axess_core/session/storage/
postgres.rs1use crate::session::storage::session_codec::{SessionCodec, SqlStoreError, expires_at};
16use crate::session::storage::sql_helpers::log_cleanup_outcome;
17use crate::session::{data::SessionData, id::SessionId, store::SessionStore};
18use axess_clock::{Clock, SystemClock};
19use sqlx::PgPool;
20use std::sync::Arc;
21use std::time::Duration;
22
/// Error type of [`PostgresSessionStore`].
///
/// Alias for the shared [`SqlStoreError`], which is the `Error` associated type of
/// this store's `SessionStore` and `crate::store::Store` implementations.
pub type PostgresStoreError = SqlStoreError;
25
/// Session store backed by a PostgreSQL `sessions` table
/// (`id TEXT PRIMARY KEY, data TEXT, expires_at BIGINT`; see `init_schema`).
///
/// Session payloads pass through a [`SessionCodec`]. The codec is encrypting when
/// built with [`PostgresSessionStore::new`] and plaintext when built with
/// [`PostgresSessionStore::plaintext`]. Each row carries an `expires_at` value that
/// is compared against `clock.now().timestamp()` when loading and pruning.
#[derive(Clone)]
pub struct PostgresSessionStore {
    // Connection pool used for every query issued by this store.
    pool: PgPool,
    // Converts `SessionData` to and from the TEXT stored in the `data` column.
    codec: SessionCodec,
    // Time source for computing and checking expiry; replaceable via `with_clock`.
    clock: Arc<dyn Clock>,
}
59
60impl PostgresSessionStore {
61 pub fn new(pool: PgPool, crypto: crate::session::crypto::SessionCrypto) -> Self {
63 Self {
64 pool,
65 codec: SessionCodec::encrypted(crypto),
66 clock: Arc::new(SystemClock),
67 }
68 }
69
70 pub fn plaintext(pool: PgPool) -> Self {
72 tracing::warn!(
73 "PostgresSessionStore created without encryption; \
74 do not use in production"
75 );
76 Self {
77 pool,
78 codec: SessionCodec::plaintext(),
79 clock: Arc::new(SystemClock),
80 }
81 }
82
83 pub fn with_clock(mut self, clock: Arc<dyn Clock>) -> Self {
85 self.clock = clock;
86 self
87 }
88
89 pub async fn init_schema(&self) -> Result<(), sqlx::Error> {
91 sqlx::query(
92 r#"
93 CREATE TABLE IF NOT EXISTS sessions (
94 id TEXT PRIMARY KEY,
95 data TEXT NOT NULL,
96 expires_at BIGINT NOT NULL
97 )
98 "#,
99 )
100 .execute(&self.pool)
101 .await?;
102
103 sqlx::query("CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions (expires_at)")
104 .execute(&self.pool)
105 .await?;
106
107 Ok(())
108 }
109
110 pub async fn cleanup_expired(&self) -> Result<u64, sqlx::Error> {
112 let now = self.clock.now().timestamp();
113 let result = sqlx::query("DELETE FROM sessions WHERE expires_at < $1")
114 .bind(now)
115 .execute(&self.pool)
116 .await?;
117 Ok(result.rows_affected())
118 }
119
120 pub fn spawn_cleanup_task(&self, interval: Duration) -> tokio::task::JoinHandle<()> {
139 let store = self.clone();
140 tokio::spawn(async move {
141 let mut ticker = tokio::time::interval(interval);
142 ticker.tick().await;
143 loop {
144 ticker.tick().await;
145 log_cleanup_outcome("postgres", store.cleanup_expired().await);
146 }
147 })
148 }
149}
150
151impl SessionStore for PostgresSessionStore {
152 type Error = SqlStoreError;
153
154 async fn load(&self, id: &SessionId) -> Result<Option<SessionData>, Self::Error> {
155 let id_str = id.to_string();
156 let now = self.clock.now().timestamp();
157
158 let row: Option<(String,)> =
159 sqlx::query_as("SELECT data FROM sessions WHERE id = $1 AND expires_at > $2")
160 .bind(&id_str)
161 .bind(now)
162 .fetch_optional(&self.pool)
163 .await?;
164
165 match row {
166 Some((stored,)) => Ok(Some(self.codec.decode(&stored)?)),
167 None => Ok(None),
168 }
169 }
170
171 async fn save(
172 &self,
173 id: &SessionId,
174 data: &SessionData,
175 ttl: Duration,
176 ) -> Result<(), Self::Error> {
177 let id_str = id.to_string();
178 let encoded = self.codec.encode(data)?;
179 let exp = expires_at(&*self.clock, ttl);
180
181 sqlx::query(
182 r#"
183 INSERT INTO sessions (id, data, expires_at)
184 VALUES ($1, $2, $3)
185 ON CONFLICT (id) DO UPDATE SET data = EXCLUDED.data, expires_at = EXCLUDED.expires_at
186 "#,
187 )
188 .bind(&id_str)
189 .bind(&encoded)
190 .bind(exp)
191 .execute(&self.pool)
192 .await?;
193
194 Ok(())
195 }
196
197 async fn delete(&self, id: &SessionId) -> Result<(), Self::Error> {
198 let id_str = id.to_string();
199 sqlx::query("DELETE FROM sessions WHERE id = $1")
200 .bind(&id_str)
201 .execute(&self.pool)
202 .await?;
203 Ok(())
204 }
205
206 async fn cycle(
207 &self,
208 old_id: &SessionId,
209 new_id: &SessionId,
210 data: &SessionData,
211 ttl: Duration,
212 ) -> Result<(), Self::Error> {
213 let encoded = self.codec.encode(data)?;
214 let exp = expires_at(&*self.clock, ttl);
215 let old_str = old_id.to_string();
216 let new_str = new_id.to_string();
217
218 let mut tx = self.pool.begin().await?;
219
220 sqlx::query("DELETE FROM sessions WHERE id = $1")
221 .bind(&old_str)
222 .execute(&mut *tx)
223 .await?;
224
225 sqlx::query("INSERT INTO sessions (id, data, expires_at) VALUES ($1, $2, $3)")
226 .bind(&new_str)
227 .bind(&encoded)
228 .bind(exp)
229 .execute(&mut *tx)
230 .await?;
231
232 tx.commit().await?;
233 Ok(())
234 }
235
236 async fn prune_expired(&self) -> Result<u64, Self::Error> {
237 Ok(self.cleanup_expired().await?)
241 }
242}
243
244use crate::health::{HealthCheck, HealthStatus};
247use crate::session::storage::sql_helpers::sql_health_probe;
248
249impl HealthCheck for PostgresSessionStore {
250 fn check(
251 &self,
252 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = HealthStatus> + Send + '_>> {
253 Box::pin(sql_health_probe(
254 "postgres",
255 sqlx::query_scalar::<_, i32>("SELECT 1").fetch_one(&self.pool),
256 ))
257 }
258}
259
260impl crate::store::Store<SessionId, SessionData> for PostgresSessionStore {
266 type Error = SqlStoreError;
267
268 fn get(
269 &self,
270 key: &SessionId,
271 ) -> impl std::future::Future<Output = Result<Option<SessionData>, Self::Error>> + Send {
272 <Self as SessionStore>::load(self, key)
273 }
274
275 fn put(
276 &self,
277 key: &SessionId,
278 value: &SessionData,
279 ttl: Duration,
280 ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
281 <Self as SessionStore>::save(self, key, value, ttl)
282 }
283
284 fn delete(
285 &self,
286 key: &SessionId,
287 ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
288 <Self as SessionStore>::delete(self, key)
289 }
290
291 fn prune_expired(&self) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
292 <Self as SessionStore>::prune_expired(self)
293 }
294}
295
#[cfg(test)]
mod postgres_tests {
    //! Database-free tests. Every operation is pointed at a lazily-connected pool
    //! aimed at an address expected to refuse connections. The tests check that
    //! failures surface as `Err` (or `Unhealthy`) rather than being swallowed.

