use async_session::{async_trait, chrono::Utc, log, serde_json, Result, Session, SessionStore};
use async_std::task;
use sqlx::{pool::PoolConnection, prelude::PgQueryAs, Executor, PgConnection, PgPool};
use std::time::Duration;
/// A session store that keeps its sessions in a PostgreSQL table.
///
/// Pairs a sqlx connection pool with the name of the table that holds the
/// session rows. Cloning yields another handle with the same table name and a
/// clone of the pool.
#[derive(Debug, Clone)]
pub struct PostgresSessionStore {
    /// Pool from which a connection is acquired for each operation.
    client: PgPool,
    /// Name of the sessions table; spliced into every SQL statement.
    table_name: String,
}
impl PostgresSessionStore {
    /// Builds a store on top of an existing connection pool, using the
    /// default table name `async_sessions`.
    ///
    /// No SQL is executed here; call `migrate` to create the table.
    pub fn from_client(client: PgPool) -> Self {
        Self {
            client,
            table_name: "async_sessions".into(),
        }
    }

    /// Creates a new pool for `database_url` and builds a store with the
    /// default table name.
    ///
    /// # Errors
    ///
    /// Returns the sqlx error reported by `PgPool::new`.
    pub async fn new(database_url: &str) -> sqlx::Result<Self> {
        let pool = PgPool::new(database_url).await?;
        Ok(Self::from_client(pool))
    }

    /// Same as `new`, followed by `with_table_name(table_name)`.
    ///
    /// # Errors
    ///
    /// Returns the sqlx error reported by `PgPool::new`.
    ///
    /// # Panics
    ///
    /// Panics if `table_name` is rejected by `with_table_name`. The pool is
    /// created before the name is checked.
    pub async fn new_with_table_name(database_url: &str, table_name: &str) -> sqlx::Result<Self> {
        Ok(Self::new(database_url).await?.with_table_name(table_name))
    }

    /// Replaces the table name used by this store and returns the store.
    ///
    /// The name must be non-empty and consist only of ASCII letters, digits,
    /// `-` and `_`. Names that contain `-` or start with a digit are emitted
    /// as double-quoted identifiers in the generated SQL (which makes them
    /// case-sensitive in PostgreSQL); all other names are emitted verbatim.
    ///
    /// # Panics
    ///
    /// Panics if `table_name` is empty or contains any other character. The
    /// name is spliced into SQL text rather than bound as a parameter, so this
    /// check is what keeps arbitrary SQL out of the statements.
    pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Self {
        let table_name = table_name.as_ref();
        if table_name.is_empty()
            || !table_name
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
        {
            panic!(
                "table name must be [a-zA-Z0-9_-]+, but {} was not",
                table_name
            );
        }
        self.table_name = table_name.to_owned();
        self
    }

    /// Creates the sessions table if it does not exist yet.
    ///
    /// The table has three columns: `id` (primary key), `expires` (nullable
    /// timestamp with time zone) and `session` (text).
    ///
    /// # Errors
    ///
    /// Returns the sqlx error if acquiring a connection or executing the
    /// statement fails.
    pub async fn migrate(&self) -> sqlx::Result<()> {
        log::info!("migrating sessions on `{}`", self.table_name);
        // Go through the same helper as every other method in this impl.
        let mut conn = self.connection().await?;
        conn.execute(&*self.substitute_table_name(
            r#"
CREATE TABLE IF NOT EXISTS %%TABLE_NAME%% (
"id" VARCHAR NOT NULL PRIMARY KEY,
"expires" TIMESTAMP WITH TIME ZONE NULL,
"session" TEXT NOT NULL
)
"#,
        ))
        .await?;
        Ok(())
    }

    /// Reports whether `name` cannot be written as a bare PostgreSQL
    /// identifier.
    ///
    /// Within the character set `with_table_name` admits, that is the case
    /// when the name contains `-` or starts with a digit. SQL keywords are not
    /// detected.
    fn needs_quoting(name: &str) -> bool {
        name.contains('-') || name.starts_with(|c: char| c.is_ascii_digit())
    }

