Skip to main content

oxirs_embed/
embedding_cache.rs

1//! LRU embedding cache with memory-bounded eviction and per-model invalidation.
2
3use std::collections::HashMap;
4
5// ──────────────────────────────────────────────────────────────────────────────
6// Types
7// ──────────────────────────────────────────────────────────────────────────────
8
/// Cache key: content hash + model identifier.
///
/// Two keys are equal only when both the content hash and the model id match,
/// so the same text embedded by different models is cached separately.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Hash of the input content (e.g. FNV-1a or SipHash).
    pub content_hash: u64,
    /// Identifier of the embedding model used.
    pub model_id: String,
}

impl CacheKey {
    /// Build a key from a precomputed content hash and a model identifier.
    ///
    /// `model_id` accepts anything convertible into a `String` (`&str`,
    /// `String`, ...).
    pub fn new(content_hash: u64, model_id: impl Into<String>) -> Self {
        let model_id: String = model_id.into();
        CacheKey {
            content_hash,
            model_id,
        }
    }
}
27
/// One embedding stored in the cache, plus bookkeeping.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// The embedding vector.
    pub embedding: Vec<f32>,
    /// How many cache hits this entry has served.
    pub access_count: u64,
    /// Payload size in bytes: `size_of::<f32>()` times the number of dimensions.
    pub size_bytes: usize,
}

