//! Embedding cache with zero-copy sharing via `Arc<[f32]>`
//!
//! This cache provides efficient storage and retrieval of embeddings with:
//! - LRU eviction policy
//! - Bytes-based capacity (not entry count)
//! - Zero-copy sharing via `Arc<[f32]>`
//! - Thread-safe access with atomic hit/miss counters
//!
//! Based on Fix 10 from the design plan:
//! > Use `Arc<[f32]>` for zero-copy sharing instead of cloning vectors
12use std::collections::HashMap;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::{Arc, Mutex};
15
/// Point-in-time statistics snapshot for the embedding cache.
///
/// Produced by `EmbeddingCache::stats`. Counters are cumulative for the
/// lifetime of the cache (they survive `clear`); `entries`/`bytes_used`
/// reflect current occupancy at the moment the snapshot was taken.
#[derive(Debug, Clone)]
pub struct EmbeddingCacheStats {
    /// Number of cache hits (cumulative).
    pub hits: u64,
    /// Number of cache misses (cumulative).
    pub misses: u64,
    /// Current number of entries in the cache.
    pub entries: usize,
    /// Current bytes used by stored embeddings (payload only, not overhead).
    pub bytes_used: usize,
    /// Maximum byte capacity the cache was configured with.
    pub max_bytes: usize,
    /// Hit rate as a percentage (0.0 - 100.0); 0.0 when there were no lookups.
    pub hit_rate: f64,
}
32
/// LRU node for tracking access order.
///
/// Nodes form a doubly-linked list threaded through the `HashMap` by key
/// (keys are `String`s rather than pointers, trading clones for safe code).
struct LruNode {
    /// The embedding data (shared via `Arc`; clones are refcount bumps).
    embedding: Arc<[f32]>,
    /// Size in bytes of the embedding payload (len * size_of::<f32>()).
    size_bytes: usize,
    /// Previous key in LRU order (more recently used); `None` at the head.
    prev: Option<String>,
    /// Next key in LRU order (less recently used); `None` at the tail.
    next: Option<String>,
}
44
/// Internal cache state protected by a mutex.
///
/// Invariants: `head` is the most recently used key and `tail` the least;
/// both are `None` exactly when `entries` is empty; `bytes_used` equals the
/// sum of `size_bytes` across all entries.
struct CacheState {
    /// Key -> LRU node mapping.
    entries: HashMap<String, LruNode>,
    /// Most recently used key (front of the LRU list).
    head: Option<String>,
    /// Least recently used key (back of the LRU list; next eviction victim).
    tail: Option<String>,
    /// Current bytes used by all stored embeddings.
    bytes_used: usize,
}
56
57impl CacheState {
58    fn new() -> Self {
59        Self {
60            entries: HashMap::new(),
61            head: None,
62            tail: None,
63            bytes_used: 0,
64        }
65    }
66
67    /// Move a key to the front (most recently used)
68    fn move_to_front(&mut self, key: &str) {
69        if self.head.as_deref() == Some(key) {
70            return; // Already at front
71        }
72
73        // Remove from current position
74        if let Some(node) = self.entries.get(key) {
75            let prev = node.prev.clone();
76            let next = node.next.clone();
77
78            // Update neighbors
79            if let Some(ref prev_key) = prev {
80                if let Some(prev_node) = self.entries.get_mut(prev_key) {
81                    prev_node.next = next.clone();
82                }
83            }
84            if let Some(ref next_key) = next {
85                if let Some(next_node) = self.entries.get_mut(next_key) {
86                    next_node.prev = prev.clone();
87                }
88            }
89
90            // Update tail if needed
91            if self.tail.as_deref() == Some(key) {
92                self.tail = prev;
93            }
94        }
95
96        // Insert at front
97        if let Some(node) = self.entries.get_mut(key) {
98            node.prev = None;
99            node.next = self.head.clone();
100        }
101
102        if let Some(ref old_head) = self.head {
103            if let Some(head_node) = self.entries.get_mut(old_head) {
104                head_node.prev = Some(key.to_string());
105            }
106        }
107
108        self.head = Some(key.to_string());
109
110        if self.tail.is_none() {
111            self.tail = self.head.clone();
112        }
113    }
114
115    /// Remove the least recently used entry and return its size
116    fn evict_lru(&mut self) -> Option<usize> {
117        let tail_key = self.tail.take()?;
118
119        if let Some(node) = self.entries.remove(&tail_key) {
120            // Update new tail
121            self.tail = node.prev.clone();
122            if let Some(ref new_tail_key) = self.tail {
123                if let Some(new_tail) = self.entries.get_mut(new_tail_key) {
124                    new_tail.next = None;
125                }
126            }
127
128            // Clear head if this was the only entry
129            if self.head.as_deref() == Some(&tail_key) {
130                self.head = None;
131            }
132
133            self.bytes_used -= node.size_bytes;
134            return Some(node.size_bytes);
135        }
136
137        None
138    }
139}
140
/// Thread-safe LRU embedding cache with bytes-based capacity.
///
/// All mutation happens under a single `Mutex`; hit/miss counters are
/// atomics so `stats` can read them without contending on the lock order.
pub struct EmbeddingCache {
    /// Cache state (entries + LRU list) protected by a mutex.
    state: Mutex<CacheState>,
    /// Maximum capacity in bytes (fixed at construction).
    max_bytes: usize,
    /// Atomic hit counter (cumulative; not reset by `clear`).
    hits: AtomicU64,
    /// Atomic miss counter (cumulative; not reset by `clear`).
    misses: AtomicU64,
}
152
153impl EmbeddingCache {
154    /// Create a new cache with the specified byte capacity
155    ///
156    /// # Arguments
157    /// - `max_bytes`: Maximum bytes to use for embeddings
158    ///   - Default recommendation: 100MB (~25K embeddings @ 1536 dims)
159    ///   - Each 1536-dim embedding uses 6144 bytes (1536 * 4)
160    pub fn new(max_bytes: usize) -> Self {
161        Self {
162            state: Mutex::new(CacheState::new()),
163            max_bytes,
164            hits: AtomicU64::new(0),
165            misses: AtomicU64::new(0),
166        }
167    }
168
169    /// Create a cache with default capacity (100MB)
170    pub fn default_capacity() -> Self {
171        Self::new(100 * 1024 * 1024) // 100MB
172    }
173
174    /// Get an embedding from the cache
175    ///
176    /// Returns Arc clone (cheap pointer copy, not vector copy)
177    pub fn get(&self, key: &str) -> Option<Arc<[f32]>> {
178        let mut state = self.state.lock().unwrap();
179
180        if state.entries.contains_key(key) {
181            state.move_to_front(key);
182            self.hits.fetch_add(1, Ordering::Relaxed);
183            state.entries.get(key).map(|n| n.embedding.clone())
184        } else {
185            self.misses.fetch_add(1, Ordering::Relaxed);
186            None
187        }
188    }
189
190    /// Insert an embedding into the cache
191    ///
192    /// If the key already exists, the embedding is updated and moved to front.
193    /// If capacity is exceeded, least recently used entries are evicted.
194    pub fn put(&self, key: String, embedding: Vec<f32>) {
195        let size_bytes = embedding.len() * std::mem::size_of::<f32>();
196
197        // Don't cache if single entry exceeds capacity
198        if size_bytes > self.max_bytes {
199            return;
200        }
201
202        let arc: Arc<[f32]> = embedding.into();
203        let mut state = self.state.lock().unwrap();
204
205        // Remove existing entry if present
206        if let Some(old_node) = state.entries.remove(&key) {
207            state.bytes_used -= old_node.size_bytes;
208
209            // Update LRU links for removed node
210            if let Some(ref prev_key) = old_node.prev {
211                if let Some(prev_node) = state.entries.get_mut(prev_key) {
212                    prev_node.next = old_node.next.clone();
213                }
214            }
215            if let Some(ref next_key) = old_node.next {
216                if let Some(next_node) = state.entries.get_mut(next_key) {
217                    next_node.prev = old_node.prev.clone();
218                }
219            }
220            if state.head.as_deref() == Some(&key) {
221                state.head = old_node.next.clone();
222            }
223            if state.tail.as_deref() == Some(&key) {
224                state.tail = old_node.prev.clone();
225            }
226        }
227
228        // Evict until we have room
229        while state.bytes_used + size_bytes > self.max_bytes {
230            if state.evict_lru().is_none() {
231                break;
232            }
233        }
234
235        // Insert new entry at front
236        let old_head = state.head.clone();
237        let node = LruNode {
238            embedding: arc,
239            size_bytes,
240            prev: None,
241            next: old_head.clone(),
242        };
243
244        // Update old head's prev pointer
245        if let Some(ref old_head_key) = old_head {
246            if let Some(head_node) = state.entries.get_mut(old_head_key) {
247                head_node.prev = Some(key.clone());
248            }
249        }
250
251        state.entries.insert(key.clone(), node);
252        state.bytes_used += size_bytes;
253        state.head = Some(key);
254
255        if state.tail.is_none() {
256            state.tail = state.head.clone();
257        }
258    }
259
260    /// Get cache statistics
261    pub fn stats(&self) -> EmbeddingCacheStats {
262        let state = self.state.lock().unwrap();
263        let hits = self.hits.load(Ordering::Relaxed);
264        let misses = self.misses.load(Ordering::Relaxed);
265        let total = hits + misses;
266
267        EmbeddingCacheStats {
268            hits,
269            misses,
270            entries: state.entries.len(),
271            bytes_used: state.bytes_used,
272            max_bytes: self.max_bytes,
273            hit_rate: if total > 0 {
274                (hits as f64 / total as f64) * 100.0
275            } else {
276                0.0
277            },
278        }
279    }
280
281    /// Clear all entries from the cache
282    pub fn clear(&self) {
283        let mut state = self.state.lock().unwrap();
284        state.entries.clear();
285        state.head = None;
286        state.tail = None;
287        state.bytes_used = 0;
288        // Note: We don't reset hit/miss counters - they're cumulative stats
289    }
290
291    /// Get the number of entries in the cache
292    pub fn len(&self) -> usize {
293        self.state.lock().unwrap().entries.len()
294    }
295
296    /// Check if the cache is empty
297    pub fn is_empty(&self) -> bool {
298        self.len() == 0
299    }
300}
301
302impl Default for EmbeddingCache {
303    fn default() -> Self {
304        Self::default_capacity()
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Insert/retrieve round-trip, miss on unknown key, and stats counters.
    #[test]
    fn test_basic_operations() {
        let cache = EmbeddingCache::new(1024 * 1024); // 1MB

        // Insert and retrieve
        let embedding = vec![1.0, 2.0, 3.0];
        cache.put("test-key".to_string(), embedding.clone());

        let retrieved = cache.get("test-key").unwrap();
        assert_eq!(&*retrieved, &[1.0, 2.0, 3.0]);

        // Miss
        assert!(cache.get("nonexistent").is_none());

        // Stats reflect the one hit and one miss above
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.entries, 1);
    }

    /// Filling past capacity evicts the least recently used entry.
    #[test]
    fn test_lru_eviction() {
        // Small cache: 48 bytes = exactly three 4-float entries (16 bytes each)
        let cache = EmbeddingCache::new(48);

        // Insert 3 entries (each 16 bytes = 4 * 4)
        cache.put("a".to_string(), vec![1.0, 2.0, 3.0, 4.0]);
        cache.put("b".to_string(), vec![5.0, 6.0, 7.0, 8.0]);
        cache.put("c".to_string(), vec![9.0, 10.0, 11.0, 12.0]);

        assert_eq!(cache.len(), 3);

        // Insert 4th entry; "a" is oldest (never re-accessed) so it goes
        cache.put("d".to_string(), vec![13.0, 14.0, 15.0, 16.0]);

        assert_eq!(cache.len(), 3);
        assert!(cache.get("a").is_none()); // Evicted
        assert!(cache.get("b").is_some());
        assert!(cache.get("c").is_some());
        assert!(cache.get("d").is_some());
    }

    /// A `get` promotes the entry, changing which key is evicted next.
    #[test]
    fn test_access_updates_lru() {
        // 32 bytes = room for two 4-float entries only
        let cache = EmbeddingCache::new(32);

        cache.put("a".to_string(), vec![1.0, 2.0, 3.0, 4.0]);
        cache.put("b".to_string(), vec![5.0, 6.0, 7.0, 8.0]);

        // Access "a" to make it most recently used
        let _ = cache.get("a");

        // Insert "c": "b" is now the LRU entry, so it is evicted, not "a"
        cache.put("c".to_string(), vec![9.0, 10.0, 11.0, 12.0]);

        assert!(cache.get("a").is_some()); // Still present
        assert!(cache.get("b").is_none()); // Evicted
        assert!(cache.get("c").is_some());
    }

    /// `clear` empties entries and byte accounting (counters are cumulative).
    #[test]
    fn test_clear() {
        let cache = EmbeddingCache::new(1024 * 1024);

        cache.put("a".to_string(), vec![1.0, 2.0, 3.0]);
        cache.put("b".to_string(), vec![4.0, 5.0, 6.0]);

        assert_eq!(cache.len(), 2);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert!(cache.get("a").is_none());
        assert!(cache.get("b").is_none());

        let stats = cache.stats();
        assert_eq!(stats.entries, 0);
        assert_eq!(stats.bytes_used, 0);
    }

    /// Re-inserting an existing key replaces the value without growing len.
    #[test]
    fn test_update_existing() {
        let cache = EmbeddingCache::new(1024 * 1024);

        cache.put("key".to_string(), vec![1.0, 2.0, 3.0]);
        let v1 = cache.get("key").unwrap();
        assert_eq!(&*v1, &[1.0, 2.0, 3.0]);

        // Update with new value (different length exercises byte accounting)
        cache.put("key".to_string(), vec![4.0, 5.0, 6.0, 7.0]);
        let v2 = cache.get("key").unwrap();
        assert_eq!(&*v2, &[4.0, 5.0, 6.0, 7.0]);

        assert_eq!(cache.len(), 1);
    }

    /// Two `get`s return Arcs pointing at the same allocation (zero-copy).
    #[test]
    fn test_zero_copy() {
        let cache = EmbeddingCache::new(1024 * 1024);

        cache.put("key".to_string(), vec![1.0, 2.0, 3.0]);

        // Get multiple references - should be Arc clones (cheap)
        let ref1 = cache.get("key").unwrap();
        let ref2 = cache.get("key").unwrap();

        // Both point to same data
        assert!(Arc::ptr_eq(&ref1, &ref2));
    }

    /// Hit-rate percentage is derived from cumulative hit/miss counts.
    #[test]
    fn test_stats_tracking() {
        let cache = EmbeddingCache::new(1024 * 1024);

        // Initial stats: no lookups yet, hit_rate defined as 0.0
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate, 0.0);

        cache.put("a".to_string(), vec![1.0, 2.0]);

        // Hit
        cache.get("a");
        // Miss
        cache.get("nonexistent");
        // Hit
        cache.get("a");

        // 2 hits / 3 lookups ≈ 66.7%
        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 66.666).abs() < 1.0);
    }
}