    use super::*;
    use crate::session::data::SessionData;
    use crate::session::id::SessionId;
    use crate::testing::mock_tracing::TracingCapture;
    use sqlx::postgres::PgPoolOptions;

    /// Builds a single-connection pool that never connects successfully.
    ///
    /// `connect_lazy` opens no connection up front, so construction succeeds. Each
    /// query then fails, bounded by the 200 ms acquire timeout.
    fn dead_pool() -> PgPool {
        let options = PgPoolOptions::new()
            .max_connections(1)
            .acquire_timeout(Duration::from_millis(200));
        options
            .connect_lazy("postgres://user:pass@127.0.0.1:1/nodb")
            .expect("connect_lazy must parse a valid URL")
    }

    /// Plaintext store over [`dead_pool`].
    fn offline_store() -> PostgresSessionStore {
        PostgresSessionStore::plaintext(dead_pool())
    }

    /// A freshly generated random session id.
    fn fresh_id() -> SessionId {
        SessionId::new(&axess_rng::SystemRng)
    }

    #[tokio::test]
    async fn plaintext_constructor_emits_warning() {
        let capture = TracingCapture::install();
        drop(PostgresSessionStore::plaintext(dead_pool()));
        let warned = capture.contains_at_level(tracing::Level::WARN, "without encryption");
        assert!(
            warned,
            "plaintext() must warn operators; captured events: {:#?}",
            capture.events()
        );
    }

