use async_trait::async_trait;
use chrono::{DateTime, Utc};
use reinhardt_utils::cache::{Cache, InMemoryCache};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use thiserror::Error;
use crate::sessions::cleanup::{CleanupableBackend, SessionMetadata};
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum SessionError {
#[error("Cache error: {0}")]
CacheError(String),
#[error("Serialization error: {0}")]
SerializationError(String),
#[error("Session has expired due to inactivity")]
SessionExpired,
}
/// Asynchronous key/value storage for session payloads.
///
/// Implementors must be `Send + Sync + Clone` so a backend handle can be
/// shared between tasks.
#[async_trait]
pub trait SessionBackend: Send + Sync + Clone {
    /// Fetches and deserializes the payload stored under `session_key`.
    ///
    /// Returns `Ok(None)` when there is no entry for the key.
    async fn load<T>(&self, session_key: &str) -> Result<Option<T>, SessionError>
    where
        T: for<'de> Deserialize<'de> + Serialize + Send + Sync;

    /// Stores `data` under `session_key`, replacing any previous payload.
    ///
    /// `ttl` is an optional lifetime; the backends in this module interpret it
    /// as a number of seconds.
    async fn save<T: Serialize + Send + Sync>(
        &self,
        session_key: &str,
        data: &T,
        ttl: Option<u64>,
    ) -> Result<(), SessionError>;

    /// Removes the entry stored under `session_key`.
    async fn delete(&self, session_key: &str) -> Result<(), SessionError>;

    /// Reports whether an entry is currently stored under `session_key`.
    async fn exists(&self, session_key: &str) -> Result<bool, SessionError>;
}
/// Session backend that stores payloads in an [`InMemoryCache`].
///
/// The cache sits behind an [`Arc`], so clones of a backend share the same
/// underlying storage.
#[derive(Clone)]
pub struct InMemorySessionBackend {
    cache: Arc<InMemoryCache>,
}

impl InMemorySessionBackend {
    /// Creates a backend over a newly constructed [`InMemoryCache`].
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for InMemorySessionBackend {
    /// Equivalent to [`InMemorySessionBackend::new`].
    fn default() -> Self {
        let fresh_cache = InMemoryCache::new();
        Self {
            cache: Arc::new(fresh_cache),
        }
    }
}
/// Forwards every session operation to the shared [`InMemoryCache`], turning
/// cache failures into [`SessionError::CacheError`].
#[async_trait]
impl SessionBackend for InMemorySessionBackend {
    /// Reads the payload for `session_key` from the cache.
    async fn load<T>(&self, session_key: &str) -> Result<Option<T>, SessionError>
    where
        T: for<'de> Deserialize<'de> + Serialize + Send + Sync,
    {
        let lookup = self.cache.get(session_key).await;
        lookup.map_err(|err| SessionError::CacheError(err.to_string()))
    }

    /// Writes `data` under `session_key`, with `ttl` (seconds) as the expiry.
    async fn save<T>(
        &self,
        session_key: &str,
        data: &T,
        ttl: Option<u64>,
    ) -> Result<(), SessionError>
    where
        T: Serialize + Send + Sync,
    {
        // The trait speaks in whole seconds; the cache takes an optional `Duration`.
        let expiry: Option<std::time::Duration> = ttl.map(std::time::Duration::from_secs);
        let written = self.cache.set(session_key, data, expiry).await;
        written.map_err(|err| SessionError::CacheError(err.to_string()))
    }

    /// Deletes the cache entry for `session_key`.
    async fn delete(&self, session_key: &str) -> Result<(), SessionError> {
        let removal = self.cache.delete(session_key).await;
        removal.map_err(|err| SessionError::CacheError(err.to_string()))
    }

    /// Asks the cache whether `session_key` is present.
    async fn exists(&self, session_key: &str) -> Result<bool, SessionError> {
        let presence = self.cache.has_key(session_key).await;
        presence.map_err(|err| SessionError::CacheError(err.to_string()))
    }
}
/// Exposes key enumeration and per-entry timestamps of the in-memory cache
/// to the session cleanup machinery.
#[async_trait]
impl CleanupableBackend for InMemorySessionBackend {
    /// Returns the keys reported by the cache's `list_keys`; this call cannot fail.
    async fn get_all_keys(&self) -> Result<Vec<String>, SessionError> {
        let keys = self.cache.list_keys().await;
        Ok(keys)
    }