    /// Replaces every `%%TABLE_NAME%%` placeholder in `query` with this
    /// store's table name.
    ///
    /// Names for which `needs_quoting` is true are wrapped in double quotes so
    /// the statement stays valid; without the quotes a name such as
    /// `my-sessions` would pass validation and then fail on every query.
    /// Wrapping is sufficient because the validated character set cannot
    /// contain a `"`. Other names are inserted unchanged.
    fn substitute_table_name(&self, query: &str) -> String {
        if Self::needs_quoting(&self.table_name) {
            query.replace("%%TABLE_NAME%%", &format!("\"{}\"", self.table_name))
        } else {
            query.replace("%%TABLE_NAME%%", &self.table_name)
        }
    }

    /// Acquires a connection from the pool.
    async fn connection(&self) -> sqlx::Result<PoolConnection<PgConnection>> {
        self.client.acquire().await
    }

    /// Spawns a task that runs `cleanup` once every `period`, forever.
    ///
    /// The task works on a clone of this store. It sleeps for `period` before
    /// the first cleanup. A failed cleanup is logged at error level and the
    /// loop carries on.
    pub fn spawn_cleanup_task(&self, period: Duration) -> task::JoinHandle<()> {
        let store = self.clone();
        task::spawn(async move {
            loop {
                task::sleep(period).await;
                if let Err(error) = store.cleanup().await {
                    log::error!("cleanup error: {}", error);
                }
            }
        })
    }

    /// Deletes every row whose `expires` is earlier than the current UTC time.
    ///
    /// Rows with a NULL `expires` never satisfy the comparison and are kept.
    ///
    /// # Errors
    ///
    /// Returns the sqlx error if acquiring a connection or executing the
    /// statement fails.
    pub async fn cleanup(&self) -> sqlx::Result<()> {
        let mut connection = self.connection().await?;
        sqlx::query(&self.substitute_table_name("DELETE FROM %%TABLE_NAME%% WHERE expires < $1"))
            .bind(Utc::now())
            .execute(&mut connection)
            .await?;
        Ok(())
    }

    /// Returns the number of rows in the sessions table.
    ///
    /// The query has no filter, so expired rows that have not been cleaned up
    /// yet are counted too.
    ///
    /// # Errors
    ///
    /// Returns the sqlx error if acquiring a connection or running the query
    /// fails.
    pub async fn count(&self) -> sqlx::Result<i64> {
        let (count,) =
            sqlx::query_as(&self.substitute_table_name("SELECT COUNT(*) FROM %%TABLE_NAME%%"))
                .fetch_one(&mut self.connection().await?)
                .await?;
        Ok(count)
    }
}
#[async_trait]
impl SessionStore for PostgresSessionStore {
    /// Looks up the session identified by `cookie_value`.
    ///
    /// The id is derived with `Session::id_from_cookie_value`. Only a row
    /// whose `expires` is NULL or later than the current UTC time is
    /// considered; its `session` column is deserialized from JSON. Yields
    /// `Ok(None)` when no such row exists.
    async fn load_session(&self, cookie_value: String) -> Result<Option<Session>> {
        let session_id = Session::id_from_cookie_value(&cookie_value)?;
        let mut conn = self.connection().await?;
        let sql = self.substitute_table_name(
            "SELECT session FROM %%TABLE_NAME%% WHERE id = $1 AND (expires IS NULL OR expires > $2)",
        );
        let row: Option<(String,)> = sqlx::query_as(&sql)
            .bind(&session_id)
            .bind(Utc::now())
            .fetch_optional(&mut conn)
            .await?;
        match row {
            Some((serialized,)) => Ok(Some(serde_json::from_str(&serialized)?)),
            None => Ok(None),
        }
    }

