use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use crate::protocol::{ClientCapabilities, Implementation};
/// Persistent state for a single session, as held by a [`SessionStore`].
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it with a
/// struct literal and should go through [`SessionRecord::new`]. Serializable and
/// deserializable via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct SessionRecord {
    /// Key under which the record is stored. [`SessionStore::create`] takes the
    /// record by `&mut`; `MemorySessionStore` replaces this id with a fresh UUID
    /// when it is already taken.
    pub id: String,
    /// Protocol version string for the session (the tests use date-style values
    /// such as `"2025-11-25"`).
    pub protocol_version: String,
    /// Client implementation details, if known; `None` after [`SessionRecord::new`].
    pub client_info: Option<Implementation>,
    /// Client capabilities, if known; `None` after [`SessionRecord::new`].
    pub client_capabilities: Option<ClientCapabilities>,
    /// When the record was constructed; [`SessionRecord::touch`] leaves it unchanged.
    pub created_at: SystemTime,
    /// Time of construction or of the most recent [`SessionRecord::touch`].
    pub last_accessed: SystemTime,
    /// Instant at or after which [`SessionRecord::is_expired`] reports `true`.
    pub expires_at: SystemTime,
}
impl SessionRecord {
    /// Builds a record whose creation, last-access and expiry times all derive
    /// from one `SystemTime::now()` reading; it expires `ttl` after that reading.
    /// Client info and capabilities start out as `None`.
    ///
    /// # Panics
    ///
    /// Panics if the expiry (`now + ttl`) cannot be represented as a
    /// [`SystemTime`], e.g. for an extremely large `ttl`.
    pub fn new(
        id: impl Into<String>,
        protocol_version: impl Into<String>,
        ttl: Duration,
    ) -> Self {
        let stamp = SystemTime::now();
        Self {
            id: id.into(),
            protocol_version: protocol_version.into(),
            client_info: None,
            client_capabilities: None,
            created_at: stamp,
            last_accessed: stamp,
            expires_at: stamp + ttl,
        }
    }

    /// Records an access at the current time and resets the expiry to `ttl`
    /// from that moment. `created_at` is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `now + ttl` cannot be represented as a [`SystemTime`].
    pub fn touch(&mut self, ttl: Duration) {
        let accessed_at = SystemTime::now();
        self.last_accessed = accessed_at;
        self.expires_at = accessed_at + ttl;
    }

    /// Returns `true` once the current time has reached `expires_at`; a record
    /// counts as expired at its exact expiry instant.
    pub fn is_expired(&self) -> bool {
        self.expires_at <= SystemTime::now()
    }
}
/// Errors a [`SessionStore`] backend can report.
///
/// Each variant carries a human-readable message. The in-memory store in this
/// module never produces them (all of its operations return `Ok`). Presumably
/// they are meant for backends that serialize records or talk to an external
/// system — TODO confirm against the other implementations.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum SessionStoreError {
    /// Turning a record into the backend's representation failed.
    #[error("encode error: {0}")]
    Encode(String),
    /// Turning stored data back into a record failed.
    #[error("decode error: {0}")]
    Decode(String),
    /// Any other failure reported by the storage backend.
    #[error("backend error: {0}")]
    Backend(String),
}
/// Result type returned by every [`SessionStore`] operation.
pub type Result<T> = std::result::Result<T, SessionStoreError>;
/// Asynchronous storage backend for [`SessionRecord`]s.
///
/// The `Send + Sync + 'static` bounds, together with `#[async_trait]`, let the
/// trait be used as a shared trait object such as `Arc<dyn SessionStore>`.
#[async_trait]
pub trait SessionStore: Send + Sync + 'static {
    /// Stores a new record.
    ///
    /// The record is passed by `&mut` so an implementation may adjust it before
    /// storing it (the in-memory store replaces a colliding `id`). Callers should
    /// read `record.id` back afterwards.
    async fn create(&self, record: &mut SessionRecord) -> Result<()>;
    /// Writes `record` under `record.id`; the in-memory store overwrites any
    /// existing entry (upsert).
    async fn save(&self, record: &SessionRecord) -> Result<()>;
    /// Looks up a record by id, returning `Ok(None)` when it is not available.
    /// The in-memory store also reports expired records as `None`.
    async fn load(&self, id: &str) -> Result<Option<SessionRecord>>;
    /// Removes the record with the given id; the in-memory store treats a
    /// missing id as success.
    async fn delete(&self, id: &str) -> Result<()>;
}
/// In-process [`SessionStore`] that keeps records in a `HashMap` behind a tokio `RwLock`.
///
/// The map lives in an `Arc`, so cloning is cheap and all clones share the same
/// records. Expired records are hidden by `load` but stay in the map until
/// [`MemorySessionStore::cleanup_expired`] removes them.
#[derive(Debug, Default, Clone)]
pub struct MemorySessionStore {
    // Keyed by `SessionRecord::id`.
    inner: Arc<RwLock<HashMap<String, SessionRecord>>>,
}
impl MemorySessionStore {
    /// Creates an empty store; equivalent to `MemorySessionStore::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Number of records currently held. Expired records are included until
    /// [`cleanup_expired`](Self::cleanup_expired) removes them.
    pub async fn len(&self) -> usize {
        let sessions = self.inner.read().await;
        sessions.len()
    }

