async_sqlx_session/
sqlite.rs

1use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore};
2use sqlx::{pool::PoolConnection, sqlite::SqlitePool, Sqlite};
3
4/// sqlx sqlite session store for async-sessions
5///
6/// ```rust
7/// use async_sqlx_session::SqliteSessionStore;
8/// use async_session::{Session, SessionStore};
9/// use std::time::Duration;
10///
11/// # fn main() -> async_session::Result { async_std::task::block_on(async {
12/// let store = SqliteSessionStore::new("sqlite::memory:").await?;
13/// store.migrate().await?;
14/// # #[cfg(feature = "async_std")]
15/// store.spawn_cleanup_task(Duration::from_secs(60 * 60));
16///
17/// let mut session = Session::new();
18/// session.insert("key", vec![1,2,3]);
19///
20/// let cookie_value = store.store_session(session).await?.unwrap();
21/// let session = store.load_session(cookie_value).await?.unwrap();
22/// assert_eq!(session.get::<Vec<i8>>("key").unwrap(), vec![1,2,3]);
23/// # Ok(()) }) }
24///
25#[derive(Clone, Debug)]
26pub struct SqliteSessionStore {
27    client: SqlitePool,
28    table_name: String,
29}
30
31impl SqliteSessionStore {
32    /// constructs a new SqliteSessionStore from an existing
33    /// sqlx::SqlitePool.  the default table name for this session
34    /// store will be "async_sessions". To override this, chain this
35    /// with [`with_table_name`](crate::SqliteSessionStore::with_table_name).
36    ///
37    /// ```rust
38    /// # use async_sqlx_session::SqliteSessionStore;
39    /// # use async_session::Result;
40    /// # fn main() -> Result { async_std::task::block_on(async {
41    /// let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
42    /// let store = SqliteSessionStore::from_client(pool)
43    ///     .with_table_name("custom_table_name");
44    /// store.migrate().await;
45    /// # Ok(()) }) }
46    /// ```
47    pub fn from_client(client: SqlitePool) -> Self {
48        Self {
49            client,
50            table_name: "async_sessions".into(),
51        }
52    }
53
54    /// Constructs a new SqliteSessionStore from a sqlite: database url. note
55    /// that this documentation uses the special `:memory:` sqlite
56    /// database for convenient testing, but a real application would
57    /// use a path like `sqlite:///path/to/database.db`. The default
58    /// table name for this session store will be "async_sessions". To
59    /// override this, either chain with
60    /// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
61    /// use
62    /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
63    ///
64    /// ```rust
65    /// # use async_sqlx_session::SqliteSessionStore;
66    /// # use async_session::Result;
67    /// # fn main() -> Result { async_std::task::block_on(async {
68    /// let store = SqliteSessionStore::new("sqlite::memory:").await?;
69    /// store.migrate().await;
70    /// # Ok(()) }) }
71    /// ```
72    pub async fn new(database_url: &str) -> sqlx::Result<Self> {
73        Ok(Self::from_client(SqlitePool::connect(database_url).await?))
74    }
75
76    /// constructs a new SqliteSessionStore from a sqlite: database url. the
77    /// default table name for this session store will be
78    /// "async_sessions". To override this, either chain with
79    /// [`with_table_name`](crate::SqliteSessionStore::with_table_name) or
80    /// use
81    /// [`new_with_table_name`](crate::SqliteSessionStore::new_with_table_name)
82    ///
83    /// ```rust
84    /// # use async_sqlx_session::SqliteSessionStore;
85    /// # use async_session::Result;
86    /// # fn main() -> Result { async_std::task::block_on(async {
87    /// let store = SqliteSessionStore::new_with_table_name("sqlite::memory:", "custom_table_name").await?;
88    /// store.migrate().await;
89    /// # Ok(()) }) }
90    /// ```
91    pub async fn new_with_table_name(database_url: &str, table_name: &str) -> sqlx::Result<Self> {
92        Ok(Self::new(database_url).await?.with_table_name(table_name))
93    }
94
95    /// Chainable method to add a custom table name. This will panic
96    /// if the table name is not `[a-zA-Z0-9_-]+`.
97    /// ```rust
98    /// # use async_sqlx_session::SqliteSessionStore;
99    /// # use async_session::Result;
100    /// # fn main() -> Result { async_std::task::block_on(async {
101    /// let store = SqliteSessionStore::new("sqlite::memory:").await?
102    ///     .with_table_name("custom_name");
103    /// store.migrate().await;
104    /// # Ok(()) }) }
105    /// ```
106    ///
107    /// ```should_panic
108    /// # use async_sqlx_session::SqliteSessionStore;
109    /// # use async_session::Result;
110    /// # fn main() -> Result { async_std::task::block_on(async {
111    /// let store = SqliteSessionStore::new("sqlite::memory:").await?
112    ///     .with_table_name("johnny (); drop users;");
113    /// # Ok(()) }) }
114    /// ```
115    pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Self {
116        let table_name = table_name.as_ref();
117        if table_name.is_empty()
118            || !table_name
119                .chars()
120                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
121        {
122            panic!(
123                "table name must be [a-zA-Z0-9_-]+, but {} was not",
124                table_name
125            );
126        }
127
128        self.table_name = table_name.to_owned();
129        self
130    }
131
132    /// Creates a session table if it does not already exist. If it
133    /// does, this will noop, making it safe to call repeatedly on
134    /// store initialization. In the future, this may make
135    /// exactly-once modifications to the schema of the session table
136    /// on breaking releases.
137    /// ```rust
138    /// # use async_sqlx_session::SqliteSessionStore;
139    /// # use async_session::{Result, SessionStore, Session};
140    /// # fn main() -> Result { async_std::task::block_on(async {
141    /// let store = SqliteSessionStore::new("sqlite::memory:").await?;
142    /// assert!(store.count().await.is_err());
143    /// store.migrate().await?;
144    /// store.store_session(Session::new()).await?;
145    /// store.migrate().await?; // calling it a second time is safe
146    /// assert_eq!(store.count().await?, 1);
147    /// # Ok(()) }) }
148    /// ```
149    pub async fn migrate(&self) -> sqlx::Result<()> {
150        log::info!("migrating sessions on `{}`", self.table_name);
151
152        let mut conn = self.client.acquire().await?;
153        sqlx::query(&self.substitute_table_name(
154            r#"
155            CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% (
156                id TEXT PRIMARY KEY NOT NULL,
157                expires INTEGER NULL,
158                session TEXT NOT NULL
159            )
160            "#,
161        ))
162        .execute(&mut conn)
163        .await?;
164        Ok(())
165    }
166
167    // private utility function because sqlite does not support
168    // parametrized table names
169    fn substitute_table_name(&self, query: &str) -> String {
170        query.replace("%%TABLE_NAME%%", &self.table_name)
171    }
172
173    /// retrieve a connection from the pool
174    async fn connection(&self) -> sqlx::Result<PoolConnection<Sqlite>> {
175        self.client.acquire().await
176    }
177
178    /// Spawns an async_std::task that clears out stale (expired)
179    /// sessions on a periodic basis. Only available with the
180    /// async_std feature enabled.
181    ///
182    /// ```rust,no_run
183    /// # use async_sqlx_session::SqliteSessionStore;
184    /// # use async_session::{Result, SessionStore, Session};
185    /// # use std::time::Duration;
186    /// # fn main() -> Result { async_std::task::block_on(async {
187    /// let store = SqliteSessionStore::new("sqlite::memory:").await?;
188    /// store.migrate().await?;
189    /// # let join_handle =
190    /// store.spawn_cleanup_task(Duration::from_secs(1));
191    /// let mut session = Session::new();
192    /// session.expire_in(Duration::from_secs(0));
193    /// store.store_session(session).await?;
194    /// assert_eq!(store.count().await?, 1);
195    /// async_std::task::sleep(Duration::from_secs(2)).await;
196    /// assert_eq!(store.count().await?, 0);
197    /// # join_handle.cancel().await;
198    /// # Ok(()) }) }
199    /// ```
200    #[cfg(feature = "async_std")]
201    pub fn spawn_cleanup_task(
202        &self,
203        period: std::time::Duration,
204    ) -> async_std::task::JoinHandle<()> {
205        let store = self.clone();
206        async_std::task::spawn(async move {
207            loop {
208                async_std::task::sleep(period).await;
209                if let Err(error) = store.cleanup().await {
210                    log::error!("cleanup error: {}", error);
211                }
212            }
213        })
214    }
215
216    /// Performs a one-time cleanup task that clears out stale
217    /// (expired) sessions. You may want to call this from cron.
218    /// ```rust
219    /// # use async_sqlx_session::SqliteSessionStore;
220    /// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session};
221    /// # fn main() -> Result { async_std::task::block_on(async {
222    /// let store = SqliteSessionStore::new("sqlite::memory:").await?;
223    /// store.migrate().await?;
224    /// let mut session = Session::new();
225    /// session.set_expiry(Utc::now() - Duration::seconds(5));
226    /// store.store_session(session).await?;
227    /// assert_eq!(store.count().await?, 1);
228    /// store.cleanup().await?;
229    /// assert_eq!(store.count().await?, 0);
230    /// # Ok(()) }) }
231    /// ```
232    pub async fn cleanup(&self) -> sqlx::Result<()> {
233        let mut connection = self.connection().await?;
234        sqlx::query(&self.substitute_table_name(
235            r#"
236            DELETE FROM %%TABLE_NAME%%
237            WHERE expires < ?
238            "#,
239        ))
240        .bind(Utc::now().timestamp())
241        .execute(&mut connection)
242        .await?;
243
244        Ok(())
245    }
246
247    /// retrieves the number of sessions currently stored, including
248    /// expired sessions
249    ///
250    /// ```rust
251    /// # use async_sqlx_session::SqliteSessionStore;
252    /// # use async_session::{Result, SessionStore, Session};
253    /// # use std::time::Duration;
254    /// # fn main() -> Result { async_std::task::block_on(async {
255    /// let store = SqliteSessionStore::new("sqlite::memory:").await?;
256    /// store.migrate().await?;
257    /// assert_eq!(store.count().await?, 0);
258    /// store.store_session(Session::new()).await?;
259    /// assert_eq!(store.count().await?, 1);
260    /// # Ok(()) }) }
261    /// ```
262
263    pub async fn count(&self) -> sqlx::Result<i32> {
264        let (count,) =
265            sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%"))
266                .fetch_one(&mut self.connection().await?)
267                .await?;
268
269        Ok(count)
270    }
271}
272
273#[async_trait]
274impl SessionStore for SqliteSessionStore {
275    async fn load_session(&self, cookie_value: String) -> Result<Option<Session>> {
276        let id = Session::id_from_cookie_value(&cookie_value)?;
277        let mut connection = self.connection().await?;
278
279        let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name(
280            r#"
281            SELECT session FROM %%TABLE_NAME%%
282              WHERE id = ? AND (expires IS NULL OR expires > ?)
283            "#,
284        ))
285        .bind(&id)
286        .bind(Utc::now().timestamp())
287        .fetch_optional(&mut connection)
288        .await?;
289
290        Ok(result
291            .map(|(session,)| serde_json::from_str(&session))
292            .transpose()?)
293    }
294
295    async fn store_session(&self, session: Session) -> Result<Option<String>> {
296        let id = session.id();
297        let string = serde_json::to_string(&session)?;
298        let mut connection = self.connection().await?;
299
300        sqlx::query(&self.substitute_table_name(
301            r#"
302            INSERT INTO %%TABLE_NAME%%
303              (id, session, expires) VALUES (?, ?, ?)
304            ON CONFLICT(id) DO UPDATE SET
305              expires = excluded.expires,
306              session = excluded.session
307            "#,
308        ))
309        .bind(&id)
310        .bind(&string)
311        .bind(&session.expiry().map(|expiry| expiry.timestamp()))
312        .execute(&mut connection)
313        .await?;
314
315        Ok(session.into_cookie_value())
316    }
317
318    async fn destroy_session(&self, session: Session) -> Result {
319        let id = session.id();
320        let mut connection = self.connection().await?;
321        sqlx::query(&self.substitute_table_name(
322            r#"
323            DELETE FROM %%TABLE_NAME%% WHERE id = ?
324            "#,
325        ))
326        .bind(&id)
327        .execute(&mut connection)
328        .await?;
329
330        Ok(())
331    }
332
333    async fn clear_store(&self) -> Result {
334        let mut connection = self.connection().await?;
335        sqlx::query(&self.substitute_table_name(
336            r#"
337            DELETE FROM %%TABLE_NAME%%
338            "#,
339        ))
340        .execute(&mut connection)
341        .await?;
342
343        Ok(())
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a store on a `sqlite::memory:` database and runs `migrate`
    /// on it, panicking if either step fails.
    async fn test_store() -> SqliteSessionStore {
        let store = SqliteSessionStore::new("sqlite::memory:")
            .await
            .expect("building a sqlite :memory: SqliteSessionStore");
        store
            .migrate()
            .await
            .expect("migrating a brand new :memory: SqliteSessionStore");
        store
    }

    #[async_std::test]
    async fn creating_a_new_session_with_no_expiry() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        session.insert("key", "value")?;
        let cloned = session.clone();
        let cookie_value = store.store_session(session).await?.unwrap();

        let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
            sqlx::query_as("select id, expires, session, count(*) from async_sessions")
                .fetch_one(&mut store.connection().await?)
                .await?;

        assert_eq!(1, count);
        assert_eq!(id, cloned.id());
        assert_eq!(expires, None);

        let deserialized_session: Session = serde_json::from_str(&serialized)?;
        assert_eq!(cloned.id(), deserialized_session.id());
        assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());

        let loaded_session = store.load_session(cookie_value).await?.unwrap();
        assert_eq!(cloned.id(), loaded_session.id());
        assert_eq!("value", &loaded_session.get::<String>("key").unwrap());

        assert!(!loaded_session.is_expired());
        Ok(())
    }

    #[async_std::test]
    async fn updating_a_session() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        let original_id = session.id().to_owned();

        session.insert("key", "value")?;
        let cookie_value = store.store_session(session).await?.unwrap();

        let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
        session.insert("key", "other value")?;
        assert_eq!(None, store.store_session(session).await?);

        let session = store.load_session(cookie_value).await?.unwrap();
        assert_eq!(session.get::<String>("key").unwrap(), "other value");

        let (id, count): (String, i64) = sqlx::query_as("select id, count(*) from async_sessions")
            .fetch_one(&mut store.connection().await?)
            .await?;

        assert_eq!(1, count);
        assert_eq!(original_id, id);

        Ok(())
    }

    #[async_std::test]
    async fn updating_a_session_extending_expiry() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        session.expire_in(Duration::from_secs(10));
        let original_id = session.id().to_owned();
        let original_expires = *session.expiry().unwrap();
        let cookie_value = store.store_session(session).await?.unwrap();

        let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(session.expiry().unwrap(), &original_expires);
        session.expire_in(Duration::from_secs(20));
        let new_expires = *session.expiry().unwrap();
        store.store_session(session).await?;

        let session = store.load_session(cookie_value).await?.unwrap();
        assert_eq!(session.expiry().unwrap(), &new_expires);

        let (id, expires, count): (String, i64, i64) =
            sqlx::query_as("select id, expires, count(*) from async_sessions")
                .fetch_one(&mut store.connection().await?)
                .await?;

        assert_eq!(1, count);
        assert_eq!(expires, new_expires.timestamp());
        assert_eq!(original_id, id);

        Ok(())
    }

    #[async_std::test]
    async fn creating_a_new_session_with_expiry() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        // The expiry column holds whole seconds. A 2s ttl guarantees the
        // stored timestamp is still in the future when the "not yet
        // expired" assertions below run, even if a second boundary is
        // crossed in between. A 1s ttl could flake here.
        session.expire_in(Duration::from_secs(2));
        session.insert("key", "value")?;
        let cloned = session.clone();

        let cookie_value = store.store_session(session).await?.unwrap();

        let (id, expires, serialized, count): (String, Option<i64>, String, i64) =
            sqlx::query_as("select id, expires, session, count(*) from async_sessions")
                .fetch_one(&mut store.connection().await?)
                .await?;

        assert_eq!(1, count);
        assert_eq!(id, cloned.id());
        assert!(expires.unwrap() > Utc::now().timestamp());

        let deserialized_session: Session = serde_json::from_str(&serialized)?;
        assert_eq!(cloned.id(), deserialized_session.id());
        assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());

        let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(cloned.id(), loaded_session.id());
        assert_eq!("value", &loaded_session.get::<String>("key").unwrap());

        assert!(!loaded_session.is_expired());

        async_std::task::sleep(Duration::from_secs(2)).await;
        assert_eq!(None, store.load_session(cookie_value).await?);

        Ok(())
    }

    #[async_std::test]
    async fn destroying_a_single_session() -> Result {
        let store = test_store().await;
        for _ in 0..3i8 {
            store.store_session(Session::new()).await?;
        }

        let cookie = store.store_session(Session::new()).await?.unwrap();
        assert_eq!(4, store.count().await?);
        let session = store.load_session(cookie.clone()).await?.unwrap();
        store.destroy_session(session.clone()).await?;
        assert_eq!(None, store.load_session(cookie).await?);
        assert_eq!(3, store.count().await?);

        // attempting to destroy the session again is not an error and
        // leaves the other sessions untouched
        store.destroy_session(session).await?;
        assert_eq!(3, store.count().await?);
        Ok(())
    }

    #[async_std::test]
    async fn clearing_the_whole_store() -> Result {
        let store = test_store().await;
        for _ in 0..3i8 {
            store.store_session(Session::new()).await?;
        }

        assert_eq!(3, store.count().await?);
        store.clear_store().await?;
        assert_eq!(0, store.count().await?);

        Ok(())
    }

    #[async_std::test]
    async fn cleaning_up_only_removes_expired_sessions() -> Result {
        let store = test_store().await;

        let mut expired = Session::new();
        expired.set_expiry(Utc::now() - async_session::chrono::Duration::seconds(5));
        store.store_session(expired).await?;

        let mut live = Session::new();
        live.expire_in(Duration::from_secs(60));
        store.store_session(live).await?;

        // sessions without an expiry are never cleaned up
        store.store_session(Session::new()).await?;

        assert_eq!(3, store.count().await?);
        store.cleanup().await?;
        assert_eq!(2, store.count().await?);

        Ok(())
    }

    #[async_std::test]
    async fn using_a_custom_table_name() -> Result {
        let store = SqliteSessionStore::new("sqlite::memory:")
            .await?
            .with_table_name("custom_sessions");
        store.migrate().await?;

        let cookie_value = store.store_session(Session::new()).await?.unwrap();
        assert_eq!(1, store.count().await?);
        assert!(store.load_session(cookie_value).await?.is_some());

        let (count,): (i64,) = sqlx::query_as("select count(*) from custom_sessions")
            .fetch_one(&mut store.connection().await?)
            .await?;
        assert_eq!(1, count);

        Ok(())
    }
}