use crate::error::CheckpointError;
use lru::LruCache;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
/// A cached payload together with its optional expiry deadline.
struct CacheEntry {
    /// Raw bytes stored for the key.
    data: Vec<u8>,
    /// Instant from which the entry counts as expired; `None` means it never expires.
    expires_at: Option<std::time::Instant>,
}

impl CacheEntry {
    /// Reports whether the entry's deadline has been reached.
    ///
    /// An entry without a deadline is never expired. The comparison is
    /// inclusive: an entry whose deadline equals the current instant is
    /// already expired.
    #[must_use]
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => deadline <= std::time::Instant::now(),
            None => false,
        }
    }
}
/// Asynchronous cache of byte values addressed by a `(namespace, key)` pair.
///
/// Every operation reports failure as a [`CheckpointError`]. Implementors
/// must be `Send + Sync + 'static`.
#[async_trait::async_trait]
pub trait BaseCache: Send + Sync + 'static {
/// Fetches the value stored under `key` in `namespace`.
///
/// The `Option` in the return type distinguishes a hit (`Some(bytes)`) from
/// a miss (`None`).
async fn get(&self, namespace: &str, key: &str) -> Result<Option<Vec<u8>>, CheckpointError>;
/// Stores `value` under `key` in `namespace`.
///
/// `ttl` is an optional time-to-live for the entry. What `None` means is up
/// to the implementor; the `MemoryCache` in this file falls back to its
/// configured default TTL.
async fn set(
&self,
namespace: &str,
key: &str,
value: Vec<u8>,
ttl: Option<Duration>,
) -> Result<(), CheckpointError>;
/// Removes the entry stored under `key` in `namespace`.
async fn delete(&self, namespace: &str, key: &str) -> Result<(), CheckpointError>;
/// Removes entries in bulk.
///
/// As implemented by the `MemoryCache` in this file, `Some(namespace)`
/// restricts the removal to that namespace and `None` empties the cache.
async fn clear(&self, namespace: Option<&str>) -> Result<(), CheckpointError>;
}
/// In-memory [`BaseCache`] implementation backed by a capacity-bounded LRU cache.
///
/// The entries live behind an [`Arc`], so a clone of a `MemoryCache` is a
/// second handle to the same underlying storage rather than an independent copy.
#[derive(Clone, Debug)]
pub struct MemoryCache {
/// Shared, lock-protected LRU store keyed by the string that `build_key`
/// derives from a namespace and a key.
entries: Arc<RwLock<LruCache<String, CacheEntry>>>,
/// TTL that `set` applies when the caller passes `ttl: None`; when this is
/// also `None`, such entries get no expiry deadline.
default_ttl: Option<Duration>,
}
impl MemoryCache {
    /// Creates a cache that holds at most `capacity` entries and applies no
    /// default TTL.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        let capacity = NonZeroUsize::new(capacity).expect("capacity must be non-zero");
        Self {
            entries: Arc::new(RwLock::new(LruCache::new(capacity))),
            default_ttl: None,
        }
    }

    /// Creates a cache that holds at most `capacity` entries and applies
    /// `default_ttl` to every entry stored without an explicit TTL.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    #[must_use]
    pub fn with_ttl(capacity: usize, default_ttl: Duration) -> Self {
        Self {
            default_ttl: Some(default_ttl),
            ..Self::new(capacity)
        }
    }

    /// Builds the internal storage key as `namespace`, a colon, then `key`.
    ///
    /// NOTE(review): the separator is not escaped, so `("a", "b:c")` and
    /// `("a:b", "c")` produce the same storage key.
    #[must_use]
    fn build_key(namespace: &str, key: &str) -> String {
        [namespace, key].join(":")
    }

    /// Removes every entry whose deadline has passed.
    ///
    /// Holds the write lock for the whole sweep and visits every entry.
    async fn purge_expired(&self) {
        let mut guard = self.entries.write().await;
        // Collect first: entries cannot be removed while the cache is being iterated.
        let stale: Vec<String> = guard
            .iter()
            .filter_map(|(name, entry)| entry.is_expired().then(|| name.clone()))
            .collect();
        for name in &stale {
            guard.pop(name);
        }
    }

    /// Returns `(current entry count, capacity)`.
    ///
    /// The count includes expired entries that have not been purged yet.
    pub async fn stats(&self) -> (usize, usize) {
        let guard = self.entries.read().await;
        let len = guard.len();
        let capacity = guard.cap().get();
        (len, capacity)
    }
}
impl Default for MemoryCache {
    /// Builds a cache with room for 1000 entries via [`MemoryCache::new`].
    fn default() -> Self {
        const DEFAULT_CAPACITY: usize = 1000;
        Self::new(DEFAULT_CAPACITY)
    }
}
#[async_trait::async_trait]
impl BaseCache for MemoryCache {
    /// Looks up the value stored under `key` in `namespace`.
    ///
    /// Every call first sweeps all expired entries out of the cache, so a
    /// lookup visits every entry in addition to the keyed access. A hit marks
    /// the entry as most recently used.
    ///
    /// Returns `Ok(None)` when the key is absent or its entry has expired.
    /// This implementation never returns `Err`.
    async fn get(&self, namespace: &str, key: &str) -> Result<Option<Vec<u8>>, CheckpointError> {
        self.purge_expired().await;
        let cache_key = Self::build_key(namespace, key);

        // An LRU lookup updates recency, so it needs the write lock.
        let mut cache = self.entries.write().await;
        let Some(entry) = cache.get(&cache_key) else {
            return Ok(None);
        };
        // The sweep above ran under a separate lock acquisition, so the entry
        // may have reached its deadline in the meantime.
        if entry.is_expired() {
            cache.pop(&cache_key);
            return Ok(None);
        }
        Ok(Some(entry.data.clone()))
    }

