1use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore};
2use sqlx::{pool::PoolConnection, Executor, PgPool, Postgres};
3
4#[derive(Clone, Debug)]
28pub struct PostgresSessionStore {
29 client: PgPool,
30 table_name: String,
31}
32
33impl PostgresSessionStore {
34 pub fn from_client(client: PgPool) -> Self {
50 Self {
51 client,
52 table_name: "async_sessions".into(),
53 }
54 }
55
56 pub async fn new(database_url: &str) -> sqlx::Result<Self> {
72 let pool = PgPool::connect(database_url).await?;
73 Ok(Self::from_client(pool))
74 }
75
76 pub async fn new_with_table_name(database_url: &str, table_name: &str) -> sqlx::Result<Self> {
92 Ok(Self::new(database_url).await?.with_table_name(table_name))
93 }
94
95 pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Self {
116 let table_name = table_name.as_ref();
117 if table_name.is_empty()
118 || !table_name
119 .chars()
120 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
121 {
122 panic!(
123 "table name must be [a-zA-Z0-9_-]+, but {} was not",
124 table_name
125 );
126 }
127
128 self.table_name = table_name.to_owned();
129 self
130 }
131
132 pub async fn migrate(&self) -> sqlx::Result<()> {
150 log::info!("migrating sessions on `{}`", self.table_name);
151
152 let mut conn = self.client.acquire().await?;
153 conn.execute(&*self.substitute_table_name(
154 r#"
155 CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% (
156 "id" VARCHAR NOT NULL PRIMARY KEY,
157 "expires" TIMESTAMP WITH TIME ZONE NULL,
158 "session" TEXT NOT NULL
159 )
160 "#,
161 ))
162 .await?;
163
164 Ok(())
165 }
166
167 fn substitute_table_name(&self, query: &str) -> String {
168 query.replace("%%TABLE_NAME%%", &self.table_name)
169 }
170
171 async fn connection(&self) -> sqlx::Result<PoolConnection<Postgres>> {
173 self.client.acquire().await
174 }
175
176 #[cfg(feature = "async_std")]
199 pub fn spawn_cleanup_task(
200 &self,
201 period: std::time::Duration,
202 ) -> async_std::task::JoinHandle<()> {
203 use async_std::task;
204 let store = self.clone();
205 task::spawn(async move {
206 loop {
207 task::sleep(period).await;
208 if let Err(error) = store.cleanup().await {
209 log::error!("cleanup error: {}", error);
210 }
211 }
212 })
213 }
214
215 pub async fn cleanup(&self) -> sqlx::Result<()> {
233 let mut connection = self.connection().await?;
234 sqlx::query(&self.substitute_table_name("DELETE FROM %%TABLE_NAME%% WHERE expires < $1"))
235 .bind(Utc::now())
236 .execute(&mut connection)
237 .await?;
238
239 Ok(())
240 }
241
242 pub async fn count(&self) -> sqlx::Result<i64> {
260 let (count,) =
261 sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%"))
262 .fetch_one(&mut self.connection().await?)
263 .await?;
264
265 Ok(count)
266 }
267}
268
269#[async_trait]
270impl SessionStore for PostgresSessionStore {
271 async fn load_session(&self, cookie_value: String) -> Result<Option<Session>> {
272 let id = Session::id_from_cookie_value(&cookie_value)?;
273 let mut connection = self.connection().await?;
274
275 let result: Option<(String,)> = sqlx::query_as(&self.substitute_table_name(
276 "SELECT session FROM %%TABLE_NAME%% WHERE id = $1 AND (expires IS NULL OR expires > $2)"
277 ))
278 .bind(&id)
279 .bind(Utc::now())
280 .fetch_optional(&mut connection)
281 .await?;
282
283 Ok(result
284 .map(|(session,)| serde_json::from_str(&session))
285 .transpose()?)
286 }
287
288 async fn store_session(&self, session: Session) -> Result<Option<String>> {
289 let id = session.id();
290 let string = serde_json::to_string(&session)?;
291 let mut connection = self.connection().await?;
292
293 sqlx::query(&self.substitute_table_name(
294 r#"
295 INSERT INTO %%TABLE_NAME%%
296 (id, session, expires) SELECT $1, $2, $3
297 ON CONFLICT(id) DO UPDATE SET
298 expires = EXCLUDED.expires,
299 session = EXCLUDED.session
300 "#,
301 ))
302 .bind(&id)
303 .bind(&string)
304 .bind(&session.expiry())
305 .execute(&mut connection)
306 .await?;
307
308 Ok(session.into_cookie_value())
309 }
310
311 async fn destroy_session(&self, session: Session) -> Result {
312 let id = session.id();
313 let mut connection = self.connection().await?;
314 sqlx::query(&self.substitute_table_name("DELETE FROM %%TABLE_NAME%% WHERE id = $1"))
315 .bind(&id)
316 .execute(&mut connection)
317 .await?;
318
319 Ok(())
320 }
321
322 async fn clear_store(&self) -> Result {
323 let mut connection = self.connection().await?;
324 sqlx::query(&self.substitute_table_name("TRUNCATE %%TABLE_NAME%%"))
325 .execute(&mut connection)
326 .await?;
327
328 Ok(())
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use async_session::chrono::DateTime;
    use std::time::Duration;

    /// Builds a migrated, empty store for a single test.
    ///
    /// The database is the one named by the `PG_TEST_DB_URL` environment
    /// variable. `cargo test` runs tests on parallel threads by default, so
    /// every test passes its own `table_name`. With a shared table, one test's
    /// `clear_store` or inserts would change the row counts another test
    /// asserts on.
    async fn test_store(table_name: &str) -> PostgresSessionStore {
        let database_url = std::env::var("PG_TEST_DB_URL")
            .expect("PG_TEST_DB_URL must point at a Postgres database for tests");
        let store = PostgresSessionStore::new(&database_url)
            .await
            .expect("building a PostgresSessionStore")
            .with_table_name(table_name);

        store
            .migrate()
            .await
            .expect("migrating a PostgresSessionStore");

        store.clear_store().await.expect("clearing");

        store
    }

    #[async_std::test]
    async fn creating_a_new_session_with_no_expiry() -> Result {
        let store = test_store("async_sessions_test_no_expiry").await;
        let mut session = Session::new();
        session.insert("key", "value")?;
        let cloned = session.clone();
        let cookie_value = store.store_session(session).await?.unwrap();

        let (id, expires, serialized, count): (String, Option<DateTime<Utc>>, String, i64) =
            sqlx::query_as(&store.substitute_table_name(
                "select id, expires, session, (select count(*) from %%TABLE_NAME%%) from %%TABLE_NAME%%",
            ))
            .fetch_one(&mut store.connection().await?)
            .await?;

        assert_eq!(1, count);
        assert_eq!(id, cloned.id());
        assert_eq!(expires, None);

        let deserialized_session: Session = serde_json::from_str(&serialized)?;
        assert_eq!(cloned.id(), deserialized_session.id());
        assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());

        let loaded_session = store.load_session(cookie_value).await?.unwrap();
        assert_eq!(cloned.id(), loaded_session.id());
        assert_eq!("value", &loaded_session.get::<String>("key").unwrap());

        assert!(!loaded_session.is_expired());
        Ok(())
    }

    #[async_std::test]
    async fn updating_a_session() -> Result {
        let store = test_store("async_sessions_test_update").await;
        let mut session = Session::new();
        let original_id = session.id().to_owned();

        session.insert("key", "value")?;
        let cookie_value = store.store_session(session).await?.unwrap();

        let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
        session.insert("key", "other value")?;
        assert_eq!(None, store.store_session(session).await?);

        let session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(session.get::<String>("key").unwrap(), "other value");

        let (id, count): (String, i64) = sqlx::query_as(&store.substitute_table_name(
            "select id, (select count(*) from %%TABLE_NAME%%) from %%TABLE_NAME%%",
        ))
        .fetch_one(&mut store.connection().await?)
        .await?;

        assert_eq!(1, count);
        assert_eq!(original_id, id);

        Ok(())
    }

    #[async_std::test]
    async fn updating_a_session_extending_expiry() -> Result {
        let store = test_store("async_sessions_test_extend_expiry").await;
        let mut session = Session::new();
        session.expire_in(Duration::from_secs(10));
        let original_id = session.id().to_owned();
        let original_expires = *session.expiry().unwrap();
        let cookie_value = store.store_session(session).await?.unwrap();

        let mut session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(session.expiry().unwrap(), &original_expires);
        session.expire_in(Duration::from_secs(20));
        let new_expires = *session.expiry().unwrap();
        store.store_session(session).await?;

        let session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(session.expiry().unwrap(), &new_expires);

        let (id, expires, count): (String, DateTime<Utc>, i64) =
            sqlx::query_as(&store.substitute_table_name(
                "select id, expires, (select count(*) from %%TABLE_NAME%%) from %%TABLE_NAME%%",
            ))
            .fetch_one(&mut store.connection().await?)
            .await?;

        assert_eq!(1, count);
        assert_eq!(expires.timestamp_millis(), new_expires.timestamp_millis());
        assert_eq!(original_id, id);

        Ok(())
    }

    #[async_std::test]
    async fn creating_a_new_session_with_expiry() -> Result {
        let store = test_store("async_sessions_test_with_expiry").await;
        let mut session = Session::new();
        session.expire_in(Duration::from_secs(1));
        session.insert("key", "value")?;
        let cloned = session.clone();

        let cookie_value = store.store_session(session).await?.unwrap();

        let (id, expires, serialized, count): (String, Option<DateTime<Utc>>, String, i64) =
            sqlx::query_as(&store.substitute_table_name(
                "select id, expires, session, (select count(*) from %%TABLE_NAME%%) from %%TABLE_NAME%%",
            ))
            .fetch_one(&mut store.connection().await?)
            .await?;

        assert_eq!(1, count);
        assert_eq!(id, cloned.id());
        assert!(expires.unwrap() > Utc::now());

        let deserialized_session: Session = serde_json::from_str(&serialized)?;
        assert_eq!(cloned.id(), deserialized_session.id());
        assert_eq!("value", &deserialized_session.get::<String>("key").unwrap());

        let loaded_session = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(cloned.id(), loaded_session.id());
        assert_eq!("value", &loaded_session.get::<String>("key").unwrap());

        assert!(!loaded_session.is_expired());

        async_std::task::sleep(Duration::from_secs(1)).await;
        assert_eq!(None, store.load_session(cookie_value).await?);

        Ok(())
    }

    #[async_std::test]
    async fn destroying_a_single_session() -> Result {
        let store = test_store("async_sessions_test_destroy").await;
        for _ in 0..3i8 {
            store.store_session(Session::new()).await?;
        }

        let cookie = store.store_session(Session::new()).await?.unwrap();
        assert_eq!(4, store.count().await?);
        let session = store.load_session(cookie.clone()).await?.unwrap();
        store.destroy_session(session.clone()).await?;
        assert_eq!(None, store.load_session(cookie).await?);
        assert_eq!(3, store.count().await?);

        // Destroying a session that is already gone is not an error.
        assert!(store.destroy_session(session).await.is_ok());
        Ok(())
    }

    #[async_std::test]
    async fn clearing_the_whole_store() -> Result {
        let store = test_store("async_sessions_test_clear").await;
        for _ in 0..3i8 {
            store.store_session(Session::new()).await?;
        }

        assert_eq!(3, store.count().await?);
        store.clear_store().await?;
        assert_eq!(0, store.count().await?);

        Ok(())
    }
}