async_sqlx_session/
mysql.rs

1use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore};
2use sqlx::{pool::PoolConnection, Executor, MySql, MySqlPool};
3
4/// sqlx mysql session store for async-sessions
5///
6/// ```rust
7/// use async_sqlx_session::MySqlSessionStore;
8/// use async_session::{Session, SessionStore};
9/// use std::time::Duration;
10///
11/// # fn main() -> async_session::Result { async_std::task::block_on(async {
12/// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await?;
13/// store.migrate().await?;
14/// # store.clear_store().await?;
15/// # #[cfg(feature = "async_std")]
16/// store.spawn_cleanup_task(Duration::from_secs(60 * 60));
17///
18/// let mut session = Session::new();
19/// session.insert("key", vec![1,2,3]);
20///
21/// let cookie_value = store.store_session(session).await?.unwrap();
22/// let session = store.load_session(cookie_value).await?.unwrap();
23/// assert_eq!(session.get::<Vec<i8>>("key").unwrap(), vec![1,2,3]);
24/// # Ok(()) }) }
25///
26#[derive(Clone, Debug)]
27pub struct MySqlSessionStore {
28    client: MySqlPool,
29    table_name: String,
30}
31
32impl MySqlSessionStore {
33    /// constructs a new MySqlSessionStore from an existing
34    /// sqlx::MySqlPool.  the default table name for this session
35    /// store will be "async_sessions". To override this, chain this
36    /// with [`with_table_name`](crate::MySqlSessionStore::with_table_name).
37    ///
38    /// ```rust
39    /// # use async_sqlx_session::MySqlSessionStore;
40    /// # use async_session::Result;
41    /// # fn main() -> Result { async_std::task::block_on(async {
42    /// let pool = sqlx::MySqlPool::connect(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await.unwrap();
43    /// let store = MySqlSessionStore::from_client(pool)
44    ///     .with_table_name("custom_table_name");
45    /// store.migrate().await;
46    /// # Ok(()) }) }
47    /// ```
48    pub fn from_client(client: MySqlPool) -> Self {
49        Self {
50            client,
51            table_name: "async_sessions".into(),
52        }
53    }
54
55    /// Constructs a new MySqlSessionStore from a mysql://
56    /// database url. The default table name for this session store
57    /// will be "async_sessions". To override this, either chain with
58    /// [`with_table_name`](crate::MySqlSessionStore::with_table_name)
59    /// or use
60    /// [`new_with_table_name`](crate::MySqlSessionStore::new_with_table_name)
61    ///
62    /// ```rust
63    /// # use async_sqlx_session::MySqlSessionStore;
64    /// # use async_session::Result;
65    /// # fn main() -> Result { async_std::task::block_on(async {
66    /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await?;
67    /// store.migrate().await;
68    /// # Ok(()) }) }
69    /// ```
70    pub async fn new(database_url: &str) -> sqlx::Result<Self> {
71        let pool = MySqlPool::connect(database_url).await?;
72        Ok(Self::from_client(pool))
73    }
74
75    /// constructs a new MySqlSessionStore from a mysql:// url. the
76    /// default table name for this session store will be
77    /// "async_sessions". To override this, either chain with
78    /// [`with_table_name`](crate::MySqlSessionStore::with_table_name) or
79    /// use
80    /// [`new_with_table_name`](crate::MySqlSessionStore::new_with_table_name)
81    ///
82    /// ```rust
83    /// # use async_sqlx_session::MySqlSessionStore;
84    /// # use async_session::Result;
85    /// # fn main() -> Result { async_std::task::block_on(async {
86    /// let store = MySqlSessionStore::new_with_table_name(&std::env::var("MYSQL_TEST_DB_URL").unwrap(), "custom_table_name").await?;
87    /// store.migrate().await;
88    /// # Ok(()) }) }
89    /// ```
90    pub async fn new_with_table_name(database_url: &str, table_name: &str) -> sqlx::Result<Self> {
91        Ok(Self::new(database_url).await?.with_table_name(table_name))
92    }
93
94    /// Chainable method to add a custom table name. This will panic
95    /// if the table name is not `[a-zA-Z0-9_-]`.
96    /// ```rust
97    /// # use async_sqlx_session::MySqlSessionStore;
98    /// # use async_session::Result;
99    /// # fn main() -> Result { async_std::task::block_on(async {
100    /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await?
101    ///     .with_table_name("custom_name");
102    /// store.migrate().await;
103    /// # Ok(()) }) }
104    /// ```
105    ///
106    /// ```should_panic
107    /// # use async_sqlx_session::MySqlSessionStore;
108    /// # use async_session::Result;
109    /// # fn main() -> Result { async_std::task::block_on(async {
110    /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await?
111    ///     .with_table_name("johnny (); drop users;");
112    /// # Ok(()) }) }
113    /// ```
114    pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Self {
115        let table_name = table_name.as_ref();
116        if table_name.is_empty()
117            || !table_name
118                .chars()
119                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
120        {
121            panic!(
122                "table name must be [a-zA-Z0-9_-], but {} was not",
123                table_name
124            );
125        }
126
127        self.table_name = table_name.to_owned();
128        self
129    }
130
131    /// Creates a session table if it does not already exist. If it
132    /// does, this will noop, making it safe to call repeatedly on
133    /// store initialization. In the future, this may make
134    /// exactly-once modifications to the schema of the session table
135    /// on breaking releases.
136    /// ```rust
137    /// # use async_sqlx_session::MySqlSessionStore;
138    /// # use async_session::{Result, SessionStore, Session};
139    /// # fn main() -> Result { async_std::task::block_on(async {
140    /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await?;
141    /// # store.clear_store().await?;
142    /// store.migrate().await?;
143    /// store.store_session(Session::new()).await?;
144    /// store.migrate().await?; // calling it a second time is safe
145    /// assert_eq!(store.count().await?, 1);
146    /// # Ok(()) }) }
147    /// ```
148    pub async fn migrate(&self) -> sqlx::Result<()> {
149        log::info!("migrating sessions on `{}`", self.table_name);
150
151        let mut conn = self.client.acquire().await?;
152        conn.execute(&*self.substitute_table_name(
153            r#"
154            CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% (
155                `id` VARCHAR(128) NOT NULL,
156                `expires` TIMESTAMP(6) NULL,
157                `session` TEXT NOT NULL,
158                PRIMARY KEY (`id`),
159                KEY `expires` (`expires`)
160            )
161            ENGINE=InnoDB
162            DEFAULT CHARSET=utf8
163            "#,
164        ))
165        .await?;
166
167        Ok(())
168    }
169
170    fn substitute_table_name(&self, query: &str) -> String {
171        query.replace("%%TABLE_NAME%%", &self.table_name)
172    }
173
174    /// retrieve a connection from the pool
175    async fn connection(&self) -> sqlx::Result<PoolConnection<MySql>> {
176        self.client.acquire().await
177    }
178
179    /// Spawns an async_std::task that clears out stale (expired)
180    /// sessions on a periodic basis. Only available with the
181    /// async_std feature enabled.
182    ///
183    /// ```rust,no_run
184    /// # use async_sqlx_session::MySqlSessionStore;
185    /// # use async_session::{Result, SessionStore, Session};
186    /// # use std::time::Duration;
187    /// # fn main() -> Result { async_std::task::block_on(async {
188    /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await?;
189    /// store.migrate().await?;
190    /// # let join_handle =
191    /// store.spawn_cleanup_task(Duration::from_secs(1));
192    /// let mut session = Session::new();
193    /// session.expire_in(Duration::from_secs(0));
194    /// store.store_session(session).await?;
195    /// assert_eq!(store.count().await?, 1);
196    /// async_std::task::sleep(Duration::from_secs(2)).await;
197    /// assert_eq!(store.count().await?, 0);
198    /// # join_handle.cancel().await;
199    /// # Ok(()) }) }
200    /// ```
201    #[cfg(feature = "async_std")]
202    pub fn spawn_cleanup_task(
203        &self,
204        period: std::time::Duration,
205    ) -> async_std::task::JoinHandle<()> {
206        let store = self.clone();
207        async_std::task::spawn(async move {
208            loop {
209                async_std::task::sleep(period).await;
210                if let Err(error) = store.cleanup().await {
211                    log::error!("cleanup error: {}", error);
212                }
213            }
214        })
215    }
216
217    /// Performs a one-time cleanup task that clears out stale
218    /// (expired) sessions. You may want to call this from cron.
219    /// ```rust
220    /// # use async_sqlx_session::MySqlSessionStore;
221    /// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session};
222    /// # fn main() -> Result { async_std::task::block_on(async {
223    /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await?;
224    /// store.migrate().await?;
225    /// # store.clear_store().await?;
226    /// let mut session = Session::new();
227    /// session.set_expiry(Utc::now() - Duration::seconds(5));
228    /// store.store_session(session).await?;
229    /// assert_eq!(store.count().await?, 1);
230    /// store.cleanup().await?;
231    /// assert_eq!(store.count().await?, 0);
232    /// # Ok(()) }) }
233    /// ```
234    pub async fn cleanup(&self) -> sqlx::Result<()> {
235        let mut connection = self.connection().await?;
236        sqlx::query(&self.substitute_table_name("DELETE FROM %%TABLE_NAME%% WHERE expires < ?"))
237            .bind(Utc::now())
238            .execute(&mut connection)
239            .await?;
240
241        Ok(())
242    }
243
244    /// retrieves the number of sessions currently stored, including
245    /// expired sessions
246    ///
247    /// ```rust
248    /// # use async_sqlx_session::MySqlSessionStore;
249    /// # use async_session::{Result, SessionStore, Session};
250    /// # use std::time::Duration;
251    /// # fn main() -> Result { async_std::task::block_on(async {
252    /// let store = MySqlSessionStore::new(&std::env::var("MYSQL_TEST_DB_URL").unwrap()).await?;
253    /// store.migrate().await?;
254    /// # store.clear_store().await?;
255    /// assert_eq!(store.count().await?, 0);
256    /// store.store_session(Session::new()).await?;
257    /// assert_eq!(store.count().await?, 1);
258    /// # Ok(()) }) }
259    /// ```
260
261    pub async fn count(&self) -> sqlx::Result<i64> {
262        let (count,) =
263            sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%"))
264                .fetch_one(&mut self.connection().await?)
265                .await?;
266
267        Ok(count)
268    }
269}
270
271#[async_trait]
272impl SessionStore for MySqlSessionStore {
273    async fn load_session(&self, cookie_value: String) -> Result<Option<Session>> {
274        let id = Session::id_from_cookie_value(&cookie_value)?;
275        let mut connection = self.connection().await?;
276
277        let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name(
278            "SELECT session FROM %%TABLE_NAME%% WHERE id = ? AND (expires IS NULL OR expires > ?)",
279        ))
280        .bind(&id)
281        .bind(Utc::now())
282        .fetch_optional(&mut connection)
283        .await?;
284
285        Ok(result
286            .map(|(session,)| serde_json::from_str(&session))
287            .transpose()?)
288    }
289
290    async fn store_session(&self, session: Session) -> Result<Option<String>> {
291        let id = session.id();
292        let string = serde_json::to_string(&session)?;
293        let mut connection = self.connection().await?;
294
295        sqlx::query(&self.substitute_table_name(
296            r#"
297            INSERT INTO %%TABLE_NAME%%
298              (id, session, expires) VALUES(?, ?, ?)
299            ON DUPLICATE KEY UPDATE
300              expires = VALUES(expires),
301              session = VALUES(session)
302            "#,
303        ))
304        .bind(&id)
305        .bind(&string)
306        .bind(&session.expiry())
307        .execute(&mut connection)
308        .await?;
309
310        Ok(session.into_cookie_value())
311    }
312
313    async fn destroy_session(&self, session: Session) -> Result {
314        let id = session.id();
315        let mut connection = self.connection().await?;
316        sqlx::query(&self.substitute_table_name("DELETE FROM %%TABLE_NAME%% WHERE id = ?"))
317            .bind(&id)
318            .execute(&mut connection)
319            .await?;
320
321        Ok(())
322    }
323
324    async fn clear_store(&self) -> Result {
325        let mut connection = self.connection().await?;
326        sqlx::query(&self.substitute_table_name("TRUNCATE %%TABLE_NAME%%"))
327            .execute(&mut connection)
328            .await?;
329
330        Ok(())
331    }
332}
333
#[cfg(test)]
mod tests {
    //! Tests against a live MySQL server whose URL is read from the
    //! `MYSQL_TEST_DB_URL` environment variable.
    //!
    //! Every test goes through `test_store`, which clears the shared default
    //! `async_sessions` table, and several tests assert exact row counts.
    //! The Rust test harness runs tests concurrently by default, so these
    //! tests can interfere with one another; run them with
    //! `cargo test -- --test-threads=1`.

