ipfrs_semantic/
cache.rs

1//! Advanced caching for vector embeddings
2//!
//! This module provides caching building blocks for vector embeddings: a hot-embedding
//! LRU cache, an adaptive cache-sizing strategy, cache-invalidation tracking, and a
//! vector wrapper whose handle is aligned to a 64-byte cache line.
5
6use lru::LruCache;
7use parking_lot::RwLock;
8use std::num::NonZeroUsize;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
/// A heap-allocated `f32` vector whose *handle* is aligned to a 64-byte cache line.
///
/// `#[repr(align(64))]` applies to the `AlignedVector` value itself, i.e. the inner
/// `Vec`'s pointer/capacity/length header, so arrays or struct fields of
/// `AlignedVector` start on 64-byte boundaries.
///
/// It does **not** change the alignment of the heap buffer that holds the elements.
/// That buffer comes from the global allocator with (at least) the natural alignment
/// of `f32`, not necessarily 64 bytes. Code that needs 64-byte-aligned element data
/// (e.g. for aligned SIMD loads) must not rely on this type to provide it.
#[repr(align(64))]
#[derive(Debug, Clone)]
pub struct AlignedVector {
    data: Vec<f32>,
}

impl AlignedVector {
    /// Wrap an existing `Vec<f32>` without copying its elements.
    pub fn new(data: Vec<f32>) -> Self {
        Self { data }
    }

    /// Create a vector of `len` zeros.
    pub fn zeros(len: usize) -> Self {
        Self {
            data: vec![0.0; len],
        }
    }

    /// Borrow the elements as a slice.
    pub fn as_slice(&self) -> &[f32] {
        &self.data
    }

    /// Borrow the elements as a mutable slice.
    pub fn as_mut_slice(&mut self) -> &mut [f32] {
        &mut self.data
    }

    /// Number of elements.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Whether the vector has no elements.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Unwrap into the underlying `Vec<f32>` without copying.
    pub fn into_vec(self) -> Vec<f32> {
        self.data
    }
}

impl From<Vec<f32>> for AlignedVector {
    /// Equivalent to [`AlignedVector::new`].
    fn from(data: Vec<f32>) -> Self {
        Self::new(data)
    }
}

impl AsRef<[f32]> for AlignedVector {
    fn as_ref(&self) -> &[f32] {
        &self.data
    }
}

impl AsMut<[f32]> for AlignedVector {
    fn as_mut(&mut self) -> &mut [f32] {
        &mut self.data
    }
}
79
/// Per-entry access history, used to rank cached embeddings by how often they are read.
#[derive(Debug, Clone)]
struct AccessStats {
    /// Accesses recorded so far; the insertion that created the entry counts as one.
    access_count: u64,
    /// When the most recent access was recorded.
    last_access: Instant,
    /// When the entry was created (its first recorded access).
    first_access: Instant,
    /// `last_access - first_access`, refreshed on every recorded access.
    time_in_cache: Duration,
}

impl AccessStats {
    /// Start a history for a newly created entry, counting creation as one access.
    fn new() -> Self {
        let created = Instant::now();
        AccessStats {
            access_count: 1,
            last_access: created,
            first_access: created,
            time_in_cache: Duration::ZERO,
        }
    }

    /// Record one more access at the current instant.
    fn record_access(&mut self) {
        self.access_count += 1;
        let now = Instant::now();
        self.last_access = now;
        self.time_in_cache = now.duration_since(self.first_access);
    }

