use std::{
collections::HashMap,
future::Future,
pin::Pin,
sync::RwLock,
time::{Duration, Instant},
};
use serde::{de::DeserializeOwned, Serialize};
use super::{Cache, CacheConfig};
/// A single cached value together with its optional expiry deadline.
struct CacheEntry {
    /// Serialized payload bytes, stored exactly as passed to [`CacheEntry::new`].
    data: Vec<u8>,
    /// Instant after which the entry counts as expired; `None` means it never expires.
    expires_at: Option<Instant>,
}

impl CacheEntry {
    /// Creates an entry holding `data` whose deadline is `ttl` from now.
    ///
    /// A `ttl` of `None` yields an entry that never expires. A `ttl` so large
    /// that the deadline cannot be represented as an `Instant` (for example
    /// `Duration::MAX`) is treated the same way instead of panicking.
    fn new(data: Vec<u8>, ttl: Option<Duration>) -> Self {
        Self {
            data,
            // `Instant + Duration` panics on overflow; `checked_add` reports it
            // as `None`, which this type already interprets as "no deadline".
            expires_at: ttl.and_then(|d| Instant::now().checked_add(d)),
        }
    }

    /// Returns `true` once the current time is strictly past the deadline.
    ///
    /// Entries without a deadline are never expired.
    fn is_expired(&self) -> bool {
        self.expires_at
            .map_or(false, |deadline| Instant::now() > deadline)
    }
}
/// In-process cache that keeps serialized entries in a `HashMap` behind an `RwLock`.
pub struct MemoryCache {
    /// Stored entries, keyed by the string produced by `CacheConfig::build_key`.
    entries: RwLock<HashMap<String, CacheEntry>>,
    /// Settings this cache reads: `default_ttl`, `max_entries` and key building.
    config: CacheConfig,
}

impl MemoryCache {
    /// Creates an empty cache configured with `CacheConfig::default()`.
    pub fn new() -> Self {
        let config = CacheConfig::default();
        Self::with_config(config)
    }

    /// Creates an empty cache that uses the supplied `config`.
    pub fn with_config(config: CacheConfig) -> Self {
        let entries = RwLock::new(HashMap::new());
        Self { entries, config }
    }

    /// Returns a [`MemoryCacheBuilder`] for configuring a cache step by step.
    pub fn builder() -> MemoryCacheBuilder {
        MemoryCacheBuilder::new()
    }

    /// Drops every entry whose deadline has passed.
    ///
    /// # Panics
    ///
    /// Panics if the entries lock is poisoned.
    pub fn cleanup_expired(&self) {
        self.entries
            .write()
            .unwrap()
            .retain(|_, cached| !cached.is_expired());
    }

    /// Picks the TTL for a write: the explicit `ttl` when one is given,
    /// otherwise the configured `default_ttl`.
    fn effective_ttl(&self, ttl: Option<Duration>) -> Option<Duration> {
        match ttl {
            Some(_) => ttl,
            None => self.config.default_ttl,
        }
    }

