use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::Mutex;
use tracing::{debug, trace};
/// One cached session ticket together with the data needed to decide when it expires.
#[derive(Debug, Clone)]
struct SessionEntry {
    /// Ticket bytes, shared through an `Arc` so lookups can hand out cheap clones.
    ticket: Arc<Vec<u8>>,
    /// Wall-clock time at which the entry was created.
    created_at: SystemTime,
    /// How long after `created_at` the entry remains valid.
    ttl: Duration,
}
impl SessionEntry {
    /// Reports whether strictly more than `ttl` has elapsed since `created_at`.
    ///
    /// When the wall clock currently reads earlier than `created_at`, the age
    /// cannot be computed (`SystemTime::elapsed` fails) and the entry is treated
    /// as expired.
    fn is_expired(&self) -> bool {
        self.created_at
            .elapsed()
            .map_or(true, |age| age > self.ttl)
    }
}
/// Capacity-bounded, TTL-aware cache of session tickets keyed by session id.
///
/// Cloning is cheap: every clone shares the same underlying state through an `Arc`.
#[derive(Clone)]
pub struct SessionCache {
    /// Capacity limit that `store` enforces.
    max_entries: usize,
    /// Time-to-live given to every newly stored session.
    default_ttl: Duration,
    /// Shared mutable state guarded by an async mutex.
    inner: Arc<Mutex<SessionCacheInner>>,
}
/// Mutable cache state that lives behind the cache's lock.
struct SessionCacheInner {
    /// Stored entries keyed by session id; may include expired entries not yet purged.
    sessions: HashMap<String, SessionEntry>,
    /// Running count of `store` calls.
    total_inserts: u64,
}
impl SessionCache {
    /// Creates an empty cache that retains at most `max_entries` sessions, each
    /// stored with a time-to-live of `default_ttl`.
    ///
    /// Storage for `max_entries` sessions is reserved up front, so a very large
    /// `max_entries` allocates eagerly. With `max_entries == 0` the cache never
    /// retains anything.
    pub fn new(max_entries: usize, default_ttl: Duration) -> Self {
        Self {
            max_entries,
            default_ttl,
            inner: Arc::new(Mutex::new(SessionCacheInner {
                sessions: HashMap::with_capacity(max_entries),
                total_inserts: 0,
            })),
        }
    }

    /// Stores `ticket` under `session_id`, replacing any existing entry for that id.
    ///
    /// Expired entries are purged first. If `session_id` is new and the cache is
    /// full, the entry with the earliest creation time is evicted to make room.
    /// Every call increments `total_inserts`, including calls on a zero-capacity
    /// cache, which retains nothing.
    pub async fn store<S: Into<String>>(&self, session_id: S, ticket: Vec<u8>) {
        // Allocation work happens before the lock is taken to keep the critical section short.
        let session_id = session_id.into();
        let ticket = Arc::new(ticket);

        let mut inner = self.inner.lock().await;
        self.evict_expired(&mut inner);
        inner.total_inserts += 1;

        if self.max_entries > 0 {
            // Make room *before* inserting so the new entry can never be picked as the
            // eviction victim: its `SystemTime` may tie with an existing entry's, or
            // precede it after a backwards wall-clock step. Overwriting an existing id
            // does not grow the map, so it needs no eviction.
            if !inner.sessions.contains_key(&session_id) {
                // `len >= max_entries >= 1` guarantees the map is non-empty, so each
                // `evict_oldest` call removes exactly one entry and the loop terminates.
                while inner.sessions.len() >= self.max_entries {
                    self.evict_oldest(&mut inner);
                }
            }
            let entry = SessionEntry {
                ticket,
                created_at: SystemTime::now(),
                ttl: self.default_ttl,
            };
            inner.sessions.insert(session_id, entry);
        }

        trace!(
            session_count = inner.sessions.len(),
            "Session stored in cache"
        );
    }

    /// Returns the ticket stored under `session_id` if present and not expired.
    ///
    /// An expired entry found during the lookup is removed on the spot.
    pub async fn get(&self, session_id: &str) -> Option<Arc<Vec<u8>>> {
        let mut inner = self.inner.lock().await;
        if let Some(entry) = inner.sessions.get(session_id) {
            if !entry.is_expired() {
                trace!("Session cache hit");
                return Some(Arc::clone(&entry.ticket));
            }
        }
        // Either missing (the removal is a no-op) or expired (drop it now rather than
        // waiting for the next `store` to purge it).
        inner.sessions.remove(session_id);
        trace!("Session cache miss or expired");
        None
    }

    /// Removes every stored session.
    ///
    /// `total_inserts` is left untouched, so it keeps counting across clears.
    pub async fn clear(&self) {
        let mut inner = self.inner.lock().await;
        let count = inner.sessions.len();
        inner.sessions.clear();
        debug!(cleared_count = count, "Session cache cleared");
    }

