Skip to main content

axess_core/session/storage/
postgres.rs

1//! PostgreSQL-backed session store using sqlx.
2//!
3//! # Schema
4//!
5//! ```sql
6//! CREATE TABLE IF NOT EXISTS sessions (
7//!   id TEXT PRIMARY KEY,
8//!   data TEXT NOT NULL,
9//!   expires_at BIGINT NOT NULL
10//! );
11//!
12//! CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions (expires_at);
13//! ```
14
15use crate::session::storage::session_codec::{SessionCodec, SqlStoreError, expires_at};
16use crate::session::storage::sql_helpers::log_cleanup_outcome;
17use crate::session::{data::SessionData, id::SessionId, store::SessionStore};
18use axess_clock::{Clock, SystemClock};
19use sqlx::PgPool;
20use std::sync::Arc;
21use std::time::Duration;
22
23/// Backward-compatible type alias. Use [`SqlStoreError`] directly in new code.
24pub type PostgresStoreError = SqlStoreError;
25
26/// PostgreSQL-backed session store with AES-256-GCM encryption at rest.
27///
28/// Wrap an existing [`PgPool`] and call [`init_schema`](Self::init_schema) once at startup.
29/// **Production deployments must also schedule cleanup** of expired session
30/// rows: either by calling [`spawn_cleanup_task`](Self::spawn_cleanup_task)
31/// at startup, configuring `pg_cron` on the database side, or running an
32/// external job that invokes [`cleanup_expired`](Self::cleanup_expired).
33/// Without one of these, the `sessions` table grows unbounded.
34///
35/// # Encryption
36///
37/// The primary constructor [`new`](PostgresSessionStore::new) requires a
38/// [`SessionCrypto`](crate::session::crypto::SessionCrypto) key; session data
39/// is encrypted before storage and decrypted on load.
40///
41/// For local development or testing where encryption is not needed, use
42/// [`plaintext`](PostgresSessionStore::plaintext) instead.
43///
44/// ```rust,ignore
45/// use axess::session::SessionCrypto;
46///
47/// // Production: encrypted (required).
48/// let store = PostgresSessionStore::new(pool, SessionCrypto::new(key));
49///
50/// // Development only: plaintext (explicit opt-out).
51/// let store = PostgresSessionStore::plaintext(pool);
52/// ```
53#[derive(Clone)]
54pub struct PostgresSessionStore {
55    pool: PgPool,
56    codec: SessionCodec,
57    clock: Arc<dyn Clock>,
58}
59
60impl PostgresSessionStore {
61    /// Create an **encrypted** store (recommended for production).
62    pub fn new(pool: PgPool, crypto: crate::session::crypto::SessionCrypto) -> Self {
63        Self {
64            pool,
65            codec: SessionCodec::encrypted(crypto),
66            clock: Arc::new(SystemClock),
67        }
68    }
69
70    /// Create a **plaintext** store (development/testing only).
71    pub fn plaintext(pool: PgPool) -> Self {
72        tracing::warn!(
73            "PostgresSessionStore created without encryption; \
74             do not use in production"
75        );
76        Self {
77            pool,
78            codec: SessionCodec::plaintext(),
79            clock: Arc::new(SystemClock),
80        }
81    }
82
83    /// Inject a [`Clock`] for deterministic-simulation testing.
84    pub fn with_clock(mut self, clock: Arc<dyn Clock>) -> Self {
85        self.clock = clock;
86        self
87    }
88
89    /// Create the `sessions` table and index if they don't already exist.
90    pub async fn init_schema(&self) -> Result<(), sqlx::Error> {
91        sqlx::query(
92            r#"
93            CREATE TABLE IF NOT EXISTS sessions (
94                id TEXT PRIMARY KEY,
95                data TEXT NOT NULL,
96                expires_at BIGINT NOT NULL
97            )
98            "#,
99        )
100        .execute(&self.pool)
101        .await?;
102
103        sqlx::query("CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions (expires_at)")
104            .execute(&self.pool)
105            .await?;
106
107        Ok(())
108    }
109
110    /// Delete all sessions whose `expires_at` is in the past.
111    pub async fn cleanup_expired(&self) -> Result<u64, sqlx::Error> {
112        let now = self.clock.now().timestamp();
113        let result = sqlx::query("DELETE FROM sessions WHERE expires_at < $1")
114            .bind(now)
115            .execute(&self.pool)
116            .await?;
117        Ok(result.rows_affected())
118    }
119
120    /// Spawn a background task that calls [`cleanup_expired`](Self::cleanup_expired) on a fixed interval.
121    ///
122    /// SQL stores accumulate expired session rows forever unless something
123    /// removes them. Production deployments **must** either call this
124    /// helper once at startup, run an external scheduled job (e.g. `pg_cron`
125    /// on the database side), or accept unbounded table growth. The
126    /// returned [`tokio::task::JoinHandle`] aborts the loop when dropped,
127    /// so store it for the lifetime of the application.
128    ///
129    /// Errors from `cleanup_expired` are logged at `warn` and swallowed;
130    /// the loop keeps running so a single transient DB blip does not
131    /// silently halt cleanup forever.
132    ///
133    /// ```rust,ignore
134    /// let store = PostgresSessionStore::new(pool, crypto);
135    /// store.init_schema().await?;
136    /// let _cleanup = store.spawn_cleanup_task(std::time::Duration::from_secs(3600));
137    /// ```
138    pub fn spawn_cleanup_task(&self, interval: Duration) -> tokio::task::JoinHandle<()> {
139        let store = self.clone();
140        tokio::spawn(async move {
141            let mut ticker = tokio::time::interval(interval);
142            ticker.tick().await;
143            loop {
144                ticker.tick().await;
145                log_cleanup_outcome("postgres", store.cleanup_expired().await);
146            }
147        })
148    }
149}
150
151impl SessionStore for PostgresSessionStore {
152    type Error = SqlStoreError;
153
154    async fn load(&self, id: &SessionId) -> Result<Option<SessionData>, Self::Error> {
155        let id_str = id.to_string();
156        let now = self.clock.now().timestamp();
157
158        let row: Option<(String,)> =
159            sqlx::query_as("SELECT data FROM sessions WHERE id = $1 AND expires_at > $2")
160                .bind(&id_str)
161                .bind(now)
162                .fetch_optional(&self.pool)
163                .await?;
164
165        match row {
166            Some((stored,)) => Ok(Some(self.codec.decode(&stored)?)),
167            None => Ok(None),
168        }
169    }
170
171    async fn save(
172        &self,
173        id: &SessionId,
174        data: &SessionData,
175        ttl: Duration,
176    ) -> Result<(), Self::Error> {
177        let id_str = id.to_string();
178        let encoded = self.codec.encode(data)?;
179        let exp = expires_at(&*self.clock, ttl);
180
181        sqlx::query(
182            r#"
183            INSERT INTO sessions (id, data, expires_at)
184            VALUES ($1, $2, $3)
185            ON CONFLICT (id) DO UPDATE SET data = EXCLUDED.data, expires_at = EXCLUDED.expires_at
186            "#,
187        )
188        .bind(&id_str)
189        .bind(&encoded)
190        .bind(exp)
191        .execute(&self.pool)
192        .await?;
193
194        Ok(())
195    }
196
197    async fn delete(&self, id: &SessionId) -> Result<(), Self::Error> {
198        let id_str = id.to_string();
199        sqlx::query("DELETE FROM sessions WHERE id = $1")
200            .bind(&id_str)
201            .execute(&self.pool)
202            .await?;
203        Ok(())
204    }
205
206    async fn cycle(
207        &self,
208        old_id: &SessionId,
209        new_id: &SessionId,
210        data: &SessionData,
211        ttl: Duration,
212    ) -> Result<(), Self::Error> {
213        let encoded = self.codec.encode(data)?;
214        let exp = expires_at(&*self.clock, ttl);
215        let old_str = old_id.to_string();
216        let new_str = new_id.to_string();
217
218        let mut tx = self.pool.begin().await?;
219
220        sqlx::query("DELETE FROM sessions WHERE id = $1")
221            .bind(&old_str)
222            .execute(&mut *tx)
223            .await?;
224
225        sqlx::query("INSERT INTO sessions (id, data, expires_at) VALUES ($1, $2, $3)")
226            .bind(&new_str)
227            .bind(&encoded)
228            .bind(exp)
229            .execute(&mut *tx)
230            .await?;
231
232        tx.commit().await?;
233        Ok(())
234    }
235
236    async fn prune_expired(&self) -> Result<u64, Self::Error> {
237        // Trait surface over the existing inherent
238        // `cleanup_expired`; applications that hold a `dyn SessionStore`
239        // can now drive the sweep without downcasting.
240        Ok(self.cleanup_expired().await?)
241    }
242}
243
244// ── HealthCheck ──────────────────────────────────────────────────────────────
245
246use crate::health::{HealthCheck, HealthStatus};
247use crate::session::storage::sql_helpers::sql_health_probe;
248
249impl HealthCheck for PostgresSessionStore {
250    fn check(
251        &self,
252    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = HealthStatus> + Send + '_>> {
253        Box::pin(sql_health_probe(
254            "postgres",
255            sqlx::query_scalar::<_, i32>("SELECT 1").fetch_one(&self.pool),
256        ))
257    }
258}
259
260// ── Store<SessionId, SessionData> ────────────────────────────────────────────
261//
// Exposes the store through the generic `Store<SessionId, SessionData>`
// surface; see the same impl on `SqliteSessionStore` for the rationale.
// Identical body shape (forwards to `SessionStore`).
264
265impl crate::store::Store<SessionId, SessionData> for PostgresSessionStore {
266    type Error = SqlStoreError;
267
268    fn get(
269        &self,
270        key: &SessionId,
271    ) -> impl std::future::Future<Output = Result<Option<SessionData>, Self::Error>> + Send {
272        <Self as SessionStore>::load(self, key)
273    }
274
275    fn put(
276        &self,
277        key: &SessionId,
278        value: &SessionData,
279        ttl: Duration,
280    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
281        <Self as SessionStore>::save(self, key, value, ttl)
282    }
283
284    fn delete(
285        &self,
286        key: &SessionId,
287    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
288        <Self as SessionStore>::delete(self, key)
289    }
290
291    fn prune_expired(&self) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
292        <Self as SessionStore>::prune_expired(self)
293    }
294}
295
#[cfg(test)]
mod postgres_tests {
    //! Pin every database-touching `PostgresSessionStore` method (the
    //! `SessionStore` impl, the inherent `init_schema`/`cleanup_expired`,
    //! and `HealthCheck::check`) against "replace the body with `Ok(...)`"
    //! mutations, without requiring a live Postgres.
    //!
    //! Strategy: a `PgPool` created via `connect_lazy` against a port that
    //! refuses connections produces a deterministic `sqlx::Error` on the
    //! first DB round-trip. Real impl: returns `Err`. Mutated impl: returns
    //! `Ok(...)`. The two are distinguishable, so the mutation is caught.
    use super::*;
    use crate::session::data::SessionData;
    use crate::session::id::SessionId;
    use crate::testing::mock_tracing::TracingCapture;
    use sqlx::postgres::PgPoolOptions;