    use super::*;
    use async_session::chrono::DateTime;
    use std::time::Duration;

    /// Connects to `MYSQL_TEST_DB_URL`, runs the migration and empties the
    /// default session table so the calling test starts from zero rows.
    async fn test_store() -> MySqlSessionStore {
        let database_url =
            std::env::var("MYSQL_TEST_DB_URL").expect("MYSQL_TEST_DB_URL must be set");
        let store = MySqlSessionStore::new(&database_url)
            .await
            .expect("building a MySqlSessionStore");

        store
            .migrate()
            .await
            .expect("migrating a MySqlSessionStore");

        store.clear_store().await.expect("clearing");

        store
    }

    #[async_std::test]
    async fn creating_a_new_session_with_no_expiry() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        session.insert("key", "value")?;
        let cloned = session.clone();
        let cookie_value = store.store_session(session).await?.unwrap();

        let (id, expires, serialized, count): (String, Option<DateTime<Utc>>, String, i64) =
            sqlx::query_as("select id, expires, session, (select count(*) from async_sessions) from async_sessions")
                .fetch_one(&mut store.connection().await?)
                .await?;

        assert_eq!(1, count);
        assert_eq!(id, cloned.id());
        assert_eq!(expires, None);

        let deserialized_session: Session = serde_json::from_str(&serialized)?;
        assert_eq!(cloned.id(), deserialized_session.id());
        assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());

        let loaded_session = store.load_session(cookie_value).await?.unwrap();
        assert_eq!(cloned.id(), loaded_session.id());
        assert_eq!("value", &loaded_session.get::<String>("key").unwrap());

        assert!(!loaded_session.is_expired());
        Ok(())
    }

    #[async_std::test]
    async fn updating_a_session() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        let original_id = session.id().to_owned();

        session.insert("key", "value")?;
        let cookie_value = store.store_session(session).await?.unwrap();

        let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
        session.insert("key", "other value")?;
        assert_eq!(None, store.store_session(session).await?);

        let session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(session.get::<String>("key").unwrap(), "other value");

        let (id, count): (String, i64) =
            sqlx::query_as("select id, (select count(*) from async_sessions) from async_sessions")
                .fetch_one(&mut store.connection().await?)
                .await?;

        assert_eq!(1, count);
        assert_eq!(original_id, id);

        Ok(())
    }

    #[async_std::test]
    async fn updating_a_session_extending_expiry() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        session.expire_in(Duration::from_secs(10));
        let original_id = session.id().to_owned();
        // DateTime<Utc> is Copy, so dereferencing is enough to take a snapshot.
        let original_expires = *session.expiry().unwrap();
        let cookie_value = store.store_session(session).await?.unwrap();

        let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(session.expiry().unwrap(), &original_expires);
        session.expire_in(Duration::from_secs(20));
        let new_expires = *session.expiry().unwrap();
        store.store_session(session).await?;

        let session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(session.expiry().unwrap(), &new_expires);

        let (id, expires, count): (String, DateTime<Utc>, i64) = sqlx::query_as(
            "select id, expires, (select count(*) from async_sessions) from async_sessions",
        )
        .fetch_one(&mut store.connection().await?)
        .await?;

        assert_eq!(1, count);
        // Compare at millisecond granularity; the column stores microseconds.
        assert_eq!(expires.timestamp_millis(), new_expires.timestamp_millis());
        assert_eq!(original_id, id);

        Ok(())
    }

    #[async_std::test]
    async fn creating_a_new_session_with_expiry() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        session.expire_in(Duration::from_secs(1));
        session.insert("key", "value")?;
        let cloned = session.clone();

        let cookie_value = store.store_session(session).await?.unwrap();

        let (id, expires, serialized, count): (String, Option<DateTime<Utc>>, String, i64) =
            sqlx::query_as("select id, expires, session, (select count(*) from async_sessions) from async_sessions")
                .fetch_one(&mut store.connection().await?)
                .await?;

        assert_eq!(1, count);
        assert_eq!(id, cloned.id());
        assert!(expires.unwrap() > Utc::now());

        let deserialized_session: Session = serde_json::from_str(&serialized)?;
        assert_eq!(cloned.id(), deserialized_session.id());
        assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());

        let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(cloned.id(), loaded_session.id());
        assert_eq!("value", &loaded_session.get::<String>("key").unwrap());

        assert!(!loaded_session.is_expired());

        // The expiry was set one second before storing, so after a full
        // second of sleep it is guaranteed to lie in the past.
        async_std::task::sleep(Duration::from_secs(1)).await;
        assert_eq!(None, store.load_session(cookie_value).await?);

        Ok(())
    }

    #[async_std::test]
    async fn destroying_a_single_session() -> Result {
        let store = test_store().await;
        for _ in 0..3i8 {
            store.store_session(Session::new()).await?;
        }

        let cookie = store.store_session(Session::new()).await?.unwrap();
        assert_eq!(4, store.count().await?);
        let session = store.load_session(cookie.clone()).await?.unwrap();
        store.destroy_session(session.clone()).await?;
        assert_eq!(None, store.load_session(cookie).await?);
        assert_eq!(3, store.count().await?);

        // attempting to destroy the session again is not an error
        assert!(store.destroy_session(session).await.is_ok());
        Ok(())
    }

    #[async_std::test]
    async fn clearing_the_whole_store() -> Result {
        let store = test_store().await;
        for _ in 0..3i8 {
            store.store_session(Session::new()).await?;
        }

        assert_eq!(3, store.count().await?);
        store.clear_store().await?;
        assert_eq!(0, store.count().await?);

        Ok(())
    }
}