rexis_rag/caching/
cache_core.rs

1//! # Core Cache Implementations
2//!
3//! Foundation cache data structures with different eviction policies.
4
5use super::{Cache, CacheEntryMetadata, CacheStats};
6use crate::RragResult;
7use std::collections::{HashMap, VecDeque};
8use std::hash::Hash;
9use std::time::{Duration, SystemTime};
10
/// Least-recently-used (LRU) cache.
///
/// Entries live in `storage`, while `access_order` records recency: the most
/// recently written key sits at the front and the eviction candidate at the
/// back.
pub struct LRUCache<K, V>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Key-to-node storage for the cached entries
    storage: HashMap<K, CacheNode<V>>,

    /// Recency order of the keys: front is most recent, back is evicted first
    access_order: VecDeque<K>,

    /// Maximum number of entries held before an insert triggers an eviction
    max_size: usize,

    /// Cache statistics
    stats: CacheStats,

    /// Zero-sized marker over `(K, V)`.
    /// NOTE(review): `PhantomData` adds no synchronization, and both type
    /// parameters are already used by the fields above, so this marker looks
    /// redundant — confirm before removing.
    _phantom: std::marker::PhantomData<(K, V)>,
}
32
/// Least-frequently-used (LFU) cache.
///
/// Every key carries a frequency counter; keys are additionally grouped into
/// per-frequency buckets so that an eviction candidate can be picked from the
/// bucket addressed by `min_frequency`.
pub struct LFUCache<K, V>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Key-to-node storage for the cached entries
    storage: HashMap<K, CacheNode<V>>,

    /// Current frequency counter of each key
    frequencies: HashMap<K, u64>,

    /// Keys grouped by their current frequency, used to pick eviction victims
    frequency_buckets: HashMap<u64, Vec<K>>,

    /// Frequency whose bucket is consulted when evicting (starts at 1)
    min_frequency: u64,

    /// Maximum number of entries held before an insert triggers an eviction
    max_size: usize,

    /// Cache statistics
    stats: CacheStats,
}
57
/// Time-to-live (TTL) cache.
///
/// Each entry is stored together with the point in time at which it expires.
/// Expired entries are hidden from reads and are physically removed by a
/// periodic sweep.
pub struct TTLCache<K, V>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Stored values paired with their expiry time
    storage: HashMap<K, (V, SystemTime)>,

    /// Lifetime applied to every inserted entry
    default_ttl: Duration,

    /// Minimum time that must pass between two sweeps of expired entries
    cleanup_interval: Duration,

    /// Time at which the last sweep ran (initialised at construction)
    last_cleanup: SystemTime,

    /// Cache statistics
    stats: CacheStats,
}
79
/// ARC (Adaptive Replacement Cache) data layout.
///
/// NOTE(review): only the struct is declared here; no `impl` block for it is
/// visible in this file, so the field descriptions below follow the field
/// names and the usual ARC terminology — confirm against the implementation.
pub struct ARCCache<K, V>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Recently used cache (T1)
    t1: HashMap<K, V>,

    /// Frequently used cache (T2)
    t2: HashMap<K, V>,

    /// Ghost entries recently evicted from T1 (B1); keys only, no values
    b1: HashMap<K, ()>,

    /// Ghost entries recently evicted from T2 (B2); keys only, no values
    b2: HashMap<K, ()>,

    /// Recency order for T1
    t1_lru: VecDeque<K>,
    /// Recency order for T2
    t2_lru: VecDeque<K>,
    /// Recency order for the B1 ghost list
    b1_lru: VecDeque<K>,
    /// Recency order for the B2 ghost list
    b2_lru: VecDeque<K>,

    /// Adaptive parameter (presumably the target size of T1 — TODO confirm)
    p: f32,

    /// Maximum capacity
    max_size: usize,

    /// Cache statistics
    stats: CacheStats,
}
113
/// Semantic-aware cache data layout.
///
/// NOTE(review): only the struct is declared here; no `impl` block for it is
/// visible in this file, so how the fields interact cannot be verified from
/// this file alone.
pub struct SemanticAwareCache<K, V>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Primary key-to-node storage
    storage: HashMap<K, CacheNode<V>>,

    /// Keys grouped under a numeric group id
    /// (presumably a similarity-cluster id — TODO confirm how it is derived)
    similarity_groups: HashMap<u64, Vec<K>>,

    /// Embedding vector per key, for similarity computation
    embeddings: HashMap<K, Vec<f32>>,

    /// Access pattern tracked per key
    access_patterns: HashMap<K, AccessPattern>,

    /// Maximum capacity
    max_size: usize,

    /// Similarity threshold for grouping
    similarity_threshold: f32,

    /// Cache statistics
    stats: CacheStats,
}
141
/// A cached value bundled with its bookkeeping data.
#[derive(Debug, Clone)]
pub struct CacheNode<V> {
    /// The cached value
    pub value: V,