impl CacheEntry {
    /// Wrap `embedding` in a fresh entry that has not been accessed yet.
    fn new(embedding: Vec<f32>) -> Self {
        // Byte size of the f32 payload (len * 4); vector capacity is not counted.
        let bytes = std::mem::size_of_val(embedding.as_slice());
        CacheEntry {
            size_bytes: bytes,
            access_count: 0,
            embedding,
        }
    }
}
49
/// Point-in-time copy of a cache's counters.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Number of lookups that found an entry.
    pub hits: u64,
    /// Number of lookups that found nothing.
    pub misses: u64,
    /// Number of entries evicted: capacity pressure, memory pressure, or
    /// explicit `evict_lru` calls.
    pub evictions: u64,
    /// Fraction `hits / (hits + misses)`; `0.0` before any lookup happened.
    pub hit_rate: f64,
    /// Combined `size_bytes` of every entry currently stored.
    pub total_size_bytes: usize,
}
64
65// ──────────────────────────────────────────────────────────────────────────────
66// Internal LRU node list
67// ──────────────────────────────────────────────────────────────────────────────
68
69/// An intrusive doubly-linked list node.
70struct LruNode {
71    key: CacheKey,
72    entry: CacheEntry,
73    prev: Option<usize>, // index into slab
74    next: Option<usize>,
75}
76
77/// Simple slab-allocated doubly-linked list for LRU tracking.
78///
79/// `head` is the most-recently-used end; `tail` is the least-recently-used.
80struct LruList {
81    nodes: Vec<Option<LruNode>>,
82    free: Vec<usize>,
83    head: Option<usize>,
84    tail: Option<usize>,
85}
86
87impl LruList {
88    fn new(capacity: usize) -> Self {
89        Self {
90            nodes: (0..capacity).map(|_| None).collect(),
91            free: (0..capacity).rev().collect(),
92            head: None,
93            tail: None,
94        }
95    }
96
97    fn allocate(&mut self, key: CacheKey, entry: CacheEntry) -> Option<usize> {
98        let idx = self.free.pop()?;
99        self.nodes[idx] = Some(LruNode {
100            key,
101            entry,
102            prev: None,
103            next: None,
104        });
105        Some(idx)
106    }
107
108    fn reclaim(&mut self, idx: usize) {
109        self.nodes[idx] = None;
110        self.free.push(idx);
111    }
112
113    /// Move `idx` to the head (most-recently-used position).
114    fn move_to_head(&mut self, idx: usize) {
115        self.detach(idx);
116        self.attach_head(idx);
117    }
118
119    fn detach(&mut self, idx: usize) {
120        let (prev, next) = {
121            let node = self.nodes[idx].as_ref().expect("node must exist");
122            (node.prev, node.next)
123        };
124        if let Some(p) = prev {
125            self.nodes[p].as_mut().expect("prev must exist").next = next;
126        } else {
127            self.head = next;
128        }
129        if let Some(n) = next {
130            self.nodes[n].as_mut().expect("next must exist").prev = prev;
131        } else {
132            self.tail = prev;
133        }
134        let node = self.nodes[idx].as_mut().expect("node must exist");
135        node.prev = None;
136        node.next = None;
137    }
138
139    fn attach_head(&mut self, idx: usize) {
140        let node = self.nodes[idx].as_mut().expect("node must exist");
141        node.next = self.head;
142        node.prev = None;
143        let old_head = self.head;
144        self.head = Some(idx);
145        if let Some(h) = old_head {
146            self.nodes[h].as_mut().expect("old head must exist").prev = Some(idx);
147        } else {
148            self.tail = Some(idx);
149        }
150    }
151
152    /// Remove and return the LRU (tail) node's key and entry.
153    fn evict_lru(&mut self) -> Option<(CacheKey, CacheEntry)> {
154        let tail_idx = self.tail?;
155        self.detach(tail_idx);
156        let node = self.nodes[tail_idx].take().expect("tail node must exist");
157        self.free.push(tail_idx);
158        Some((node.key, node.entry))
159    }
160}
161
162// ──────────────────────────────────────────────────────────────────────────────
163// EmbeddingCache
164// ──────────────────────────────────────────────────────────────────────────────
165
166/// LRU embedding cache with a fixed entry-count capacity.
167pub struct EmbeddingCache {
168    capacity: usize,
169    map: HashMap<CacheKey, usize>, // key → slab index
170    list: LruList,
171    hits: u64,
172    misses: u64,
173    evictions: u64,
174    total_size_bytes: usize,
175}
176
177impl EmbeddingCache {
178    /// Create a cache that holds at most `capacity` embeddings.
179    ///
180    /// Capacity is clamped to at least 1.
181    pub fn new(capacity: usize) -> Self {
182        let capacity = capacity.max(1);
183        Self {
184            capacity,
185            map: HashMap::with_capacity(capacity),
186            list: LruList::new(capacity),
187            hits: 0,
188            misses: 0,
189            evictions: 0,
190            total_size_bytes: 0,
191        }
192    }
193
194    /// Look up an embedding. Promotes the entry to MRU on hit.
195    pub fn get(&mut self, key: &CacheKey) -> Option<&[f32]> {
196        if let Some(&idx) = self.map.get(key) {
197            self.list.move_to_head(idx);
198            self.hits += 1;
199            let node = self.list.nodes[idx].as_mut().expect("node must exist");
200            node.entry.access_count += 1;
201            // Safety: we return a reference into the node which lives inside the
202            // Vec<Option<LruNode>>. We need to go through the map again to avoid
203            // holding `&mut self`.
204            let idx2 = *self.map.get(key).expect("just inserted");
205            let node2 = self.list.nodes[idx2].as_ref().expect("node must exist");
206            return Some(&node2.entry.embedding);
207        }
208        self.misses += 1;
209        None
210    }
211
212    /// Insert an embedding. Evicts the LRU entry when at capacity.
213    pub fn insert(&mut self, key: CacheKey, embedding: Vec<f32>) {
214        // If already present, update in place.
215        if let Some(&idx) = self.map.get(&key) {
216            let node = self.list.nodes[idx].as_mut().expect("node must exist");
217            let old_size = node.entry.size_bytes;
218            node.entry = CacheEntry::new(embedding);
219            let new_size = node.entry.size_bytes;
220            self.total_size_bytes = self.total_size_bytes - old_size + new_size;
221            self.list.move_to_head(idx);
222            return;
223        }
224
225        // Evict LRU if full.
226        if self.map.len() >= self.capacity {
227            if let Some((evicted_key, evicted_entry)) = self.list.evict_lru() {
228                self.total_size_bytes -= evicted_entry.size_bytes;
229                self.map.remove(&evicted_key);
230                self.evictions += 1;
231            }
232        }
233
234        let entry = CacheEntry::new(embedding);
235        self.total_size_bytes += entry.size_bytes;
236
237        if let Some(idx) = self.list.allocate(key.clone(), entry) {
238            self.list.attach_head(idx);
239            self.map.insert(key, idx);
240        }
241    }
242
243    /// Manually evict the LRU entry.
244    pub fn evict_lru(&mut self) -> Option<(CacheKey, CacheEntry)> {
245        let (k, e) = self.list.evict_lru()?;
246        self.map.remove(&k);
247        self.total_size_bytes -= e.size_bytes;
248        self.evictions += 1;
249        Some((k, e))
250    }
251
252    /// Remove a specific entry. Returns `true` if it was present.
253    pub fn invalidate(&mut self, key: &CacheKey) -> bool {
254        if let Some(idx) = self.map.remove(key) {
255            self.list.detach(idx);
256            let node = self.list.nodes[idx].take().expect("node must exist");
257            self.list.free.push(idx);
258            self.total_size_bytes -= node.entry.size_bytes;
259            true
260        } else {
261            false
262        }
263    }
264
265    /// Remove all entries for a given model.
266    pub fn invalidate_model(&mut self, model_id: &str) -> usize {
267        let keys_to_remove: Vec<CacheKey> = self
268            .map
269            .keys()
270            .filter(|k| k.model_id == model_id)
271            .cloned()
272            .collect();
273        let count = keys_to_remove.len();
274        for key in keys_to_remove {
275            self.invalidate(&key);
276        }
277        count
278    }
279
280    /// Current statistics snapshot.
281    pub fn stats(&self) -> CacheStats {
282        let total = self.hits + self.misses;
283        let hit_rate = if total == 0 {
284            0.0
285        } else {
286            self.hits as f64 / total as f64
287        };
288        CacheStats {
289            hits: self.hits,
290            misses: self.misses,
291            evictions: self.evictions,
292            hit_rate,
293            total_size_bytes: self.total_size_bytes,
294        }
295    }
296
297    /// Number of entries currently in the cache.
298    pub fn len(&self) -> usize {
299        self.map.len()
300    }
301
302    /// `true` when the cache is empty.
303    pub fn is_empty(&self) -> bool {
304        self.map.is_empty()
305    }
306
307    /// Maximum entry capacity.
308    pub fn capacity(&self) -> usize {
309        self.capacity
310    }
311}
312
313// ──────────────────────────────────────────────────────────────────────────────
314// MemoryBoundedCache
315// ──────────────────────────────────────────────────────────────────────────────
316
317/// An `EmbeddingCache` variant that evicts entries when total memory exceeds a
318/// byte limit.
319pub struct MemoryBoundedCache {
320    inner: EmbeddingCache,
321    max_bytes: usize,
322}
323
324impl MemoryBoundedCache {
325    /// Create a memory-bounded cache with a maximum of `max_bytes`.
326    ///
327    /// The internal entry capacity is set to a generous upper bound so that the
328    /// byte limit is always the binding constraint.
329    pub fn new(max_bytes: usize) -> Self {
330        // Assume max 4 bytes/dim × 1024 dims → 4 KiB per entry.
331        // Use a large entry-count capacity so byte limit governs.
332        let capacity = (max_bytes / (4 * 128)).max(4);
333        Self {
334            inner: EmbeddingCache::new(capacity),
335            max_bytes,
336        }
337    }
338
339    /// Insert an embedding, evicting LRU entries until within the memory limit.
340    pub fn insert(&mut self, key: CacheKey, embedding: Vec<f32>) {
341        self.inner.insert(key, embedding);
342        // Evict until within bounds.
343        while self.inner.total_size_bytes > self.max_bytes {
344            if self.inner.evict_lru().is_none() {
345                break;
346            }
347        }
348    }
349
350    /// Delegate to inner cache.
351    pub fn get(&mut self, key: &CacheKey) -> Option<&[f32]> {
352        self.inner.get(key)
353    }
354
355    /// Current byte usage.
356    pub fn total_size_bytes(&self) -> usize {
357        self.inner.total_size_bytes
358    }
359
360    /// Maximum byte limit.
361    pub fn max_bytes(&self) -> usize {
362        self.max_bytes
363    }
364
365    /// Statistics.
366    pub fn stats(&self) -> CacheStats {
367        self.inner.stats()
368    }
369
370    /// Number of entries.
371    pub fn len(&self) -> usize {
372        self.inner.len()
373    }
374
375    /// True when empty.
376    pub fn is_empty(&self) -> bool {
377        self.inner.is_empty()
378    }
379}
380
381// ──────────────────────────────────────────────────────────────────────────────
382// Tests
383// ──────────────────────────────────────────────────────────────────────────────
384
#[cfg(test)]
mod tests {
    use super::*;

