use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use tokio::sync::RwLock;
/// Key prefix for session-version cache entries (e.g. `sv:abc`).
pub const SV_CACHE_KEY_PREFIX: &str = "sv:";
/// Time-to-live for session-version cache entries (60 seconds).
pub const SV_CACHE_TTL: Duration = Duration::from_secs(60);
/// Maximum number of entries the in-memory cache holds before it starts evicting.
const MAX_ENTRIES: usize = 10_000;
/// Asynchronous cache of session versions keyed by string.
///
/// Implementations must be shareable across threads (`Send + Sync`).
#[async_trait]
pub trait SessionVersionCache: Send + Sync {
    /// Returns the cached session version for `key`, or `None` when there is
    /// no usable entry.
    async fn get(&self, key: &str) -> Option<i64>;
    /// Stores session version `sv` under `key`; `ttl` is the lifetime the
    /// caller requests for the entry.
    async fn set(&self, key: &str, sv: i64, ttl: Duration);
}
/// In-process [`SessionVersionCache`] backed by a `HashMap` behind a tokio `RwLock`.
///
/// Capacity is bounded by `MAX_ENTRIES`: inserting a new key into a full cache
/// first drops expired entries and, if it is still full, evicts one more entry.
#[derive(Debug)]
pub struct MemorySessionVersionCache {
    /// Cache key -> (session version, `Instant` used to decide whether the entry has expired).
    inner: Arc<RwLock<HashMap<String, (i64, Instant)>>>,
}
impl MemorySessionVersionCache {
#[must_use]
pub fn new() -> Self {
Self {
inner: Arc::new(RwLock::new(HashMap::new())),
}
}
}
impl Default for MemorySessionVersionCache {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SessionVersionCache for MemorySessionVersionCache {
    /// Returns the cached session version for `key`, or `None` if the key is
    /// absent or its entry's deadline has passed.
    ///
    /// Expired entries are left in place here (only a read lock is held); they
    /// are purged by `set` when the cache reaches capacity.
    async fn get(&self, key: &str) -> Option<i64> {
        let guard = self.inner.read().await;
        let &(sv, expires_at) = guard.get(key)?;
        if Instant::now() >= expires_at {
            return None;
        }
        Some(sv)
    }

    /// Stores `sv` under `key`, valid for `ttl` from now.
    ///
    /// Overwriting an existing key never evicts anything. Inserting a new key
    /// into a full cache (`MAX_ENTRIES`) first drops every expired entry and,
    /// if the cache is still full, evicts the live entry closest to expiry.
    /// Both steps scan the whole map, so inserts are O(n) while the cache is
    /// saturated with live entries.
    async fn set(&self, key: &str, sv: i64, ttl: Duration) {
        // Caps caller-supplied TTLs (~100 years) so `Instant + Duration` cannot
        // overflow and panic on an absurd value such as `Duration::MAX`.
        const MAX_TTL: Duration = Duration::from_secs(100 * 365 * 24 * 60 * 60);

        let mut guard = self.inner.write().await;
        let now = Instant::now();
        // If even the capped deadline is unrepresentable, store `now`, which
        // `get` treats as already expired, rather than panicking.
        let expires_at = now.checked_add(ttl.min(MAX_TTL)).unwrap_or(now);

        if guard.len() >= MAX_ENTRIES && !guard.contains_key(key) {
            // Expired entries are dead weight; drop them before evicting anything live.
            guard.retain(|_, (_, expiry)| now < *expiry);
            if guard.len() >= MAX_ENTRIES {
                // Still full: evict the entry closest to expiry. With a uniform
                // TTL this is the oldest write.
                let victim = guard
                    .iter()
                    .min_by_key(|(_, (_, expiry))| *expiry)
                    .map(|(k, _)| k.clone());
                if let Some(k) = victim {
                    guard.remove(&k);
                }
            }
        }
        guard.insert(key.to_string(), (sv, expires_at));
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A key that was just written reads back; a key never written misses.
    #[tokio::test]
    async fn memory_cache_respects_ttl() {
        let store = MemorySessionVersionCache::new();
        store.set("sv:abc", 42, SV_CACHE_TTL).await;

        let hit = store.get("sv:abc").await;
        assert_eq!(hit, Some(42));
        let miss = store.get("sv:missing").await;
        assert_eq!(miss, None);
    }

    /// A second write to the same key replaces the first value.
    #[tokio::test]
    async fn memory_cache_overwrite() {
        let store = MemorySessionVersionCache::new();
        for version in 1..=2 {
            store.set("sv:xyz", version, SV_CACHE_TTL).await;
        }
        assert_eq!(store.get("sv:xyz").await, Some(2));
    }

    /// Writing past the cap keeps the map at or below `MAX_ENTRIES` and never
    /// evicts the entry written last.
    #[tokio::test]
    async fn memory_cache_bounded_by_max_entries() {
        let store = MemorySessionVersionCache::new();
        let total = MAX_ENTRIES + 100;
        for i in 0..total {
            let key = format!("sv:{i}");
            store.set(&key, i as i64, SV_CACHE_TTL).await;
        }

        let len = store.inner.read().await.len();
        assert!(len <= MAX_ENTRIES, "cache exceeded cap: {len} > {MAX_ENTRIES}");

        let newest = total - 1;
        let last_key = format!("sv:{}", newest);
        assert_eq!(
            store.get(&last_key).await,
            Some(newest as i64),
            "newest entry must survive eviction"
        );
    }
}