async_sqlx_session/
pg.rs

1use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore};
2use sqlx::{pool::PoolConnection, Executor, PgPool, Postgres};
3
/// sqlx postgres session store for async-sessions
///
/// ```rust
/// use async_sqlx_session::PostgresSessionStore;
/// use async_session::{Session, SessionStore};
/// use std::time::Duration;
///
/// # fn main() -> async_session::Result { async_std::task::block_on(async {
/// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await?;
/// store.migrate().await?;
/// # store.clear_store().await?;
/// # #[cfg(feature = "async_std")] {
/// store.spawn_cleanup_task(Duration::from_secs(60 * 60));
/// # }
///
/// let mut session = Session::new();
/// session.insert("key", vec![1,2,3])?;
///
/// let cookie_value = store.store_session(session).await?.unwrap();
/// let session = store.load_session(cookie_value).await?.unwrap();
/// assert_eq!(session.get::<Vec<i8>>("key").unwrap(), vec![1,2,3]);
/// # Ok(()) }) }
/// ```
#[derive(Clone, Debug)]
pub struct PostgresSessionStore {
    /// connection pool that every query acquires its connection from
    client: PgPool,
    /// name of the session table; substituted for the `%%TABLE_NAME%%`
    /// placeholder in each query
    table_name: String,
}
32
33impl PostgresSessionStore {
34    /// constructs a new PostgresSessionStore from an existing
35    /// sqlx::PgPool.  the default table name for this session
36    /// store will be "async_sessions". To override this, chain this
37    /// with [`with_table_name`](crate::PostgresSessionStore::with_table_name).
38    ///
39    /// ```rust
40    /// # use async_sqlx_session::PostgresSessionStore;
41    /// # use async_session::Result;
42    /// # fn main() -> Result { async_std::task::block_on(async {
43    /// let pool = sqlx::PgPool::connect(&std::env::var("PG_TEST_DB_URL").unwrap()).await.unwrap();
44    /// let store = PostgresSessionStore::from_client(pool)
45    ///     .with_table_name("custom_table_name");
46    /// store.migrate().await;
47    /// # Ok(()) }) }
48    /// ```
49    pub fn from_client(client: PgPool) -> Self {
50        Self {
51            client,
52            table_name: "async_sessions".into(),
53        }
54    }
55
56    /// Constructs a new PostgresSessionStore from a postgres://
57    /// database url. The default table name for this session store
58    /// will be "async_sessions". To override this, either chain with
59    /// [`with_table_name`](crate::PostgresSessionStore::with_table_name)
60    /// or use
61    /// [`new_with_table_name`](crate::PostgresSessionStore::new_with_table_name)
62    ///
63    /// ```rust
64    /// # use async_sqlx_session::PostgresSessionStore;
65    /// # use async_session::Result;
66    /// # fn main() -> Result { async_std::task::block_on(async {
67    /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await?;
68    /// store.migrate().await;
69    /// # Ok(()) }) }
70    /// ```
71    pub async fn new(database_url: &str) -> sqlx::Result<Self> {
72        let pool = PgPool::connect(database_url).await?;
73        Ok(Self::from_client(pool))
74    }
75
76    /// constructs a new PostgresSessionStore from a postgres:// url. the
77    /// default table name for this session store will be
78    /// "async_sessions". To override this, either chain with
79    /// [`with_table_name`](crate::PostgresSessionStore::with_table_name) or
80    /// use
81    /// [`new_with_table_name`](crate::PostgresSessionStore::new_with_table_name)
82    ///
83    /// ```rust
84    /// # use async_sqlx_session::PostgresSessionStore;
85    /// # use async_session::Result;
86    /// # fn main() -> Result { async_std::task::block_on(async {
87    /// let store = PostgresSessionStore::new_with_table_name(&std::env::var("PG_TEST_DB_URL").unwrap(), "custom_table_name").await?;
88    /// store.migrate().await;
89    /// # Ok(()) }) }
90    /// ```
91    pub async fn new_with_table_name(database_url: &str, table_name: &str) -> sqlx::Result<Self> {
92        Ok(Self::new(database_url).await?.with_table_name(table_name))
93    }
94
95    /// Chainable method to add a custom table name. This will panic
96    /// if the table name is not `[a-zA-Z0-9_-]+`.
97    /// ```rust
98    /// # use async_sqlx_session::PostgresSessionStore;
99    /// # use async_session::Result;
100    /// # fn main() -> Result { async_std::task::block_on(async {
101    /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await?
102    ///     .with_table_name("custom_name");
103    /// store.migrate().await;
104    /// # Ok(()) }) }
105    /// ```
106    ///
107    /// ```should_panic
108    /// # use async_sqlx_session::PostgresSessionStore;
109    /// # use async_session::Result;
110    /// # fn main() -> Result { async_std::task::block_on(async {
111    /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await?
112    ///     .with_table_name("johnny (); drop users;");
113    /// # Ok(()) }) }
114    /// ```
115    pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Self {
116        let table_name = table_name.as_ref();
117        if table_name.is_empty()
118            || !table_name
119                .chars()
120                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
121        {
122            panic!(
123                "table name must be [a-zA-Z0-9_-]+, but {} was not",
124                table_name
125            );
126        }
127
128        self.table_name = table_name.to_owned();
129        self
130    }
131
132    /// Creates a session table if it does not already exist. If it
133    /// does, this will noop, making it safe to call repeatedly on
134    /// store initialization. In the future, this may make
135    /// exactly-once modifications to the schema of the session table
136    /// on breaking releases.
137    /// ```rust
138    /// # use async_sqlx_session::PostgresSessionStore;
139    /// # use async_session::{Result, SessionStore, Session};
140    /// # fn main() -> Result { async_std::task::block_on(async {
141    /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await?;
142    /// # store.clear_store().await?;
143    /// store.migrate().await?;
144    /// store.store_session(Session::new()).await?;
145    /// store.migrate().await?; // calling it a second time is safe
146    /// assert_eq!(store.count().await?, 1);
147    /// # Ok(()) }) }
148    /// ```
149    pub async fn migrate(&self) -> sqlx::Result<()> {
150        log::info!("migrating sessions on `{}`", self.table_name);
151
152        let mut conn = self.client.acquire().await?;
153        conn.execute(&*self.substitute_table_name(
154            r#"
155            CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% (
156                "id" VARCHAR NOT NULL PRIMARY KEY,
157                "expires" TIMESTAMP WITH TIME ZONE NULL,
158                "session" TEXT NOT NULL
159            )
160            "#,
161        ))
162        .await?;
163
164        Ok(())
165    }
166
167    fn substitute_table_name(&self, query: &str) -> String {
168        query.replace("%%TABLE_NAME%%", &self.table_name)
169    }
170
171    /// retrieve a connection from the pool
172    async fn connection(&self) -> sqlx::Result<PoolConnection<Postgres>> {
173        self.client.acquire().await
174    }
175
176    /// Spawns an async_std::task that clears out stale (expired)
177    /// sessions on a periodic basis. Only available with the
178    /// async_std feature enabled.
179    ///
180    /// ```rust,no_run
181    /// # use async_sqlx_session::PostgresSessionStore;
182    /// # use async_session::{Result, SessionStore, Session};
183    /// # use std::time::Duration;
184    /// # fn main() -> Result { async_std::task::block_on(async {
185    /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await?;
186    /// store.migrate().await?;
187    /// # let join_handle =
188    /// store.spawn_cleanup_task(Duration::from_secs(1));
189    /// let mut session = Session::new();
190    /// session.expire_in(Duration::from_secs(0));
191    /// store.store_session(session).await?;
192    /// assert_eq!(store.count().await?, 1);
193    /// async_std::task::sleep(Duration::from_secs(2)).await;
194    /// assert_eq!(store.count().await?, 0);
195    /// # join_handle.cancel().await;
196    /// # Ok(()) }) }
197    /// ```
198    #[cfg(feature = "async_std")]
199    pub fn spawn_cleanup_task(
200        &self,
201        period: std::time::Duration,
202    ) -> async_std::task::JoinHandle<()> {
203        use async_std::task;
204        let store = self.clone();
205        task::spawn(async move {
206            loop {
207                task::sleep(period).await;
208                if let Err(error) = store.cleanup().await {
209                    log::error!("cleanup error: {}", error);
210                }
211            }
212        })
213    }
214
215    /// Performs a one-time cleanup task that clears out stale
216    /// (expired) sessions. You may want to call this from cron.
217    /// ```rust
218    /// # use async_sqlx_session::PostgresSessionStore;
219    /// # use async_session::{chrono::{Utc,Duration}, Result, SessionStore, Session};
220    /// # fn main() -> Result { async_std::task::block_on(async {
221    /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await?;
222    /// store.migrate().await?;
223    /// # store.clear_store().await?;
224    /// let mut session = Session::new();
225    /// session.set_expiry(Utc::now() - Duration::seconds(5));
226    /// store.store_session(session).await?;
227    /// assert_eq!(store.count().await?, 1);
228    /// store.cleanup().await?;
229    /// assert_eq!(store.count().await?, 0);
230    /// # Ok(()) }) }
231    /// ```
232    pub async fn cleanup(&self) -> sqlx::Result<()> {
233        let mut connection = self.connection().await?;
234        sqlx::query(&self.substitute_table_name("DELETE FROM %%TABLE_NAME%% WHERE expires < $1"))
235            .bind(Utc::now())
236            .execute(&mut connection)
237            .await?;
238
239        Ok(())
240    }
241
242    /// retrieves the number of sessions currently stored, including
243    /// expired sessions
244    ///
245    /// ```rust
246    /// # use async_sqlx_session::PostgresSessionStore;
247    /// # use async_session::{Result, SessionStore, Session};
248    /// # use std::time::Duration;
249    /// # fn main() -> Result { async_std::task::block_on(async {
250    /// let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap()).await?;
251    /// store.migrate().await?;
252    /// # store.clear_store().await?;
253    /// assert_eq!(store.count().await?, 0);
254    /// store.store_session(Session::new()).await?;
255    /// assert_eq!(store.count().await?, 1);
256    /// # Ok(()) }) }
257    /// ```
258
259    pub async fn count(&self) -> sqlx::Result<i64> {
260        let (count,) =
261            sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%"))
262                .fetch_one(&mut self.connection().await?)
263                .await?;
264
265        Ok(count)
266    }
267}
268
269#[async_trait]
270impl SessionStore for PostgresSessionStore {
271    async fn load_session(&self, cookie_value: String) -> Result<Option<Session>> {
272        let id = Session::id_from_cookie_value(&cookie_value)?;
273        let mut connection = self.connection().await?;
274
275        let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name(
276            "SELECT session FROM %%TABLE_NAME%% WHERE id = $1 AND (expires IS NULL OR expires > $2)"
277        ))
278        .bind(&id)
279        .bind(Utc::now())
280        .fetch_optional(&mut connection)
281        .await?;
282
283        Ok(result
284            .map(|(session,)| serde_json::from_str(&session))
285            .transpose()?)
286    }
287
288    async fn store_session(&self, session: Session) -> Result<Option<String>> {
289        let id = session.id();
290        let string = serde_json::to_string(&session)?;
291        let mut connection = self.connection().await?;
292
293        sqlx::query(&self.substitute_table_name(
294            r#"
295            INSERT INTO %%TABLE_NAME%%
296              (id, session, expires) SELECT $1, $2, $3
297            ON CONFLICT(id) DO UPDATE SET
298              expires = EXCLUDED.expires,
299              session = EXCLUDED.session
300            "#,
301        ))
302        .bind(&id)
303        .bind(&string)
304        .bind(&session.expiry())
305        .execute(&mut connection)
306        .await?;
307
308        Ok(session.into_cookie_value())
309    }
310
311    async fn destroy_session(&self, session: Session) -> Result {
312        let id = session.id();
313        let mut connection = self.connection().await?;
314        sqlx::query(&self.substitute_table_name("DELETE FROM %%TABLE_NAME%% WHERE id = $1"))
315            .bind(&id)
316            .execute(&mut connection)
317            .await?;
318
319        Ok(())
320    }
321
322    async fn clear_store(&self) -> Result {
323        let mut connection = self.connection().await?;
324        sqlx::query(&self.substitute_table_name("TRUNCATE %%TABLE_NAME%%"))
325            .execute(&mut connection)
326            .await?;
327
328        Ok(())
329    }
330}
331
332#[cfg(test)]
333mod tests {
334    use super::*;
335    use async_session::chrono::DateTime;
336    use std::time::Duration;
337
338    async fn test_store() -> PostgresSessionStore {
339        let store = PostgresSessionStore::new(&std::env::var("PG_TEST_DB_URL").unwrap())
340            .await
341            .expect("building a PostgresSessionStore");
342
343        store
344            .migrate()
345            .await
346            .expect("migrating a PostgresSessionStore");
347
348        store.clear_store().await.expect("clearing");
349
350        store
351    }
352
353    #[async_std::test]
354    async fn creating_a_new_session_with_no_expiry() -> Result {
355        let store = test_store().await;
356        let mut session = Session::new();
357        session.insert("key", "value")?;
358        let cloned = session.clone();
359        let cookie_value = store.store_session(session).await?.unwrap();
360
361        let (id, expires, serialized, count): (String, Option<DateTime<Utc>>, String, i64) =
362            sqlx::query_as("select id, expires, session, (select count(*) from async_sessions) from async_sessions")
363                .fetch_one(&mut store.connection().await?)
364                .await?;
365
366        assert_eq!(1, count);
367        assert_eq!(id, cloned.id());
368        assert_eq!(expires, None);
369
370        let deserialized_session: Session = serde_json::from_str(&serialized)?;
371        assert_eq!(cloned.id(), deserialized_session.id());
372        assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());
373
374        let loaded_session = store.load_session(cookie_value).await?.unwrap();
375        assert_eq!(cloned.id(), loaded_session.id());
376        assert_eq!("value", &loaded_session.get::<String>("key").unwrap());
377
378        assert!(!loaded_session.is_expired());
379        Ok(())
380    }
381
382    #[async_std::test]
383    async fn updating_a_session() -> Result {
384        let store = test_store().await;
385        let mut session = Session::new();
386        let original_id = session.id().to_owned();
387
388        session.insert("key", "value")?;
389        let cookie_value = store.store_session(session).await?.unwrap();
390
391        let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
392        session.insert("key", "other value")?;
393        assert_eq!(None, store.store_session(session).await?);
394
395        let session = store.load_session(cookie_value.clone()).await?.unwrap();
396        assert_eq!(session.get::<String>("key").unwrap(), "other value");
397
398        let (id, count): (String, i64) =
399            sqlx::query_as("select id, (select count(*) from async_sessions) from async_sessions")
400                .fetch_one(&mut store.connection().await?)
401                .await?;
402
403        assert_eq!(1, count);
404        assert_eq!(original_id, id);
405
406        Ok(())
407    }
408
409    #[async_std::test]
410    async fn updating_a_session_extending_expiry() -> Result {
411        let store = test_store().await;
412        let mut session = Session::new();
413        session.expire_in(Duration::from_secs(10));
414        let original_id = session.id().to_owned();
415        let original_expires = session.expiry().unwrap().clone();
416        let cookie_value = store.store_session(session).await?.unwrap();
417
418        let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
419        assert_eq!(session.expiry().unwrap(), &original_expires);
420        session.expire_in(Duration::from_secs(20));
421        let new_expires = session.expiry().unwrap().clone();
422        store.store_session(session).await?;
423
424        let session = store.load_session(cookie_value.clone()).await?.unwrap();
425        assert_eq!(session.expiry().unwrap(), &new_expires);
426
427        let (id, expires, count): (String, DateTime<Utc>, i64) = sqlx::query_as(
428            "select id, expires, (select count(*) from async_sessions) from async_sessions",
429        )
430        .fetch_one(&mut store.connection().await?)
431        .await?;
432
433        assert_eq!(1, count);
434        assert_eq!(expires.timestamp_millis(), new_expires.timestamp_millis());
435        assert_eq!(original_id, id);
436
437        Ok(())
438    }
439
440    #[async_std::test]
441    async fn creating_a_new_session_with_expiry() -> Result {
442        let store = test_store().await;
443        let mut session = Session::new();
444        session.expire_in(Duration::from_secs(1));
445        session.insert("key", "value")?;
446        let cloned = session.clone();
447
448        let cookie_value = store.store_session(session).await?.unwrap();
449
450        let (id, expires, serialized, count): (String, Option<DateTime<Utc>>, String, i64) =
451            sqlx::query_as("select id, expires, session, (select count(*) from async_sessions) from async_sessions")
452                .fetch_one(&mut store.connection().await?)
453                .await?;
454
455        assert_eq!(1, count);
456        assert_eq!(id, cloned.id());
457        assert!(expires.unwrap() > Utc::now());
458
459        let deserialized_session: Session = serde_json::from_str(&serialized)?;
460        assert_eq!(cloned.id(), deserialized_session.id());
461        assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());
462
463        let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap();
464        assert_eq!(cloned.id(), loaded_session.id());
465        assert_eq!("value", &loaded_session.get::<String>("key").unwrap());
466
467        assert!(!loaded_session.is_expired());
468
469        async_std::task::sleep(Duration::from_secs(1)).await;
470        assert_eq!(None, store.load_session(cookie_value).await?);
471
472        Ok(())
473    }
474
475    #[async_std::test]
476    async fn destroying_a_single_session() -> Result {
477        let store = test_store().await;
478        for _ in 0..3i8 {
479            store.store_session(Session::new()).await?;
480        }
481
482        let cookie = store.store_session(Session::new()).await?.unwrap();
483        assert_eq!(4, store.count().await?);
484        let session = store.load_session(cookie.clone()).await?.unwrap();
485        store.destroy_session(session.clone()).await.unwrap();
486        assert_eq!(None, store.load_session(cookie).await?);
487        assert_eq!(3, store.count().await?);
488
489        // // attempting to destroy the session again is not an error
490        assert!(store.destroy_session(session).await.is_ok());
491        Ok(())
492    }
493
494    #[async_std::test]
495    async fn clearing_the_whole_store() -> Result {
496        let store = test_store().await;
497        for _ in 0..3i8 {
498            store.store_session(Session::new()).await?;
499        }
500
501        assert_eq!(3, store.count().await?);
502        store.clear_store().await.unwrap();
503        assert_eq!(0, store.count().await?);
504
505        Ok(())
506    }
507}