axess_core/session/storage/
mysql.rs1use crate::session::storage::session_codec::{SessionCodec, SqlStoreError, expires_at};
20use crate::session::storage::sql_helpers::log_cleanup_outcome;
21use crate::session::{data::SessionData, id::SessionId, store::SessionStore};
22use axess_clock::{Clock, SystemClock};
23use sqlx::MySqlPool;
24use std::sync::Arc;
25use std::time::Duration;
26
27pub type MysqlStoreError = SqlStoreError;
30
31#[derive(Clone)]
56pub struct MysqlSessionStore {
57 pool: MySqlPool,
58 codec: SessionCodec,
59 clock: Arc<dyn Clock>,
60}
61
62impl MysqlSessionStore {
63 pub fn new(pool: MySqlPool, crypto: crate::session::crypto::SessionCrypto) -> Self {
65 Self {
66 pool,
67 codec: SessionCodec::encrypted(crypto),
68 clock: Arc::new(SystemClock),
69 }
70 }
71
72 pub fn plaintext(pool: MySqlPool) -> Self {
74 tracing::warn!(
75 "MysqlSessionStore created without encryption; \
76 do not use in production"
77 );
78 Self {
79 pool,
80 codec: SessionCodec::plaintext(),
81 clock: Arc::new(SystemClock),
82 }
83 }
84
85 pub fn with_clock(mut self, clock: Arc<dyn Clock>) -> Self {
87 self.clock = clock;
88 self
89 }
90
91 pub async fn init_schema(&self) -> Result<(), sqlx::Error> {
97 sqlx::query(
98 r#"
99 CREATE TABLE IF NOT EXISTS sessions (
100 id VARCHAR(64) PRIMARY KEY,
101 data TEXT NOT NULL,
102 expires_at BIGINT NOT NULL,
103 INDEX idx_sessions_expires_at (expires_at)
104 )
105 "#,
106 )
107 .execute(&self.pool)
108 .await?;
109
110 Ok(())
111 }
112
113 pub async fn cleanup_expired(&self) -> Result<u64, sqlx::Error> {
115 let now = self.clock.now().timestamp();
116 let result = sqlx::query("DELETE FROM sessions WHERE expires_at < ?")
117 .bind(now)
118 .execute(&self.pool)
119 .await?;
120 Ok(result.rows_affected())
121 }
122
123 pub fn spawn_cleanup_task(&self, interval: Duration) -> tokio::task::JoinHandle<()> {
127 let store = self.clone();
128 tokio::spawn(async move {
129 let mut ticker = tokio::time::interval(interval);
130 ticker.tick().await;
131 loop {
132 ticker.tick().await;
133 log_cleanup_outcome("mysql", store.cleanup_expired().await);
134 }
135 })
136 }
137}
138
139impl SessionStore for MysqlSessionStore {
140 type Error = SqlStoreError;
141
142 async fn load(&self, id: &SessionId) -> Result<Option<SessionData>, Self::Error> {
143 let id_str = id.to_string();
144 let now = self.clock.now().timestamp();
145
146 let row: Option<(String,)> =
147 sqlx::query_as("SELECT data FROM sessions WHERE id = ? AND expires_at > ?")
148 .bind(&id_str)
149 .bind(now)
150 .fetch_optional(&self.pool)
151 .await?;
152
153 match row {
154 Some((stored,)) => Ok(Some(self.codec.decode(&stored)?)),
155 None => Ok(None),
156 }
157 }
158
159 async fn save(
160 &self,
161 id: &SessionId,
162 data: &SessionData,
163 ttl: Duration,
164 ) -> Result<(), Self::Error> {
165 let id_str = id.to_string();
166 let encoded = self.codec.encode(data)?;
167 let exp = expires_at(&*self.clock, ttl);
168
169 sqlx::query(
176 r#"
177 INSERT INTO sessions (id, data, expires_at)
178 VALUES (?, ?, ?)
179 ON DUPLICATE KEY UPDATE data = VALUES(data), expires_at = VALUES(expires_at)
180 "#,
181 )
182 .bind(&id_str)
183 .bind(&encoded)
184 .bind(exp)
185 .execute(&self.pool)
186 .await?;
187
188 Ok(())
189 }
190
191 async fn delete(&self, id: &SessionId) -> Result<(), Self::Error> {
192 let id_str = id.to_string();
193 sqlx::query("DELETE FROM sessions WHERE id = ?")
194 .bind(&id_str)
195 .execute(&self.pool)
196 .await?;
197 Ok(())
198 }
199
200 async fn cycle(
201 &self,
202 old_id: &SessionId,
203 new_id: &SessionId,
204 data: &SessionData,
205 ttl: Duration,
206 ) -> Result<(), Self::Error> {
207 let encoded = self.codec.encode(data)?;
208 let exp = expires_at(&*self.clock, ttl);
209 let old_str = old_id.to_string();
210 let new_str = new_id.to_string();
211
212 let mut tx = self.pool.begin().await?;
213
214 sqlx::query("DELETE FROM sessions WHERE id = ?")
215 .bind(&old_str)
216 .execute(&mut *tx)
217 .await?;
218
219 sqlx::query("INSERT INTO sessions (id, data, expires_at) VALUES (?, ?, ?)")
220 .bind(&new_str)
221 .bind(&encoded)
222 .bind(exp)
223 .execute(&mut *tx)
224 .await?;
225
226 tx.commit().await?;
227 Ok(())
228 }
229
230 async fn prune_expired(&self) -> Result<u64, Self::Error> {
231 Ok(self.cleanup_expired().await?)
232 }
233}
234
235use crate::health::{HealthCheck, HealthStatus};
238use crate::session::storage::sql_helpers::sql_health_probe;
239
240impl HealthCheck for MysqlSessionStore {
241 fn check(
242 &self,
243 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = HealthStatus> + Send + '_>> {
244 Box::pin(sql_health_probe(
245 "mysql",
246 sqlx::query_scalar::<_, i32>("SELECT 1").fetch_one(&self.pool),
247 ))
248 }
249}
250
251impl crate::store::Store<SessionId, SessionData> for MysqlSessionStore {
257 type Error = SqlStoreError;
258
259 fn get(
260 &self,
261 key: &SessionId,
262 ) -> impl std::future::Future<Output = Result<Option<SessionData>, Self::Error>> + Send {
263 <Self as SessionStore>::load(self, key)
264 }
265
266 fn put(
267 &self,
268 key: &SessionId,
269 value: &SessionData,
270 ttl: Duration,
271 ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
272 <Self as SessionStore>::save(self, key, value, ttl)
273 }
274
275 fn delete(
276 &self,
277 key: &SessionId,
278 ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
279 <Self as SessionStore>::delete(self, key)
280 }
281
282 fn prune_expired(&self) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
283 <Self as SessionStore>::prune_expired(self)
284 }
285}
286
#[cfg(test)]
mod mysql_tests {
    use super::*;
    use crate::session::data::SessionData;
    use crate::session::id::SessionId;
    use crate::testing::mock_tracing::TracingCapture;
    use sqlx::mysql::MySqlPoolOptions;