    fn unreachable_pool() -> PgPool {
        // 127.0.0.1:1: port 1 (IANA tcpmux) is essentially never listened on
        // by dev or CI hosts. `connect_lazy` does not open a connection, so
        // this constructor never blocks; the first query fails with a
        // connection error (typically ECONNREFUSED) well within the 200 ms
        // acquire timeout.
        PgPoolOptions::new()
            .max_connections(1)
            .acquire_timeout(Duration::from_millis(200))
            .connect_lazy("postgres://user:pass@127.0.0.1:1/nodb")
            .expect("connect_lazy must parse a valid URL")
    }

    fn store() -> PostgresSessionStore {
        PostgresSessionStore::plaintext(unreachable_pool())
    }

    fn sample_id() -> SessionId {
        SessionId::new(&axess_rng::SystemRng)
    }

    #[tokio::test]
    async fn plaintext_constructor_emits_warning() {
        // Pins `PostgresSessionStore::plaintext` against
        // `-> Default::default()` (unviable; no `Default` impl exists) AND
        // guards its inline `tracing::warn!` diagnostic against accidental
        // removal during refactors.
        let capture = TracingCapture::install();
        drop(PostgresSessionStore::plaintext(unreachable_pool()));
        assert!(
            capture.contains_at_level(tracing::Level::WARN, "without encryption"),
            "plaintext() must warn operators; captured events: {:#?}",
            capture.events()
        );
    }