    /// Returns a point-in-time snapshot of the cache counters.
    ///
    /// `expired_count` counts entries that are past their TTL but have not yet
    /// been purged; it is computed by scanning every entry.
    pub async fn stats(&self) -> SessionCacheStats {
        let inner = self.inner.lock().await;
        let expired_count = inner.sessions.values().filter(|e| e.is_expired()).count();
        SessionCacheStats {
            total_entries: inner.sessions.len(),
            max_entries: self.max_entries,
            expired_count,
            total_inserts: inner.total_inserts,
        }
    }

    /// Purges expired entries without storing anything.
    // Nothing in this file calls this, hence the `dead_code` allow.
    // NOTE(review): presumably intended for a periodic background sweep — confirm it
    // is still wanted, or remove it.
    #[allow(dead_code)]
    async fn evict_expired_async(&self) {
        let mut inner = self.inner.lock().await;
        self.evict_expired(&mut inner);
    }

    /// Removes every entry whose TTL has elapsed, logging how many were dropped.
    ///
    /// The caller must already hold the lock guarding `inner`.
    fn evict_expired(&self, inner: &mut SessionCacheInner) {
        let before = inner.sessions.len();
        inner.sessions.retain(|_, entry| !entry.is_expired());
        let after = inner.sessions.len();
        if before != after {
            debug!(
                removed_count = before - after,
                remaining_count = after,
                "Expired sessions evicted"
            );
        }
    }

    /// Removes the single entry with the earliest `created_at`; a no-op on an empty map.
    ///
    /// Linear scan over all entries. Ties are broken by `HashMap` iteration order,
    /// which is arbitrary.
    fn evict_oldest(&self, inner: &mut SessionCacheInner) {
        if let Some(oldest_key) = inner
            .sessions
            .iter()
            .min_by_key(|(_, entry)| entry.created_at)
            .map(|(key, _)| key.clone())
        {
            inner.sessions.remove(&oldest_key);
            debug!("Oldest session evicted to make room");
        }
    }
}
/// Point-in-time snapshot of a session cache's counters.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly, and `Default`
/// for an all-zero snapshot.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SessionCacheStats {
    /// Number of entries currently stored, including expired-but-unpurged ones.
    pub total_entries: usize,
    /// Configured capacity of the cache.
    pub max_entries: usize,
    /// Number of stored entries whose TTL has already elapsed.
    pub expired_count: usize,
    /// Running count of store operations.
    pub total_inserts: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_store_and_retrieve() {
        let cache = SessionCache::new(10, Duration::from_secs(60));
        cache.store("session-1", vec![1, 2, 3, 4]).await;
        // `as_deref` compares the ticket contents without unwrapping.
        let ticket = cache.get("session-1").await;
        assert_eq!(ticket.as_deref(), Some(&vec![1, 2, 3, 4]));
    }

    #[tokio::test]
    async fn test_missing_session() {
        let cache = SessionCache::new(10, Duration::from_secs(60));
        let ticket = cache.get("nonexistent").await;
        assert!(ticket.is_none());
    }

    #[tokio::test]
    async fn test_capacity_eviction() {
        let cache = SessionCache::new(3, Duration::from_secs(60));
        for i in 0u8..5 {
            cache.store(format!("session-{i}"), vec![i]).await;
        }
        let stats = cache.stats().await;
        assert_eq!(stats.total_entries, 3);
        assert_eq!(stats.total_inserts, 5);
    }

    /// Re-storing an existing id replaces the ticket without growing the cache.
    #[tokio::test]
    async fn test_overwrite_same_key() {
        let cache = SessionCache::new(10, Duration::from_secs(60));
        cache.store("session-1", vec![1]).await;
        cache.store("session-1", vec![2]).await;
        assert_eq!(cache.get("session-1").await.as_deref(), Some(&vec![2]));
        let stats = cache.stats().await;
        assert_eq!(stats.total_entries, 1);
        assert_eq!(stats.total_inserts, 2);
    }

    /// A zero-capacity cache retains nothing but still counts store attempts.
    #[tokio::test]
    async fn test_zero_capacity_stores_nothing() {
        let cache = SessionCache::new(0, Duration::from_secs(60));
        cache.store("session-1", vec![1]).await;
        assert!(cache.get("session-1").await.is_none());
        let stats = cache.stats().await;
        assert_eq!(stats.total_entries, 0);
        assert_eq!(stats.total_inserts, 1);
    }

    #[tokio::test]
    async fn test_clear() {
        let cache = SessionCache::new(10, Duration::from_secs(60));
        cache.store("session-1", vec![1, 2, 3]).await;
        cache.clear().await;
        let ticket = cache.get("session-1").await;
        assert!(ticket.is_none());
        // Clearing drops entries but keeps the lifetime insert counter.
        assert_eq!(cache.stats().await.total_inserts, 1);
    }
}