    /// Returns `true` when the map holds no records at all, expired or not.
    pub async fn is_empty(&self) -> bool {
        let sessions = self.inner.read().await;
        sessions.is_empty()
    }

    /// Removes every expired record and returns how many were dropped.
    ///
    /// The write lock is held for the whole sweep, so other operations on this
    /// store (and its clones) wait until it finishes.
    pub async fn cleanup_expired(&self) -> usize {
        let mut sessions = self.inner.write().await;
        let count_before = sessions.len();
        sessions.retain(|_, session| !session.is_expired());
        let count_after = sessions.len();
        count_before - count_after
    }
}
#[async_trait]
impl SessionStore for MemorySessionStore {
    /// Inserts `record`, drawing fresh UUIDs for `record.id` until it no longer
    /// clashes with an existing key (expired entries still count as taken).
    async fn create(&self, record: &mut SessionRecord) -> Result<()> {
        let mut sessions = self.inner.write().await;
        // The caller sees the final id through `record`, since it is `&mut`.
        while sessions.contains_key(record.id.as_str()) {
            record.id = uuid::Uuid::new_v4().to_string();
        }
        let key = record.id.clone();
        sessions.insert(key, record.clone());
        Ok(())
    }

    /// Upserts a copy of `record` under `record.id`.
    async fn save(&self, record: &SessionRecord) -> Result<()> {
        let mut sessions = self.inner.write().await;
        sessions.insert(record.id.clone(), record.clone());
        Ok(())
    }

    /// Returns a copy of the record if present and not yet expired. Expired
    /// entries are left in place for `cleanup_expired`.
    async fn load(&self, id: &str) -> Result<Option<SessionRecord>> {
        let sessions = self.inner.read().await;
        let found = match sessions.get(id) {
            Some(session) if !session.is_expired() => Some(session.clone()),
            _ => None,
        };
        Ok(found)
    }

    /// Removes the entry for `id`; absent ids are not an error.
    async fn delete(&self, id: &str) -> Result<()> {
        let mut sessions = self.inner.write().await;
        sessions.remove(id);
        Ok(())
    }
}
/// Two-tier session store: a fast `cache` tier in front of an authoritative
/// `store` tier.
///
/// Reads consult the cache first and fall back to the store. Writes go to the
/// store and are then mirrored into the cache.
#[derive(Debug, Clone)]
pub struct CachingSessionStore<Cache, Store> {
    /// Tier consulted first on reads.
    cache: Cache,
    /// Tier that writes reach first; the source of truth on cache misses.
    store: Store,
}

impl<Cache, Store> CachingSessionStore<Cache, Store> {
    /// Puts `cache` in front of `store`.
    pub fn new(cache: Cache, store: Store) -> Self {
        CachingSessionStore { cache, store }
    }
}
impl<Cache, Store> CachingSessionStore<Cache, Store>
where
    Cache: SessionStore,
{
    /// Copies `record` into the cache tier after the store has accepted it.
    ///
    /// If the cache write fails, the cached entry for `record.id` is deleted
    /// (best effort). Later reads then fall through to the store instead of
    /// returning an older cached copy. The cache's error is still returned.
    async fn write_to_cache(&self, record: &SessionRecord) -> Result<()> {
        match self.cache.save(record).await {
            Ok(()) => Ok(()),
            Err(err) => {
                // Deliberately ignored: `err` is already being reported, and a
                // failed invalidation leaves nothing further to try.
                let _ = self.cache.delete(&record.id).await;
                Err(err)
            }
        }
    }
}