    /// Returns the creation and last-access timestamps recorded for
    /// `session_key`, or `None` when the cache reports no entry for it.
    async fn get_metadata(
        &self,
        session_key: &str,
    ) -> Result<Option<SessionMetadata>, SessionError> {
        let timestamps = self
            .cache
            .inspect_entry_with_timestamps(session_key)
            .await
            .map_err(|err| SessionError::CacheError(err.to_string()))?;

        Ok(timestamps.map(|(created, accessed)| SessionMetadata {
            created_at: DateTime::<Utc>::from(created),
            last_accessed: accessed.map(DateTime::<Utc>::from),
        }))
    }
}
/// Session backend that delegates storage to a shared cache of type `C`.
///
/// The cache is held behind an [`Arc`], so cloning the backend only bumps the
/// reference count and every clone talks to the same cache instance. The
/// `SessionBackend` implementation requires `C` to implement [`Cache`]; the
/// struct itself places no bounds on `C`.
pub struct CacheSessionBackend<C> {
    cache: Arc<C>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a
// `C: Clone` bound even though only the `Arc` handle is ever cloned.
impl<C> Clone for CacheSessionBackend<C> {
    fn clone(&self) -> Self {
        Self {
            cache: Arc::clone(&self.cache),
        }
    }
}

impl<C> CacheSessionBackend<C> {
    /// Wraps an already-shared cache; the backend keeps the given `Arc` handle.
    pub fn new(cache: Arc<C>) -> Self {
        Self { cache }
    }
}
/// Forwards every session operation to the wrapped cache `C`, turning cache
/// failures into [`SessionError::CacheError`].
#[async_trait]
impl<C: Cache + Clone + 'static> SessionBackend for CacheSessionBackend<C> {
    /// Reads the payload for `session_key` from the wrapped cache.
    async fn load<T>(&self, session_key: &str) -> Result<Option<T>, SessionError>
    where
        T: for<'de> Deserialize<'de> + Serialize + Send + Sync,
    {
        let lookup = self.cache.get(session_key).await;
        lookup.map_err(|err| SessionError::CacheError(err.to_string()))
    }

    /// Writes `data` under `session_key`, with `ttl` (seconds) as the expiry.
    async fn save<T>(
        &self,
        session_key: &str,
        data: &T,
        ttl: Option<u64>,
    ) -> Result<(), SessionError>
    where
        T: Serialize + Send + Sync,
    {
        // The trait speaks in whole seconds; the cache takes an optional `Duration`.
        let expiry: Option<std::time::Duration> = ttl.map(std::time::Duration::from_secs);
        let written = self.cache.set(session_key, data, expiry).await;
        written.map_err(|err| SessionError::CacheError(err.to_string()))
    }

    /// Deletes the cache entry for `session_key`.
    async fn delete(&self, session_key: &str) -> Result<(), SessionError> {
        let removal = self.cache.delete(session_key).await;
        removal.map_err(|err| SessionError::CacheError(err.to_string()))
    }

