use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Arc;
use tokio::sync::{RwLock, Semaphore};
use crate::{
eviction::{EvictionContext, EvictionStrategy},
search::Searchable,
CacheConfig, CacheEntry, CacheError, EntryMetadata, Result, StorageBackend,
};
/// Shared, lock-protected map from each key to the list of entries stored for it.
type CacheStorage<K, V, M> = Arc<RwLock<HashMap<K, Vec<CacheEntry<K, V, M>>>>>;
/// Owned, dynamically dispatched eviction strategy (built by `crate::eviction::create_strategy`).
type EvictionStrategyBox<K, V, M> = Box<dyn EvictionStrategy<K, V, M>>;
/// Shorthand for a single cache entry.
type Entry<K, V, M> = CacheEntry<K, V, M>;
/// Asynchronous key/value cache interface.
///
/// Every operation is fallible with the implementor-chosen [`AsyncCache::Error`].
#[async_trait]
pub trait AsyncCache<K, V>: Send + Sync
where
    K: Hash + Eq + Clone + Send + Sync,
    V: Clone + Send + Sync,
{
    /// Error type shared by all operations.
    type Error;

    /// Looks up the value associated with `key`.
    async fn get(&self, key: &K) -> std::result::Result<Option<V>, Self::Error>;

    /// Stores `value` under `key`.
    async fn put(&self, key: K, value: V) -> std::result::Result<(), Self::Error>;

    /// Removes `key`, returning the value that was associated with it, if any.
    async fn remove(&self, key: &K) -> std::result::Result<Option<V>, Self::Error>;

    /// Removes every key.
    async fn clear(&self) -> std::result::Result<(), Self::Error>;

    /// Reports whether `key` is present.
    async fn contains(&self, key: &K) -> std::result::Result<bool, Self::Error>;

    /// Returns the implementation-defined number of stored items.
    async fn len(&self) -> std::result::Result<usize, Self::Error>;

    /// Reports whether [`len`](AsyncCache::len) is zero; errors from `len` are
    /// passed through unchanged.
    async fn is_empty(&self) -> std::result::Result<bool, Self::Error> {
        self.len().await.map(|count| count == 0)
    }
}
/// Cache that keeps a list of [`CacheEntry`] values per key in memory and can
/// persist them through a [`StorageBackend`].
///
/// Clones share the entries, backend, save semaphore and operation counter
/// (all behind `Arc`), but each clone builds its own eviction strategy.
// NOTE(review): presumably silences `clippy::type_complexity` for the
// defaulted generic parameters — confirm the allow is still needed.
#[allow(clippy::type_complexity)]
pub struct Cache<K, V, M = (), B = crate::backends::memory::MemoryBackend<K, V, M>>
where
K: Hash + Eq + Clone + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
M: EntryMetadata + Default,
B: StorageBackend<Key = K, Value = V, Metadata = M>,
{
// Per-key entry lists; shared between clones.
entries: CacheStorage<K, V, M>,
// Limits, eviction policy and persistence settings.
config: CacheConfig,
// Backend used for load/save/remove/clear.
backend: Arc<B>,
// Created with a single permit; held while saving to the backend.
save_semaphore: Arc<Semaphore>,
// Mutating operations since the last periodic sync was scheduled.
operation_count: Arc<RwLock<usize>>,
// Invoked from `add_entry` when `max_total_entries` is exceeded.
eviction_strategy: EvictionStrategyBox<K, V, M>,
}
impl<K, V, M, B> Cache<K, V, M, B>
where
K: Hash + Eq + Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
V: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
M: EntryMetadata + Default,
B: StorageBackend<Key = K, Value = V, Metadata = M>,
{
pub async fn new(config: CacheConfig, backend: B) -> Result<Self> {
let eviction_strategy = crate::eviction::create_strategy(&config.eviction_policy);
let cache = Self {
entries: Arc::new(RwLock::new(HashMap::new())),
config,
backend: Arc::new(backend),
save_semaphore: Arc::new(Semaphore::new(1)),
operation_count: Arc::new(RwLock::new(0)),
eviction_strategy,
};
if cache.config.persistence.enabled && cache.config.persistence.load_on_startup {
let _ = cache.load_from_storage().await;
}
Ok(cache)
}
pub async fn with_config(config: CacheConfig) -> Result<Self>
where
B: Default,
{
Self::new(config, B::default()).await
}
#[allow(clippy::type_complexity)]
pub async fn add_entry(&self, entry: Entry<K, V, M>) -> Result<()> {
let key = entry.key.clone();
{
let mut entries = self.entries.write().await;
let key_entries = entries.entry(key).or_insert_with(Vec::new);
key_entries.push(entry);
if key_entries.len() > self.config.max_entries_per_key {
key_entries.remove(0);
}
let total_entries: usize = entries.values().map(|v| v.len()).sum();
if total_entries > self.config.max_total_entries {
let context = EvictionContext {
max_total_entries: self.config.max_total_entries,
current_total_entries: total_entries,
};
self.eviction_strategy.evict(&mut entries, &context).await;
}
}
self.increment_and_maybe_sync().await?;
Ok(())
}
pub async fn get_entries(&self, key: &K) -> Option<Vec<CacheEntry<K, V, M>>> {
let mut entries = self.entries.write().await;
entries.get_mut(key).map(|entries| {
for entry in entries.iter_mut() {
entry.record_access();
}
entries.clone()
})
}
pub async fn get_latest(&self, key: &K) -> Option<CacheEntry<K, V, M>> {
let mut entries = self.entries.write().await;
entries.get_mut(key).and_then(|entries| {
entries.iter_mut().max_by_key(|e| e.timestamp).map(|e| {
e.record_access();
e.clone()
})
})
}
pub async fn search<Q>(&self, query: &Q) -> Vec<CacheEntry<K, V, M>>
where
CacheEntry<K, V, M>: Searchable<Query = Q>,
{
let entries = self.entries.read().await;
entries
.values()
.flat_map(|v| v.iter())
.filter(|entry| entry.matches(query))
.cloned()
.collect()
}
pub async fn get_stats(&self) -> CacheStats {
let entries = self.entries.read().await;
let total_entries: usize = entries.values().map(|v| v.len()).sum();
let total_keys = entries.len();
let mut total_access_count = 0u64;
let mut expired_count = 0usize;
for entry_vec in entries.values() {
for entry in entry_vec {
total_access_count += entry.access_count;
if entry.is_expired() {
expired_count += 1;
}
}
}
CacheStats {
total_entries,
total_keys,
total_access_count,
expired_count,
memory_usage_bytes: 0, }
}
async fn save_to_storage(&self) -> Result<()> {
if !self.config.persistence.enabled {
return Ok(());
}
let _permit = self.save_semaphore.acquire().await.unwrap();
let entries = self.entries.read().await;
self.backend.save(&entries).await
}
async fn load_from_storage(&self) -> Result<()> {
if !self.config.persistence.enabled {
return Ok(());
}
let loaded_entries = self.backend.load().await?;
let mut entries = self.entries.write().await;
*entries = loaded_entries;
Ok(())
}
async fn increment_and_maybe_sync(&self) -> Result<()> {
let mut count = self.operation_count.write().await;
*count += 1;
if *count >= self.config.persistence.sync_interval {
*count = 0;
drop(count);
let cache = self.clone();
tokio::spawn(async move {
let _ = cache.save_to_storage().await;
});
}
Ok(())
}
}
impl<K, V, M, B> Clone for Cache<K, V, M, B>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    M: EntryMetadata + Default,
    B: StorageBackend<Key = K, Value = V, Metadata = M>,
{
    /// Produces a handle that shares this cache's entries, backend, save
    /// semaphore and operation counter, and owns a copy of the config.
    ///
    /// The eviction strategy is not shared: a new one is created from the
    /// configured eviction policy, so whatever state the original strategy
    /// held is not carried over.
    fn clone(&self) -> Self {
        let Self {
            entries,
            config,
            backend,
            save_semaphore,
            operation_count,
            eviction_strategy: _,
        } = self;
        Self {
            entries: Arc::clone(entries),
            config: config.clone(),
            backend: Arc::clone(backend),
            save_semaphore: Arc::clone(save_semaphore),
            operation_count: Arc::clone(operation_count),
            eviction_strategy: crate::eviction::create_strategy(&config.eviction_policy),
        }
    }
}
#[async_trait]
impl<K, V, M, B> AsyncCache<K, V> for Cache<K, V, M, B>
where
    K: Hash + Eq + Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
    V: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
    M: EntryMetadata + Default,
    B: StorageBackend<Key = K, Value = V, Metadata = M>,
{
    type Error = CacheError;

    /// Returns the value of the newest (greatest `timestamp`) entry for `key`,
    /// recording an access on it.
    async fn get(&self, key: &K) -> std::result::Result<Option<V>, Self::Error> {
        Ok(self.get_latest(key).await.map(|entry| entry.value))
    }

    /// Replaces every entry stored for `key` with a single new entry.
    async fn put(&self, key: K, value: V) -> std::result::Result<(), Self::Error> {
        {
            let mut entries = self.entries.write().await;
            let key_entries = entries.entry(key.clone()).or_insert_with(Vec::new);
            key_entries.clear();
            key_entries.push(CacheEntry::new(key, value));
        }
        self.increment_and_maybe_sync().await?;
        Ok(())
    }

    /// Removes all entries for `key` in memory and in the backend.
    ///
    /// Returns the value `get` would have returned (the newest entry by
    /// `timestamp`), or `None` if the key was absent, in which case the
    /// backend is not contacted.
    ///
    /// # Errors
    ///
    /// Propagates the backend's remove error; the in-memory removal has
    /// already happened at that point.
    async fn remove(&self, key: &K) -> std::result::Result<Option<V>, Self::Error> {
        // The write lock is held across the backend call so no other cache
        // operation interleaves between the in-memory and backend removal.
        let mut entries = self.entries.write().await;
        let removed = match entries.remove(key) {
            Some(removed) => removed,
            None => return Ok(None),
        };
        self.backend.remove(key).await?;
        self.increment_and_maybe_sync().await?;
        // Pick by timestamp (not insertion order) to stay consistent with `get`.
        Ok(removed
            .into_iter()
            .max_by_key(|e| e.timestamp)
            .map(|e| e.value))
    }

    /// Empties the in-memory map and then the backend.
    async fn clear(&self) -> std::result::Result<(), Self::Error> {
        let mut entries = self.entries.write().await;
        entries.clear();
        self.backend.clear().await?;
        Ok(())
    }

    /// Reports whether `key` has an entry list in memory.
    async fn contains(&self, key: &K) -> std::result::Result<bool, Self::Error> {
        let entries = self.entries.read().await;
        Ok(entries.contains_key(key))
    }

    /// Returns the number of entries summed over all keys (not the key count).
    async fn len(&self) -> std::result::Result<usize, Self::Error> {
        let entries = self.entries.read().await;
        Ok(entries.values().map(|v| v.len()).sum())
    }
}
impl<K, V, M, B> Drop for Cache<K, V, M, B>
where
K: Hash + Eq + Clone + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
M: EntryMetadata + Default,
B: StorageBackend<Key = K, Value = V, Metadata = M>,
{
fn drop(&mut self) {
if self.config.persistence.enabled && self.config.persistence.save_on_drop {
let entries = self.entries.clone();
let backend = self.backend.clone();
if let Ok(handle) = tokio::runtime::Handle::try_current() {
handle.spawn(async move {
let entries = entries.read().await;
let _ = backend.save(&entries).await;
});
}
}
}
}
/// Point-in-time summary returned by `Cache::get_stats`.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
/// Number of entries summed over all keys.
pub total_entries: usize,
/// Number of distinct keys.
pub total_keys: usize,
/// Sum of every entry's `access_count`.
pub total_access_count: u64,
/// Number of entries whose `is_expired()` returned true.
pub expired_count: usize,
/// Memory usage estimate; `Cache::get_stats` currently always reports 0.
pub memory_usage_bytes: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::memory::MemoryBackend;

    /// Builds a default-configured, in-memory `String -> String` cache.
    async fn new_cache() -> Cache<String, String> {
        Cache::new(CacheConfig::default(), MemoryBackend::new())
            .await
            .unwrap()
    }

    /// Shorthand for an owned `String`.
    fn s(text: &str) -> String {
        text.to_owned()
    }

    #[tokio::test]
    async fn test_cache_basic_operations() {
        let cache = new_cache().await;

        cache.put(s("key1"), s("value1")).await.unwrap();
        assert_eq!(cache.get(&s("key1")).await.unwrap(), Some(s("value1")));

        assert!(cache.contains(&s("key1")).await.unwrap());
        assert!(!cache.contains(&s("key2")).await.unwrap());
        assert_eq!(cache.len().await.unwrap(), 1);

        let removed = cache.remove(&s("key1")).await.unwrap();
        assert_eq!(removed, Some(s("value1")));
        assert_eq!(cache.len().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let cache = new_cache().await;

        for (key, value) in [("key1", "value1"), ("key2", "value2")] {
            cache.put(s(key), s(value)).await.unwrap();
        }
        assert_eq!(cache.len().await.unwrap(), 2);

        cache.clear().await.unwrap();
        assert_eq!(cache.len().await.unwrap(), 0);
        assert!(!cache.contains(&s("key1")).await.unwrap());
    }
}