#[async_trait]
impl<Cache, Store> SessionStore for CachingSessionStore<Cache, Store>
where
    Cache: SessionStore,
    Store: SessionStore,
{
    /// Creates the record in the backing store, which may rewrite `record.id`,
    /// then mirrors the final record into the cache.
    ///
    /// # Errors
    ///
    /// Returns the store's error if creation fails; the cache is not touched in
    /// that case. If only the cache write fails, the record is already persisted
    /// in the store. The cache entry is then invalidated best-effort and the
    /// cache's error is returned.
    async fn create(&self, record: &mut SessionRecord) -> Result<()> {
        self.store.create(record).await?;
        self.write_to_cache(record).await
    }

    /// Writes the record to the store, then mirrors it into the cache.
    ///
    /// # Errors
    ///
    /// Returns the store's error if the store write fails; the cache is not
    /// touched in that case. If only the cache write fails, the cached copy is
    /// invalidated best-effort so it cannot shadow the newer stored version. The
    /// cache's error is then returned.
    async fn save(&self, record: &SessionRecord) -> Result<()> {
        self.store.save(record).await?;
        self.write_to_cache(record).await
    }

    /// Serves the record from the cache when possible, otherwise from the store.
    ///
    /// The cache is treated as an optimisation. A cache read error is handled
    /// like a miss, and refilling the cache after a store hit is best effort.
    ///
    /// # Errors
    ///
    /// Only errors from the backing store are propagated.
    async fn load(&self, id: &str) -> Result<Option<SessionRecord>> {
        // A failing cache must not take reads down while the authoritative
        // store can still answer.
        if let Ok(Some(record)) = self.cache.load(id).await {
            return Ok(Some(record));
        }
        match self.store.load(id).await? {
            Some(record) => {
                // Opportunistic warm-up: a failure here only costs a later miss.
                let _ = self.cache.save(&record).await;
                Ok(Some(record))
            }
            None => Ok(None),
        }
    }

    /// Deletes the record from the store, then from the cache. Both tiers are
    /// always attempted.
    ///
    /// # Errors
    ///
    /// Returns the store's error if the store delete failed, otherwise the
    /// cache's error (if any).
    async fn delete(&self, id: &str) -> Result<()> {
        // Store first. If the cache were cleared first, a concurrent `load` could
        // miss the cache, read the still-present record from the store and
        // re-cache it. That would leave the deleted session readable from the
        // cache after this call returns.
        let store_result = self.store.delete(id).await;
        let cache_result = self.cache.delete(id).await;
        store_result.and(cache_result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Record with a 60-second TTL, long enough that it never expires mid-test.
    fn sample_record(id: &str) -> SessionRecord {
        SessionRecord::new(id, "2025-11-25", Duration::from_secs(60))
    }

    #[tokio::test]
    async fn memory_store_create_load_delete() {
        let store = MemorySessionStore::new();
        let mut record = sample_record("abc");
        store.create(&mut record).await.unwrap();
        let loaded = store.load("abc").await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().id, "abc");
        store.delete("abc").await.unwrap();
        assert!(store.load("abc").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn memory_store_create_regenerates_id_on_collision() {
        let store = MemorySessionStore::new();
        let mut first = sample_record("dup");
        store.create(&mut first).await.unwrap();
        let mut second = sample_record("dup");
        store.create(&mut second).await.unwrap();
        assert_ne!(first.id, second.id);
        assert!(store.load(&first.id).await.unwrap().is_some());
        assert!(store.load(&second.id).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn memory_store_save_upserts() {
        let store = MemorySessionStore::new();
        let mut record = sample_record("upsert");
        store.create(&mut record).await.unwrap();
        record.protocol_version = "2025-06-18".into();
        store.save(&record).await.unwrap();
        let loaded = store.load("upsert").await.unwrap().unwrap();
        assert_eq!(loaded.protocol_version, "2025-06-18");
    }

    #[tokio::test]
    async fn memory_store_hides_expired_records() {
        let store = MemorySessionStore::new();
        // 1 ms TTL followed by a 10 ms sleep guarantees the record has expired.
        let mut record = SessionRecord::new("expired", "2025-11-25", Duration::from_millis(1));
        store.create(&mut record).await.unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(store.load("expired").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn memory_store_cleanup_removes_expired() {
        let store = MemorySessionStore::new();
        let mut live = SessionRecord::new("live", "2025-11-25", Duration::from_secs(60));
        store.create(&mut live).await.unwrap();
        let mut dead = SessionRecord::new("dead", "2025-11-25", Duration::from_millis(1));
        store.create(&mut dead).await.unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;
        let removed = store.cleanup_expired().await;
        assert_eq!(removed, 1);
        assert_eq!(store.len().await, 1);
    }

    #[tokio::test]
    async fn memory_store_delete_is_idempotent() {
        let store = MemorySessionStore::new();
        store.delete("nonexistent").await.unwrap();
    }

    #[tokio::test]
    async fn record_touch_updates_timestamps() {
        let mut record = SessionRecord::new("t", "2025-11-25", Duration::from_secs(60));
        let original_expiry = record.expires_at;
        // Sleep so `touch` observes a strictly later wall-clock time.
        tokio::time::sleep(Duration::from_millis(10)).await;
        record.touch(Duration::from_secs(60));
        assert!(record.expires_at > original_expiry);
        assert!(record.last_accessed > record.created_at);
    }

    #[tokio::test]
    async fn dyn_session_store_object_safe() {
        let store: Arc<dyn SessionStore> = Arc::new(MemorySessionStore::new());
        let mut record = sample_record("dyn");
        store.create(&mut record).await.unwrap();
        assert!(store.load(&record.id).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn caching_store_writes_to_both_tiers() {
        let cache = MemorySessionStore::new();
        let backend = MemorySessionStore::new();
        let store = CachingSessionStore::new(cache.clone(), backend.clone());
        let mut record = sample_record("cached");
        store.create(&mut record).await.unwrap();
        assert!(cache.load(&record.id).await.unwrap().is_some());
        assert!(backend.load(&record.id).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn caching_store_save_updates_both_tiers() {
        let cache = MemorySessionStore::new();
        let backend = MemorySessionStore::new();
        let store = CachingSessionStore::new(cache.clone(), backend.clone());
        let mut record = sample_record("updated");
        store.create(&mut record).await.unwrap();
        record.protocol_version = "2025-06-18".into();
        store.save(&record).await.unwrap();
        let cached = cache.load(&record.id).await.unwrap().unwrap();
        let persisted = backend.load(&record.id).await.unwrap().unwrap();
        assert_eq!(cached.protocol_version, "2025-06-18");
        assert_eq!(persisted.protocol_version, "2025-06-18");
    }

    #[tokio::test]
    async fn caching_store_serves_cache_hits_without_backend() {
        let cache = MemorySessionStore::new();
        let backend = MemorySessionStore::new();
        // Seed only the cache tier; a hit must not require the backend.
        let mut record = sample_record("hot");
        cache.create(&mut record).await.unwrap();
        let store = CachingSessionStore::new(cache, backend.clone());
        assert!(store.load(&record.id).await.unwrap().is_some());
        assert!(backend.load(&record.id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn caching_store_populates_cache_on_miss() {
        let cache = MemorySessionStore::new();
        let backend = MemorySessionStore::new();
        let mut record = sample_record("warm");
        backend.create(&mut record).await.unwrap();
        let id = record.id.clone();
        assert!(cache.load(&id).await.unwrap().is_none());
        let store = CachingSessionStore::new(cache.clone(), backend);
        let loaded = store.load(&id).await.unwrap();
        assert!(loaded.is_some());
        assert!(cache.load(&id).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn caching_store_delete_clears_both() {
        let cache = MemorySessionStore::new();
        let backend = MemorySessionStore::new();
        let store = CachingSessionStore::new(cache.clone(), backend.clone());
        let mut record = sample_record("gone");
        store.create(&mut record).await.unwrap();
        store.delete(&record.id).await.unwrap();
        assert!(cache.load(&record.id).await.unwrap().is_none());
        assert!(backend.load(&record.id).await.unwrap().is_none());
    }
}