    #[tokio::test]
    async fn load_propagates_connection_error_not_ok_none() {
        // Pins the `SessionStore::load` body against `Ok(None)` and
        // `Ok(Some(Default::default()))` mutations.
        let result = store().load(&sample_id()).await;
        assert!(
            result.is_err(),
            "load must propagate sqlx error from an unreachable pool, \
             not silently return an Ok variant"
        );
    }

    #[tokio::test]
    async fn save_propagates_connection_error_not_ok_unit() {
        // Pins the `SessionStore::save` body against `Ok(())`.
        let result = store()
            .save(
                &sample_id(),
                &SessionData::default(),
                Duration::from_secs(60),
            )
            .await;
        assert!(
            result.is_err(),
            "save must propagate sqlx error, not Ok(())"
        );
    }

    #[tokio::test]
    async fn delete_propagates_connection_error_not_ok_unit() {
        // Pins the `SessionStore::delete` body against `Ok(())`.
        let result = store().delete(&sample_id()).await;
        assert!(
            result.is_err(),
            "delete must propagate sqlx error, not Ok(())"
        );
    }

    #[tokio::test]
    async fn cycle_propagates_connection_error_not_ok_unit() {
        // Pins the `SessionStore::cycle` body against `Ok(())`. `cycle`
        // opens a transaction with two statements, but the `begin` fails
        // first against an unreachable pool, so Err is observed from the
        // very first DB await.
        let result = store()
            .cycle(
                &sample_id(),
                &sample_id(),
                &SessionData::default(),
                Duration::from_secs(60),
            )
            .await;
        assert!(
            result.is_err(),
            "cycle must propagate sqlx error, not Ok(())"
        );
    }