    /// Stores `value` under `key` in `namespace`, replacing any previous value.
    ///
    /// `ttl` overrides the cache's default TTL; when both are `None` the entry
    /// never expires. A TTL so large that the deadline cannot be represented
    /// as an `Instant` is also treated as "never expires" instead of panicking.
    ///
    /// This implementation never returns `Err`.
    async fn set(
        &self,
        namespace: &str,
        key: &str,
        value: Vec<u8>,
        ttl: Option<Duration>,
    ) -> Result<(), CheckpointError> {
        let cache_key = Self::build_key(namespace, key);
        // `Instant + Duration` panics on overflow; `checked_add` yields `None`
        // instead, which maps onto "no deadline".
        let expires_at = ttl
            .or(self.default_ttl)
            .and_then(|duration| std::time::Instant::now().checked_add(duration));
        let entry = CacheEntry {
            data: value,
            expires_at,
        };
        self.entries.write().await.put(cache_key, entry);
        Ok(())
    }

    /// Removes the entry stored under `key` in `namespace`, if any.
    ///
    /// Deleting a missing key is not an error. This implementation never
    /// returns `Err`.
    async fn delete(&self, namespace: &str, key: &str) -> Result<(), CheckpointError> {
        let cache_key = Self::build_key(namespace, key);
        self.entries.write().await.pop(&cache_key);
        Ok(())
    }

