1use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore};
2use sqlx::{pool::PoolConnection, sqlite::SqlitePool, Sqlite};
3
4#[derive(Clone, Debug)]
26pub struct SqliteSessionStore {
27 client: SqlitePool,
28 table_name: String,
29}
30
31impl SqliteSessionStore {
32 pub fn from_client(client: SqlitePool) -> Self {
48 Self {
49 client,
50 table_name: "async_sessions".into(),
51 }
52 }
53
54 pub async fn new(database_url: &str) -> sqlx::Result<Self> {
73 Ok(Self::from_client(SqlitePool::connect(database_url).await?))
74 }
75
76 pub async fn new_with_table_name(database_url: &str, table_name: &str) -> sqlx::Result<Self> {
92 Ok(Self::new(database_url).await?.with_table_name(table_name))
93 }
94
95 pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Self {
116 let table_name = table_name.as_ref();
117 if table_name.is_empty()
118 || !table_name
119 .chars()
120 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
121 {
122 panic!(
123 "table name must be [a-zA-Z0-9_-]+, but {} was not",
124 table_name
125 );
126 }
127
128 self.table_name = table_name.to_owned();
129 self
130 }
131
132 pub async fn migrate(&self) -> sqlx::Result<()> {
150 log::info!("migrating sessions on `{}`", self.table_name);
151
152 let mut conn = self.client.acquire().await?;
153 sqlx::query(&self.substitute_table_name(
154 r#"
155 CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% (
156 id TEXT PRIMARY KEY NOT NULL,
157 expires INTEGER NULL,
158 session TEXT NOT NULL
159 )
160 "#,
161 ))
162 .execute(&mut conn)
163 .await?;
164 Ok(())
165 }
166
167 fn substitute_table_name(&self, query: &str) -> String {
170 query.replace("%%TABLE_NAME%%", &self.table_name)
171 }
172
173 async fn connection(&self) -> sqlx::Result<PoolConnection<Sqlite>> {
175 self.client.acquire().await
176 }
177
178 #[cfg(feature = "async_std")]
201 pub fn spawn_cleanup_task(
202 &self,
203 period: std::time::Duration,
204 ) -> async_std::task::JoinHandle<()> {
205 let store = self.clone();
206 async_std::task::spawn(async move {
207 loop {
208 async_std::task::sleep(period).await;
209 if let Err(error) = store.cleanup().await {
210 log::error!("cleanup error: {}", error);
211 }
212 }
213 })
214 }
215
216 pub async fn cleanup(&self) -> sqlx::Result<()> {
233 let mut connection = self.connection().await?;
234 sqlx::query(&self.substitute_table_name(
235 r#"
236 DELETE FROM %%TABLE_NAME%%
237 WHERE expires < ?
238 "#,
239 ))
240 .bind(Utc::now().timestamp())
241 .execute(&mut connection)
242 .await?;
243
244 Ok(())
245 }
246
247 pub async fn count(&self) -> sqlx::Result<i32> {
264 let (count,) =
265 sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%"))
266 .fetch_one(&mut self.connection().await?)
267 .await?;
268
269 Ok(count)
270 }
271}
272
273#[async_trait]
274impl SessionStore for SqliteSessionStore {
275 async fn load_session(&self, cookie_value: String) -> Result<Option<Session>> {
276 let id = Session::id_from_cookie_value(&cookie_value)?;
277 let mut connection = self.connection().await?;
278
279 let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name(
280 r#"
281 SELECT session FROM %%TABLE_NAME%%
282 WHERE id = ? AND (expires IS NULL OR expires > ?)
283 "#,
284 ))
285 .bind(&id)
286 .bind(Utc::now().timestamp())
287 .fetch_optional(&mut connection)
288 .await?;
289
290 Ok(result
291 .map(|(session,)| serde_json::from_str(&session))
292 .transpose()?)
293 }
294
295 async fn store_session(&self, session: Session) -> Result<Option<String>> {
296 let id = session.id();
297 let string = serde_json::to_string(&session)?;
298 let mut connection = self.connection().await?;
299
300 sqlx::query(&self.substitute_table_name(
301 r#"
302 INSERT INTO %%TABLE_NAME%%
303 (id, session, expires) VALUES (?, ?, ?)
304 ON CONFLICT(id) DO UPDATE SET
305 expires = excluded.expires,
306 session = excluded.session
307 "#,
308 ))
309 .bind(&id)
310 .bind(&string)
311 .bind(&session.expiry().map(|expiry| expiry.timestamp()))
312 .execute(&mut connection)
313 .await?;
314
315 Ok(session.into_cookie_value())
316 }
317
318 async fn destroy_session(&self, session: Session) -> Result {
319 let id = session.id();
320 let mut connection = self.connection().await?;
321 sqlx::query(&self.substitute_table_name(
322 r#"
323 DELETE FROM %%TABLE_NAME%% WHERE id = ?
324 "#,
325 ))
326 .bind(&id)
327 .execute(&mut connection)
328 .await?;
329
330 Ok(())
331 }
332
333 async fn clear_store(&self) -> Result {
334 let mut connection = self.connection().await?;
335 sqlx::query(&self.substitute_table_name(
336 r#"
337 DELETE FROM %%TABLE_NAME%%
338 "#,
339 ))
340 .execute(&mut connection)
341 .await?;
342
343 Ok(())
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a store on a fresh in-memory database and creates the
    /// sessions table in it.
    async fn test_store() -> SqliteSessionStore {
        let store = SqliteSessionStore::new("sqlite::memory:")
            .await
            .expect("building a sqlite :memory: SqliteSessionStore");
        store
            .migrate()
            .await
            .expect("migrating a brand new :memory: SqliteSessionStore");
        store
    }

    // A session without an expiry is stored with a NULL `expires` and can be
    // loaded back with its data intact.
    #[async_std::test]
    async fn creating_a_new_session_with_no_expiry() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        session.insert("key", "value")?;
        let expected = session.clone();
        let cookie_value = store.store_session(session).await?.unwrap();

        let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
            sqlx::query_as("select id, expires, session, count(*) from async_sessions")
                .fetch_one(&mut store.connection().await?)
                .await?;

        assert_eq!(1, count);
        assert_eq!(id, expected.id());
        assert_eq!(expires, None);

        let deserialized_session: Session = serde_json::from_str(&serialized)?;
        assert_eq!(expected.id(), deserialized_session.id());
        assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());

        let loaded_session = store.load_session(cookie_value).await?.unwrap();
        assert_eq!(expected.id(), loaded_session.id());
        assert_eq!("value", &loaded_session.get::<String>("key").unwrap());

        assert!(!loaded_session.is_expired());
        Ok(())
    }

    // Storing a loaded session again updates the existing row instead of
    // adding a second one.
    #[async_std::test]
    async fn updating_a_session() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        let original_id = session.id().to_owned();

        session.insert("key", "value")?;
        let cookie_value = store.store_session(session).await?.unwrap();

        let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
        session.insert("key", "other value")?;
        assert_eq!(None, store.store_session(session).await?);

        let session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(session.get::<String>("key").unwrap(), "other value");

        let (id, count): (String, i64) = sqlx::query_as("select id, count(*) from async_sessions")
            .fetch_one(&mut store.connection().await?)
            .await?;

        assert_eq!(1, count);
        assert_eq!(original_id, id);

        Ok(())
    }

    // Re-storing a session with a later expiry overwrites the `expires`
    // column of the same row.
    #[async_std::test]
    async fn updating_a_session_extending_expiry() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        session.expire_in(Duration::from_secs(10));
        let original_id = session.id().to_owned();
        let original_expires = session.expiry().cloned().unwrap();
        let cookie_value = store.store_session(session).await?.unwrap();

        let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(session.expiry().unwrap(), &original_expires);
        session.expire_in(Duration::from_secs(20));
        let new_expires = session.expiry().cloned().unwrap();
        store.store_session(session).await?;

        let session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(session.expiry().unwrap(), &new_expires);

        let (id, expires, count): (String, i64, i64) =
            sqlx::query_as("select id, expires, count(*) from async_sessions")
                .fetch_one(&mut store.connection().await?)
                .await?;

        assert_eq!(1, count);
        assert_eq!(expires, new_expires.timestamp());
        assert_eq!(original_id, id);

        Ok(())
    }

    // A session with an expiry is loadable until it expires and is no
    // longer returned afterwards.
    #[async_std::test]
    async fn creating_a_new_session_with_expiry() -> Result {
        let store = test_store().await;

        // Sampled before `expire_in`. The stored expiry is at least one full
        // second later, so its whole-second value is strictly greater than
        // this one even if the clock ticks over during the store round-trip.
        let created_at = Utc::now().timestamp();

        let mut session = Session::new();
        session.expire_in(Duration::from_secs(1));
        session.insert("key", "value")?;
        let expected = session.clone();

        let cookie_value = store.store_session(session).await?.unwrap();

        let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
            sqlx::query_as("select id, expires, session, count(*) from async_sessions")
                .fetch_one(&mut store.connection().await?)
                .await?;

        assert_eq!(1, count);
        assert_eq!(id, expected.id());
        assert!(expires.unwrap() > created_at);

        let deserialized_session: Session = serde_json::from_str(&serialized)?;
        assert_eq!(expected.id(), deserialized_session.id());
        assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());

        let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(expected.id(), loaded_session.id());
        assert_eq!("value", &loaded_session.get::<String>("key").unwrap());

        assert!(!loaded_session.is_expired());

        // Wait out the one-second lifetime; the row must no longer load.
        async_std::task::sleep(Duration::from_secs(1)).await;
        assert_eq!(None, store.load_session(cookie_value).await?);

        Ok(())
    }

    // Destroying one session removes only that row.
    #[async_std::test]
    async fn destroying_a_single_session() -> Result {
        let store = test_store().await;
        for _ in 0..3 {
            store.store_session(Session::new()).await?;
        }

        let cookie = store.store_session(Session::new()).await?.unwrap();
        assert_eq!(4, store.count().await?);
        let session = store.load_session(cookie.clone()).await?.unwrap();
        store.destroy_session(session).await.unwrap();
        assert_eq!(None, store.load_session(cookie).await?);
        assert_eq!(3, store.count().await?);

        Ok(())
    }

    // Clearing the store removes every row.
    #[async_std::test]
    async fn clearing_the_whole_store() -> Result {
        let store = test_store().await;
        for _ in 0..3 {
            store.store_session(Session::new()).await?;
        }

        assert_eq!(3, store.count().await?);
        store.clear_store().await.unwrap();
        assert_eq!(0, store.count().await?);

        Ok(())
    }
}