1use std::collections::HashMap;
4
/// Identifies one cached embedding: which content was embedded and by which model.
///
/// Two keys are equal only when both the content hash and the model id match, so the same
/// content embedded by different models occupies separate cache slots.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CacheKey {
    /// Caller-supplied hash of the embedded content.
    pub content_hash: u64,
    /// Identifier of the model that produced the embedding.
    pub model_id: String,
}

impl CacheKey {
    /// Builds a key from a content hash and a model id (anything convertible into `String`).
    pub fn new(content_hash: u64, model_id: impl Into<String>) -> Self {
        let model_id: String = model_id.into();
        CacheKey {
            content_hash,
            model_id,
        }
    }
}
27
/// A cached embedding vector plus the bookkeeping the cache keeps about it.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// The embedding values.
    pub embedding: Vec<f32>,
    /// Number of cache hits served from this entry; starts at 0.
    pub access_count: u64,
    /// Payload size in bytes: `embedding.len() * size_of::<f32>()` (excludes `Vec` overhead).
    pub size_bytes: usize,
}

impl CacheEntry {
    /// Wraps `embedding`, recording its payload size and a zero access count.
    fn new(embedding: Vec<f32>) -> Self {
        // Field expressions are evaluated in source order, so the size is computed
        // before `embedding` is moved into the struct.
        Self {
            size_bytes: std::mem::size_of::<f32>() * embedding.len(),
            access_count: 0,
            embedding,
        }
    }
}
49
/// Point-in-time snapshot of an [`EmbeddingCache`]'s counters, as returned by its `stats()`.
///
/// `PartialEq` is derived so snapshots can be compared directly (e.g. before/after an
/// operation); `Eq` is not possible because of the `f64` field.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct CacheStats {
    /// Lookups that found their key.
    pub hits: u64,
    /// Lookups that did not find their key.
    pub misses: u64,
    /// Entries removed by LRU eviction.
    pub evictions: u64,
    /// `hits / (hits + misses)`, or `0.0` before any lookup.
    pub hit_rate: f64,
    /// Total embedding payload of resident entries, in bytes.
    pub total_size_bytes: usize,
}
64
65struct LruNode {
71 key: CacheKey,
72 entry: CacheEntry,
73 prev: Option<usize>, next: Option<usize>,
75}
76
77struct LruList {
81 nodes: Vec<Option<LruNode>>,
82 free: Vec<usize>,
83 head: Option<usize>,
84 tail: Option<usize>,
85}
86
87impl LruList {
88 fn new(capacity: usize) -> Self {
89 Self {
90 nodes: (0..capacity).map(|_| None).collect(),
91 free: (0..capacity).rev().collect(),
92 head: None,
93 tail: None,
94 }
95 }
96
97 fn allocate(&mut self, key: CacheKey, entry: CacheEntry) -> Option<usize> {
98 let idx = self.free.pop()?;
99 self.nodes[idx] = Some(LruNode {
100 key,
101 entry,
102 prev: None,
103 next: None,
104 });
105 Some(idx)
106 }
107
108 fn reclaim(&mut self, idx: usize) {
109 self.nodes[idx] = None;
110 self.free.push(idx);
111 }
112
113 fn move_to_head(&mut self, idx: usize) {
115 self.detach(idx);
116 self.attach_head(idx);
117 }
118
119 fn detach(&mut self, idx: usize) {
120 let (prev, next) = {
121 let node = self.nodes[idx].as_ref().expect("node must exist");
122 (node.prev, node.next)
123 };
124 if let Some(p) = prev {
125 self.nodes[p].as_mut().expect("prev must exist").next = next;
126 } else {
127 self.head = next;
128 }
129 if let Some(n) = next {
130 self.nodes[n].as_mut().expect("next must exist").prev = prev;
131 } else {
132 self.tail = prev;
133 }
134 let node = self.nodes[idx].as_mut().expect("node must exist");
135 node.prev = None;
136 node.next = None;
137 }
138
139 fn attach_head(&mut self, idx: usize) {
140 let node = self.nodes[idx].as_mut().expect("node must exist");
141 node.next = self.head;
142 node.prev = None;
143 let old_head = self.head;
144 self.head = Some(idx);
145 if let Some(h) = old_head {
146 self.nodes[h].as_mut().expect("old head must exist").prev = Some(idx);
147 } else {
148 self.tail = Some(idx);
149 }
150 }
151
152 fn evict_lru(&mut self) -> Option<(CacheKey, CacheEntry)> {
154 let tail_idx = self.tail?;
155 self.detach(tail_idx);
156 let node = self.nodes[tail_idx].take().expect("tail node must exist");
157 self.free.push(tail_idx);
158 Some((node.key, node.entry))
159 }
160}
161
162pub struct EmbeddingCache {
168 capacity: usize,
169 map: HashMap<CacheKey, usize>, list: LruList,
171 hits: u64,
172 misses: u64,
173 evictions: u64,
174 total_size_bytes: usize,
175}
176
177impl EmbeddingCache {
178 pub fn new(capacity: usize) -> Self {
182 let capacity = capacity.max(1);
183 Self {
184 capacity,
185 map: HashMap::with_capacity(capacity),
186 list: LruList::new(capacity),
187 hits: 0,
188 misses: 0,
189 evictions: 0,
190 total_size_bytes: 0,
191 }
192 }
193
194 pub fn get(&mut self, key: &CacheKey) -> Option<&[f32]> {
196 if let Some(&idx) = self.map.get(key) {
197 self.list.move_to_head(idx);
198 self.hits += 1;
199 let node = self.list.nodes[idx].as_mut().expect("node must exist");
200 node.entry.access_count += 1;
201 let idx2 = *self.map.get(key).expect("just inserted");
205 let node2 = self.list.nodes[idx2].as_ref().expect("node must exist");
206 return Some(&node2.entry.embedding);
207 }
208 self.misses += 1;
209 None
210 }
211
212 pub fn insert(&mut self, key: CacheKey, embedding: Vec<f32>) {
214 if let Some(&idx) = self.map.get(&key) {
216 let node = self.list.nodes[idx].as_mut().expect("node must exist");
217 let old_size = node.entry.size_bytes;
218 node.entry = CacheEntry::new(embedding);
219 let new_size = node.entry.size_bytes;
220 self.total_size_bytes = self.total_size_bytes - old_size + new_size;
221 self.list.move_to_head(idx);
222 return;
223 }
224
225 if self.map.len() >= self.capacity {
227 if let Some((evicted_key, evicted_entry)) = self.list.evict_lru() {
228 self.total_size_bytes -= evicted_entry.size_bytes;
229 self.map.remove(&evicted_key);
230 self.evictions += 1;
231 }
232 }
233
234 let entry = CacheEntry::new(embedding);
235 self.total_size_bytes += entry.size_bytes;
236
237 if let Some(idx) = self.list.allocate(key.clone(), entry) {
238 self.list.attach_head(idx);
239 self.map.insert(key, idx);
240 }
241 }
242
243 pub fn evict_lru(&mut self) -> Option<(CacheKey, CacheEntry)> {
245 let (k, e) = self.list.evict_lru()?;
246 self.map.remove(&k);
247 self.total_size_bytes -= e.size_bytes;
248 self.evictions += 1;
249 Some((k, e))
250 }
251
252 pub fn invalidate(&mut self, key: &CacheKey) -> bool {
254 if let Some(idx) = self.map.remove(key) {
255 self.list.detach(idx);
256 let node = self.list.nodes[idx].take().expect("node must exist");
257 self.list.free.push(idx);
258 self.total_size_bytes -= node.entry.size_bytes;
259 true
260 } else {
261 false
262 }
263 }
264
265 pub fn invalidate_model(&mut self, model_id: &str) -> usize {
267 let keys_to_remove: Vec<CacheKey> = self
268 .map
269 .keys()
270 .filter(|k| k.model_id == model_id)
271 .cloned()
272 .collect();
273 let count = keys_to_remove.len();
274 for key in keys_to_remove {
275 self.invalidate(&key);
276 }
277 count
278 }
279
280 pub fn stats(&self) -> CacheStats {
282 let total = self.hits + self.misses;
283 let hit_rate = if total == 0 {
284 0.0
285 } else {
286 self.hits as f64 / total as f64
287 };
288 CacheStats {
289 hits: self.hits,
290 misses: self.misses,
291 evictions: self.evictions,
292 hit_rate,
293 total_size_bytes: self.total_size_bytes,
294 }
295 }
296
297 pub fn len(&self) -> usize {
299 self.map.len()
300 }
301
302 pub fn is_empty(&self) -> bool {
304 self.map.is_empty()
305 }
306
307 pub fn capacity(&self) -> usize {
309 self.capacity
310 }
311}
312
313pub struct MemoryBoundedCache {
320 inner: EmbeddingCache,
321 max_bytes: usize,
322}
323
324impl MemoryBoundedCache {
325 pub fn new(max_bytes: usize) -> Self {
330 let capacity = (max_bytes / (4 * 128)).max(4);
333 Self {
334 inner: EmbeddingCache::new(capacity),
335 max_bytes,
336 }
337 }
338
339 pub fn insert(&mut self, key: CacheKey, embedding: Vec<f32>) {
341 self.inner.insert(key, embedding);
342 while self.inner.total_size_bytes > self.max_bytes {
344 if self.inner.evict_lru().is_none() {
345 break;
346 }
347 }
348 }
349
350 pub fn get(&mut self, key: &CacheKey) -> Option<&[f32]> {
352 self.inner.get(key)
353 }
354
355 pub fn total_size_bytes(&self) -> usize {
357 self.inner.total_size_bytes
358 }
359
360 pub fn max_bytes(&self) -> usize {
362 self.max_bytes
363 }
364
365 pub fn stats(&self) -> CacheStats {
367 self.inner.stats()
368 }
369
370 pub fn len(&self) -> usize {
372 self.inner.len()
373 }
374
375 pub fn is_empty(&self) -> bool {
377 self.inner.is_empty()
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    fn key(hash: u64, model: &str) -> CacheKey {
        CacheKey::new(hash, model)
    }

    fn emb(dims: usize) -> Vec<f32> {
        (0..dims).map(|i| i as f32 * 0.1).collect()
    }

    /// Reads the `access_count` of a resident entry straight from the cache internals.
    fn access_count(c: &EmbeddingCache, k: &CacheKey) -> u64 {
        let idx = c.map[k];
        c.list.nodes[idx]
            .as_ref()
            .expect("entry should be resident")
            .entry
            .access_count
    }

    // ---- CacheKey ----

    #[test]
    fn test_cache_key_equality() {
        assert_eq!(key(1, "model-a"), key(1, "model-a"));
        assert_ne!(key(1, "model-a"), key(1, "model-b"));
        assert_ne!(key(1, "model-a"), key(2, "model-a"));
    }

    #[test]
    fn test_cache_key_clone() {
        let k = key(42, "bert");
        let k2 = k.clone();
        assert_eq!(k, k2);
    }

    // ---- CacheEntry ----

    #[test]
    fn test_cache_entry_size_bytes() {
        let e = CacheEntry::new(vec![0.0f32; 128]);
        assert_eq!(e.size_bytes, 128 * 4);
    }

    #[test]
    fn test_cache_entry_access_count_starts_zero() {
        let e = CacheEntry::new(vec![1.0; 4]);
        assert_eq!(e.access_count, 0);
    }

    // ---- EmbeddingCache: basics ----

    #[test]
    fn test_cache_empty_initially() {
        let c = EmbeddingCache::new(10);
        assert!(c.is_empty());
        assert_eq!(c.len(), 0);
    }

    #[test]
    fn test_cache_capacity() {
        let c = EmbeddingCache::new(5);
        assert_eq!(c.capacity(), 5);
    }

    #[test]
    fn test_cache_capacity_min_one() {
        let c = EmbeddingCache::new(0);
        assert_eq!(c.capacity(), 1);
    }

    #[test]
    fn test_cache_insert_and_get() {
        let mut c = EmbeddingCache::new(10);
        let k = key(1, "m");
        let e = emb(4);
        c.insert(k.clone(), e.clone());
        let got = c.get(&k).expect("should be cached");
        assert_eq!(got, e.as_slice());
    }

    #[test]
    fn test_cache_miss() {
        let mut c = EmbeddingCache::new(10);
        assert!(c.get(&key(99, "m")).is_none());
    }

    #[test]
    fn test_cache_len_after_insert() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        assert_eq!(c.len(), 2);
    }

    // ---- EmbeddingCache: stats ----

    #[test]
    fn test_cache_stats_hits_misses() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), emb(4));
        c.get(&key(1, "m"));
        c.get(&key(99, "m"));
        let s = c.stats();
        assert_eq!(s.hits, 1);
        assert_eq!(s.misses, 1);
        assert!((s.hit_rate - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_cache_stats_no_lookups_hit_rate_zero() {
        let c = EmbeddingCache::new(10);
        assert_eq!(c.stats().hit_rate, 0.0);
    }

    #[test]
    fn test_cache_stats_total_size_bytes() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), vec![0.0f32; 64]);
        c.insert(key(2, "m"), vec![0.0f32; 128]);
        assert_eq!(c.stats().total_size_bytes, (64 + 128) * 4);
    }

    // ---- EmbeddingCache: LRU eviction ----

    #[test]
    fn test_cache_evicts_lru_when_full() {
        let mut c = EmbeddingCache::new(3);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        c.insert(key(3, "m"), emb(4));
        c.insert(key(4, "m"), emb(4));
        assert!(c.get(&key(1, "m")).is_none(), "key(1) should be evicted");
        assert!(c.get(&key(4, "m")).is_some());
    }

    #[test]
    fn test_cache_get_promotes_to_mru() {
        let mut c = EmbeddingCache::new(3);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        c.insert(key(3, "m"), emb(4));
        c.get(&key(1, "m"));
        c.insert(key(4, "m"), emb(4));
        assert!(c.get(&key(2, "m")).is_none(), "key(2) should be evicted");
        assert!(c.get(&key(1, "m")).is_some());
    }

    #[test]
    fn test_cache_evictions_stat_incremented() {
        let mut c = EmbeddingCache::new(2);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        c.insert(key(3, "m"), emb(4));
        assert_eq!(c.stats().evictions, 1);
    }

    #[test]
    fn test_cache_manual_evict_lru() {
        let mut c = EmbeddingCache::new(3);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        let (evicted_key, _) = c.evict_lru().expect("should evict");
        assert_eq!(evicted_key, key(1, "m"));
        assert_eq!(c.len(), 1);
    }

    #[test]
    fn test_cache_manual_evict_lru_empty() {
        let mut c = EmbeddingCache::new(3);
        assert!(c.evict_lru().is_none());
    }

    // ---- EmbeddingCache: invalidation ----

    #[test]
    fn test_cache_invalidate_present() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), emb(4));
        assert!(c.invalidate(&key(1, "m")));
        assert!(c.is_empty());
    }

    #[test]
    fn test_cache_invalidate_absent() {
        let mut c = EmbeddingCache::new(10);
        assert!(!c.invalidate(&key(99, "m")));
    }

    #[test]
    fn test_cache_invalidate_model() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "bert"), emb(4));
        c.insert(key(2, "bert"), emb(4));
        c.insert(key(3, "gpt"), emb(4));
        let removed = c.invalidate_model("bert");
        assert_eq!(removed, 2);
        assert!(c.get(&key(1, "bert")).is_none());
        assert!(c.get(&key(2, "bert")).is_none());
        assert!(c.get(&key(3, "gpt")).is_some());
    }

    #[test]
    fn test_cache_invalidate_model_none() {
        let mut c = EmbeddingCache::new(10);
        assert_eq!(c.invalidate_model("unknown"), 0);
    }

    #[test]
    fn test_cache_size_decreases_on_invalidate() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), vec![0.0f32; 64]);
        let before = c.stats().total_size_bytes;
        c.invalidate(&key(1, "m"));
        assert_eq!(c.stats().total_size_bytes, before - 64 * 4);
    }

    #[test]
    fn test_cache_invalidated_slot_is_reused() {
        let mut c = EmbeddingCache::new(2);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        assert!(c.invalidate(&key(1, "m")));
        c.insert(key(3, "m"), emb(4));
        assert_eq!(c.len(), 2);
        assert_eq!(c.stats().evictions, 0);
        assert!(c.get(&key(2, "m")).is_some());
        assert!(c.get(&key(3, "m")).is_some());
    }

    // ---- EmbeddingCache: updates ----

    #[test]
    fn test_cache_insert_same_key_updates() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(1, "m"), vec![99.0; 4]);
        let got = c.get(&key(1, "m")).expect("should exist");
        assert_eq!(got[0], 99.0);
        assert_eq!(c.len(), 1);
    }

    #[test]
    fn test_cache_reinsert_replaces_size() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), vec![0.0f32; 10]);
        c.insert(key(1, "m"), vec![0.0f32; 20]);
        assert_eq!(c.len(), 1);
        assert_eq!(c.stats().total_size_bytes, 20 * 4);
    }

    // ---- MemoryBoundedCache ----

    #[test]
    fn test_memory_bounded_cache_empty() {
        let c = MemoryBoundedCache::new(1024);
        assert!(c.is_empty());
        assert_eq!(c.total_size_bytes(), 0);
    }

    #[test]
    fn test_memory_bounded_cache_max_bytes() {
        let c = MemoryBoundedCache::new(4096);
        assert_eq!(c.max_bytes(), 4096);
    }

    #[test]
    fn test_memory_bounded_cache_stays_within_limit() {
        let mut c = MemoryBoundedCache::new(512);
        for i in 0..10u64 {
            c.insert(key(i, "m"), vec![0.0f32; 64]);
        }
        assert!(c.total_size_bytes() <= 512);
    }

    #[test]
    fn test_memory_bounded_cache_insert_and_get() {
        let mut c = MemoryBoundedCache::new(1 << 20);
        let k = key(1, "m");
        c.insert(k.clone(), vec![1.0; 128]);
        let got = c.get(&k).expect("should be present");
        assert_eq!(got.len(), 128);
        assert!((got[0] - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_memory_bounded_cache_stats() {
        let mut c = MemoryBoundedCache::new(1 << 20);
        c.insert(key(1, "m"), vec![0.0; 32]);
        c.get(&key(1, "m"));
        c.get(&key(2, "m"));
        let s = c.stats();
        assert_eq!(s.hits, 1);
        assert_eq!(s.misses, 1);
    }

    #[test]
    fn test_memory_bounded_cache_keeps_most_recent() {
        let mut c = MemoryBoundedCache::new(512);
        for i in 0u64..3 {
            c.insert(key(i, "m"), vec![0.0f32; 64]);
        }
        assert!(c.total_size_bytes() <= 512);
        assert!(c.get(&key(2, "m")).is_some(), "newest entry must survive");
        assert!(c.get(&key(0, "m")).is_none(), "oldest entry should be evicted");
    }

    // ---- EmbeddingCache: bulk behaviour ----

    #[test]
    fn test_cache_stress_insert_get() {
        let mut c = EmbeddingCache::new(100);
        for i in 0u64..200 {
            c.insert(key(i, "m"), emb(32));
        }
        assert_eq!(c.len(), 100);
        let s = c.stats();
        assert_eq!(s.evictions, 100);
    }

    #[test]
    fn test_cache_all_hits() {
        let mut c = EmbeddingCache::new(50);
        for i in 0u64..50 {
            c.insert(key(i, "m"), emb(8));
        }
        for i in 0u64..50 {
            assert!(c.get(&key(i, "m")).is_some());
        }
        assert_eq!(c.stats().hits, 50);
        assert_eq!(c.stats().misses, 0);
        assert!((c.stats().hit_rate - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cache_access_count_increments_on_get() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), emb(4));
        assert_eq!(access_count(&c, &key(1, "m")), 0);
        c.get(&key(1, "m"));
        c.get(&key(1, "m"));
        c.get(&key(1, "m"));
        assert_eq!(access_count(&c, &key(1, "m")), 3);
        assert_eq!(c.stats().hits, 3);
    }

    #[test]
    fn test_cache_multiple_models_isolated() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "bert"), emb(4));
        c.insert(key(1, "gpt"), emb(8));
        assert!(c.get(&key(1, "bert")).is_some());
        assert!(c.get(&key(1, "gpt")).is_some());
    }

    #[test]
    fn test_cache_key_hash_collision_different_models() {
        let k1 = key(100, "modelA");
        let k2 = key(100, "modelB");
        assert_ne!(k1, k2);
    }

    #[test]
    fn test_cache_evict_only_one_on_overflow() {
        let mut c = EmbeddingCache::new(3);
        c.insert(key(1, "m"), emb(4));
        c.insert(key(2, "m"), emb(4));
        c.insert(key(3, "m"), emb(4));
        c.insert(key(4, "m"), emb(4));
        assert_eq!(c.len(), 3);
    }

    #[test]
    fn test_cache_get_returns_correct_embedding() {
        let mut c = EmbeddingCache::new(5);
        let v = vec![1.0f32, 2.0, 3.0, 4.0];
        c.insert(key(42, "m"), v.clone());
        let got = c.get(&key(42, "m")).expect("should be present");
        assert_eq!(got, v.as_slice());
    }

    #[test]
    fn test_cache_is_not_empty_after_insert() {
        let mut c = EmbeddingCache::new(5);
        c.insert(key(1, "m"), emb(4));
        assert!(!c.is_empty());
    }

    #[test]
    fn test_cache_len_zero_initially() {
        let c = EmbeddingCache::new(5);
        assert_eq!(c.len(), 0);
    }

    #[test]
    fn test_cache_invalidate_all_via_model() {
        let mut c = EmbeddingCache::new(10);
        for i in 0u64..5 {
            c.insert(key(i, "bert"), emb(4));
        }
        c.insert(key(99, "gpt"), emb(4));
        let removed = c.invalidate_model("bert");
        assert_eq!(removed, 5);
        assert_eq!(c.len(), 1);
    }

    #[test]
    fn test_cache_stats_evictions_multiple() {
        let mut c = EmbeddingCache::new(2);
        for i in 0u64..6 {
            c.insert(key(i, "m"), emb(4));
        }
        assert_eq!(c.stats().evictions, 4);
    }

    #[test]
    fn test_cache_size_zero_after_all_invalidated() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), emb(32));
        c.insert(key(2, "m"), emb(32));
        c.invalidate(&key(1, "m"));
        c.invalidate(&key(2, "m"));
        assert_eq!(c.stats().total_size_bytes, 0);
    }

    #[test]
    fn test_memory_bounded_cache_evicts_to_stay_within_limit() {
        let mut c = MemoryBoundedCache::new(256);
        for i in 0u64..5 {
            c.insert(key(i, "m"), vec![0.0f32; 64]);
        }
        assert!(c.total_size_bytes() <= 256);
    }

    #[test]
    fn test_memory_bounded_cache_get_returns_none_for_missing() {
        let mut c = MemoryBoundedCache::new(1024);
        assert!(c.get(&key(99, "m")).is_none());
    }

    #[test]
    fn test_memory_bounded_cache_len_tracks_inserts() {
        let mut c = MemoryBoundedCache::new(1 << 20);
        assert_eq!(c.len(), 0);
        c.insert(key(1, "m"), emb(4));
        assert_eq!(c.len(), 1);
    }

    #[test]
    fn test_cache_insert_updates_size_correctly() {
        let mut c = EmbeddingCache::new(10);
        c.insert(key(1, "m"), vec![0.0f32; 10]);
        assert_eq!(c.stats().total_size_bytes, 10 * 4);
        c.insert(key(2, "m"), vec![0.0f32; 20]);
        assert_eq!(c.stats().total_size_bytes, 30 * 4);
    }
}