use crate::session::storage::session_codec::{SessionCodec, SqlStoreError, expires_at};
use crate::session::storage::sql_helpers::log_cleanup_outcome;
use crate::session::{data::SessionData, id::SessionId, store::SessionStore};
use axess_clock::{Clock, SystemClock};
use sqlx::SqlitePool;
use std::sync::Arc;
use std::time::Duration;
/// Error type returned by [`SqliteSessionStore`]; an alias of the shared SQL store error.
pub type SqliteStoreError = SqlStoreError;

/// Session store that persists sessions in a SQLite `sessions` table.
///
/// `Clone` is derived so that `spawn_cleanup_task` can move a copy of the store into the
/// spawned background task.
#[derive(Clone)]
pub struct SqliteSessionStore {
    /// Connection pool that every query runs against.
    pool: SqlitePool,
    /// Time source used both to compute expiry on write and to filter/clean up on read.
    clock: Arc<dyn Clock>,
    /// Encodes/decodes session payloads; encrypted via `new`, plaintext via `plaintext`.
    codec: SessionCodec,
}
impl SqliteSessionStore {
    /// Creates a store whose session payloads are encrypted with `crypto` before being
    /// written to `pool`.
    ///
    /// The store reads time from the system clock; replace it with [`Self::with_clock`].
    pub fn new(pool: SqlitePool, crypto: crate::session::crypto::SessionCrypto) -> Self {
        Self {
            pool,
            codec: SessionCodec::encrypted(crypto),
            clock: Arc::new(SystemClock),
        }
    }

    /// Creates a store that writes session payloads without encryption.
    ///
    /// Logs a `tracing` warning on every call, because plaintext storage must not be used
    /// in production.
    pub fn plaintext(pool: SqlitePool) -> Self {
        tracing::warn!(
            "SqliteSessionStore created without encryption; \
             do not use in production"
        );
        Self {
            pool,
            codec: SessionCodec::plaintext(),
            clock: Arc::new(SystemClock),
        }
    }

    /// Replaces the clock used to compute expiry on write and to filter expired rows on read.
    pub fn with_clock(mut self, clock: Arc<dyn Clock>) -> Self {
        self.clock = clock;
        self
    }

