use super::{Cache, CacheError};
use lru::LruCache;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// A cached payload together with the deadline computed when it was stored.
struct CacheEntry<T> {
    /// The stored payload (serialized bytes in [`MemoryCache`]).
    value: T,
    /// Monotonic deadline; readers compare it against `Instant::now()`.
    expires_at: Instant,
}
/// Backing store: LRU map from string keys to JSON-serialized entries.
type EntryStore = LruCache<String, CacheEntry<Vec<u8>>>;

/// In-process [`Cache`] implementation backed by a bounded LRU map.
///
/// Values are kept as serialized bytes next to their expiry instant, and all
/// access is serialized through a single mutex.
#[derive(Debug)]
pub struct MemoryCache {
    cache: Arc<Mutex<EntryStore>>,
}
impl MemoryCache {
    /// Creates an empty cache that holds at most `capacity` entries.
    ///
    /// A `capacity` of `0` is clamped to `1`, because [`LruCache`] requires a
    /// non-zero capacity.
    pub fn new(capacity: usize) -> Self {
        let capacity = std::num::NonZeroUsize::new(capacity.max(1))
            .expect("capacity.max(1) should always be >= 1");
        Self {
            cache: Arc::new(Mutex::new(LruCache::new(capacity))),
        }
    }

    /// Removes every entry whose deadline has been reached.
    ///
    /// An entry is expired once `now >= expires_at`. An entry stored with a
    /// zero TTL is therefore never treated as live, even when two
    /// `Instant::now()` readings are equal. A poisoned mutex is recovered
    /// rather than propagated as a panic. Cost is linear in the number of
    /// stored entries.
    fn cleanup_expired(&self) {
        let mut cache = self
            .cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let now = Instant::now();
        // Collect the keys first: the cache cannot be mutated while borrowed
        // by the iterator.
        let expired: Vec<String> = cache
            .iter()
            .filter(|(_, entry)| entry.expires_at <= now)
            .map(|(key, _)| key.clone())
            .collect();
        for key in &expired {
            cache.pop(key);
        }
    }
}
impl<K, V> Cache<K, V> for MemoryCache
where
    K: AsRef<str>,
    V: Clone + Serialize + for<'de> Deserialize<'de>,
{
    /// Returns the value stored under `key`.
    ///
    /// Returns `None` if the key is missing, its entry has expired (the entry
    /// is then removed), or the stored bytes do not deserialize as `V`.
    /// Expiry is checked only for the requested key, so reads stay O(1)
    /// instead of scanning the whole cache. JSON decoding happens after the
    /// lock is released.
    fn get(&self, key: &K) -> Option<V> {
        let key_str = key.as_ref();
        let bytes = {
            let mut cache = self
                .cache
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let entry = cache.get(key_str)?;
            // `<=`: an entry is no longer live at its exact deadline.
            if entry.expires_at <= Instant::now() {
                cache.pop(key_str);
                return None;
            }
            entry.value.clone()
        };
        serde_json::from_slice(&bytes).ok()
    }

    /// Stores `value` under `key`, expiring after `ttl`.
    ///
    /// TTLs longer than roughly 100 years are clamped. This avoids the panic
    /// `Instant + Duration` raises on overflow (e.g. for `Duration::MAX`).
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Serialization`] if `value` cannot be serialized
    /// to JSON.
    fn set(&self, key: &K, value: &V, ttl: Duration) -> Result<(), CacheError> {
        // Effectively "never expires" for an in-process cache.
        const MAX_TTL: Duration = Duration::from_secs(100 * 365 * 24 * 60 * 60);

        let serialized =
            serde_json::to_vec(value).map_err(|e| CacheError::Serialization(e.to_string()))?;
        let entry = CacheEntry {
            value: serialized,
            expires_at: Instant::now() + ttl.min(MAX_TTL),
        };
        let key_string = key.as_ref().to_string();

        // Putting a new key into a full LRU evicts the least-recently-used
        // entry even if it is still live. Purge expired entries first so stale
        // data makes room instead. The full scan only runs when an eviction is
        // imminent. Another thread may race in between, so this is best-effort.
        let would_evict = {
            let cache = self
                .cache
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            cache.len() >= cache.cap().get() && !cache.contains(key_string.as_str())
        };
        if would_evict {
            self.cleanup_expired();
        }

        self.cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .put(key_string, entry);
        Ok(())
    }

    /// Removes `key` if present; missing keys are not an error.
    fn invalidate(&self, key: &K) -> Result<(), CacheError> {
        self.cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .pop(key.as_ref());
        Ok(())
    }

    /// Removes every entry from the cache.
    fn clear(&self) -> Result<(), CacheError> {
        self.cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clear();
        Ok(())
    }
}
// Tests are meant to panic on any unexpected `Err`, so `unwrap` is the
// intended failure mechanism throughout this module.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration as StdDuration;

    /// Reads `key` through the `Cache<String, String>` implementation.
    fn fetch(cache: &MemoryCache, key: &str) -> Option<String> {
        <MemoryCache as Cache<String, String>>::get(cache, &key.to_owned())
    }

    /// Writes `value` under `key` through the `Cache<String, String>`
    /// implementation, panicking if the cache reports an error.
    fn store(cache: &MemoryCache, key: &str, value: &str, ttl: StdDuration) {
        <MemoryCache as Cache<String, String>>::set(
            cache,
            &key.to_owned(),
            &value.to_owned(),
            ttl,
        )
        .unwrap();
    }

    #[test]
    fn test_memory_cache_get_set() {
        let cache = MemoryCache::new(10);
        assert!(fetch(&cache, "test_key").is_none());
        store(&cache, "test_key", "test_value", StdDuration::from_secs(60));
        assert_eq!(fetch(&cache, "test_key"), Some("test_value".to_string()));
    }

    #[test]
    fn test_memory_cache_ttl_expiration() {
        let cache = MemoryCache::new(10);
        store(&cache, "test_key", "test_value", StdDuration::from_millis(100));
        assert!(fetch(&cache, "test_key").is_some());
        // Sleep past the 100 ms TTL so the next read sees a stale entry.
        thread::sleep(StdDuration::from_millis(150));
        assert!(fetch(&cache, "test_key").is_none());
    }

    #[test]
    fn test_memory_cache_lru_eviction() {
        let cache = MemoryCache::new(2);
        let ttl = StdDuration::from_secs(60);
        store(&cache, "key1", "value1", ttl);
        store(&cache, "key2", "value2", ttl);
        // Touch key1 so that key2 becomes the least recently used entry.
        fetch(&cache, "key1");
        store(&cache, "key3", "value3", ttl);
        assert!(fetch(&cache, "key1").is_some());
        assert!(fetch(&cache, "key2").is_none());
        assert!(fetch(&cache, "key3").is_some());
    }

    #[test]
    fn test_memory_cache_invalidate() {
        let cache = MemoryCache::new(10);
        store(&cache, "test_key", "test_value", StdDuration::from_secs(60));
        assert!(fetch(&cache, "test_key").is_some());
        <MemoryCache as Cache<String, String>>::invalidate(&cache, &"test_key".to_string())
            .unwrap();
        assert!(fetch(&cache, "test_key").is_none());
    }

    #[test]
    fn test_memory_cache_clear() {
        let cache = MemoryCache::new(10);
        for key in ["key1", "key2"] {
            store(&cache, key, "test_value", StdDuration::from_secs(60));
        }
        <MemoryCache as Cache<String, String>>::clear(&cache).unwrap();
        for key in ["key1", "key2"] {
            assert!(fetch(&cache, key).is_none());
        }
    }

    #[test]
    fn test_memory_cache_thread_safety() {
        let cache = Arc::new(MemoryCache::new(100));
        // Spawn every writer before joining any, so the writes overlap.
        let workers: Vec<_> = (0..10)
            .map(|i| {
                let shared = Arc::clone(&cache);
                thread::spawn(move || {
                    for j in 0..10 {
                        store(
                            &shared,
                            &format!("key_{i}_{j}"),
                            &format!("value_{i}_{j}"),
                            StdDuration::from_secs(60),
                        );
                    }
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
        for i in 0..10 {
            for j in 0..10 {
                assert_eq!(
                    fetch(&cache, &format!("key_{i}_{j}")),
                    Some(format!("value_{i}_{j}"))
                );
            }
        }
    }
}