use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Shared, lock-guarded map from `String` keys to raw byte buffers.
///
/// This is the type of the optional cache slot held by `CachedFactory`.
// NOTE(review): the name suggests the byte buffers hold serialized values;
// nothing in this file writes to the map, so confirm against the callers.
type SerializedCacheStore = Arc<RwLock<HashMap<String, Vec<u8>>>>;
/// A cached value together with the instant at which it was stored.
#[derive(Debug, Clone)]
pub struct CacheEntry<V> {
    /// The cached payload.
    pub value: V,
    /// Moment the entry was created; the reference point for expiry checks.
    pub created_at: Instant,
}

impl<V> CacheEntry<V> {
    /// Wraps `value` in an entry stamped with the current instant.
    pub fn new(value: V) -> Self {
        let created_at = Instant::now();
        Self { value, created_at }
    }

    /// Reports whether strictly more than `ttl` has passed since creation.
    ///
    /// An entry whose age equals `ttl` exactly is still considered live.
    pub fn is_expired(&self, ttl: Duration) -> bool {
        let age = self.created_at.elapsed();
        age > ttl
    }
}
/// A thread-safe cache whose entries all share one time-to-live.
///
/// The map lives behind an `Arc<RwLock<..>>`. Expired entries are not removed
/// in the background: they are evicted when [`FactoryCache::get`] encounters
/// one or when [`FactoryCache::cleanup_expired`] is called.
///
/// Every operation is best-effort with respect to lock poisoning: if the lock
/// is poisoned, reads behave as if the cache were empty and writes are
/// silently skipped.
pub struct FactoryCache<K, V>
where
    K: Hash + Eq,
    V: Clone,
{
    store: Arc<RwLock<HashMap<K, CacheEntry<V>>>>,
    ttl: Duration,
}

