1use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore};
2use sqlx::{pool::PoolConnection, Executor, MySql, MySqlPool};
3
4#[derive(Clone, Debug)]
27pub struct MySqlSessionStore {
28 client: MySqlPool,
29 table_name: String,
30}
31
32impl MySqlSessionStore {
33 pub fn from_client(client: MySqlPool) -> Self {
49 Self {
50 client,
51 table_name: "async_sessions".into(),
52 }
53 }
54
55 pub async fn new(database_url: &str) -> sqlx::Result<Self> {
71 let pool = MySqlPool::connect(database_url).await?;
72 Ok(Self::from_client(pool))
73 }
74
75 pub async fn new_with_table_name(database_url: &str, table_name: &str) -> sqlx::Result<Self> {
91 Ok(Self::new(database_url).await?.with_table_name(table_name))
92 }
93
94 pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Self {
115 let table_name = table_name.as_ref();
116 if table_name.is_empty()
117 || !table_name
118 .chars()
119 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
120 {
121 panic!(
122 "table name must be [a-zA-Z0-9_-], but {} was not",
123 table_name
124 );
125 }
126
127 self.table_name = table_name.to_owned();
128 self
129 }
130
131 pub async fn migrate(&self) -> sqlx::Result<()> {
149 log::info!("migrating sessions on `{}`", self.table_name);
150
151 let mut conn = self.client.acquire().await?;
152 conn.execute(&*self.substitute_table_name(
153 r#"
154 CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% (
155 `id` VARCHAR(128) NOT NULL,
156 `expires` TIMESTAMP(6) NULL,
157 `session` TEXT NOT NULL,
158 PRIMARY KEY (`id`),
159 KEY `expires` (`expires`)
160 )
161 ENGINE=InnoDB
162 DEFAULT CHARSET=utf8
163 "#,
164 ))
165 .await?;
166
167 Ok(())
168 }
169
170 fn substitute_table_name(&self, query: &str) -> String {
171 query.replace("%%TABLE_NAME%%", &self.table_name)
172 }
173
174 async fn connection(&self) -> sqlx::Result<PoolConnection<MySql>> {
176 self.client.acquire().await
177 }
178
179 #[cfg(feature = "async_std")]
202 pub fn spawn_cleanup_task(
203 &self,
204 period: std::time::Duration,
205 ) -> async_std::task::JoinHandle<()> {
206 let store = self.clone();
207 async_std::task::spawn(async move {
208 loop {
209 async_std::task::sleep(period).await;
210 if let Err(error) = store.cleanup().await {
211 log::error!("cleanup error: {}", error);
212 }
213 }
214 })
215 }
216
217 pub async fn cleanup(&self) -> sqlx::Result<()> {
235 let mut connection = self.connection().await?;
236 sqlx::query(&self.substitute_table_name("DELETE FROM %%TABLE_NAME%% WHERE expires < ?"))
237 .bind(Utc::now())
238 .execute(&mut connection)
239 .await?;
240
241 Ok(())
242 }
243
244 pub async fn count(&self) -> sqlx::Result<i64> {
262 let (count,) =
263 sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%"))
264 .fetch_one(&mut self.connection().await?)
265 .await?;
266
267 Ok(count)
268 }
269}
270
271#[async_trait]
272impl SessionStore for MySqlSessionStore {
273 async fn load_session(&self, cookie_value: String) -> Result<Option<Session>> {
274 let id = Session::id_from_cookie_value(&cookie_value)?;
275 let mut connection = self.connection().await?;
276
277 let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name(
278 "SELECT session FROM %%TABLE_NAME%% WHERE id = ? AND (expires IS NULL OR expires > ?)",
279 ))
280 .bind(&id)
281 .bind(Utc::now())
282 .fetch_optional(&mut connection)
283 .await?;
284
285 Ok(result
286 .map(|(session,)| serde_json::from_str(&session))
287 .transpose()?)
288 }
289
290 async fn store_session(&self, session: Session) -> Result<Option<String>> {
291 let id = session.id();
292 let string = serde_json::to_string(&session)?;
293 let mut connection = self.connection().await?;
294
295 sqlx::query(&self.substitute_table_name(
296 r#"
297 INSERT INTO %%TABLE_NAME%%
298 (id, session, expires) VALUES(?, ?, ?)
299 ON DUPLICATE KEY UPDATE
300 expires = VALUES(expires),
301 session = VALUES(session)
302 "#,
303 ))
304 .bind(&id)
305 .bind(&string)
306 .bind(&session.expiry())
307 .execute(&mut connection)
308 .await?;
309
310 Ok(session.into_cookie_value())
311 }
312
313 async fn destroy_session(&self, session: Session) -> Result {
314 let id = session.id();
315 let mut connection = self.connection().await?;
316 sqlx::query(&self.substitute_table_name("DELETE FROM %%TABLE_NAME%% WHERE id = ?"))
317 .bind(&id)
318 .execute(&mut connection)
319 .await?;
320
321 Ok(())
322 }
323
324 async fn clear_store(&self) -> Result {
325 let mut connection = self.connection().await?;
326 sqlx::query(&self.substitute_table_name("TRUNCATE %%TABLE_NAME%%"))
327 .execute(&mut connection)
328 .await?;
329
330 Ok(())
331 }
332}
333
// Integration tests: they need a reachable MySQL server whose URL is given in
// the `MYSQL_TEST_DB_URL` environment variable, and they all use the default
// `async_sessions` table.
#[cfg(test)]
mod tests {
    use super::*;
    use async_session::chrono::DateTime;
    use std::time::Duration;

