use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value;
use crate::cache::{BoxedCache, CacheError};
/// Namespace for session entries; keys in the cache look like `"session:<id>"`.
const KEY_PREFIX: &str = "session";
/// Default session lifetime: fourteen days, expressed in seconds.
const DEFAULT_TTL_SECS: u64 = 14 * 24 * 60 * 60;
/// Number of random bytes in a session id (192 bits; 32 chars as unpadded base64).
const ID_BYTES: usize = 24;
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Session {
data: HashMap<String, Value>,
#[serde(skip)]
dirty: bool,
}
impl Session {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
self.data
.get(key)
.and_then(|v| serde_json::from_value(v.clone()).ok())
}
pub fn set<T: Serialize>(&mut self, key: impl Into<String>, value: T) {
if let Ok(v) = serde_json::to_value(value) {
self.data.insert(key.into(), v);
self.dirty = true;
}
}
pub fn remove(&mut self, key: &str) -> Option<Value> {
let prev = self.data.remove(key);
if prev.is_some() {
self.dirty = true;
}
prev
}
pub fn clear(&mut self) {
if !self.data.is_empty() {
self.dirty = true;
}
self.data.clear();
}
#[must_use]
pub fn is_dirty(&self) -> bool {
self.dirty
}
#[must_use]
pub fn len(&self) -> usize {
self.data.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
#[must_use]
pub fn keys(&self) -> impl Iterator<Item = &String> {
self.data.keys()
}
}
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
#[error("cache: {0}")]
Cache(#[from] CacheError),
#[error("session deserialize: {0}")]
Serialization(String),
}
/// Persists [`Session`] values as JSON strings in a [`BoxedCache`], keyed as
/// `"session:<id>"`, with a per-store TTL applied on every write.
#[derive(Clone)]
pub struct SessionStore {
    /// Backing key/value cache holding serialized sessions.
    cache: BoxedCache,
    /// Expiry passed to the cache on `save_with_id` and `touch`.
    ttl: Arc<Duration>,
}

impl SessionStore {
    /// Creates a store over `cache` using the default TTL (`DEFAULT_TTL_SECS`).
    #[must_use]
    pub fn new(cache: BoxedCache) -> Self {
        let ttl = Arc::new(Duration::from_secs(DEFAULT_TTL_SECS));
        Self { ttl, cache }
    }

    /// Builder-style override of the TTL used for subsequent writes.
    #[must_use]
    pub fn ttl(self, ttl: Duration) -> Self {
        Self {
            ttl: Arc::new(ttl),
            ..self
        }
    }

    /// Stores `session` under a freshly generated random id and returns that id.
    ///
    /// # Errors
    /// Propagates any error from [`SessionStore::save_with_id`].
    pub async fn save(&self, session: &Session) -> Result<String, SessionError> {
        let new_id = generate_id();
        self.save_with_id(&new_id, session).await?;
        Ok(new_id)
    }

    /// Serializes `session` to JSON and writes it under `id` with the store's TTL.
    ///
    /// # Errors
    /// [`SessionError::Serialization`] if JSON encoding fails;
    /// [`SessionError::Cache`] if the cache write fails.
    pub async fn save_with_id(&self, id: &str, session: &Session) -> Result<(), SessionError> {
        let payload = match serde_json::to_string(session) {
            Ok(text) => text,
            Err(err) => return Err(SessionError::Serialization(err.to_string())),
        };
        let key = self.cache_key(id);
        self.cache.set(&key, &payload, Some(*self.ttl)).await?;
        Ok(())
    }

    /// Fetches and decodes the session stored under `id`.
    ///
    /// Returns `Ok(None)` when the key is absent *or* when the stored text is
    /// not a valid serialized session: corrupt entries are treated as missing
    /// rather than reported. A loaded session is not dirty.
    ///
    /// # Errors
    /// [`SessionError::Cache`] if the cache read fails.
    pub async fn load(&self, id: &str) -> Result<Option<Session>, SessionError> {
        let key = self.cache_key(id);
        let Some(raw) = self.cache.get(&key).await? else {
            return Ok(None);
        };
        let Ok(mut session) = serde_json::from_str::<Session>(&raw) else {
            return Ok(None);
        };
        session.dirty = false;
        Ok(Some(session))
    }

