use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use tokio::sync::RwLock;
/// Key prefix `"sv:"` for session-version cache entries.
///
/// Nothing in this file prepends it automatically; the tests here pass keys
/// that already start with it (e.g. `"sv:abc"`).
pub const SV_CACHE_KEY_PREFIX: &str = "sv:";
/// A 60-second time-to-live for cached session versions.
///
/// The tests in this file pass it as the `ttl` argument of
/// `SessionVersionCache::set`.
pub const SV_CACHE_TTL: Duration = Duration::from_secs(60);
/// Asynchronous cache mapping a string key to an `i64` session version.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait SessionVersionCache: Send + Sync {
/// Looks up `key`, yielding the stored version or `None` when the
/// implementation has nothing to return for it.
async fn get(&self, key: &str) -> Option<i64>;
/// Stores `sv` under `key`.
///
/// `ttl` is presumably how long the entry should remain readable; how it is
/// honoured is up to each implementation — confirm against the implementor.
async fn set(&self, key: &str, sv: i64, ttl: Duration);
}
/// In-process [`SessionVersionCache`] backed by a `HashMap` behind an
/// `Arc<tokio::sync::RwLock<..>>`.
pub struct MemorySessionVersionCache {
// Each key maps to the cached version paired with an `Instant` that the
// `SessionVersionCache` impl consults to decide whether the entry is still live.
inner: Arc<RwLock<HashMap<String, (i64, Instant)>>>,
}
impl MemorySessionVersionCache {
    /// Creates a cache that holds no entries.
    #[must_use]
    pub fn new() -> Self {
        // Build the layers inside-out: map, then lock, then shared handle.
        let entries = HashMap::new();
        let locked = RwLock::new(entries);
        let inner = Arc::new(locked);
        Self { inner }
    }
}
impl Default for MemorySessionVersionCache {
/// Delegates to [`MemorySessionVersionCache::new`], yielding an empty cache.
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SessionVersionCache for MemorySessionVersionCache {
    /// Returns the cached session version for `key`, or `None` when the key
    /// is absent or its entry has reached the expiry deadline recorded by
    /// [`set`](SessionVersionCache::set).
    ///
    /// Only a read lock is taken, so expired entries are not removed here;
    /// they remain in the map until a later `set` overwrites them.
    async fn get(&self, key: &str) -> Option<i64> {
        let guard = self.inner.read().await;
        let &(sv, expires_at) = guard.get(key)?;
        // An entry is dead once the deadline is reached (inclusive), so a
        // zero TTL never produces a hit.
        if Instant::now() >= expires_at {
            return None;
        }
        Some(sv)
    }

    /// Stores `sv` under `key`, replacing any previous entry, and makes it
    /// readable for `ttl` from the moment of insertion.
    ///
    /// The `Instant` kept next to the value is the expiry deadline
    /// (`now + ttl`), so each entry honours the TTL it was written with
    /// rather than a global constant.
    async fn set(&self, key: &str, sv: i64, ttl: Duration) {
        // Upper bound (~10 years) that keeps `Instant + Duration` representable
        // even when a caller passes something like `Duration::MAX`.
        const MAX_TTL: Duration = Duration::from_secs(10 * 365 * 24 * 60 * 60);

        let mut guard = self.inner.write().await;
        let now = Instant::now();
        // `checked_add` instead of `+` so an unrepresentable deadline can never
        // panic; the fallback degrades to an already-expired entry, i.e. a
        // plain cache miss, which also hides any stale value under `key`.
        let expires_at = now.checked_add(ttl.min(MAX_TTL)).unwrap_or(now);
        guard.insert(key.to_string(), (sv, expires_at));
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A key written with the standard TTL is immediately readable, while a
    /// key that was never written is a miss.
    #[tokio::test]
    async fn memory_cache_respects_ttl() {
        let store = MemorySessionVersionCache::new();
        store.set("sv:abc", 42, SV_CACHE_TTL).await;

        let hit = store.get("sv:abc").await;
        let miss = store.get("sv:missing").await;

        assert_eq!(hit, Some(42));
        assert_eq!(miss, None);
    }

    /// Writing the same key twice leaves only the most recent version visible.
    #[tokio::test]
    async fn memory_cache_overwrite() {
        let store = MemorySessionVersionCache::new();
        for version in 1..=2_i64 {
            store.set("sv:xyz", version, SV_CACHE_TTL).await;
        }

        let latest = store.get("sv:xyz").await;
        assert_eq!(latest, Some(2));
    }
}