impl<K, V> FactoryCache<K, V>
where
    K: Hash + Eq,
    V: Clone,
{
    /// Creates an empty cache whose entries expire once they are older than `ttl`.
    pub fn new(ttl: Duration) -> Self {
        Self {
            store: Arc::new(RwLock::new(HashMap::new())),
            ttl,
        }
    }

    /// Returns a clone of the live value stored under `key`.
    ///
    /// Returns `None` when the key is absent, when its entry has expired (the
    /// expired entry is evicted as a side effect), or when the lock is
    /// poisoned.
    pub fn get(&self, key: &K) -> Option<V> {
        // Fast path: a shared lock is enough to serve a live entry.
        {
            let store = self.store.read().ok()?;
            let entry = store.get(key)?;
            if !entry.is_expired(self.ttl) {
                return Some(entry.value.clone());
            }
        }

        // Slow path: the entry looked expired. `RwLock` cannot upgrade a read
        // guard, so the read lock was released above and another thread may
        // have refreshed or removed the key in the meantime. Re-check under
        // the write lock so a freshly inserted entry is never evicted.
        let mut store = self.store.write().ok()?;
        let entry = store.get(key)?;
        if entry.is_expired(self.ttl) {
            store.remove(key);
            None
        } else {
            Some(entry.value.clone())
        }
    }

    /// Stores `value` under `key` with a fresh creation timestamp, replacing
    /// any previous entry. Does nothing if the lock is poisoned.
    pub fn insert(&self, key: K, value: V) {
        if let Ok(mut store) = self.store.write() {
            store.insert(key, CacheEntry::new(value));
        }
    }

    /// Removes the entry for `key`, if any. Does nothing if the lock is poisoned.
    pub fn invalidate(&self, key: &K) {
        if let Ok(mut store) = self.store.write() {
            store.remove(key);
        }
    }

    /// Removes every entry. Does nothing if the lock is poisoned.
    pub fn clear(&self) {
        if let Ok(mut store) = self.store.write() {
            store.clear();
        }
    }

    /// Evicts all expired entries and returns how many were removed.
    ///
    /// Returns `0` if the lock is poisoned.
    pub fn cleanup_expired(&self) -> usize {
        if let Ok(mut store) = self.store.write() {
            let initial_len = store.len();
            store.retain(|_, entry| !entry.is_expired(self.ttl));
            initial_len - store.len()
        } else {
            0
        }
    }

    /// Returns the number of stored entries, including expired ones that have
    /// not been evicted yet. Returns `0` if the lock is poisoned.
    pub fn len(&self) -> usize {
        self.store.read().map(|s| s.len()).unwrap_or(0)
    }

    /// Returns `true` when [`FactoryCache::len`] is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `(total, expired)`: the number of stored entries and how many
    /// of them are currently expired. Returns `(0, 0)` if the lock is poisoned.
    pub fn stats(&self) -> (usize, usize) {
        if let Ok(store) = self.store.read() {
            let total = store.len();
            let expired = store.values().filter(|e| e.is_expired(self.ttl)).count();
            (total, expired)
        } else {
            (0, 0)
        }
    }
}
impl<K, V> Clone for FactoryCache<K, V>
where
    K: Hash + Eq,
    V: Clone,
{
    /// Returns another handle to the same underlying store.
    ///
    /// Only the `Arc` is cloned, so entries inserted through either handle
    /// are visible through both; the TTL is copied.
    fn clone(&self) -> Self {
        let Self { store, ttl } = self;
        Self {
            store: Arc::clone(store),
            ttl: *ttl,
        }
    }
}
pub struct CachedFactory<F> {
factory: F,
cache: Option<SerializedCacheStore>, }
impl<F> CachedFactory<F> {
pub fn new(factory: F) -> Self {
Self {
factory,
cache: Some(Arc::new(RwLock::new(HashMap::new()))),
}
}
pub fn without_cache(factory: F) -> Self {
Self {
factory,
cache: None,
}
}
pub fn inner(&self) -> &F {
&self.factory
}
pub fn into_inner(self) -> F {
self.factory
}
pub fn is_cached(&self) -> bool {
self.cache.is_some()
}
pub fn clear_cache(&self) {
if let Some(cache) = &self.cache {
if let Ok(mut store) = cache.write() {
store.clear();
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// TTL for tests that need entries to expire quickly.
    const SHORT_TTL: Duration = Duration::from_millis(100);
    /// Pause that outlasts `SHORT_TTL`.
    const PAST_SHORT_TTL: Duration = Duration::from_millis(150);
    /// TTL long enough that nothing expires during a test run.
    const LONG_TTL: Duration = Duration::from_secs(60);

    /// Builds the `i32 -> String` cache used throughout these tests.
    fn new_cache(ttl: Duration) -> FactoryCache<i32, String> {
        FactoryCache::new(ttl)
    }

    // Insert, lookup, miss, and size reporting on live entries.
    #[test]
    fn test_cache_basic_operations() {
        let cache = new_cache(Duration::from_secs(1));
        cache.insert(1, String::from("value1"));
        cache.insert(2, String::from("value2"));

        assert_eq!(cache.get(&1).as_deref(), Some("value1"));
        assert_eq!(cache.get(&2).as_deref(), Some("value2"));
        assert!(cache.get(&3).is_none());
        assert_eq!(cache.len(), 2);
        assert!(!cache.is_empty());
    }

    // An entry is served while live and disappears once its TTL has passed.
    #[test]
    fn test_cache_expiration() {
        let cache = new_cache(SHORT_TTL);
        cache.insert(1, String::from("value1"));
        assert_eq!(cache.get(&1).as_deref(), Some("value1"));

        thread::sleep(PAST_SHORT_TTL);
        assert!(cache.get(&1).is_none());
    }

    // Explicit invalidation removes a live entry.
    #[test]
    fn test_cache_invalidation() {
        let cache = new_cache(LONG_TTL);
        cache.insert(1, String::from("value1"));
        assert_eq!(cache.get(&1).as_deref(), Some("value1"));

        cache.invalidate(&1);
        assert!(cache.get(&1).is_none());
    }

    // `clear` drops every entry.
    #[test]
    fn test_cache_clear() {
        let cache = new_cache(LONG_TTL);
        cache.insert(1, String::from("value1"));
        cache.insert(2, String::from("value2"));
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    // `cleanup_expired` reports and removes every expired entry.
    #[test]
    fn test_cache_cleanup_expired() {
        let cache = new_cache(SHORT_TTL);
        cache.insert(1, String::from("value1"));
        cache.insert(2, String::from("value2"));
        cache.insert(3, String::from("value3"));

        thread::sleep(PAST_SHORT_TTL);
        assert_eq!(cache.cleanup_expired(), 3);
        assert_eq!(cache.len(), 0);
    }

    // `stats` counts expired entries without evicting them.
    #[test]
    fn test_cache_stats() {
        let cache = new_cache(SHORT_TTL);
        cache.insert(1, String::from("value1"));
        cache.insert(2, String::from("value2"));
        assert_eq!(cache.stats(), (2, 0));

        thread::sleep(PAST_SHORT_TTL);
        assert_eq!(cache.stats(), (2, 2));
    }

    // A value inserted through a clone on another thread is visible here.
    #[test]
    fn test_cache_thread_safety() {
        let cache = new_cache(LONG_TTL);
        let writer = {
            let shared = cache.clone();
            thread::spawn(move || shared.insert(1, String::from("thread_value")))
        };
        writer.join().unwrap();

        assert_eq!(cache.get(&1).as_deref(), Some("thread_value"));
    }

    // Constructors set the cache flag, and `inner` borrows the stored factory.
    #[test]
    fn test_cached_factory() {
        struct DummyFactory;

        let cached = CachedFactory::new(DummyFactory);
        assert!(cached.is_cached());
        assert!(std::ptr::eq(cached.inner(), &cached.factory));

        let uncached = CachedFactory::without_cache(DummyFactory);
        assert!(!uncached.is_cached());
    }
}