    /// Entry metadata
    pub metadata: CacheEntryMetadata,

    /// Approximate size in bytes. The caches in this file fill it with
    /// `std::mem::size_of::<V>()`, i.e. the shallow size of the value type,
    /// which does not include heap data the value owns.
    pub size_bytes: usize,
}
154
/// Access statistics collected for a single cache entry.
#[derive(Debug, Clone)]
pub struct AccessPattern {
    /// Total number of recorded accesses
    pub count: u64,

    /// Timestamps of the most recent accesses, oldest first
    /// (`record_access` keeps at most the last 10)
    pub recent_accesses: VecDeque<SystemTime>,

    /// Average time between consecutive recent accesses
    /// (zero until at least two accesses were recorded)
    pub avg_interval: Duration,

    /// Access trend (increasing, decreasing, stable)
    pub trend: AccessTrend,
}
170
/// Direction in which the access rate of an entry is moving.
///
/// The enum is a plain, field-less value type, so it derives equality in
/// addition to `Copy`; this lets callers and tests compare trends directly
/// (`pattern.trend == AccessTrend::Stable`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessTrend {
    /// Accesses are becoming more frequent
    Increasing,
    /// Accesses are becoming less frequent
    Decreasing,
    /// Access rate is roughly constant
    Stable,
    /// Not enough data to classify yet (the initial state of a new pattern)
    Unknown,
}
179
/// Priority entry for frequency-based eviction.
///
/// Pairs a key with its access frequency and last access time so that
/// entries can be ordered for eviction.
/// NOTE(review): no use of this type is visible in this file — confirm it is
/// still needed.
#[derive(Debug, Clone, PartialEq, Eq)]
struct FrequencyEntry<K>
where
    K: Ord,
{
    /// Cache key this entry refers to
    key: K,
    /// Number of recorded accesses
    frequency: u64,
    /// Time of the most recent access
    last_access: SystemTime,
}
190
191impl<K, V> LRUCache<K, V>
192where
193    K: Hash + Eq + Clone + Send + Sync + 'static,
194    V: Clone + Send + Sync + 'static,
195{
196    /// Create new LRU cache
197    pub fn new(max_size: usize) -> Self {
198        Self {
199            storage: HashMap::with_capacity(max_size),
200            access_order: VecDeque::with_capacity(max_size),
201            max_size,
202            stats: CacheStats::default(),
203            _phantom: std::marker::PhantomData,
204        }
205    }
206
207    /// Update LRU order
208    fn update_lru(&mut self, key: &K) {
209        // Remove from current position
210        if let Some(pos) = self.access_order.iter().position(|k| k == key) {
211            self.access_order.remove(pos);
212        }
213
214        // Add to front (most recent)
215        self.access_order.push_front(key.clone());
216    }
217
218    /// Evict least recently used entry
219    fn evict_lru(&mut self) -> Option<K> {
220        if let Some(key) = self.access_order.pop_back() {
221            self.storage.remove(&key);
222            self.stats.evictions += 1;
223            Some(key)
224        } else {
225            None
226        }
227    }
228}
229
230impl<K, V> Cache<K, V> for LRUCache<K, V>
231where
232    K: Hash + Eq + Clone + Send + Sync + 'static,
233    V: Clone + Send + Sync + 'static,
234{
235    fn get(&self, key: &K) -> Option<V> {
236        let _start_time = SystemTime::now();
237
238        if let Some(node) = self.storage.get(key) {
239            // Update stats - hits handled by mutable reference in real implementation
240            Some(node.value.clone())
241        } else {
242            // Miss handled by mutable reference in real implementation
243            None
244        }
245    }
246
247    fn put(&mut self, key: K, value: V) -> RragResult<()> {
248        let size_bytes = std::mem::size_of::<V>();
249        let node = CacheNode {
250            value,
251            metadata: CacheEntryMetadata::new(),
252            size_bytes,
253        };
254
255        // If key exists, update and move to front
256        if self.storage.contains_key(&key) {
257            self.storage.insert(key.clone(), node);
258            self.update_lru(&key);
259            return Ok(());
260        }
261
262        // If at capacity, evict LRU
263        if self.storage.len() >= self.max_size {
264            self.evict_lru();
265        }
266
267        // Insert new entry
268        self.storage.insert(key.clone(), node);
269        self.update_lru(&key);
270
271        Ok(())
272    }
273
274    fn remove(&mut self, key: &K) -> Option<V> {
275        if let Some(node) = self.storage.remove(key) {
276            // Remove from LRU order
277            if let Some(pos) = self.access_order.iter().position(|k| k == key) {
278                self.access_order.remove(pos);
279            }
280            Some(node.value)
281        } else {
282            None
283        }
284    }
285
286    fn contains(&self, key: &K) -> bool {
287        self.storage.contains_key(key)
288    }
289
290    fn clear(&mut self) {
291        self.storage.clear();
292        self.access_order.clear();
293        self.stats = CacheStats::default();
294    }
295
296    fn size(&self) -> usize {
297        self.storage.len()
298    }
299
300    fn stats(&self) -> CacheStats {
301        self.stats.clone()
302    }
303}
304
305impl<K, V> LFUCache<K, V>
306where
307    K: Hash + Eq + Clone + Send + Sync + 'static,
308    V: Clone + Send + Sync + 'static,
309{
310    /// Create new LFU cache
311    pub fn new(max_size: usize) -> Self {
312        Self {
313            storage: HashMap::with_capacity(max_size),
314            frequencies: HashMap::with_capacity(max_size),
315            frequency_buckets: HashMap::new(),
316            min_frequency: 1,
317            max_size,
318            stats: CacheStats::default(),
319        }
320    }
321
322    /// Update frequency
323    fn update_frequency(&mut self, key: &K) {
324        let old_freq = self.frequencies.get(key).copied().unwrap_or(0);
325        let new_freq = old_freq + 1;
326
327        self.frequencies.insert(key.clone(), new_freq);
328
329        // Update frequency buckets
330        if old_freq > 0 {
331            if let Some(bucket) = self.frequency_buckets.get_mut(&old_freq) {
332                bucket.retain(|k| k != key);
333                if bucket.is_empty() && old_freq == self.min_frequency {
334                    self.min_frequency += 1;
335                }
336            }
337        }
338
339        self.frequency_buckets
340            .entry(new_freq)
341            .or_insert_with(Vec::new)
342            .push(key.clone());
343    }
344
345    /// Evict least frequently used entry
346    fn evict_lfu(&mut self) -> Option<K> {
347        if let Some(bucket) = self.frequency_buckets.get_mut(&self.min_frequency) {
348            if let Some(key) = bucket.pop() {
349                self.storage.remove(&key);
350                self.frequencies.remove(&key);
351                self.stats.evictions += 1;
352                return Some(key);
353            }
354        }
355        None
356    }
357}
358
359impl<K, V> Cache<K, V> for LFUCache<K, V>
360where
361    K: Hash + Eq + Clone + Send + Sync + 'static,
362    V: Clone + Send + Sync + 'static,
363{
364    fn get(&self, key: &K) -> Option<V> {
365        if let Some(node) = self.storage.get(key) {
366            Some(node.value.clone())
367        } else {
368            None
369        }
370    }
371
372    fn put(&mut self, key: K, value: V) -> RragResult<()> {
373        let size_bytes = std::mem::size_of::<V>();
374        let node = CacheNode {
375            value,
376            metadata: CacheEntryMetadata::new(),
377            size_bytes,
378        };
379
380        // If key exists, update
381        if self.storage.contains_key(&key) {
382            self.storage.insert(key.clone(), node);
383            self.update_frequency(&key);
384            return Ok(());
385        }
386
387        // If at capacity, evict LFU
388        if self.storage.len() >= self.max_size {
389            self.evict_lfu();
390        }
391
392        // Insert new entry
393        self.storage.insert(key.clone(), node);
394        self.update_frequency(&key);
395
396        Ok(())
397    }
398
399    fn remove(&mut self, key: &K) -> Option<V> {
400        if let Some(node) = self.storage.remove(key) {
401            self.frequencies.remove(key);
402            Some(node.value)
403        } else {
404            None
405        }
406    }
407
408    fn contains(&self, key: &K) -> bool {
409        self.storage.contains_key(key)
410    }
411
412    fn clear(&mut self) {
413        self.storage.clear();
414        self.frequencies.clear();
415        self.frequency_buckets.clear();
416        self.min_frequency = 1;
417        self.stats = CacheStats::default();
418    }
419
420    fn size(&self) -> usize {
421        self.storage.len()
422    }
423
424    fn stats(&self) -> CacheStats {
425        self.stats.clone()
426    }
427}
428
429impl<K, V> TTLCache<K, V>
430where
431    K: Hash + Eq + Clone + Send + Sync + 'static,
432    V: Clone + Send + Sync + 'static,
433{
434    /// Create new TTL cache
435    pub fn new(default_ttl: Duration) -> Self {
436        Self {
437            storage: HashMap::new(),
438            default_ttl,
439            cleanup_interval: Duration::from_secs(60), // 1 minute
440            last_cleanup: SystemTime::now(),
441            stats: CacheStats::default(),
442        }
443    }
444
445    /// Cleanup expired entries
446    fn cleanup_expired(&mut self) {
447        let now = SystemTime::now();
448
449        // Only cleanup if interval has passed
450        if now.duration_since(self.last_cleanup).unwrap_or_default() < self.cleanup_interval {
451            return;
452        }
453
454        let before_count = self.storage.len();
455        self.storage.retain(|_key, (_, expiry)| now < *expiry);
456        let after_count = self.storage.len();
457
458        self.stats.evictions += (before_count - after_count) as u64;
459        self.last_cleanup = now;
460    }
461}
462
463impl<K, V> Cache<K, V> for TTLCache<K, V>
464where
465    K: Hash + Eq + Clone + Send + Sync + 'static,
466    V: Clone + Send + Sync + 'static,
467{
468    fn get(&self, key: &K) -> Option<V> {
469        if let Some((value, expiry)) = self.storage.get(key) {
470            if SystemTime::now() < *expiry {
471                Some(value.clone())
472            } else {
473                None
474            }
475        } else {
476            None
477        }
478    }
479
480    fn put(&mut self, key: K, value: V) -> RragResult<()> {
481        let expiry = SystemTime::now() + self.default_ttl;
482        self.storage.insert(key, (value, expiry));
483
484        // Periodic cleanup
485        self.cleanup_expired();
486
487        Ok(())
488    }
489
490    fn remove(&mut self, key: &K) -> Option<V> {
491        self.storage.remove(key).map(|(value, _)| value)
492    }
493
494    fn contains(&self, key: &K) -> bool {
495        if let Some((_, expiry)) = self.storage.get(key) {
496            SystemTime::now() < *expiry
497        } else {
498            false
499        }
500    }
501
502    fn clear(&mut self) {
503        self.storage.clear();
504        self.stats = CacheStats::default();
505    }
506
507    fn size(&self) -> usize {
508        // Count only non-expired entries
509        let now = SystemTime::now();
510        self.storage
511            .values()
512            .filter(|(_, expiry)| now < *expiry)
513            .count()
514    }
515
516    fn stats(&self) -> CacheStats {
517        self.stats.clone()
518    }
519}
520
/// Partial ordering for `FrequencyEntry`; always defined, because it simply
/// delegates to the total order provided by the `Ord` implementation.
impl<K> PartialOrd for FrequencyEntry<K>
where
    K: Ord,
{
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Delegating keeps `PartialOrd` and `Ord` in agreement by construction.
        Some(self.cmp(other))
    }
}
529
530impl<K> Ord for FrequencyEntry<K>
531where
532    K: Ord,
533{
534    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
535        // Lower frequency first (for min-heap)
536        self.frequency
537            .cmp(&other.frequency)
538            .then_with(|| self.last_access.cmp(&other.last_access))
539    }
540}
541
542impl AccessPattern {
543    /// Create new access pattern
544    pub fn new() -> Self {
545        Self {
546            count: 0,
547            recent_accesses: VecDeque::new(),
548            avg_interval: Duration::from_secs(0),
549            trend: AccessTrend::Unknown,
550        }
551    }
552
553    /// Record an access
554    pub fn record_access(&mut self) {
555        let now = SystemTime::now();
556        self.count += 1;
557        self.recent_accesses.push_back(now);
558
559        // Keep only recent accesses (last 10)
560        if self.recent_accesses.len() > 10 {
561            self.recent_accesses.pop_front();
562        }
563
564        self.update_metrics();
565    }
566
567    /// Update computed metrics
568    fn update_metrics(&mut self) {
569        if self.recent_accesses.len() < 2 {
570            return;
571        }
572
573        // Calculate average interval
574        let mut total_interval = Duration::from_secs(0);
575        let mut interval_count = 0;
576
577        for window in self.recent_accesses.as_slices().0.windows(2) {
578            if let Ok(interval) = window[1].duration_since(window[0]) {
579                total_interval += interval;
580                interval_count += 1;
581            }
582        }
583
584        if interval_count > 0 {
585            self.avg_interval = total_interval / interval_count as u32;
586        }
587
588        // Determine trend (simplified)
589        if self.recent_accesses.len() >= 4 {
590            let _first_half_avg = self.recent_accesses.len() / 2;
591            // Trend analysis would go here
592            self.trend = AccessTrend::Stable; // Simplified for now
593        }
594    }
595}
596
/// Unit tests for the cache implementations and access-pattern tracking.
#[cfg(test)]
mod tests {
    use super::*;