    /// Connects to the test database, runs the migration and empties the table.
    async fn test_store() -> MySqlSessionStore {
        let database_url = std::env::var("MYSQL_TEST_DB_URL").unwrap();
        let store = MySqlSessionStore::new(&database_url)
            .await
            .expect("building a MySqlSessionStore");

        store
            .migrate()
            .await
            .expect("migrating a MySqlSessionStore");

        store.clear_store().await.expect("clearing");

        store
    }

    #[async_std::test]
    async fn creating_a_new_session_with_no_expiry() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        session.insert("key", "value")?;
        let expected = session.clone();
        let cookie_value = store.store_session(session).await?.unwrap();

        // Inspect the raw row to check what was actually written.
        let mut conn = store.connection().await?;
        let row: (String, Option<DateTime<Utc>>, String, i64) =
            sqlx::query_as("select id, expires, session, (select count(*) from async_sessions) from async_sessions")
                .fetch_one(&mut conn)
                .await?;
        let (id, expires, serialized, count) = row;

        assert_eq!(1, count);
        assert_eq!(id, expected.id());
        assert_eq!(expires, None);

        let from_row: Session = serde_json::from_str(&serialized)?;
        assert_eq!(expected.id(), from_row.id());
        assert_eq!("value", &from_row.get::<String>("key").unwrap());

        let from_store = store.load_session(cookie_value).await?.unwrap();
        assert_eq!(expected.id(), from_store.id());
        assert_eq!("value", &from_store.get::<String>("key").unwrap());

        assert!(!from_store.is_expired());
        Ok(())
    }

    #[async_std::test]
    async fn updating_a_session() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        let original_id = session.id().to_owned();

        session.insert("key", "value")?;
        let cookie_value = store.store_session(session).await?.unwrap();

        // Storing a session that was loaded from the store yields no new cookie value.
        let mut reloaded = store.load_session(cookie_value.clone()).await?.unwrap();
        reloaded.insert("key", "other value")?;
        assert_eq!(None, store.store_session(reloaded).await?);

        let updated = store.load_session(cookie_value).await?.unwrap();
        assert_eq!(updated.get::<String>("key").unwrap(), "other value");

        let mut conn = store.connection().await?;
        let row: (String, i64) =
            sqlx::query_as("select id, (select count(*) from async_sessions) from async_sessions")
                .fetch_one(&mut conn)
                .await?;
        let (id, count) = row;

        assert_eq!(1, count);
        assert_eq!(original_id, id);

        Ok(())
    }

    #[async_std::test]
    async fn updating_a_session_extending_expiry() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        session.expire_in(Duration::from_secs(10));
        let original_id = session.id().to_owned();
        let original_expires = session.expiry().cloned().unwrap();
        let cookie_value = store.store_session(session).await?.unwrap();

        let mut reloaded = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(reloaded.expiry().unwrap(), &original_expires);
        reloaded.expire_in(Duration::from_secs(20));
        let new_expires = reloaded.expiry().cloned().unwrap();
        store.store_session(reloaded).await?;

        let extended = store.load_session(cookie_value).await?.unwrap();
        assert_eq!(extended.expiry().unwrap(), &new_expires);

        let mut conn = store.connection().await?;
        let row: (String, DateTime<Utc>, i64) = sqlx::query_as(
            "select id, expires, (select count(*) from async_sessions) from async_sessions",
        )
        .fetch_one(&mut conn)
        .await?;
        let (id, expires, count) = row;

        assert_eq!(1, count);
        // Compared at millisecond precision.
        assert_eq!(expires.timestamp_millis(), new_expires.timestamp_millis());
        assert_eq!(original_id, id);

        Ok(())
    }

    #[async_std::test]
    async fn creating_a_new_session_with_expiry() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        session.expire_in(Duration::from_secs(1));
        session.insert("key", "value")?;
        let expected = session.clone();

        let cookie_value = store.store_session(session).await?.unwrap();

        let mut conn = store.connection().await?;
        let row: (String, Option<DateTime<Utc>>, String, i64) =
            sqlx::query_as("select id, expires, session, (select count(*) from async_sessions) from async_sessions")
                .fetch_one(&mut conn)
                .await?;
        let (id, expires, serialized, count) = row;

        assert_eq!(1, count);
        assert_eq!(id, expected.id());
        assert!(expires.unwrap() > Utc::now());

        let from_row: Session = serde_json::from_str(&serialized)?;
        assert_eq!(expected.id(), from_row.id());
        assert_eq!("value", &from_row.get::<String>("key").unwrap());

        let from_store = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(expected.id(), from_store.id());
        assert_eq!("value", &from_store.get::<String>("key").unwrap());

        assert!(!from_store.is_expired());

        // Once the one-second lifetime has passed the session must no longer load.
        async_std::task::sleep(Duration::from_secs(1)).await;
        assert_eq!(None, store.load_session(cookie_value).await?);

        Ok(())
    }

    #[async_std::test]
    async fn destroying_a_single_session() -> Result {
        let store = test_store().await;
        for _ in 0..3 {
            store.store_session(Session::new()).await?;
        }

        let cookie = store.store_session(Session::new()).await?.unwrap();
        assert_eq!(4, store.count().await?);

        let session = store.load_session(cookie.clone()).await?.unwrap();
        store.destroy_session(session.clone()).await.unwrap();
        assert_eq!(None, store.load_session(cookie).await?);
        assert_eq!(3, store.count().await?);

        // Destroying an already destroyed session still succeeds.
        assert!(store.destroy_session(session).await.is_ok());
        Ok(())
    }

    #[async_std::test]
    async fn clearing_the_whole_store() -> Result {
        let store = test_store().await;
        for _ in 0..3 {
            store.store_session(Session::new()).await?;
        }

        assert_eq!(3, store.count().await?);
        store.clear_store().await.unwrap();
        assert_eq!(0, store.count().await?);

        Ok(())
    }
}