use std::error::Error;
use async_trait::async_trait;
use deadpool_redis::{Config, Pool as RedisPool, Runtime};
use redis::{AsyncCommands, ExistenceCheck, SetExpiry, SetOptions};
use thiserror::Error;
use time::OffsetDateTime;
use tower_sessions::session::{Id, Record};
use tower_sessions::{SessionStore, session_store};
use crate::config::CacheUrl;
use crate::session::store::{ERROR_PREFIX, MAX_COLLISION_RETRIES};
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum RedisStoreError {
#[error("{ERROR_PREFIX} pool connection error: {0}")]
PoolConnection(Box<dyn Error + Send + Sync>),
#[error("{ERROR_PREFIX} pool creation error: {0}")]
PoolCreation(Box<dyn Error + Send + Sync>),
#[error("{ERROR_PREFIX} command error: {0}")]
Command(Box<dyn Error + Send + Sync>),
#[error("{ERROR_PREFIX} session-id collision retried too many times ({0})")]
TooManyIdCollisions(u32),
#[error("{ERROR_PREFIX} serialization error: {0}")]
Serialize(Box<dyn Error + Send + Sync>),
#[error("{ERROR_PREFIX} deserialization error: {0}")]
Deserialize(Box<dyn Error + Send + Sync>),
}
impl From<RedisStoreError> for session_store::Error {
fn from(err: RedisStoreError) -> session_store::Error {
match err {
RedisStoreError::PoolConnection(inner) | RedisStoreError::PoolCreation(inner) => {
session_store::Error::Backend(inner.to_string())
}
RedisStoreError::Command(inner) => session_store::Error::Backend(inner.to_string()),
RedisStoreError::Serialize(inner) => session_store::Error::Encode(inner.to_string()),
RedisStoreError::Deserialize(inner) => session_store::Error::Decode(inner.to_string()),
other => session_store::Error::Backend(other.to_string()),
}
}
}
/// Session store that persists `tower_sessions` records in Redis.
///
/// All operations check a connection out of the wrapped pool. Cloning the
/// store clones the underlying `RedisPool` value.
#[derive(Clone, Debug)]
pub struct RedisStore {
    /// Pool from which every store operation obtains its connection.
    pool: RedisPool,
}
impl RedisStore {
    /// Builds a store whose connection pool is configured from `url` and
    /// driven by the Tokio 1.x runtime.
    ///
    /// # Errors
    ///
    /// Returns [`RedisStoreError::PoolCreation`] if the pool cannot be created
    /// from the URL's configuration.
    pub fn new(url: &CacheUrl) -> Result<RedisStore, RedisStoreError> {
        Config::from_url(url.as_str())
            .create_pool(Some(Runtime::Tokio1))
            .map(|pool| Self { pool })
            .map_err(|err| RedisStoreError::PoolCreation(Box::new(err)))
    }