    /// Loopback port 1 — assumed to have no MySQL listener, so every
    /// connection attempt is expected to fail.
    const UNREACHABLE_URL: &str = "mysql://user:pass@127.0.0.1:1/nodb";

    /// Builds a pool aimed at [`UNREACHABLE_URL`].
    ///
    /// `connect_lazy` opens no connections up front. The short acquire timeout
    /// keeps each failing test fast.
    fn unreachable_pool() -> MySqlPool {
        MySqlPoolOptions::new()
            .max_connections(1)
            .acquire_timeout(Duration::from_millis(200))
            .connect_lazy(UNREACHABLE_URL)
            .expect("connect_lazy must parse a valid URL")
    }

    /// A plaintext store backed by the unreachable pool.
    fn store() -> MysqlSessionStore {
        MysqlSessionStore::plaintext(unreachable_pool())
    }

    /// A freshly generated session id.
    fn sample_id() -> SessionId {
        SessionId::new(&axess_rng::SystemRng)
    }

    /// Fails the current test with `message` unless `result` is an `Err`.
    fn assert_is_err<T, E>(result: &Result<T, E>, message: &str) {
        assert!(result.is_err(), "{message}");
    }

    #[tokio::test]
    async fn plaintext_constructor_emits_warning() {
        let capture = TracingCapture::install();
        let _ = MysqlSessionStore::plaintext(unreachable_pool());
        let warned = capture.contains_at_level(tracing::Level::WARN, "without encryption");
        assert!(
            warned,
            "plaintext() must warn operators; captured events: {:#?}",
            capture.events()
        );
    }

    #[tokio::test]
    async fn load_propagates_connection_error_not_ok_none() {
        let outcome = store().load(&sample_id()).await;
        assert_is_err(
            &outcome,
            "load must propagate sqlx error from an unreachable pool, \
             not silently return an Ok variant",
        );
    }

    #[tokio::test]
    async fn save_propagates_connection_error_not_ok_unit() {
        let ttl = Duration::from_secs(60);
        let outcome = store()
            .save(&sample_id(), &SessionData::default(), ttl)
            .await;
        assert_is_err(&outcome, "save must propagate sqlx error, not Ok(())");
    }

    #[tokio::test]
    async fn delete_propagates_connection_error_not_ok_unit() {
        let outcome = store().delete(&sample_id()).await;
        assert_is_err(&outcome, "delete must propagate sqlx error, not Ok(())");
    }

    #[tokio::test]
    async fn cycle_propagates_connection_error_not_ok_unit() {
        let ttl = Duration::from_secs(60);
        let outcome = store()
            .cycle(&sample_id(), &sample_id(), &SessionData::default(), ttl)
            .await;
        assert_is_err(&outcome, "cycle must propagate sqlx error, not Ok(())");
    }

    #[tokio::test]
    async fn prune_expired_propagates_connection_error_not_ok_count() {
        let outcome = store().prune_expired().await;
        assert_is_err(
            &outcome,
            "prune_expired must propagate sqlx error, not an Ok(u64) count",
        );
    }

    #[tokio::test]
    async fn cleanup_expired_propagates_connection_error_not_ok_count() {
        let outcome = store().cleanup_expired().await;
        assert_is_err(
            &outcome,
            "cleanup_expired must propagate sqlx error, not an Ok(u64) count",
        );
    }

    #[tokio::test]
    async fn init_schema_propagates_connection_error_not_ok_unit() {
        let outcome = store().init_schema().await;
        assert_is_err(&outcome, "init_schema must propagate sqlx error, not Ok(())");
    }

    #[tokio::test]
    async fn health_check_returns_unhealthy_on_unreachable_pool() {
        let status = store().check().await;
        let unhealthy = matches!(status, HealthStatus::Unhealthy(_));
        assert!(
            unhealthy,
            "check() must report Unhealthy against unreachable pool, got {status:?}"
        );
    }
}