    #[tokio::test]
    async fn prune_expired_propagates_connection_error_not_ok_count() {
        // Pins the `SessionStore::prune_expired` body against `Ok(0)` and
        // `Ok(1)`. `prune_expired` is a thin wrapper around
        // `cleanup_expired`, which is covered separately below; this test
        // specifically guards the trait surface.
        let result = store().prune_expired().await;
        assert!(
            result.is_err(),
            "prune_expired must propagate sqlx error, not an Ok(u64) count"
        );
    }

    #[tokio::test]
    async fn cleanup_expired_propagates_connection_error_not_ok_count() {
        // Pins the inherent `cleanup_expired` body against `Ok(0)` and
        // `Ok(1)` mutations.
        let result = store().cleanup_expired().await;
        assert!(
            result.is_err(),
            "cleanup_expired must propagate sqlx error, not an Ok(u64) count"
        );
    }

    #[tokio::test]
    async fn init_schema_propagates_connection_error_not_ok_unit() {
        // Pins the inherent `init_schema` body against `Ok(())`.
        let result = store().init_schema().await;
        assert!(
            result.is_err(),
            "init_schema must propagate sqlx error, not Ok(())"
        );
    }

    #[tokio::test]
    async fn health_check_returns_unhealthy_on_unreachable_pool() {
        // Pins the `HealthCheck::check` body against mutations that would
        // degrade the response to a default future. The real impl yields
        // `HealthStatus::Unhealthy(_)` against an unreachable pool.
        let status = store().check().await;
        assert!(
            matches!(status, HealthStatus::Unhealthy(_)),
            "check() must report Unhealthy against unreachable pool, got {status:?}"
        );
    }
}