    #[tokio::test]
    async fn load_propagates_connection_error_not_ok_none() {
        let outcome = offline_store().load(&fresh_id()).await;
        let failed = outcome.is_err();
        assert!(
            failed,
            "load must propagate sqlx error from an unreachable pool, \
             not silently return an Ok variant"
        );
    }

    #[tokio::test]
    async fn save_propagates_connection_error_not_ok_unit() {
        let payload = SessionData::default();
        let ttl = Duration::from_secs(60);
        let outcome = offline_store().save(&fresh_id(), &payload, ttl).await;
        let failed = outcome.is_err();
        assert!(failed, "save must propagate sqlx error, not Ok(())");
    }

    #[tokio::test]
    async fn delete_propagates_connection_error_not_ok_unit() {
        let outcome = offline_store().delete(&fresh_id()).await;
        let failed = outcome.is_err();
        assert!(failed, "delete must propagate sqlx error, not Ok(())");
    }

    #[tokio::test]
    async fn cycle_propagates_connection_error_not_ok_unit() {
        let (old, new) = (fresh_id(), fresh_id());
        let payload = SessionData::default();
        let ttl = Duration::from_secs(60);
        let outcome = offline_store().cycle(&old, &new, &payload, ttl).await;
        let failed = outcome.is_err();
        assert!(failed, "cycle must propagate sqlx error, not Ok(())");
    }

    #[tokio::test]
    async fn prune_expired_propagates_connection_error_not_ok_count() {
        let outcome = offline_store().prune_expired().await;
        let failed = outcome.is_err();
        assert!(
            failed,
            "prune_expired must propagate sqlx error, not an Ok(u64) count"
        );
    }

    #[tokio::test]
    async fn cleanup_expired_propagates_connection_error_not_ok_count() {
        let outcome = offline_store().cleanup_expired().await;
        let failed = outcome.is_err();
        assert!(
            failed,
            "cleanup_expired must propagate sqlx error, not an Ok(u64) count"
        );
    }

    #[tokio::test]
    async fn init_schema_propagates_connection_error_not_ok_unit() {
        let outcome = offline_store().init_schema().await;
        let failed = outcome.is_err();
        assert!(failed, "init_schema must propagate sqlx error, not Ok(())");
    }

    #[tokio::test]
    async fn health_check_returns_unhealthy_on_unreachable_pool() {
        let status = offline_store().check().await;
        let unhealthy = matches!(status, HealthStatus::Unhealthy(_));
        assert!(
            unhealthy,
            "check() must report Unhealthy against unreachable pool, got {status:?}"
        );
    }
}