    /// Creates the `sessions` table and its `expires_at` index if they do not already exist.
    ///
    /// Safe to call repeatedly: both statements use `IF NOT EXISTS`.
    ///
    /// # Errors
    ///
    /// Returns the underlying [`sqlx::Error`] if either statement fails.
    pub async fn init_schema(&self) -> Result<(), sqlx::Error> {
        sqlx::query(
            r#"
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
data TEXT NOT NULL,
expires_at INTEGER NOT NULL
)
"#,
        )
        .execute(&self.pool)
        .await?;
        sqlx::query("CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions (expires_at)")
            .execute(&self.pool)
            .await?;
        Ok(())
    }

    /// Deletes every session whose `expires_at` is at or before the clock's current timestamp.
    ///
    /// Returns the number of rows removed.
    ///
    /// # Errors
    ///
    /// Returns the underlying [`sqlx::Error`] if the DELETE fails.
    pub async fn cleanup_expired(&self) -> Result<u64, sqlx::Error> {
        let now = self.clock.now().timestamp();
        // `<=` mirrors `load`, which only returns rows with `expires_at > now`. With `<`, a row
        // expiring exactly at `now` would be unreadable yet survive this sweep.
        let result = sqlx::query("DELETE FROM sessions WHERE expires_at <= ?1")
            .bind(now)
            .execute(&self.pool)
            .await?;
        Ok(result.rows_affected())
    }

    /// Spawns a background task that calls [`Self::cleanup_expired`] every `interval`.
    ///
    /// The first sweep happens one full `interval` after this call, not immediately. Each
    /// outcome is handed to `log_cleanup_outcome`. The task loops forever; abort it through
    /// the returned [`tokio::task::JoinHandle`].
    ///
    /// # Panics
    ///
    /// Panics if `interval` is zero. `tokio::time::interval` rejects a zero period; checking
    /// here surfaces the mistake at the call site instead of inside the detached task.
    pub fn spawn_cleanup_task(&self, interval: Duration) -> tokio::task::JoinHandle<()> {
        assert!(!interval.is_zero(), "cleanup interval must be non-zero");
        let store = self.clone();
        tokio::spawn(async move {
            let mut ticker = tokio::time::interval(interval);
            // After a slow sweep or a stalled runtime, wait a full interval again instead of
            // replaying missed ticks back-to-back (tokio's default `Burst` behaviour).
            ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            // A tokio interval's first tick completes immediately; consume it so the first
            // cleanup runs after one interval.
            ticker.tick().await;
            loop {
                ticker.tick().await;
                log_cleanup_outcome("sqlite", store.cleanup_expired().await);
            }
        })
    }
}
impl SessionStore for SqliteSessionStore {
type Error = SqlStoreError;
async fn load(&self, id: &SessionId) -> Result<Option<SessionData>, Self::Error> {
let id_str = id.to_string();
let now = self.clock.now().timestamp();
let row: Option<(String,)> =
sqlx::query_as("SELECT data FROM sessions WHERE id = ?1 AND expires_at > ?2")
.bind(&id_str)
.bind(now)
.fetch_optional(&self.pool)
.await?;
match row {
Some((stored,)) => Ok(Some(self.codec.decode(&stored)?)),
None => Ok(None),
}
}
async fn save(
&self,
id: &SessionId,
data: &SessionData,
ttl: Duration,
) -> Result<(), Self::Error> {
let id_str = id.to_string();
let encoded = self.codec.encode(data)?;
let exp = expires_at(&*self.clock, ttl);
sqlx::query(
r#"
INSERT INTO sessions (id, data, expires_at)
VALUES (?1, ?2, ?3)
ON CONFLICT(id) DO UPDATE SET data = excluded.data, expires_at = excluded.expires_at
"#,
)
.bind(&id_str)
.bind(&encoded)
.bind(exp)
.execute(&self.pool)
.await?;
Ok(())
}
async fn delete(&self, id: &SessionId) -> Result<(), Self::Error> {
let id_str = id.to_string();
sqlx::query("DELETE FROM sessions WHERE id = ?1")
.bind(&id_str)
.execute(&self.pool)
.await?;
Ok(())
}
async fn cycle(
&self,
old_id: &SessionId,
new_id: &SessionId,
data: &SessionData,
ttl: Duration,
) -> Result<(), Self::Error> {
let encoded = self.codec.encode(data)?;
let exp = expires_at(&*self.clock, ttl);
let old_str = old_id.to_string();
let new_str = new_id.to_string();
let mut tx = self.pool.begin().await?;
sqlx::query("DELETE FROM sessions WHERE id = ?1")
.bind(&old_str)
.execute(&mut *tx)
.await?;
sqlx::query("INSERT INTO sessions (id, data, expires_at) VALUES (?1, ?2, ?3)")
.bind(&new_str)
.bind(&encoded)
.bind(exp)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(())
}
async fn prune_expired(&self) -> Result<u64, Self::Error> {
Ok(self.cleanup_expired().await?)
}
}
use crate::health::{HealthCheck, HealthStatus};
use crate::session::storage::sql_helpers::sql_health_probe;
impl HealthCheck for SqliteSessionStore {
    /// Reports health by handing a `SELECT 1` round-trip on the pool to `sql_health_probe`.
    fn check(
        &self,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = HealthStatus> + Send + '_>> {
        let ping = sqlx::query_scalar::<_, i32>("SELECT 1").fetch_one(&self.pool);
        Box::pin(sql_health_probe("sqlite", ping))
    }
}
impl crate::store::Store<SessionId, SessionData> for SqliteSessionStore {
type Error = SqlStoreError;
fn get(
&self,
key: &SessionId,
) -> impl std::future::Future<Output = Result<Option<SessionData>, Self::Error>> + Send {
<Self as SessionStore>::load(self, key)
}
fn put(
&self,
key: &SessionId,
value: &SessionData,
ttl: Duration,
) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
<Self as SessionStore>::save(self, key, value, ttl)
}
fn delete(
&self,
key: &SessionId,
) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
<Self as SessionStore>::delete(self, key)
}
fn prune_expired(&self) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
<Self as SessionStore>::prune_expired(self)
}
}
#[cfg(test)]
mod sqlite_tests {
    use super::*;
    use crate::session::data::SessionData;
    use crate::session::id::SessionId;
    use axess_rng::SystemRng;
    use sqlx::sqlite::SqlitePoolOptions;