    /// Accesses per second over the span between the first and last access.
    ///
    /// When that span is zero (no access recorded since creation, or all accesses
    /// fell within the clock's resolution) the raw access count is returned instead.
    fn access_frequency(&self) -> f64 {
        let count = self.access_count as f64;
        let span_secs = self.time_in_cache.as_secs_f64();
        if span_secs > 0.0 {
            count / span_secs
        } else {
            count
        }
    }
}
118
119/// Cached embedding entry
120#[derive(Debug, Clone)]
121struct CachedEmbedding {
122    vector: AlignedVector,
123    stats: AccessStats,
124}
125
126/// Hot embedding cache with LRU eviction and adaptive sizing
127///
128/// Caches frequently accessed embeddings in memory with cache-aligned
129/// storage for optimal SIMD performance.
130pub struct HotEmbeddingCache {
131    /// LRU cache for embeddings
132    cache: Arc<RwLock<LruCache<String, CachedEmbedding>>>,
133    /// Access statistics (hits/misses)
134    hits: Arc<AtomicU64>,
135    misses: Arc<AtomicU64>,
136    /// Total cache capacity
137    capacity: usize,
138    /// Prefetch queue for predicted accesses
139    prefetch_queue: Arc<RwLock<Vec<String>>>,
140}
141
142impl HotEmbeddingCache {
143    /// Create a new hot embedding cache
144    pub fn new(capacity: usize) -> Self {
145        Self {
146            cache: Arc::new(RwLock::new(LruCache::new(
147                NonZeroUsize::new(capacity).unwrap(),
148            ))),
149            hits: Arc::new(AtomicU64::new(0)),
150            misses: Arc::new(AtomicU64::new(0)),
151            capacity,
152            prefetch_queue: Arc::new(RwLock::new(Vec::new())),
153        }
154    }
155
156    /// Get an embedding from the cache
157    pub fn get(&self, key: &str) -> Option<AlignedVector> {
158        let mut cache = self.cache.write();
159        if let Some(entry) = cache.get_mut(key) {
160            entry.stats.record_access();
161            self.hits.fetch_add(1, Ordering::Relaxed);
162            Some(entry.vector.clone())
163        } else {
164            self.misses.fetch_add(1, Ordering::Relaxed);
165            None
166        }
167    }
168
169    /// Insert an embedding into the cache
170    pub fn insert(&self, key: String, vector: Vec<f32>) {
171        let aligned = AlignedVector::new(vector);
172        let entry = CachedEmbedding {
173            vector: aligned,
174            stats: AccessStats::new(),
175        };
176        self.cache.write().put(key, entry);
177    }
178
179    /// Get cache statistics
180    pub fn stats(&self) -> HotCacheStats {
181        let hits = self.hits.load(Ordering::Relaxed);
182        let misses = self.misses.load(Ordering::Relaxed);
183        let total = hits + misses;
184        let hit_rate = if total > 0 {
185            hits as f64 / total as f64
186        } else {
187            0.0
188        };
189
190        let cache = self.cache.read();
191        let size = cache.len();
192
193        HotCacheStats {
194            hits,
195            misses,
196            hit_rate,
197            size,
198            capacity: self.capacity,
199        }
200    }
201
202    /// Clear the cache
203    pub fn clear(&self) {
204        self.cache.write().clear();
205        self.hits.store(0, Ordering::Relaxed);
206        self.misses.store(0, Ordering::Relaxed);
207    }
208
209    /// Get the current size of the cache
210    pub fn len(&self) -> usize {
211        self.cache.read().len()
212    }
213
214    /// Check if the cache is empty
215    pub fn is_empty(&self) -> bool {
216        self.cache.read().is_empty()
217    }
218
219    /// Prefetch embeddings (add to prefetch queue)
220    pub fn prefetch(&self, keys: Vec<String>) {
221        let mut queue = self.prefetch_queue.write();
222        queue.extend(keys);
223    }
224
225    /// Get hot embeddings (most frequently accessed)
226    pub fn get_hot_keys(&self, top_n: usize) -> Vec<String> {
227        let cache = self.cache.read();
228        let mut entries: Vec<_> = cache
229            .iter()
230            .map(|(k, v)| (k.clone(), v.stats.access_frequency()))
231            .collect();
232
233        entries.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
234        entries.into_iter().take(top_n).map(|(k, _)| k).collect()
235    }
236}
237
/// Point-in-time snapshot of a [`HotEmbeddingCache`]'s counters, as produced by
/// [`HotEmbeddingCache::stats`].
#[derive(Debug, Clone)]
pub struct HotCacheStats {
    /// Lookups that found their key.
    pub hits: u64,
    /// Lookups that did not find their key.
    pub misses: u64,
    /// `hits / (hits + misses)`, or `0.0` when no lookups have happened.
    pub hit_rate: f64,
    /// Number of entries stored when the snapshot was taken.
    pub size: usize,
    /// Maximum number of entries the cache can hold.
    pub capacity: usize,
}
252
253/// Adaptive caching strategy
254///
255/// Dynamically adjusts cache size and eviction policy based on
256/// access patterns and hit rates.
257pub struct AdaptiveCacheStrategy {
258    /// Current cache size target
259    target_size: Arc<RwLock<usize>>,
260    /// Minimum cache size
261    min_size: usize,
262    /// Maximum cache size
263    max_size: usize,
264    /// Target hit rate (0.0-1.0)
265    target_hit_rate: f64,
266    /// Adjustment factor for cache sizing
267    adjustment_factor: f64,
268}
269
270impl AdaptiveCacheStrategy {
271    /// Create a new adaptive caching strategy
272    pub fn new(min_size: usize, max_size: usize, target_hit_rate: f64) -> Self {
273        Self {
274            target_size: Arc::new(RwLock::new((min_size + max_size) / 2)),
275            min_size,
276            max_size,
277            target_hit_rate,
278            adjustment_factor: 1.1, // 10% adjustment per iteration
279        }
280    }
281
282    /// Adjust cache size based on current hit rate
283    pub fn adjust(&self, current_hit_rate: f64) -> usize {
284        let mut target = self.target_size.write();
285
286        if current_hit_rate < self.target_hit_rate {
287            // Hit rate too low, increase cache size
288            let new_size =
289                (*target as f64 * self.adjustment_factor).min(self.max_size as f64) as usize;
290            *target = new_size;
291        } else if current_hit_rate > self.target_hit_rate + 0.05 {
292            // Hit rate high enough, can reduce cache size
293            let new_size =
294                (*target as f64 / self.adjustment_factor).max(self.min_size as f64) as usize;
295            *target = new_size;
296        }
297
298        *target
299    }
300
301    /// Get the current target cache size
302    pub fn target_size(&self) -> usize {
303        *self.target_size.read()
304    }
305
306    /// Reset to default size
307    pub fn reset(&self) {
308        *self.target_size.write() = (self.min_size + self.max_size) / 2;
309    }
310}
311
/// Rule a [`CacheInvalidator`] uses to decide that a cache has gone stale.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InvalidationPolicy {
    /// Stale once the contained duration has elapsed since the last invalidation.
    TTL(Duration),
    /// Never stale on a timer; invalidation is triggered explicitly by calling code.
    Event,
    /// Never stale automatically; only explicit invalidation applies.
    Never,
}
322
323/// Cache invalidation tracker
324pub struct CacheInvalidator {
325    /// Invalidation policy
326    policy: InvalidationPolicy,
327    /// Timestamp of last invalidation
328    last_invalidation: Arc<RwLock<Instant>>,
329}
330
331impl CacheInvalidator {
332    /// Create a new cache invalidator
333    pub fn new(policy: InvalidationPolicy) -> Self {
334        Self {
335            policy,
336            last_invalidation: Arc::new(RwLock::new(Instant::now())),
337        }
338    }
339
340    /// Check if cache should be invalidated
341    pub fn should_invalidate(&self) -> bool {
342        match self.policy {
343            InvalidationPolicy::TTL(ttl) => {
344                let elapsed = self.last_invalidation.read().elapsed();
345                elapsed >= ttl
346            }
347            InvalidationPolicy::Event => false, // Manual invalidation only
348            InvalidationPolicy::Never => false,
349        }
350    }
351
352    /// Mark cache as invalidated
353    pub fn invalidate(&self) {
354        *self.last_invalidation.write() = Instant::now();
355    }
356
357    /// Get time since last invalidation
358    pub fn time_since_invalidation(&self) -> Duration {
359        self.last_invalidation.read().elapsed()
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aligned_vector_creation() {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let vector = AlignedVector::new(values.clone());

        assert_eq!(vector.len(), 4);
        assert_eq!(vector.as_slice(), values.as_slice());
    }

    #[test]
    fn test_aligned_vector_alignment() {
        let vector = AlignedVector::zeros(100);

        // `repr(align(64))` governs the struct (the Vec header); the heap buffer only
        // gets the allocator's alignment for f32.
        assert_eq!(
            std::mem::align_of::<AlignedVector>(),
            64,
            "AlignedVector struct should be aligned to 64 bytes"
        );

        let addr = vector.as_slice().as_ptr() as usize;
        assert_eq!(
            addr % std::mem::align_of::<f32>(),
            0,
            "Data pointer should be properly aligned for f32"
        );
    }

    #[test]
    fn test_hot_cache_basic() {
        let cache = HotEmbeddingCache::new(10);
        cache.insert("key1".to_string(), vec![1.0, 2.0, 3.0]);
        cache.insert("key2".to_string(), vec![4.0, 5.0, 6.0]);

        assert_eq!(cache.get("key1").unwrap().as_slice(), &[1.0, 2.0, 3.0]);
        assert_eq!(cache.get("key2").unwrap().as_slice(), &[4.0, 5.0, 6.0]);

        // A key that was never inserted is a miss.
        assert!(cache.get("key3").is_none());
    }

    #[test]
    fn test_hot_cache_stats() {
        let cache = HotEmbeddingCache::new(10);
        cache.insert("key1".to_string(), vec![1.0, 2.0, 3.0]);

        let _ = cache.get("key1"); // hit
        let _ = cache.get("key2"); // miss

        let HotCacheStats {
            hits,
            misses,
            hit_rate,
            ..
        } = cache.stats();
        assert_eq!(hits, 1);
        assert_eq!(misses, 1);
        assert_eq!(hit_rate, 0.5);
    }

    #[test]
    fn test_hot_cache_lru() {
        let cache = HotEmbeddingCache::new(2);
        for (i, key) in ["key1", "key2", "key3"].iter().enumerate() {
            cache.insert(key.to_string(), vec![(i + 1) as f32]);
        }

        // Capacity is 2, so inserting key3 evicted the oldest entry, key1.
        assert!(cache.get("key1").is_none());
        assert!(cache.get("key2").is_some());
        assert!(cache.get("key3").is_some());
    }

    #[test]
    fn test_adaptive_strategy() {
        let strategy = AdaptiveCacheStrategy::new(100, 1000, 0.8);
        let initial = strategy.target_size();
        assert_eq!(initial, 550); // (100 + 1000) / 2

        // A hit rate below target grows the size.
        assert!(strategy.adjust(0.5) > initial);

        // A hit rate well above target shrinks it, starting again from the default.
        strategy.reset();
        assert!(strategy.adjust(0.95) < initial);
    }

    #[test]
    fn test_cache_invalidator_ttl() {
        let ttl = Duration::from_millis(10);
        let invalidator = CacheInvalidator::new(InvalidationPolicy::TTL(ttl));

        assert!(!invalidator.should_invalidate());

        // Sleep past the TTL so the policy fires.
        std::thread::sleep(ttl + Duration::from_millis(5));

        assert!(invalidator.should_invalidate());
    }

    #[test]
    fn test_cache_invalidator_never() {
        let invalidator = CacheInvalidator::new(InvalidationPolicy::Never);
        std::thread::sleep(Duration::from_millis(10));
        assert!(!invalidator.should_invalidate());
    }

    #[test]
    fn test_hot_keys_tracking() {
        let cache = HotEmbeddingCache::new(10);
        cache.insert("key1".to_string(), vec![1.0]);
        cache.insert("key2".to_string(), vec![2.0]);
        cache.insert("key3".to_string(), vec![3.0]);

        // Let some time pass so frequencies are computed over a non-zero span.
        std::thread::sleep(Duration::from_millis(1));

        for _ in 0..5 {
            let _ = cache.get("key1");
        }
        for _ in 0..2 {
            let _ = cache.get("key2");
        }

        let hot = cache.get_hot_keys(2);
        assert_eq!(hot.len(), 2);
        // Both accessed keys must rank; their relative order is deliberately not asserted.
        for expected in ["key1", "key2"] {
            assert!(hot.contains(&expected.to_string()));
        }
    }
}