    /// Writes `session` to the table as JSON, keyed by its id.
    ///
    /// If a row with the same id already exists, its `expires` and `session`
    /// columns are overwritten. Returns whatever
    /// `Session::into_cookie_value` yields for the stored session.
    async fn store_session(&self, session: Session) -> Result<Option<String>> {
        let session_id = session.id();
        let payload = serde_json::to_string(&session)?;
        let mut conn = self.connection().await?;
        let sql = self.substitute_table_name(
            r#"
INSERT INTO %%TABLE_NAME%%
(id, session, expires) SELECT $1, $2, $3
ON CONFLICT(id) DO UPDATE SET
expires = EXCLUDED.expires,
session = EXCLUDED.session
"#,
        );
        sqlx::query(&sql)
            .bind(&session_id)
            .bind(&payload)
            .bind(&session.expiry())
            .execute(&mut conn)
            .await?;
        Ok(session.into_cookie_value())
    }

    /// Deletes the row whose id matches `session`.
    ///
    /// Succeeds even when no row matches, as long as the statement itself
    /// executes.
    async fn destroy_session(&self, session: Session) -> Result {
        let session_id = session.id();
        let mut conn = self.connection().await?;
        let sql = self.substitute_table_name("DELETE FROM %%TABLE_NAME%% WHERE id = $1");
        sqlx::query(&sql)
            .bind(&session_id)
            .execute(&mut conn)
            .await?;
        Ok(())
    }

    /// Removes every row from the sessions table with `TRUNCATE`.
    async fn clear_store(&self) -> Result {
        let mut conn = self.connection().await?;
        let sql = self.substitute_table_name("TRUNCATE %%TABLE_NAME%%");
        sqlx::query(&sql).execute(&mut conn).await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_session::chrono::DateTime;

    /// Connects to the database named by the `PG_TEST_DB_URL` environment
    /// variable, makes sure the sessions table exists, and empties it.
    async fn test_store() -> PostgresSessionStore {
        let database_url = std::env::var("PG_TEST_DB_URL").unwrap();
        let store = PostgresSessionStore::new(&database_url)
            .await
            .expect("building a PostgresSessionStore");
        store
            .migrate()
            .await
            .expect("migrating a PostgresSessionStore");
        store.clear_store().await.expect("clearing");
        store
    }

    /// A session without an expiry is stored with a NULL `expires` column and
    /// can be loaded back.
    #[async_std::test]
    async fn creating_a_new_session_with_no_expiry() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        session.insert("key", "value")?;
        let original = session.clone();
        let cookie_value = store.store_session(session).await?.unwrap();

        // Inspect the raw row; the connection is a temporary so it goes back
        // to the pool as soon as the query finishes.
        let row: (String, Option<DateTime<Utc>>, String, i64) =
            sqlx::query_as("select id, expires, session, (select count(*) from async_sessions) from async_sessions")
                .fetch_one(&mut store.connection().await?)
                .await?;
        let (id, expires, serialized, count) = row;
        assert_eq!(1, count);
        assert_eq!(id, original.id());
        assert_eq!(expires, None);

        let from_row: Session = serde_json::from_str(&serialized)?;
        assert_eq!(original.id(), from_row.id());
        assert_eq!("value", &from_row.get::<String>("key").unwrap());

        let reloaded = store.load_session(cookie_value).await?.unwrap();
        assert_eq!(original.id(), reloaded.id());
        assert_eq!("value", &reloaded.get::<String>("key").unwrap());
        assert!(!reloaded.is_expired());
        Ok(())
    }

    /// Storing a loaded session again updates the existing row instead of
    /// adding a second one.
    #[async_std::test]
    async fn updating_a_session() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        let original_id = session.id().to_owned();
        session.insert("key", "value")?;
        let cookie_value = store.store_session(session).await?.unwrap();

        let mut reloaded = store.load_session(cookie_value.clone()).await?.unwrap();
        reloaded.insert("key", "other value")?;
        assert_eq!(None, store.store_session(reloaded).await?);

        let updated = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(updated.get::<String>("key").unwrap(), "other value");