    /// Maps a caller-supplied key to the stored key via `CacheConfig::build_key`.
    fn full_key(&self, key: &str) -> String {
        let config = &self.config;
        config.build_key(key)
    }
}
impl Default for MemoryCache {
fn default() -> Self {
Self::new()
}
}
/// [`Cache`] implementation backed by the in-memory map.
///
/// Each method derives the stored key eagerly and performs the map access
/// inside the returned future. The lock guard is taken and released within a
/// single poll; no method awaits while holding it. Every method panics if the
/// entries lock is poisoned.
impl Cache for MemoryCache {
    /// Looks up `key` and deserializes the stored JSON bytes into `T`.
    ///
    /// Resolves to `None` when the key is absent, the entry has expired, or
    /// the stored bytes do not deserialize into `T`.
    fn get<T: DeserializeOwned + Send>(
        &self,
        key: &str,
    ) -> Pin<Box<dyn Future<Output = Option<T>> + Send + '_>> {
        let key = self.full_key(key);
        Box::pin(async move {
            let entries = self.entries.read().unwrap();
            entries
                .get(&key)
                .filter(|entry| !entry.is_expired())
                .and_then(|entry| serde_json::from_slice(&entry.data).ok())
        })
    }

    /// Serializes `value` as JSON and stores it under `key`.
    ///
    /// `ttl` overrides the configured default TTL when given. Serialization
    /// happens eagerly; the map is only modified when the returned future runs.
    ///
    /// When `max_entries` is configured and inserting a new key would exceed
    /// it, expired entries are purged first. If the map is still full, the
    /// entry with the earliest expiry is evicted; when no remaining entry has
    /// an expiry, whichever entry the map iterates first is evicted instead.
    /// A `max_entries` of zero means nothing is ever stored.
    fn set<T: Serialize + Send + Sync>(
        &self,
        key: &str,
        value: &T,
        ttl: Option<Duration>,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        let key = self.full_key(key);
        let ttl = self.effective_ttl(ttl);
        // The trait offers no way to report failure, so a value that cannot be
        // serialized is dropped and the cache is left unchanged.
        let data = match serde_json::to_vec(value) {
            Ok(bytes) => bytes,
            Err(_) => return Box::pin(async {}),
        };
        Box::pin(async move {
            let max_entries = self.config.max_entries;
            // A capacity of zero leaves no room for any entry; inserting here
            // would silently exceed the configured bound.
            if max_entries == Some(0) {
                return;
            }

            let mut entries = self.entries.write().unwrap();
            if let Some(max) = max_entries {
                // Overwriting an existing key does not grow the map, so only
                // new keys can trigger eviction.
                if entries.len() >= max && !entries.contains_key(&key) {
                    // Reclaim space from dead entries before evicting live ones.
                    entries.retain(|_, entry| !entry.is_expired());
                    if entries.len() >= max {
                        // Prefer the entry closest to expiring. With a uniform
                        // TTL this is the oldest entry. `HashMap` iteration
                        // order is arbitrary, so the fallback is an arbitrary
                        // entry, used only when nothing has a deadline.
                        let victim = entries
                            .iter()
                            .filter_map(|(k, entry)| entry.expires_at.map(|at| (at, k)))
                            .min_by_key(|&(at, _)| at)
                            .map(|(_, k)| k.clone())
                            .or_else(|| entries.keys().next().cloned());
                        if let Some(victim) = victim {
                            entries.remove(&victim);
                        }
                    }
                }
            }
            entries.insert(key, CacheEntry::new(data, ttl));
        })
    }

    /// Removes `key`, resolving to `true` if an entry was present.
    ///
    /// An expired entry that has not been purged yet still counts as present.
    fn delete(&self, key: &str) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
        let key = self.full_key(key);
        Box::pin(async move {
            let mut entries = self.entries.write().unwrap();
            entries.remove(&key).is_some()
        })
    }

    /// Resolves to `true` if `key` is stored and has not expired.
    fn exists(&self, key: &str) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
        let key = self.full_key(key);
        Box::pin(async move {
            let entries = self.entries.read().unwrap();
            entries
                .get(&key)
                .map_or(false, |entry| !entry.is_expired())
        })
    }

    /// Removes every entry held by this cache.
    fn clear(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        Box::pin(async move {
            let mut entries = self.entries.write().unwrap();
            entries.clear();
        })
    }

    /// Resolves to the number of entries that have not expired.
    ///
    /// Always `Some`; expired entries still in the map are not counted.
    fn len(&self) -> Pin<Box<dyn Future<Output = Option<usize>> + Send + '_>> {
        Box::pin(async move {
            let entries = self.entries.read().unwrap();
            let live = entries.values().filter(|e| !e.is_expired()).count();
            Some(live)
        })
    }
}
/// Step-by-step builder for a [`MemoryCache`].
pub struct MemoryCacheBuilder {
    /// Configuration accumulated so far; handed to the cache by `build`.
    config: CacheConfig,
}

impl MemoryCacheBuilder {
    /// Starts a builder from `CacheConfig::default()`.
    pub fn new() -> Self {
        let config = CacheConfig::default();
        Self { config }
    }

    /// Sets the key prefix stored in the configuration.
    pub fn prefix(mut self, prefix: impl Into<String>) -> Self {
        let config = &mut self.config;
        config.prefix = Some(prefix.into());
        self
    }

    /// Sets the TTL applied to writes that do not pass their own.
    pub fn default_ttl(mut self, ttl: Duration) -> Self {
        let config = &mut self.config;
        config.default_ttl = Some(ttl);
        self
    }

    /// Sets the maximum number of entries the cache may hold.
    pub fn max_entries(mut self, max: usize) -> Self {
        let config = &mut self.config;
        config.max_entries = Some(max);
        self
    }

