Skip to main content

threatflux_cache/
cache.rs

1//! Core cache implementation
2
3use async_trait::async_trait;
4use serde::{de::DeserializeOwned, Serialize};
5use std::collections::HashMap;
6use std::hash::Hash;
7use std::sync::Arc;
8use tokio::sync::{RwLock, Semaphore};
9
10use crate::{
11    eviction::{EvictionContext, EvictionStrategy},
12    search::Searchable,
13    CacheConfig, CacheEntry, CacheError, EntryMetadata, Result, StorageBackend,
14};
15
16/// Type alias for cache entries storage
17type CacheStorage<K, V, M> = Arc<RwLock<HashMap<K, Vec<CacheEntry<K, V, M>>>>>;
18
19/// Type alias for eviction strategy
20type EvictionStrategyBox<K, V, M> = Box<dyn EvictionStrategy<K, V, M>>;
21
22/// Type alias for cache entry
23type Entry<K, V, M> = CacheEntry<K, V, M>;
24
25/// Async cache trait defining the core cache operations
26#[async_trait]
27pub trait AsyncCache<K, V>: Send + Sync
28where
29    K: Hash + Eq + Clone + Send + Sync,
30    V: Clone + Send + Sync,
31{
32    /// Error type for cache operations
33    type Error;
34
35    /// Get a value from the cache
36    async fn get(&self, key: &K) -> std::result::Result<Option<V>, Self::Error>;
37
38    /// Put a value into the cache
39    async fn put(&self, key: K, value: V) -> std::result::Result<(), Self::Error>;
40
41    /// Remove a value from the cache
42    async fn remove(&self, key: &K) -> std::result::Result<Option<V>, Self::Error>;
43
44    /// Clear all entries from the cache
45    async fn clear(&self) -> std::result::Result<(), Self::Error>;
46
47    /// Check if the cache contains a key
48    async fn contains(&self, key: &K) -> std::result::Result<bool, Self::Error>;
49
50    /// Get the number of entries in the cache
51    async fn len(&self) -> std::result::Result<usize, Self::Error>;
52
53    /// Check if the cache is empty
54    async fn is_empty(&self) -> std::result::Result<bool, Self::Error> {
55        Ok(self.len().await? == 0)
56    }
57}
58
/// Main cache implementation.
///
/// Each key maps to a `Vec` of entries (a per-key history). All shared state
/// (entries, backend, save semaphore, operation counter) sits behind `Arc`s,
/// so cloned handles observe the same data.
///
/// `M` defaults to `()` and `B` defaults to the in-memory backend.
#[allow(clippy::type_complexity)]
pub struct Cache<K, V, M = (), B = crate::backends::memory::MemoryBackend<K, V, M>>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    M: EntryMetadata + Default,
    B: StorageBackend<Key = K, Value = V, Metadata = M>,
{
    /// Per-key entry lists, shared between clones.
    entries: CacheStorage<K, V, M>,
    /// Entry limits, eviction policy and persistence settings.
    config: CacheConfig,
    /// Storage backend used for persistence.
    backend: Arc<B>,
    /// Semaphore (created with a single permit) guarding saves to the backend.
    save_semaphore: Arc<Semaphore>,
    /// Counter of mutating operations since the last scheduled sync.
    operation_count: Arc<RwLock<usize>>,
    /// Eviction policy built from `config.eviction_policy`.
    eviction_strategy: EvictionStrategyBox<K, V, M>,
}
75
76impl<K, V, M, B> Cache<K, V, M, B>
77where
78    K: Hash + Eq + Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
79    V: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
80    M: EntryMetadata + Default,
81    B: StorageBackend<Key = K, Value = V, Metadata = M>,
82{
83    /// Create a new cache with the given configuration and backend
84    pub async fn new(config: CacheConfig, backend: B) -> Result<Self> {
85        let eviction_strategy = crate::eviction::create_strategy(&config.eviction_policy);
86
87        let cache = Self {
88            entries: Arc::new(RwLock::new(HashMap::new())),
89            config,
90            backend: Arc::new(backend),
91            save_semaphore: Arc::new(Semaphore::new(1)),
92            operation_count: Arc::new(RwLock::new(0)),
93            eviction_strategy,
94        };
95
96        // Load existing cache if configured
97        if cache.config.persistence.enabled && cache.config.persistence.load_on_startup {
98            let _ = cache.load_from_storage().await;
99        }
100
101        Ok(cache)
102    }
103
104    /// Create a new cache with default memory backend
105    pub async fn with_config(config: CacheConfig) -> Result<Self>
106    where
107        B: Default,
108    {
109        Self::new(config, B::default()).await
110    }
111
112    /// Add an entry to the cache
113    #[allow(clippy::type_complexity)]
114    pub async fn add_entry(&self, entry: Entry<K, V, M>) -> Result<()> {
115        let key = entry.key.clone();
116
117        {
118            let mut entries = self.entries.write().await;
119            let key_entries = entries.entry(key).or_insert_with(Vec::new);
120            key_entries.push(entry);
121
122            // Limit entries per key
123            if key_entries.len() > self.config.max_entries_per_key {
124                key_entries.remove(0);
125            }
126
127            // Check if we need to evict
128            let total_entries: usize = entries.values().map(|v| v.len()).sum();
129            if total_entries > self.config.max_total_entries {
130                let context = EvictionContext {
131                    max_total_entries: self.config.max_total_entries,
132                    current_total_entries: total_entries,
133                };
134                self.eviction_strategy.evict(&mut entries, &context).await;
135            }
136        }
137
138        // Increment operation count and check if we need to sync
139        self.increment_and_maybe_sync().await?;
140
141        Ok(())
142    }
143
144    /// Get all entries for a key
145    pub async fn get_entries(&self, key: &K) -> Option<Vec<CacheEntry<K, V, M>>> {
146        let mut entries = self.entries.write().await;
147        entries.get_mut(key).map(|entries| {
148            // Update access statistics
149            for entry in entries.iter_mut() {
150                entry.record_access();
151            }
152            entries.clone()
153        })
154    }
155
156    /// Get the latest entry for a key
157    pub async fn get_latest(&self, key: &K) -> Option<CacheEntry<K, V, M>> {
158        let mut entries = self.entries.write().await;
159        entries.get_mut(key).and_then(|entries| {
160            entries.iter_mut().max_by_key(|e| e.timestamp).map(|e| {
161                e.record_access();
162                e.clone()
163            })
164        })
165    }
166
167    /// Search entries based on a query
168    pub async fn search<Q>(&self, query: &Q) -> Vec<CacheEntry<K, V, M>>
169    where
170        CacheEntry<K, V, M>: Searchable<Query = Q>,
171    {
172        let entries = self.entries.read().await;
173        entries
174            .values()
175            .flat_map(|v| v.iter())
176            .filter(|entry| entry.matches(query))
177            .cloned()
178            .collect()
179    }
180
181    /// Get cache statistics
182    pub async fn get_stats(&self) -> CacheStats {
183        let entries = self.entries.read().await;
184
185        let total_entries: usize = entries.values().map(|v| v.len()).sum();
186        let total_keys = entries.len();
187        let mut total_access_count = 0u64;
188        let mut expired_count = 0usize;
189
190        for entry_vec in entries.values() {
191            for entry in entry_vec {
192                total_access_count += entry.access_count;
193                if entry.is_expired() {
194                    expired_count += 1;
195                }
196            }
197        }
198
199        CacheStats {
200            total_entries,
201            total_keys,
202            total_access_count,
203            expired_count,
204            memory_usage_bytes: 0, // Would need size estimation
205        }
206    }
207
208    /// Save cache to storage backend
209    async fn save_to_storage(&self) -> Result<()> {
210        if !self.config.persistence.enabled {
211            return Ok(());
212        }
213
214        let _permit = self.save_semaphore.acquire().await.unwrap();
215        let entries = self.entries.read().await;
216        self.backend.save(&entries).await
217    }
218
219    /// Load cache from storage backend
220    async fn load_from_storage(&self) -> Result<()> {
221        if !self.config.persistence.enabled {
222            return Ok(());
223        }
224
225        let loaded_entries = self.backend.load().await?;
226        let mut entries = self.entries.write().await;
227        *entries = loaded_entries;
228        Ok(())
229    }
230
231    /// Increment operation count and sync if needed
232    async fn increment_and_maybe_sync(&self) -> Result<()> {
233        let mut count = self.operation_count.write().await;
234        *count += 1;
235
236        if *count >= self.config.persistence.sync_interval {
237            *count = 0;
238            drop(count); // Release the lock before saving
239
240            // Spawn background save
241            let cache = self.clone();
242            tokio::spawn(async move {
243                let _ = cache.save_to_storage().await;
244            });
245        }
246
247        Ok(())
248    }
249}
250
251// Implement Clone for Cache
252impl<K, V, M, B> Clone for Cache<K, V, M, B>
253where
254    K: Hash + Eq + Clone + Send + Sync + 'static,
255    V: Clone + Send + Sync + 'static,
256    M: EntryMetadata + Default,
257    B: StorageBackend<Key = K, Value = V, Metadata = M>,
258{
259    fn clone(&self) -> Self {
260        Self {
261            entries: Arc::clone(&self.entries),
262            config: self.config.clone(),
263            backend: Arc::clone(&self.backend),
264            save_semaphore: Arc::clone(&self.save_semaphore),
265            operation_count: Arc::clone(&self.operation_count),
266            eviction_strategy: crate::eviction::create_strategy(&self.config.eviction_policy),
267        }
268    }
269}
270
271// Implement AsyncCache trait
272#[async_trait]
273impl<K, V, M, B> AsyncCache<K, V> for Cache<K, V, M, B>
274where
275    K: Hash + Eq + Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
276    V: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
277    M: EntryMetadata + Default,
278    B: StorageBackend<Key = K, Value = V, Metadata = M>,
279{
280    type Error = CacheError;
281
282    async fn get(&self, key: &K) -> std::result::Result<Option<V>, Self::Error> {
283        Ok(self.get_latest(key).await.map(|entry| entry.value))
284    }
285
286    async fn put(&self, key: K, value: V) -> std::result::Result<(), Self::Error> {
287        {
288            let mut entries = self.entries.write().await;
289            let key_entries = entries.entry(key.clone()).or_insert_with(Vec::new);
290
291            // For AsyncCache trait, replace existing entries rather than add
292            key_entries.clear();
293            key_entries.push(CacheEntry::new(key, value));
294        }
295
296        // Increment operation count and check if we need to sync
297        self.increment_and_maybe_sync().await?;
298        Ok(())
299    }
300
301    async fn remove(&self, key: &K) -> std::result::Result<Option<V>, Self::Error> {
302        let mut entries = self.entries.write().await;
303        let removed = entries.remove(key);
304
305        if removed.is_some() {
306            // Remove from backend
307            self.backend.remove(key).await?;
308            self.increment_and_maybe_sync().await?;
309        }
310
311        Ok(removed.and_then(|entries| entries.into_iter().next_back().map(|e| e.value)))
312    }
313
314    async fn clear(&self) -> std::result::Result<(), Self::Error> {
315        let mut entries = self.entries.write().await;
316        entries.clear();
317
318        self.backend.clear().await?;
319
320        Ok(())
321    }
322
323    async fn contains(&self, key: &K) -> std::result::Result<bool, Self::Error> {
324        let entries = self.entries.read().await;
325        Ok(entries.contains_key(key))
326    }
327
328    async fn len(&self) -> std::result::Result<usize, Self::Error> {
329        let entries = self.entries.read().await;
330        Ok(entries.values().map(|v| v.len()).sum())
331    }
332}
333
334// Implement Drop to save on shutdown
335impl<K, V, M, B> Drop for Cache<K, V, M, B>
336where
337    K: Hash + Eq + Clone + Send + Sync + 'static,
338    V: Clone + Send + Sync + 'static,
339    M: EntryMetadata + Default,
340    B: StorageBackend<Key = K, Value = V, Metadata = M>,
341{
342    fn drop(&mut self) {
343        if self.config.persistence.enabled && self.config.persistence.save_on_drop {
344            // Try to save synchronously in drop
345            let entries = self.entries.clone();
346            let backend = self.backend.clone();
347
348            // We can't use async in drop, so we spawn a task to save
349            if let Ok(handle) = tokio::runtime::Handle::try_current() {
350                handle.spawn(async move {
351                    let entries = entries.read().await;
352                    let _ = backend.save(&entries).await;
353                });
354            }
355        }
356    }
357}
358
/// Snapshot of cache statistics, as produced by `Cache::get_stats`.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Total number of entries across all keys (every list element counts).
    pub total_entries: usize,
    /// Number of distinct keys.
    pub total_keys: usize,
    /// Sum of `access_count` over all entries.
    pub total_access_count: u64,
    /// Number of entries for which `is_expired()` returned true.
    pub expired_count: usize,
    /// Approximate memory usage in bytes.
    ///
    /// NOTE(review): `Cache::get_stats` currently always reports 0 here; size
    /// estimation is not implemented.
    pub memory_usage_bytes: usize,
}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::memory::MemoryBackend;

    /// Builds a `String -> String` cache over a fresh in-memory backend using
    /// the default configuration.
    async fn fresh_cache() -> Cache<String, String> {
        let config = CacheConfig::default();
        let backend = MemoryBackend::new();
        Cache::new(config, backend).await.unwrap()
    }

    /// Shorthand for an owned key/value string.
    fn s(text: &str) -> String {
        text.to_string()
    }

    #[tokio::test]
    async fn test_cache_basic_operations() {
        let cache = fresh_cache().await;

        // A put is visible to a subsequent get.
        cache.put(s("key1"), s("value1")).await.unwrap();
        let fetched = cache.get(&s("key1")).await.unwrap();
        assert_eq!(fetched, Some(s("value1")));

        // Only the inserted key is reported as present.
        assert!(cache.contains(&s("key1")).await.unwrap());
        assert!(!cache.contains(&s("key2")).await.unwrap());

        assert_eq!(cache.len().await.unwrap(), 1);

        // Removing returns the stored value and empties the cache.
        let removed = cache.remove(&s("key1")).await.unwrap();
        assert_eq!(removed, Some(s("value1")));
        assert_eq!(cache.len().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let cache = fresh_cache().await;

        for (k, v) in [("key1", "value1"), ("key2", "value2")] {
            cache.put(s(k), s(v)).await.unwrap();
        }
        assert_eq!(cache.len().await.unwrap(), 2);

        cache.clear().await.unwrap();
        assert_eq!(cache.len().await.unwrap(), 0);
        assert!(!cache.contains(&s("key1")).await.unwrap());
    }
}