    /// Removes all entries of `namespace`, or every entry when `namespace` is `None`.
    ///
    /// Namespace membership is decided by the storage-key prefix
    /// `"{namespace}:"`.
    ///
    /// NOTE(review): because the separator is not escaped, clearing `"a"` also
    /// removes entries stored under a namespace such as `"a:b"`.
    ///
    /// This implementation never returns `Err`.
    async fn clear(&self, namespace: Option<&str>) -> Result<(), CheckpointError> {
        let mut cache = self.entries.write().await;
        let Some(ns) = namespace else {
            cache.clear();
            return Ok(());
        };

        let prefix = format!("{ns}:");
        // Collect first: entries cannot be removed while the cache is being iterated.
        let keys_to_remove: Vec<String> = cache
            .iter()
            .map(|(stored_key, _)| stored_key)
            .filter(|stored_key| stored_key.starts_with(&prefix))
            .cloned()
            .collect();
        for stored_key in &keys_to_remove {
            cache.pop(stored_key);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Stores `payload` under `(namespace, key)` with no per-entry TTL,
    /// panicking if the cache reports an error.
    async fn store(cache: &MemoryCache, namespace: &str, key: &str, payload: &[u8]) {
        cache
            .set(namespace, key, payload.to_vec(), None)
            .await
            .unwrap();
    }

    /// Reads `(namespace, key)`, panicking if the cache reports an error.
    async fn fetch(cache: &MemoryCache, namespace: &str, key: &str) -> Option<Vec<u8>> {
        cache.get(namespace, key).await.unwrap()
    }

    #[tokio::test]
    async fn test_memory_cache_set_get() {
        let cache = MemoryCache::new(10);
        store(&cache, "ns1", "key1", b"hello").await;

        assert_eq!(fetch(&cache, "ns1", "key1").await, Some(b"hello".to_vec()));
    }

    #[tokio::test]
    async fn test_memory_cache_miss() {
        let cache = MemoryCache::new(10);

        assert_eq!(fetch(&cache, "ns1", "nonexistent").await, None);
    }

    #[tokio::test]
    async fn test_memory_cache_delete() {
        let cache = MemoryCache::new(10);
        store(&cache, "ns1", "key1", b"hello").await;

        cache.delete("ns1", "key1").await.unwrap();

        assert_eq!(fetch(&cache, "ns1", "key1").await, None);
    }

    #[tokio::test]
    async fn test_memory_cache_ttl() {
        let cache = MemoryCache::with_ttl(10, Duration::from_millis(100));
        store(&cache, "ns1", "key1", b"hello").await;

        // Still fresh right after the write.
        assert_eq!(fetch(&cache, "ns1", "key1").await, Some(b"hello".to_vec()));

        // Wait past the default TTL; the entry must be gone.
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert_eq!(fetch(&cache, "ns1", "key1").await, None);
    }

    #[tokio::test]
    async fn test_memory_cache_clear_namespace() {
        let cache = MemoryCache::new(10);
        store(&cache, "ns1", "key1", b"data1").await;
        store(&cache, "ns2", "key2", b"data2").await;

        cache.clear(Some("ns1")).await.unwrap();

        assert_eq!(fetch(&cache, "ns1", "key1").await, None);
        assert_eq!(fetch(&cache, "ns2", "key2").await, Some(b"data2".to_vec()));
    }

    #[tokio::test]
    async fn test_memory_cache_clear_all() {
        let cache = MemoryCache::new(10);
        store(&cache, "ns1", "key1", b"data1").await;
        store(&cache, "ns2", "key2", b"data2").await;

        cache.clear(None).await.unwrap();

        assert_eq!(fetch(&cache, "ns1", "key1").await, None);
        assert_eq!(fetch(&cache, "ns2", "key2").await, None);
    }

    #[tokio::test]
    async fn test_memory_cache_lru_eviction() {
        let cache = MemoryCache::new(2);
        store(&cache, "ns1", "key1", b"data1").await;
        store(&cache, "ns1", "key2", b"data2").await;

        // Touch key1 so that key2 becomes the least recently used entry.
        let _ = fetch(&cache, "ns1", "key1").await;

        // Inserting a third entry into a two-slot cache evicts key2.
        store(&cache, "ns1", "key3", b"data3").await;

        assert_eq!(fetch(&cache, "ns1", "key1").await, Some(b"data1".to_vec()));
        assert_eq!(fetch(&cache, "ns1", "key2").await, None);
        assert_eq!(fetch(&cache, "ns1", "key3").await, Some(b"data3".to_vec()));
    }

    #[tokio::test]
    async fn test_memory_cache_stats() {
        let cache = MemoryCache::new(100);
        store(&cache, "ns1", "key1", b"data1").await;
        store(&cache, "ns1", "key2", b"data2").await;

        let (size, capacity) = cache.stats().await;

        assert_eq!(size, 2);
        assert_eq!(capacity, 100);
    }
}