    /// Asks the wrapped cache whether `session_key` is present.
    async fn exists(&self, session_key: &str) -> Result<bool, SessionError> {
        let presence = self.cache.has_key(session_key).await;
        presence.map_err(|err| SessionError::CacheError(err.to_string()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use serde_json::json;
    use std::collections::HashMap;

    /// TTL used by tests that only need an entry to outlive the test itself.
    const ONE_HOUR: Option<u64> = Some(3600);

    /// Loads `key` from `backend` as an untyped JSON value, panicking on backend errors.
    async fn load_json<B: SessionBackend>(backend: &B, key: &str) -> Option<serde_json::Value> {
        backend.load(key).await.unwrap()
    }

    /// A typed map survives a save/load round trip unchanged.
    #[rstest]
    #[tokio::test]
    async fn test_in_memory_save_and_load_roundtrip() {
        let backend = InMemorySessionBackend::new();
        let payload = HashMap::from([
            ("user_id".to_string(), json!(42)),
            ("username".to_string(), json!("alice")),
        ]);
        backend.save("sess_1", &payload, ONE_HOUR).await.unwrap();

        let restored: HashMap<String, serde_json::Value> = backend
            .load("sess_1")
            .await
            .unwrap()
            .expect("sess_1 should have been stored");
        assert_eq!(restored["user_id"], json!(42));
        assert_eq!(restored["username"], json!("alice"));
    }

    /// Loading a key that was never written yields `None`, not an error.
    #[rstest]
    #[tokio::test]
    async fn test_in_memory_load_nonexistent_key() {
        let backend = InMemorySessionBackend::new();
        assert!(load_json(&backend, "nonexistent").await.is_none());
    }

    /// A deleted session can no longer be loaded.
    #[rstest]
    #[tokio::test]
    async fn test_in_memory_delete_removes_session() {
        let backend = InMemorySessionBackend::new();
        backend
            .save("sess_del", &json!({"key": "value"}), ONE_HOUR)
            .await
            .unwrap();
        backend.delete("sess_del").await.unwrap();
        assert!(load_json(&backend, "sess_del").await.is_none());
    }

    /// `exists` tracks the entry through save and delete.
    #[rstest]
    #[tokio::test]
    async fn test_in_memory_exists_reflects_state() {
        const KEY: &str = "sess_ex";
        let backend = InMemorySessionBackend::new();
        assert!(!backend.exists(KEY).await.unwrap());
        backend
            .save(KEY, &json!({"active": true}), ONE_HOUR)
            .await
            .unwrap();
        assert!(backend.exists(KEY).await.unwrap());
        backend.delete(KEY).await.unwrap();
        assert!(!backend.exists(KEY).await.unwrap());
    }

    /// The most recent save under a key wins.
    #[rstest]
    #[tokio::test]
    async fn test_in_memory_save_overwrites_existing() {
        let backend = InMemorySessionBackend::new();
        for version in [1, 2] {
            backend
                .save("sess_ow", &json!({ "version": version }), ONE_HOUR)
                .await
                .unwrap();
        }
        let latest = load_json(&backend, "sess_ow").await.unwrap();
        assert_eq!(latest["version"], 2);
    }

    /// An entry saved with a short TTL is readable right away.
    #[rstest]
    #[tokio::test]
    async fn test_in_memory_save_with_ttl() {
        let backend = InMemorySessionBackend::new();
        backend
            .save("sess_ttl", &json!({"ttl_test": true}), Some(60))
            .await
            .unwrap();
        let stored = load_json(&backend, "sess_ttl").await.unwrap();
        assert_eq!(stored["ttl_test"], true);
    }

    /// The generic wrapper round-trips data through an injected cache.
    #[rstest]
    #[tokio::test]
    async fn test_cache_backend_wrapper_save_and_load() {
        let backend = CacheSessionBackend::new(Arc::new(InMemoryCache::new()));
        backend
            .save(
                "wrapped_sess",
                &json!({"wrapped": "value", "count": 99}),
                ONE_HOUR,
            )
            .await
            .unwrap();
        let stored = load_json(&backend, "wrapped_sess").await.unwrap();
        assert_eq!(stored["wrapped"], "value");
        assert_eq!(stored["count"], 99);
    }

    /// The generic wrapper reflects deletions in both `exists` and `load`.
    #[rstest]
    #[tokio::test]
    async fn test_cache_backend_wrapper_delete_and_exists() {
        let backend = CacheSessionBackend::new(Arc::new(InMemoryCache::new()));
        backend
            .save("wrap_del", &json!({"item": "to_delete"}), ONE_HOUR)
            .await
            .unwrap();
        assert!(backend.exists("wrap_del").await.unwrap());
        backend.delete("wrap_del").await.unwrap();
        assert!(!backend.exists("wrap_del").await.unwrap());
        assert!(load_json(&backend, "wrap_del").await.is_none());
    }
}