    /// TTL for every session saved through the store API in these tests.
    const TTL: Duration = Duration::from_secs(60);

    /// Builds a plaintext store on a single-connection in-memory pool, with the schema created.
    async fn store() -> SqliteSessionStore {
        let pool = SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("in-memory sqlite must connect");
        let store = SqliteSessionStore::plaintext(pool);
        store.init_schema().await.expect("init_schema");
        store
    }

    /// A freshly generated random session id.
    fn sample_id() -> SessionId {
        SessionId::new(&SystemRng)
    }

    /// Session data with a non-default `custom` field, so a default payload is distinguishable.
    fn payload_with_custom() -> SessionData {
        SessionData {
            custom: serde_json::json!({"k": "v"}),
            ..SessionData::default()
        }
    }

    /// Inserts a row for `id` with `expires_at = 0`, bypassing `save`, so it counts as expired.
    async fn insert_expired_row(store: &SqliteSessionStore, id: &SessionId) {
        sqlx::query("INSERT INTO sessions (id, data, expires_at) VALUES (?1, '{}', 0)")
            .bind(id.to_string())
            .execute(&store.pool)
            .await
            .expect("manual insert");
    }

    #[tokio::test]
    async fn init_schema_creates_sessions_table() {
        let store = store().await;
        let result = store.save(&sample_id(), &SessionData::default(), TTL).await;
        assert!(
            result.is_ok(),
            "save after init_schema must succeed: {result:?}"
        );
    }

    #[tokio::test]
    async fn save_then_load_returns_persisted_payload() {
        let store = store().await;
        let id = sample_id();
        let data = payload_with_custom();
        store.save(&id, &data, TTL).await.expect("save");
        let loaded = store.load(&id).await.expect("load").expect("Some");
        assert_eq!(
            serde_json::to_string(&data).unwrap(),
            serde_json::to_string(&loaded).unwrap(),
            "load must return what save persisted; kills Ok(None) AND Ok(Some(Default))"
        );
    }

    #[tokio::test]
    async fn load_of_absent_key_returns_none() {
        let store = store().await;
        let missing = store.load(&sample_id()).await.expect("load");
        assert!(
            missing.is_none(),
            "absent key must yield None, not Some(Default)"
        );
    }

    #[tokio::test]
    async fn delete_actually_removes_row() {
        let store = store().await;
        let id = sample_id();
        store
            .save(&id, &payload_with_custom(), TTL)
            .await
            .expect("save");
        assert!(store.load(&id).await.expect("load before delete").is_some());
        store.delete(&id).await.expect("delete");
        let after = store.load(&id).await.expect("load after delete");
        assert!(
            after.is_none(),
            "delete must remove the row; mutated body would leave it"
        );
    }

    #[tokio::test]
    async fn cycle_atomically_swaps_session_ids() {
        let store = store().await;
        let (old, new) = (sample_id(), sample_id());
        let data = payload_with_custom();
        store.save(&old, &data, TTL).await.expect("save");
        store.cycle(&old, &new, &data, TTL).await.expect("cycle");
        assert!(
            store.load(&old).await.expect("load old").is_none(),
            "old id must be gone after cycle"
        );
        assert!(
            store.load(&new).await.expect("load new").is_some(),
            "new id must exist after cycle; kills cycle->Ok(()) mutant"
        );
    }

    #[tokio::test]
    async fn cleanup_expired_returns_real_row_count() {
        let store = store().await;
        for id in [sample_id(), sample_id()] {
            insert_expired_row(&store, &id).await;
        }
        let removed = store.cleanup_expired().await.expect("cleanup_expired");
        assert_eq!(
            removed, 2,
            "cleanup_expired must report the real row-count; kills Ok(0) and Ok(1)"
        );
    }

    #[tokio::test]
    async fn prune_expired_trait_surface_matches_inherent() {
        let store = store().await;
        for _ in 0..3 {
            insert_expired_row(&store, &sample_id()).await;
        }
        let removed = store.prune_expired().await.expect("prune_expired");
        assert_eq!(removed, 3);
    }
}