    /// Filling an LRU cache beyond its capacity keeps the size at capacity.
    #[test]
    fn test_lru_cache() {
        let mut cache = LRUCache::new(3);

        cache.put("a".to_string(), 1).unwrap();
        cache.put("b".to_string(), 2).unwrap();
        cache.put("c".to_string(), 3).unwrap();

        assert_eq!(cache.size(), 3);
        assert_eq!(cache.get(&"a".to_string()), Some(1));

        // The cache is full, so this insert evicts one entry. Note that `get`
        // takes `&self` and does not refresh recency, so the read of "a"
        // above does not protect it from eviction.
        cache.put("d".to_string(), 4).unwrap();
        assert_eq!(cache.size(), 3);
    }

    /// Inserting into a full LFU cache evicts one of the existing entries.
    #[test]
    fn test_lfu_cache() {
        let mut cache = LFUCache::new(2);

        cache.put("a".to_string(), 1).unwrap();
        cache.put("b".to_string(), 2).unwrap();

        // Read 'a' twice. `get` takes `&self` and does not change
        // frequencies, so both keys remain at frequency 1 afterwards.
        cache.get(&"a".to_string());
        cache.get(&"a".to_string());

        // With both keys tied at frequency 1, the eviction takes the key that
        // was added to the bucket last, which is 'b'.
        cache.put("c".to_string(), 3).unwrap();

        assert_eq!(cache.get(&"a".to_string()), Some(1));
        assert_eq!(cache.get(&"b".to_string()), None);
        assert_eq!(cache.get(&"c".to_string()), Some(3));
    }

    /// An entry is readable before its TTL elapses and hidden afterwards.
    #[test]
    fn test_ttl_cache() {
        let mut cache = TTLCache::new(Duration::from_millis(100));

        cache.put("key".to_string(), "value".to_string()).unwrap();
        assert_eq!(cache.get(&"key".to_string()), Some("value".to_string()));

        // Sleep longer than TTL
        std::thread::sleep(Duration::from_millis(150));
        assert_eq!(cache.get(&"key".to_string()), None);
    }

    /// Every recorded access increments the access counter by one.
    #[test]
    fn test_access_pattern() {
        let mut pattern = AccessPattern::new();
        assert_eq!(pattern.count, 0);

        pattern.record_access();
        assert_eq!(pattern.count, 1);

        pattern.record_access();
        assert_eq!(pattern.count, 2);
    }
}