    /// Consumes the builder and creates the configured [`MemoryCache`].
    pub fn build(self) -> MemoryCache {
        let config = self.config;
        MemoryCache::with_config(config)
    }
}
impl Default for MemoryCacheBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored value can be read back under the same key.
    #[tokio::test]
    async fn test_memory_cache_basic() {
        let store = MemoryCache::new();
        store.set("key", &"value", None).await;

        let fetched: Option<String> = store.get("key").await;
        assert_eq!(fetched.as_deref(), Some("value"));
    }

    /// Reading a key that was never written yields `None`.
    #[tokio::test]
    async fn test_memory_cache_missing_key() {
        let store = MemoryCache::new();

        let fetched: Option<String> = store.get("nonexistent").await;
        assert!(fetched.is_none());
    }

    /// Deleting a present key reports `true` and removes the entry.
    #[tokio::test]
    async fn test_memory_cache_delete() {
        let store = MemoryCache::new();
        store.set("key", &"value", None).await;
        assert!(store.exists("key").await);

        let was_present = store.delete("key").await;
        assert!(was_present);
        assert!(!store.exists("key").await);
    }

    /// Deleting an absent key reports `false`.
    #[tokio::test]
    async fn test_memory_cache_delete_nonexistent() {
        let store = MemoryCache::new();

        let was_present = store.delete("nonexistent").await;
        assert!(!was_present);
    }

    /// `clear` removes every stored key.
    #[tokio::test]
    async fn test_memory_cache_clear() {
        let store = MemoryCache::new();
        store.set("key1", &"value1", None).await;
        store.set("key2", &"value2", None).await;

        store.clear().await;

        assert!(!store.exists("key1").await);
        assert!(!store.exists("key2").await);
    }

    /// `len` counts the stored entries.
    #[tokio::test]
    async fn test_memory_cache_len() {
        let store = MemoryCache::new();
        let empty_count = store.len().await;
        assert_eq!(empty_count, Some(0));

        store.set("key1", &"value1", None).await;
        store.set("key2", &"value2", None).await;

        let filled_count = store.len().await;
        assert_eq!(filled_count, Some(2));
    }

    /// An entry is no longer readable once its TTL has elapsed.
    #[tokio::test]
    async fn test_memory_cache_ttl_expired() {
        let store = MemoryCache::new();
        let short_ttl = Some(Duration::from_millis(1));
        store.set("key", &"value", short_ttl).await;

        tokio::time::sleep(Duration::from_millis(10)).await;

        let fetched: Option<String> = store.get("key").await;
        assert!(fetched.is_none());
    }

    /// An entry stays readable while its TTL has not elapsed.
    #[tokio::test]
    async fn test_memory_cache_ttl_not_expired() {
        let store = MemoryCache::new();
        let long_ttl = Some(Duration::from_secs(60));
        store.set("key", &"value", long_ttl).await;

        let fetched: Option<String> = store.get("key").await;
        assert_eq!(fetched.as_deref(), Some("value"));
    }

    /// Reads and writes round-trip when a key prefix is configured.
    #[tokio::test]
    async fn test_memory_cache_with_prefix() {
        let store = MemoryCache::builder().prefix("test").build();
        store.set("key", &"value", None).await;

        let fetched: Option<String> = store.get("key").await;
        assert_eq!(fetched.as_deref(), Some("value"));
    }

    /// The entry count never exceeds the configured maximum.
    #[tokio::test]
    async fn test_memory_cache_max_entries() {
        let store = MemoryCache::builder().max_entries(2).build();
        store.set("key1", &"value1", None).await;
        store.set("key2", &"value2", None).await;
        store.set("key3", &"value3", None).await;

        let live = store.len().await.unwrap();
        assert!(live <= 2);
    }

    /// A user-defined serde struct round-trips through the cache.
    #[tokio::test]
    async fn test_memory_cache_complex_type() {
        #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
        struct User {
            id: u64,
            name: String,
        }

        let store = MemoryCache::new();
        let alice = User {
            id: 1,
            name: "Alice".to_string(),
        };
        store.set("user:1", &alice, None).await;

        let fetched: Option<User> = store.get("user:1").await;
        assert_eq!(fetched, Some(alice));
    }

    /// The builder's default TTL applies to writes that pass no TTL.
    #[tokio::test]
    async fn test_memory_cache_builder_default_ttl() {
        let builder = MemoryCache::builder().default_ttl(Duration::from_millis(1));
        let store = builder.build();
        store.set("key", &"value", None).await;

        tokio::time::sleep(Duration::from_millis(10)).await;

        let fetched: Option<String> = store.get("key").await;
        assert!(fetched.is_none());
    }
}