    /// Deletes whatever is stored under `id`.
    ///
    /// # Errors
    /// [`SessionError::Cache`] if the cache delete fails.
    pub async fn destroy(&self, id: &str) -> Result<(), SessionError> {
        let key = self.cache_key(id);
        self.cache.delete(&key).await?;
        Ok(())
    }

    /// Re-writes the stored payload unchanged, passing the store's TTL again.
    ///
    /// Returns `Ok(false)` if nothing is stored under `id`, `Ok(true)` otherwise.
    ///
    /// NOTE(review): this is a non-atomic read-then-write. A concurrent
    /// `save_with_id` landing between the `get` and the `set` is overwritten
    /// with the older payload, and a concurrent `destroy` can be undone. If the
    /// cache backend offers an expire/refresh primitive, prefer it here.
    ///
    /// # Errors
    /// [`SessionError::Cache`] if either cache call fails.
    pub async fn touch(&self, id: &str) -> Result<bool, SessionError> {
        let key = self.cache_key(id);
        let existing = self.cache.get(&key).await?;
        if let Some(raw) = existing {
            self.cache.set(&key, &raw, Some(*self.ttl)).await?;
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Builds the namespaced cache key `"{KEY_PREFIX}:{id}"`.
    fn cache_key(&self, id: &str) -> String {
        [KEY_PREFIX, id].join(":")
    }
}
/// Produces a fresh session id: `ID_BYTES` bytes drawn from `rand::thread_rng()`,
/// encoded as unpadded URL-safe base64.
fn generate_id() -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine as _;
    use rand::RngCore as _;

    let mut entropy = [0u8; ID_BYTES];
    rand::thread_rng().fill_bytes(&mut entropy);
    URL_SAFE_NO_PAD.encode(entropy)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::InMemoryCache;
    use std::sync::Arc as StdArc;

    /// Fresh store backed by an in-memory cache, using the default TTL.
    fn store() -> SessionStore {
        let cache: BoxedCache = StdArc::new(InMemoryCache::new());
        SessionStore::new(cache)
    }

    // ---- Session: in-memory behaviour ----

    #[test]
    fn fresh_session_is_clean_and_empty() {
        let s = Session::new();
        assert!(!s.is_dirty());
        assert!(s.is_empty());
        assert_eq!(s.len(), 0);
    }

    #[test]
    fn set_marks_dirty_and_stores() {
        let mut s = Session::new();
        s.set("user_id", 42_i64);
        assert!(s.is_dirty());
        assert_eq!(s.get::<i64>("user_id"), Some(42));
        assert_eq!(s.len(), 1);
    }

    #[test]
    fn set_overwrites_existing_value() {
        let mut s = Session::new();
        s.set("k", 1);
        s.set("k", 2);
        assert_eq!(s.get::<i64>("k"), Some(2));
        assert_eq!(s.len(), 1);
    }

    #[test]
    fn get_returns_none_for_missing() {
        let s = Session::new();
        assert_eq!(s.get::<i64>("nope"), None);
    }

    #[test]
    fn get_returns_none_for_wrong_type() {
        let mut s = Session::new();
        s.set("flag", "string-not-a-number");
        assert_eq!(s.get::<i64>("flag"), None);
    }

    #[test]
    fn remove_returns_previous_and_marks_dirty() {
        let mut s = Session::new();
        s.set("k", "v");
        let prev = s.remove("k");
        assert_eq!(prev.unwrap(), "v");
        assert!(s.is_dirty());
        assert!(s.is_empty());
    }

    #[test]
    fn remove_missing_does_not_mark_dirty() {
        let mut s = Session::new();
        assert!(s.remove("nope").is_none());
        assert!(!s.is_dirty());
    }

    #[test]
    fn clear_wipes_all_keys() {
        let mut s = Session::new();
        s.set("a", 1);
        s.set("b", 2);
        s.clear();
        assert!(s.is_empty());
        assert!(s.is_dirty());
    }

    #[test]
    fn clear_on_empty_session_stays_clean() {
        let mut s = Session::new();
        s.clear();
        assert!(!s.is_dirty());
    }

    #[test]
    fn keys_iterates_inserted_keys() {
        let mut s = Session::new();
        s.set("a", 1);
        s.set("b", 2);
        let mut keys: Vec<&String> = s.keys().collect();
        keys.sort();
        assert_eq!(keys.iter().map(|s| s.as_str()).collect::<Vec<_>>(), vec!["a", "b"]);
    }

    #[test]
    fn dirty_flag_is_not_serialized() {
        // Pins the wire format: only `data` is persisted, `dirty` is skipped.
        let mut s = Session::new();
        s.set("a", 1);
        let wire = serde_json::to_value(&s).unwrap();
        assert_eq!(wire, serde_json::json!({ "data": { "a": 1 } }));
    }

    // ---- SessionStore: persistence ----

    #[tokio::test]
    async fn save_then_load_roundtrips() {
        let store = store();
        let mut s = Session::new();
        s.set("user_id", 42_i64);
        s.set("name", "Alice");
        let id = store.save(&s).await.unwrap();
        let loaded = store.load(&id).await.unwrap().unwrap();
        assert_eq!(loaded.get::<i64>("user_id"), Some(42));
        assert_eq!(loaded.get::<String>("name").as_deref(), Some("Alice"));
        assert!(!loaded.is_dirty());
    }

    #[tokio::test]
    async fn custom_ttl_store_roundtrips() {
        let store = store().ttl(Duration::from_secs(5));
        let mut s = Session::new();
        s.set("v", 9_i64);
        let id = store.save(&s).await.unwrap();
        let loaded = store.load(&id).await.unwrap().unwrap();
        assert_eq!(loaded.get::<i64>("v"), Some(9));
    }

    #[tokio::test]
    async fn load_unknown_id_returns_none() {
        let store = store();
        assert!(store.load("does-not-exist").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn destroy_removes_session() {
        let store = store();
        let id = store.save(&Session::new()).await.unwrap();
        assert!(store.load(&id).await.unwrap().is_some());
        store.destroy(&id).await.unwrap();
        assert!(store.load(&id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn touch_extends_ttl_on_existing_session() {
        // Only checks that a touched session survives; observing the TTL
        // extension itself would need a controllable clock in the cache.
        let store = store();
        let id = store.save(&Session::new()).await.unwrap();
        assert!(store.touch(&id).await.unwrap());
        assert!(store.load(&id).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn touch_preserves_session_data() {
        let store = store();
        let mut s = Session::new();
        s.set("user_id", 7_i64);
        let id = store.save(&s).await.unwrap();
        assert!(store.touch(&id).await.unwrap());
        let loaded = store.load(&id).await.unwrap().unwrap();
        assert_eq!(loaded.get::<i64>("user_id"), Some(7));
    }

    #[tokio::test]
    async fn touch_returns_false_on_missing_session() {
        let store = store();
        assert!(!store.touch("does-not-exist").await.unwrap());
    }

    #[tokio::test]
    async fn save_with_id_rewrites_existing_session() {
        let store = store();
        let mut s = Session::new();
        s.set("v", 1);
        let id = store.save(&s).await.unwrap();
        let mut loaded = store.load(&id).await.unwrap().unwrap();
        loaded.set("v", 2);
        store.save_with_id(&id, &loaded).await.unwrap();
        let again = store.load(&id).await.unwrap().unwrap();
        assert_eq!(again.get::<i64>("v"), Some(2));
    }

    #[tokio::test]
    async fn each_save_generates_distinct_id() {
        let store = store();
        let id1 = store.save(&Session::new()).await.unwrap();
        let id2 = store.save(&Session::new()).await.unwrap();
        assert_ne!(id1, id2);
    }

    #[tokio::test]
    async fn corrupted_cache_value_loads_as_none() {
        let store = store();
        // Derive the key through the store so the test tracks KEY_PREFIX.
        let key = store.cache_key("corrupt");
        store
            .cache
            .set(&key, "not-json-{}", Some(Duration::from_secs(60)))
            .await
            .unwrap();
        assert!(store.load("corrupt").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn complex_value_roundtrips() {
        let store = store();
        let mut s = Session::new();
        let payload = serde_json::json!({"role": "admin", "perms": ["read", "write"]});
        s.set("ctx", payload.clone());
        let id = store.save(&s).await.unwrap();
        let loaded = store.load(&id).await.unwrap().unwrap();
        assert_eq!(loaded.get::<serde_json::Value>("ctx"), Some(payload));
    }

    // ---- id generation ----

    #[test]
    fn generated_id_is_url_safe_and_192_bits() {
        let id = generate_id();
        assert_eq!(id.len(), 32);
        assert!(id.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
    }

    #[test]
    fn generated_ids_are_distinct() {
        let a = generate_id();
        let b = generate_id();
        assert_ne!(a, b);
    }
}