    fn key(hash: u64, model: &str) -> CacheKey {
        CacheKey::new(hash, model)
    }

    fn emb(dims: usize) -> Vec<f32> {
        (0..dims).map(|i| i as f32 * 0.1).collect()
    }

    // ── CacheKey ─────────────────────────────────────────────────────────────

    #[test]
    fn test_cache_key_equality() {
        assert_eq!(key(1, "model-a"), key(1, "model-a"));
        assert_ne!(key(1, "model-a"), key(1, "model-b"));
        assert_ne!(key(1, "model-a"), key(2, "model-a"));
    }

    #[test]
    fn test_cache_key_clone() {
        let k = key(42, "bert");
        let k2 = k.clone();
        assert_eq!(k, k2);
    }

    // ── CacheEntry ────────────────────────────────────────────────────────────

    #[test]
    fn test_cache_entry_size_bytes() {
        let e = CacheEntry::new(vec![0.0f32; 128]);
        assert_eq!(e.size_bytes, 128 * 4);
    }

    #[test]
    fn test_cache_entry_access_count_starts_zero() {
        let e = CacheEntry::new(vec![1.0; 4]);
        assert_eq!(e.access_count, 0);
    }

    // ── EmbeddingCache basic ──────────────────────────────────────────────────

    #[test]
    fn test_cache_empty_initially() {
        let c = EmbeddingCache::new(10);
        assert!(c.is_empty());
        assert_eq!(c.len(), 0);
    }