        let row: (String, i64) =
            sqlx::query_as("select id, (select count(*) from async_sessions) from async_sessions")
                .fetch_one(&mut store.connection().await?)
                .await?;
        let (id, count) = row;
        assert_eq!(1, count);
        assert_eq!(original_id, id);
        Ok(())
    }

    /// Re-storing a session with a later expiry overwrites the stored expiry.
    #[async_std::test]
    async fn updating_a_session_extending_expiry() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        session.expire_in(Duration::from_secs(10));
        let original_id = session.id().to_owned();
        let original_expires = session.expiry().cloned().unwrap();
        let cookie_value = store.store_session(session).await?.unwrap();

        let mut reloaded = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(reloaded.expiry().unwrap(), &original_expires);
        reloaded.expire_in(Duration::from_secs(20));
        let new_expires = reloaded.expiry().cloned().unwrap();
        store.store_session(reloaded).await?;

        let extended = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(extended.expiry().unwrap(), &new_expires);

        let row: (String, DateTime<Utc>, i64) = sqlx::query_as(
            "select id, expires, (select count(*) from async_sessions) from async_sessions",
        )
        .fetch_one(&mut store.connection().await?)
        .await?;
        let (id, expires, count) = row;
        assert_eq!(1, count);
        assert_eq!(expires.timestamp_millis(), new_expires.timestamp_millis());
        assert_eq!(original_id, id);
        Ok(())
    }

    /// A session with an expiry can be loaded until the expiry passes, after
    /// which loading yields `None`.
    #[async_std::test]
    async fn creating_a_new_session_with_expiry() -> Result {
        let store = test_store().await;
        let mut session = Session::new();
        session.expire_in(Duration::from_secs(1));
        session.insert("key", "value")?;
        let original = session.clone();
        let cookie_value = store.store_session(session).await?.unwrap();

        let row: (String, Option<DateTime<Utc>>, String, i64) =
            sqlx::query_as("select id, expires, session, (select count(*) from async_sessions) from async_sessions")
                .fetch_one(&mut store.connection().await?)
                .await?;
        let (id, expires, serialized, count) = row;
        assert_eq!(1, count);
        assert_eq!(id, original.id());
        assert!(expires.unwrap() > Utc::now());

        let from_row: Session = serde_json::from_str(&serialized)?;
        assert_eq!(original.id(), from_row.id());
        assert_eq!("value", &from_row.get::<String>("key").unwrap());

        let reloaded = store.load_session(cookie_value.clone()).await?.unwrap();
        assert_eq!(original.id(), reloaded.id());
        assert_eq!("value", &reloaded.get::<String>("key").unwrap());
        assert!(!reloaded.is_expired());

        // Wait out the one-second lifetime set above.
        task::sleep(Duration::from_secs(1)).await;
        assert_eq!(None, store.load_session(cookie_value).await?);
        Ok(())
    }

    /// Destroying one session removes only that row, and destroying it a
    /// second time still succeeds.
    #[async_std::test]
    async fn destroying_a_single_session() -> Result {
        let store = test_store().await;
        for _ in 0..3i8 {
            store.store_session(Session::new()).await?;
        }
        let cookie = store.store_session(Session::new()).await?.unwrap();
        assert_eq!(4, store.count().await?);

        let target = store.load_session(cookie.clone()).await?.unwrap();
        store.destroy_session(target.clone()).await.unwrap();
        assert_eq!(None, store.load_session(cookie).await?);
        assert_eq!(3, store.count().await?);

        assert!(store.destroy_session(target).await.is_ok());
        Ok(())
    }

    /// `clear_store` leaves the table empty.
    #[async_std::test]
    async fn clearing_the_whole_store() -> Result {
        let store = test_store().await;
        for _ in 0..3i8 {
            store.store_session(Session::new()).await?;
        }
        assert_eq!(3, store.count().await?);

        store.clear_store().await.unwrap();
        assert_eq!(0, store.count().await?);
        Ok(())
    }
}