    /// Checks a connection out of the pool.
    ///
    /// # Errors
    ///
    /// Returns [`RedisStoreError::PoolConnection`] if the pool cannot provide
    /// a connection.
    pub async fn get_connection(&self) -> Result<deadpool_redis::Connection, RedisStoreError> {
        match self.pool.get().await {
            Ok(connection) => Ok(connection),
            Err(err) => Err(RedisStoreError::PoolConnection(Box::new(err))),
        }
    }
}
/// Converts an absolute session expiry into a Redis `EX` time-to-live, in
/// whole seconds from now.
///
/// Redis rejects `SET ... EX 0` (and negative values) with an "invalid expire
/// time" error. The result is therefore clamped to at least one second. This
/// covers expiry dates that are already in the past. It also covers dates less
/// than a second away: both timestamps are truncated to whole seconds, so
/// their difference can otherwise be `0`.
fn get_expiry_as_u64(expiry: OffsetDateTime) -> u64 {
    let now = OffsetDateTime::now_utc();
    expiry
        .unix_timestamp()
        .saturating_sub(now.unix_timestamp())
        // Clamping to 1 (not 0) keeps the value valid for `EX`; the value is
        // then non-negative, so `unsigned_abs` is a lossless conversion.
        .max(1)
        .unsigned_abs()
}
#[async_trait]
impl SessionStore for RedisStore {
async fn create(&self, session_record: &mut Record) -> session_store::Result<()> {
let mut conn = self.get_connection().await?;
let data: String = serde_json::to_string(&session_record)
.map_err(|err| RedisStoreError::Serialize(Box::new(err)))?;
let options = SetOptions::default()
.conditional_set(ExistenceCheck::NX) .with_expiration(SetExpiry::EX(get_expiry_as_u64(session_record.expiry_date)));
for _ in 0..=MAX_COLLISION_RETRIES {
let key = session_record.id.to_string();
let set_ok: bool = conn
.set_options(key, &data, options)
.await
.map_err(|err| RedisStoreError::Command(Box::new(err)))?;
if set_ok {
return Ok(());
}
session_record.id = Id::default();
}
Err(RedisStoreError::TooManyIdCollisions(MAX_COLLISION_RETRIES))?
}
async fn save(&self, session_record: &Record) -> session_store::Result<()> {
let mut conn = self.get_connection().await?;
let key: String = session_record.id.to_string();
let data: String = serde_json::to_string(&session_record)
.map_err(|err| RedisStoreError::Serialize(Box::new(err)))?;
let options = SetOptions::default()
.conditional_set(ExistenceCheck::XX) .with_expiration(SetExpiry::EX(get_expiry_as_u64(session_record.expiry_date)));
let set_ok: bool = conn
.set_options(key, data, options)
.await
.map_err(|err| RedisStoreError::Command(Box::new(err)))?;
if !set_ok {
let mut record = session_record.clone();
self.create(&mut record).await?;
}
Ok(())
}
async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>> {
let mut conn = self.get_connection().await?;
let key = session_id.to_string();
let data: Option<String> = conn
.get(key)
.await
.map_err(|err| RedisStoreError::Command(Box::new(err)))?;
if let Some(data) = data {
let rec = serde_json::from_str::<Record>(&data)
.map_err(|err| RedisStoreError::Deserialize(Box::new(err)))?;
return Ok(Some(rec));
}
Ok(None)
}
async fn delete(&self, session_id: &Id) -> session_store::Result<()> {
let mut conn = self.get_connection().await?;
let key = session_id.to_string();
conn.del::<_, ()>(key)
.await
.map_err(|err| RedisStoreError::Command(Box::new(err)))?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::{env, io};
use time::{Duration, OffsetDateTime};
use tower_sessions::session::{Id, Record};
use super::*;
use crate::config::CacheUrl;
async fn make_store() -> RedisStore {
let redis_url =
env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string());
let url = CacheUrl::from(redis_url);
let store = RedisStore::new(&url).expect("failed to create RedisStore");
store.get_connection().await.expect("get_connection failed");
store
}
fn make_record() -> Record {
Record {
id: Id::default(),
data: HashMap::default(),
expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30),
}
}
#[cot::test]
#[ignore = "requires external Redis service"]
async fn test_create_and_load() {
let store = make_store().await;
let mut rec = make_record();
store.create(&mut rec).await.expect("create failed");
let loaded = store.load(&rec.id).await.expect("load err");
assert_eq!(Some(rec.clone()), loaded);
}
#[cot::test]
#[ignore = "requires external Redis service"]
async fn test_save_overwrites() {
let store = make_store().await;
let mut rec = make_record();
store.create(&mut rec).await.unwrap();
let mut rec2 = rec.clone();
rec2.data.insert("x".into(), "y".into());
store.save(&rec2).await.expect("save failed");
let loaded = store.load(&rec.id).await.unwrap().unwrap();
assert_eq!(rec2.data, loaded.data);
}
#[cot::test]
#[ignore = "requires external Redis service"]
async fn test_save_creates_if_missing() {
let store = make_store().await;
let rec = make_record();
store.save(&rec).await.expect("save failed");
let loaded = store.load(&rec.id).await.unwrap();
assert_eq!(Some(rec), loaded);
}
#[cot::test]
#[ignore = "requires external Redis service"]
async fn test_delete() {
let store = make_store().await;
let mut rec = make_record();
store.create(&mut rec).await.unwrap();
store.delete(&rec.id).await.expect("delete failed");
let loaded = store.load(&rec.id).await.unwrap();
assert!(loaded.is_none());
store.delete(&rec.id).await.expect("second delete");
}
#[cot::test]
#[ignore = "requires external Redis service"]
async fn test_create_id_collision() {
let store = make_store().await;
let expiry = OffsetDateTime::now_utc() + Duration::minutes(30);
let mut r1 = Record {
id: Id::default(),
data: HashMap::default(),
expiry_date: expiry,
};
store.create(&mut r1).await.unwrap();
let mut r2 = Record {
id: r1.id,
data: HashMap::default(),
expiry_date: expiry,
};
store.create(&mut r2).await.unwrap();
assert_ne!(r1.id, r2.id, "ID collision not resolved");
let loaded1 = store.load(&r1.id).await.unwrap();
let loaded2 = store.load(&r2.id).await.unwrap();
assert!(loaded1.is_some() && loaded2.is_some());
}
#[cot::test]
async fn test_from_redis_store_error_to_session_store_error() {
let pool_err = io::Error::other("pool conn failure");
let sess_err: session_store::Error =
RedisStoreError::PoolConnection(Box::new(pool_err)).into();
assert!(matches!(sess_err, session_store::Error::Backend(_)));
let create_err = io::Error::other("pool creation failure");
let sess_err: session_store::Error =
RedisStoreError::PoolCreation(Box::new(create_err)).into();
assert!(matches!(sess_err, session_store::Error::Backend(_)));
let cmd_err = io::Error::other("redis command failure");
let sess_err: session_store::Error = RedisStoreError::Command(Box::new(cmd_err)).into();
assert!(matches!(sess_err, session_store::Error::Backend(_)));
let ser_err = io::Error::other("serialization oops");
let sess_err: session_store::Error = RedisStoreError::Serialize(Box::new(ser_err)).into();
assert!(matches!(sess_err, session_store::Error::Encode(_)));
let parse_err = serde_json::from_str::<Record>("not a json").unwrap_err();
let sess_err: session_store::Error =
RedisStoreError::Deserialize(Box::new(parse_err)).into();
assert!(matches!(sess_err, session_store::Error::Decode(_)));
let sess_err: session_store::Error = RedisStoreError::TooManyIdCollisions(99).into();
assert!(matches!(sess_err, session_store::Error::Backend(_)));
}
}