    #[test]
    fn test_cache_capacity() {
        let c = EmbeddingCache::new(5);
        assert_eq!(c.capacity(), 5);
    }

    #[test]
    fn test_cache_capacity_min_one() {
        let c = EmbeddingCache::new(0);
        assert_eq!(c.capacity(), 1);
    }

    #[test]
    fn test_cache_insert_and_get() {
        let mut c = EmbeddingCache::new(10);
        let k = key(1, "m");
        let e = emb(4);
        c.insert(k.clone(), e.clone());
        let got = c.get(&k).expect("should be cached");
        assert_eq!(got, e.as_slice());
    }

    #[test]
    fn test_cache_miss() {
        let mut c = EmbeddingCache::new(10);
        assert!(c.get(&key(99, "m")).is_none());
    }

    #[test]
    fn test_cache_len_after_insert() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        assert_eq!(c.len(), 2);
    }

    #[test]
    fn test_cache_stats_hits_misses() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), emb(4));
        c.get(&key(1, "m")); // hit
        c.get(&key(99, "m")); // miss
        let s = c.stats();
        assert_eq!(s.hits, 1);
        assert_eq!(s.misses, 1);
        assert!((s.hit_rate - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_cache_stats_no_lookups_hit_rate_zero() {
        let c = EmbeddingCache::new(10);
        assert_eq!(c.stats().hit_rate, 0.0);
    }

    #[test]
    fn test_cache_stats_total_size_bytes() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), vec![0.0f32; 64]);
        c.insert(key(2, "m"), vec![0.0f32; 128]);
        assert_eq!(c.stats().total_size_bytes, (64 + 128) * 4);
    }

    // ── LRU eviction ─────────────────────────────────────────────────────────

    #[test]
    fn test_cache_evicts_lru_when_full() {
        let mut c = EmbeddingCache::new(3);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        c.insert(key(3, "m"), emb(4));
        // key(1) is LRU; inserting key(4) should evict it
        c.insert(key(4, "m"), emb(4));
        assert!(c.get(&key(1, "m")).is_none(), "key(1) should be evicted");
        assert!(c.get(&key(4, "m")).is_some());
    }

    #[test]
    fn test_cache_get_promotes_to_mru() {
        let mut c = EmbeddingCache::new(3);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        c.insert(key(3, "m"), emb(4));
        // Promote key(1) so key(2) becomes LRU
        c.get(&key(1, "m"));
        c.insert(key(4, "m"), emb(4));
        assert!(c.get(&key(2, "m")).is_none(), "key(2) should be evicted");
        assert!(c.get(&key(1, "m")).is_some());
    }

    #[test]
    fn test_cache_update_promotes_to_mru() {
        let mut c = EmbeddingCache::new(2);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        // Re-inserting key(1) makes key(2) the LRU entry.
        c.insert(key(1, "m"), emb(4));
        c.insert(key(3, "m"), emb(4));
        assert!(c.get(&key(2, "m")).is_none(), "key(2) should be evicted");
        assert!(c.get(&key(1, "m")).is_some());
        assert!(c.get(&key(3, "m")).is_some());
    }

    #[test]
    fn test_cache_evictions_stat_incremented() {
        let mut c = EmbeddingCache::new(2);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        c.insert(key(3, "m"), emb(4)); // evicts key(1)
        assert_eq!(c.stats().evictions, 1);
    }

    #[test]
    fn test_cache_manual_evict_lru() {
        let mut c = EmbeddingCache::new(3);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        let (evicted_key, _) = c.evict_lru().expect("should evict");
        // The LRU is key(1) since key(2) was inserted last.
        assert_eq!(evicted_key, key(1, "m"));
        assert_eq!(c.len(), 1);
    }

    #[test]
    fn test_cache_manual_evict_lru_empty() {
        let mut c = EmbeddingCache::new(3);
        assert!(c.evict_lru().is_none());
    }

    // ── Invalidation ─────────────────────────────────────────────────────────

    #[test]
    fn test_cache_invalidate_present() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), emb(4));
        assert!(c.invalidate(&key(1, "m")));
        assert!(c.is_empty());
    }

    #[test]
    fn test_cache_invalidate_absent() {
        let mut c = EmbeddingCache::new(10);
        assert!(!c.invalidate(&key(99, "m")));
    }

    #[test]
    fn test_cache_invalidate_model() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "bert"), emb(4));
        c.insert(key(2, "bert"), emb(4));
        c.insert(key(3, "gpt"), emb(4));
        let removed = c.invalidate_model("bert");
        assert_eq!(removed, 2);
        assert!(c.get(&key(1, "bert")).is_none());
        assert!(c.get(&key(2, "bert")).is_none());
        assert!(c.get(&key(3, "gpt")).is_some());
    }

    #[test]
    fn test_cache_invalidate_model_none() {
        let mut c = EmbeddingCache::new(10);
        assert_eq!(c.invalidate_model("unknown"), 0);
    }

    #[test]
    fn test_cache_size_decreases_on_invalidate() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), vec![0.0f32; 64]);
        let before = c.stats().total_size_bytes;
        c.invalidate(&key(1, "m"));
        assert_eq!(c.stats().total_size_bytes, before - 64 * 4);
    }

    #[test]
    fn test_cache_slot_reused_after_invalidate() {
        let mut c = EmbeddingCache::new(2);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        c.invalidate(&key(1, "m"));
        // The freed slot is reused, so no eviction is needed for key(3).
        c.insert(key(3, "m"), emb(4));
        assert_eq!(c.len(), 2);
        assert_eq!(c.stats().evictions, 0);
        assert!(c.get(&key(2, "m")).is_some());
        assert!(c.get(&key(3, "m")).is_some());
    }

    // ── Update in-place ───────────────────────────────────────────────────────

    #[test]
    fn test_cache_insert_same_key_updates() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(1, "m"), vec![99.0; 4]);
        let got = c.get(&key(1, "m")).expect("should exist");
        assert_eq!(got[0], 99.0);
        assert_eq!(c.len(), 1);
    }

    #[test]
    fn test_cache_insert_same_key_adjusts_size() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), vec![0.0f32; 4]);
        c.insert(key(1, "m"), vec![0.0f32; 8]);
        assert_eq!(c.stats().total_size_bytes, 8 * 4);
        c.insert(key(1, "m"), vec![0.0f32; 2]);
        assert_eq!(c.stats().total_size_bytes, 2 * 4);
    }

    // ── MemoryBoundedCache ────────────────────────────────────────────────────

    #[test]
    fn test_memory_bounded_cache_empty() {
        let c = MemoryBoundedCache::new(1024);
        assert!(c.is_empty());
        assert_eq!(c.total_size_bytes(), 0);
    }

    #[test]
    fn test_memory_bounded_cache_max_bytes() {
        let c = MemoryBoundedCache::new(4096);
        assert_eq!(c.max_bytes(), 4096);
    }

    #[test]
    fn test_memory_bounded_cache_stays_within_limit() {
        // 512 bytes limit; each 64-dim embedding = 256 bytes
        let mut c = MemoryBoundedCache::new(512);
        for i in 0..10u64 {
            c.insert(key(i, "m"), vec![0.0f32; 64]);
        }
        assert!(c.total_size_bytes() <= 512);
    }

    #[test]
    fn test_memory_bounded_cache_insert_and_get() {
        let mut c = MemoryBoundedCache::new(1 << 20); // 1 MiB
        let k = key(1, "m");
        c.insert(k.clone(), vec![1.0; 128]);
        let got = c.get(&k).expect("should be present");
        assert_eq!(got.len(), 128);
        assert!((got[0] - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_memory_bounded_cache_stats() {
        let mut c = MemoryBoundedCache::new(1 << 20);
        c.insert(key(1, "m"), vec![0.0; 32]);
        c.get(&key(1, "m")); // hit
        c.get(&key(2, "m")); // miss
        let s = c.stats();
        assert_eq!(s.hits, 1);
        assert_eq!(s.misses, 1);
    }

    // ── Stress ────────────────────────────────────────────────────────────────

    #[test]
    fn test_cache_stress_insert_get() {
        let mut c = EmbeddingCache::new(100);
        for i in 0u64..200 {
            c.insert(key(i, "m"), emb(32));
        }
        assert_eq!(c.len(), 100);
        let s = c.stats();
        assert_eq!(s.evictions, 100);
    }

    #[test]
    fn test_cache_all_hits() {
        let mut c = EmbeddingCache::new(50);
        for i in 0u64..50 {
            c.insert(key(i, "m"), emb(8));
        }
        for i in 0u64..50 {
            assert!(c.get(&key(i, "m")).is_some());
        }
        assert_eq!(c.stats().hits, 50);
        assert_eq!(c.stats().misses, 0);
        assert!((c.stats().hit_rate - 1.0).abs() < 1e-9);
    }

    // ── Additional coverage ───────────────────────────────────────────────────

    #[test]
    fn test_cache_access_count_increments_on_get() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), emb(4));
        c.get(&key(1, "m"));
        c.get(&key(1, "m"));
        c.get(&key(1, "m"));
        assert_eq!(c.stats().hits, 3);
        // access_count has no public accessor; inspect the stored entry directly.
        let idx = c.map[&key(1, "m")];
        let count = c.list.nodes[idx]
            .as_ref()
            .expect("entry must exist")
            .entry
            .access_count;
        assert_eq!(count, 3);
    }

    #[test]
    fn test_cache_multiple_models_isolated() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "bert"), emb(4));
        c.insert(key(1, "gpt"), emb(8));
        assert!(c.get(&key(1, "bert")).is_some());
        assert!(c.get(&key(1, "gpt")).is_some());
    }

    #[test]
    fn test_cache_key_hash_collision_different_models() {
        // Same hash, different model → different keys.
        let k1 = key(100, "modelA");
        let k2 = key(100, "modelB");
        assert_ne!(k1, k2);
    }

    #[test]
    fn test_cache_evict_only_one_on_overflow() {
        let mut c = EmbeddingCache::new(3);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        c.insert(key(3, "m"), emb(4));
        c.insert(key(4, "m"), emb(4)); // evicts key(1)
        assert_eq!(c.len(), 3);
    }

    #[test]
    fn test_cache_get_returns_correct_embedding() {
        let mut c = EmbeddingCache::new(5);
        let v = vec![1.0f32, 2.0, 3.0, 4.0];
        c.insert(key(42, "m"), v.clone());
        let got = c.get(&key(42, "m")).expect("should be present");
        assert_eq!(got, v.as_slice());
    }

    #[test]
    fn test_cache_is_not_empty_after_insert() {
        let mut c = EmbeddingCache::new(5);
        c.insert(key(1, "m"), emb(4));
        assert!(!c.is_empty());
    }

    #[test]
    fn test_cache_len_zero_initially() {
        let c = EmbeddingCache::new(5);
        assert_eq!(c.len(), 0);
    }

    #[test]
    fn test_cache_invalidate_all_via_model() {
        let mut c = EmbeddingCache::new(10);
        for i in 0u64..5 {
            c.insert(key(i, "bert"), emb(4));
        }
        c.insert(key(99, "gpt"), emb(4));
        let removed = c.invalidate_model("bert");
        assert_eq!(removed, 5);
        assert_eq!(c.len(), 1); // only gpt remains
    }

    #[test]
    fn test_cache_stats_evictions_multiple() {
        let mut c = EmbeddingCache::new(2);
        for i in 0u64..6 {
            c.insert(key(i, "m"), emb(4));
        }
        // Each of the 4 inserts beyond capacity evicts one entry.
        assert_eq!(c.stats().evictions, 4);
    }

    #[test]
    fn test_cache_size_zero_after_all_invalidated() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), emb(32));
        c.insert(key(2, "m"), emb(32));
        c.invalidate(&key(1, "m"));
        c.invalidate(&key(2, "m"));
        assert_eq!(c.stats().total_size_bytes, 0);
    }

    #[test]
    fn test_memory_bounded_cache_evicts_to_stay_within_limit() {
        // 256 bytes = 1 entry of 64 dims * 4 bytes each.
        let mut c = MemoryBoundedCache::new(256);
        for i in 0u64..5 {
            c.insert(key(i, "m"), vec![0.0f32; 64]);
        }
        assert!(c.total_size_bytes() <= 256);
    }

    #[test]
    fn test_memory_bounded_cache_get_returns_none_for_missing() {
        let mut c = MemoryBoundedCache::new(1024);
        assert!(c.get(&key(99, "m")).is_none());
    }

    #[test]
    fn test_memory_bounded_cache_len_tracks_inserts() {
        let mut c = MemoryBoundedCache::new(1 << 20);
        assert_eq!(c.len(), 0);
        c.insert(key(1, "m"), emb(4));
        assert_eq!(c.len(), 1);
    }

    #[test]
    fn test_cache_insert_updates_size_correctly() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), vec![0.0f32; 10]);
        assert_eq!(c.stats().total_size_bytes, 10 * 4);
        c.insert(key(2, "m"), vec![0.0f32; 20]);
        assert_eq!(c.stats().total_size_bytes, 30 * 4);
    }
}