1use crate::model::ModelConfig;
20use crate::service::EmbeddingRole;
21use lru::LruCache;
22use parking_lot::RwLock;
23use std::num::NonZeroUsize;
24use std::sync::Arc;
25use std::sync::atomic::{AtomicU64, Ordering};
26use tracing::debug;
27
/// 32-byte BLAKE3 digest identifying one (text, model, role) embedding request.
pub type CacheKey = [u8; 32];

/// Total number of entries held by a cache built with
/// `EmbeddingCache::with_default_capacity` (and therefore `Default`).
pub const DEFAULT_CACHE_CAPACITY: usize = 4_000;

/// Number of independently locked LRU shards.
const NUM_SHARDS: usize = 1 << 4;

/// Bit mask mapping a byte onto a shard index; only correct because
/// `NUM_SHARDS` is a power of two (enforced below).
const SHARD_MASK: usize = NUM_SHARDS - 1;

// Compile-time guard: masking with `SHARD_MASK` reaches every shard only when
// `NUM_SHARDS` is a power of two.
const _: () = assert!(
    NUM_SHARDS.is_power_of_two(),
    "NUM_SHARDS must be a power of 2"
);
48
49struct CacheShard {
51 lru: RwLock<LruCache<CacheKey, Arc<[f32]>>>,
52 hits: AtomicU64,
53 misses: AtomicU64,
54}
55
56impl CacheShard {
57 fn new(capacity: NonZeroUsize) -> Self {
58 Self {
59 lru: RwLock::new(LruCache::new(capacity)),
60 hits: AtomicU64::new(0),
61 misses: AtomicU64::new(0),
62 }
63 }
64
65 #[inline]
66 fn get(&self, key: &CacheKey) -> Option<Arc<[f32]>> {
67 let mut lru = self.lru.write();
68 let result = lru.get(key).cloned();
69 if result.is_some() {
70 self.hits.fetch_add(1, Ordering::Relaxed);
71 } else {
72 self.misses.fetch_add(1, Ordering::Relaxed);
73 }
74 result
75 }
76
77 #[inline]
78 fn put(&self, key: CacheKey, embedding: Arc<[f32]>) {
79 let mut lru = self.lru.write();
80 lru.put(key, embedding);
81 }
82
83 fn len(&self) -> usize {
84 self.lru.read().len()
85 }
86
87 fn clear(&self) {
88 self.lru.write().clear();
89 }
90
91 fn hits(&self) -> u64 {
92 self.hits.load(Ordering::Relaxed)
93 }
94
95 fn misses(&self) -> u64 {
96 self.misses.load(Ordering::Relaxed)
97 }
98}
99
/// Sharded LRU cache of embedding vectors keyed by [`CacheKey`].
///
/// Entries are spread over `NUM_SHARDS` shards, each behind its own lock; the
/// shard for a key is chosen by `shard_index` from the key's first byte.
pub struct EmbeddingCache {
    /// The LRU shards; `new` always creates exactly `NUM_SHARDS` of them.
    shards: Vec<CacheShard>,
    /// `false` when constructed with capacity 0; reads then return nothing and
    /// writes are ignored.
    enabled: bool,
    /// Capacity passed to `new`, reported by `stats()` while the cache is enabled.
    capacity: usize,
}
142
143#[inline(always)]
146fn shard_index(key: &CacheKey) -> usize {
147 key[0] as usize & SHARD_MASK
148}
149
150impl EmbeddingCache {
151 pub fn new(capacity: usize) -> Self {
160 let enabled = capacity != 0;
161
162 let per_shard = if enabled {
166 let base = capacity.div_ceil(NUM_SHARDS);
168 if base == 0 { 1 } else { base }
169 } else {
170 1 };
172
173 let per_shard_nz = NonZeroUsize::new(per_shard).expect("per_shard is always >= 1");
174
175 let shards = (0..NUM_SHARDS)
176 .map(|_| CacheShard::new(per_shard_nz))
177 .collect();
178
179 Self {
180 shards,
181 enabled,
182 capacity,
183 }
184 }
185
186 pub fn with_default_capacity() -> Self {
188 Self::new(DEFAULT_CACHE_CAPACITY)
189 }
190
191 pub fn compute_key(
201 &self,
202 text: &str,
203 model_config: ModelConfig,
204 role: EmbeddingRole,
205 ) -> CacheKey {
206 let mut hasher = blake3::Hasher::new();
207 hasher.update(text.as_bytes());
208 let model_key = format!(
210 "{}:{}:{}:{}",
211 model_config.model,
212 model_config.model.key_version(),
213 model_config.dimensions(),
214 role.cache_tag(),
215 );
216 hasher.update(model_key.as_bytes());
217 *hasher.finalize().as_bytes()
218 }
219
220 pub fn get(&self, key: &CacheKey) -> Option<Arc<[f32]>> {
225 if !self.enabled {
226 return None;
227 }
228
229 let idx = shard_index(key);
230 let result = self.shards[idx].get(key);
231
232 if result.is_some() {
233 debug!("cache hit for key {:?}", &key[..8]);
234 }
235
236 result
237 }
238
239 pub fn put(&self, key: CacheKey, embedding: Vec<f32>) {
244 if !self.enabled {
245 return;
246 }
247
248 let idx = shard_index(&key);
249 self.shards[idx].put(key, Arc::from(embedding));
250 debug!("cached embedding for key {:?}", &key[..8]);
251 }
252
253 pub fn get_many(&self, keys: &[CacheKey]) -> Vec<Option<Arc<[f32]>>> {
258 if !self.enabled {
259 return vec![None; keys.len()];
260 }
261
262 keys.iter()
263 .map(|key| {
264 let idx = shard_index(key);
265 self.shards[idx].get(key)
266 })
267 .collect()
268 }
269
270 pub fn put_many(&self, entries: Vec<(CacheKey, Vec<f32>)>) {
274 if !self.enabled {
275 return;
276 }
277
278 for (key, embedding) in entries {
279 let idx = shard_index(&key);
280 self.shards[idx].put(key, Arc::from(embedding));
281 }
282 }
283
284 pub fn stats(&self) -> CacheStats {
288 if !self.enabled {
289 let (hits, misses) = self.aggregate_counters();
290 return CacheStats {
291 size: 0,
292 capacity: 0,
293 hits,
294 misses,
295 };
296 }
297
298 let size: usize = self.shards.iter().map(CacheShard::len).sum();
299 let (hits, misses) = self.aggregate_counters();
300
301 CacheStats {
302 size,
303 capacity: self.capacity,
304 hits,
305 misses,
306 }
307 }
308
309 pub fn per_shard_stats(&self) -> Vec<ShardStats> {
313 self.shards
314 .iter()
315 .enumerate()
316 .map(|(i, s)| ShardStats {
317 shard_id: i,
318 size: s.len(),
319 hits: s.hits(),
320 misses: s.misses(),
321 })
322 .collect()
323 }
324
325 pub fn clear(&self) {
327 if !self.enabled {
328 return;
329 }
330
331 for shard in &self.shards {
332 shard.clear();
333 }
334 debug!("cache cleared");
335 }
336
337 #[inline]
339 pub fn is_enabled(&self) -> bool {
340 self.enabled
341 }
342
343 fn aggregate_counters(&self) -> (u64, u64) {
345 let hits: u64 = self.shards.iter().map(CacheShard::hits).sum();
346 let misses: u64 = self.shards.iter().map(CacheShard::misses).sum();
347 (hits, misses)
348 }
349}
350
351impl Default for EmbeddingCache {
352 fn default() -> Self {
353 Self::with_default_capacity()
354 }
355}
356
/// Point-in-time snapshot of aggregate `EmbeddingCache` statistics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of entries currently stored across all shards.
    pub size: usize,
    /// Configured capacity (0 when the cache is disabled).
    pub capacity: usize,
    /// Total lookups that found an entry.
    pub hits: u64,
    /// Total lookups that found nothing.
    pub misses: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        // Saturate rather than overflow (which would panic in debug builds).
        let total = self.hits.saturating_add(self.misses);
        if total == 0 {
            0.0
        } else {
            // Lossy only above 2^53, which is irrelevant for a ratio.
            self.hits as f64 / total as f64
        }
    }
}
383
/// Point-in-time statistics for a single shard of an `EmbeddingCache`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShardStats {
    /// Position of the shard, as enumerated by `per_shard_stats`.
    pub shard_id: usize,
    /// Number of entries currently stored in this shard.
    pub size: usize,
    /// Lookups on this shard that found an entry.
    pub hits: u64,
    /// Lookups on this shard that found nothing.
    pub misses: u64,
}

impl ShardStats {
    /// Fraction of this shard's lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when the shard has recorded no lookups.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        // Saturate rather than overflow (which would panic in debug builds).
        let total = self.hits.saturating_add(self.misses);
        if total == 0 {
            0.0
        } else {
            // Lossy only above 2^53, which is irrelevant for a ratio.
            self.hits as f64 / total as f64
        }
    }
}
410
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::EmbeddingModel;

    /// Key for `text` under the model/role most tests use (BGE small, generic role).
    fn generic_key(cache: &EmbeddingCache, text: &str) -> CacheKey {
        cache.compute_key(
            text,
            ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
            EmbeddingRole::Generic,
        )
    }

    #[test]
    fn test_cache_basic_operations() {
        let cache = EmbeddingCache::new(100);
        let key = generic_key(&cache, "hello");

        assert!(cache.get(&key).is_none());

        let embedding = vec![0.1, 0.2, 0.3];
        cache.put(key, embedding.clone());

        let cached = cache.get(&key).unwrap();
        assert_eq!(&*cached, &embedding[..]);
    }

    #[test]
    fn test_cache_eviction() {
        // 16 slots over 16 shards: one entry per shard, so 32 inserts must evict.
        let cache = EmbeddingCache::new(16);

        let mut keys = Vec::with_capacity(32);
        for i in 0..32u32 {
            let key = generic_key(&cache, &format!("text_{i}"));
            cache.put(key, vec![i as f32]);
            keys.push(key);
        }

        let stats = cache.stats();
        assert!(stats.size <= 16, "size {} exceeds capacity 16", stats.size);

        // The latest insert is the most recently used entry of its shard.
        let last = keys.last().expect("32 keys were inserted");
        assert!(cache.get(last).is_some(), "most recent insert must survive");

        let survivors = keys.iter().filter(|&k| cache.get(k).is_some()).count();
        assert!(survivors <= 16, "{survivors} keys survived, capacity is 16");
    }

    #[test]
    fn test_cache_lru_eviction_within_shard() {
        // 32 slots over 16 shards: two entries per shard.
        let cache = EmbeddingCache::new(32);

        let target_shard = shard_index(&generic_key(&cache, "probe_0"));

        // First three probe keys that land in the target shard.
        let same_shard_keys: Vec<(CacheKey, u32)> = (0u32..)
            .map(|i| (generic_key(&cache, &format!("lru_test_{i}")), i))
            .filter(|(key, _)| shard_index(key) == target_shard)
            .take(3)
            .collect();

        let (k1, v1) = same_shard_keys[0];
        let (k2, v2) = same_shard_keys[1];
        let (k3, v3) = same_shard_keys[2];

        cache.put(k1, vec![v1 as f32]);
        cache.put(k2, vec![v2 as f32]);

        // Touch k1 so that k2 becomes the least recently used entry.
        assert!(cache.get(&k1).is_some());

        // The shard is full, so inserting k3 must evict k2.
        cache.put(k3, vec![v3 as f32]);

        assert!(
            cache.get(&k1).is_some(),
            "k1 should survive (recently accessed)"
        );
        assert!(cache.get(&k2).is_none(), "k2 should be evicted (LRU)");
        assert!(cache.get(&k3).is_some(), "k3 should exist (just inserted)");
    }

    #[test]
    fn test_cache_different_models_different_keys() {
        let cache = EmbeddingCache::new(100);

        let key_small = cache.compute_key(
            "text",
            ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
            EmbeddingRole::Generic,
        );
        let key_base = cache.compute_key(
            "text",
            ModelConfig::new(EmbeddingModel::BgeBaseEnV15),
            EmbeddingRole::Generic,
        );

        assert_ne!(key_small, key_base);
    }

    #[test]
    fn test_cache_stats() {
        let cache = EmbeddingCache::new(100);
        let key = generic_key(&cache, "hello");

        let _ = cache.get(&key); // miss
        cache.put(key, vec![0.1]);
        let _ = cache.get(&key); // hit

        let stats = cache.stats();
        assert_eq!(stats.size, 1);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_cache_get_many() {
        let cache = EmbeddingCache::new(100);

        let key1 = generic_key(&cache, "one");
        let key2 = generic_key(&cache, "two");
        let key3 = generic_key(&cache, "three");

        cache.put(key1, vec![1.0]);
        cache.put(key3, vec![3.0]);

        let results = cache.get_many(&[key1, key2, key3]);
        assert_eq!(results.len(), 3);
        assert_eq!(&**results[0].as_ref().unwrap(), &[1.0f32]);
        assert!(results[1].is_none());
        assert_eq!(&**results[2].as_ref().unwrap(), &[3.0f32]);
    }

    #[test]
    fn test_cache_put_many() {
        let cache = EmbeddingCache::new(100);

        let key1 = generic_key(&cache, "one");
        let key2 = generic_key(&cache, "two");

        cache.put_many(vec![(key1, vec![1.0]), (key2, vec![2.0])]);

        let v1 = cache.get(&key1).unwrap();
        assert_eq!(&*v1, [1.0f32].as_slice());
        let v2 = cache.get(&key2).unwrap();
        assert_eq!(&*v2, [2.0f32].as_slice());
    }

    #[test]
    fn test_cache_clear() {
        let cache = EmbeddingCache::new(100);
        let key = generic_key(&cache, "hello");

        cache.put(key, vec![0.1]);
        assert!(cache.get(&key).is_some());

        cache.clear();
        assert!(cache.get(&key).is_none());
        assert_eq!(cache.stats().size, 0);
    }

    #[test]
    fn test_cache_default_capacity() {
        let cache = EmbeddingCache::with_default_capacity();
        assert_eq!(cache.stats().capacity, DEFAULT_CACHE_CAPACITY);
    }

    #[test]
    fn test_cache_disabled_is_noop() {
        let cache = EmbeddingCache::new(0);
        assert!(!cache.is_enabled());

        let key = generic_key(&cache, "hello");
        cache.put(key, vec![0.1]);
        assert!(cache.get(&key).is_none());

        let stats = cache.stats();
        assert_eq!(stats.capacity, 0);
        assert_eq!(stats.size, 0);
    }

    #[test]
    fn test_concurrent_access() {
        use std::thread;

        // Large enough that 800 entries never trigger an eviction.
        let cache = Arc::new(EmbeddingCache::new(4000));

        let handles: Vec<_> = (0..8)
            .map(|t| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    for i in 0..100 {
                        let key = generic_key(&cache, &format!("thread_{t}_item_{i}"));
                        let embedding = vec![t as f32; 384];
                        cache.put(key, embedding.clone());

                        let result = cache.get(&key);
                        assert!(result.is_some(), "put followed by get must succeed");
                        assert_eq!(result.unwrap().len(), 384);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().expect("thread panicked");
        }

        let stats = cache.stats();
        assert_eq!(stats.size, 800);
        assert!(stats.hits >= 800, "at least 800 hits expected");
    }

    #[test]
    fn test_shard_distribution() {
        let cache = EmbeddingCache::new(4000);

        let n = 800;
        for i in 0..n {
            let key = generic_key(&cache, &format!("item_{i}"));
            cache.put(key, vec![i as f32]);
        }

        let shard_stats = cache.per_shard_stats();
        assert_eq!(shard_stats.len(), NUM_SHARDS);

        for ss in &shard_stats {
            assert!(
                ss.size > 0,
                "shard {} is empty — distribution is pathological",
                ss.shard_id
            );
        }

        let total: usize = shard_stats.iter().map(|s| s.size).sum();
        assert_eq!(total, n);

        let avg = n / NUM_SHARDS;
        for ss in &shard_stats {
            assert!(
                ss.size <= avg * 3,
                "shard {} has {} entries (avg {}), distribution too skewed",
                ss.shard_id,
                ss.size,
                avg
            );
        }
    }

    #[test]
    fn test_per_shard_stats_hit_tracking() {
        let cache = EmbeddingCache::new(100);

        let key1 = generic_key(&cache, "hello");
        let key2 = generic_key(&cache, "world");

        cache.put(key1, vec![1.0]);
        cache.put(key2, vec![2.0]);

        // Three hits on key1, one on key2.
        let _ = cache.get(&key1);
        let _ = cache.get(&key1);
        let _ = cache.get(&key1);
        let _ = cache.get(&key2);

        let shard_stats = cache.per_shard_stats();
        let total_hits: u64 = shard_stats.iter().map(|s| s.hits).sum();
        assert_eq!(total_hits, 4, "total hits should be 4");

        let stats = cache.stats();
        assert_eq!(stats.hits, 4);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_small_capacity_rounds_up() {
        // Fewer slots than shards: every shard still gets at least one slot.
        let cache = EmbeddingCache::new(3);
        assert!(cache.is_enabled());

        let key = generic_key(&cache, "x");
        cache.put(key, vec![42.0]);
        assert!(cache.get(&key).is_some());
    }

    #[test]
    fn test_role_query_vs_passage_different_keys() {
        let cache = EmbeddingCache::new(100);
        let model = ModelConfig::new(EmbeddingModel::MultilingualE5Small);
        let text = "hello world";

        let key_query = cache.compute_key(text, model, EmbeddingRole::Query);
        let key_passage = cache.compute_key(text, model, EmbeddingRole::Passage);
        let key_generic = cache.compute_key(text, model, EmbeddingRole::Generic);

        assert_ne!(key_query, key_passage, "query vs passage must differ");
        assert_ne!(key_query, key_generic, "query vs generic must differ");
        assert_ne!(key_passage, key_generic, "passage vs generic must differ");
    }

    #[test]
    fn test_role_key_deterministic() {
        let cache = EmbeddingCache::new(100);
        let model = ModelConfig::new(EmbeddingModel::BgeSmallEnV15);

        let k1 = cache.compute_key("test", model, EmbeddingRole::Query);
        let k2 = cache.compute_key("test", model, EmbeddingRole::Query);
        assert_eq!(k1, k2, "identical inputs must produce identical key");
    }

    #[test]
    fn test_role_cache_isolation() {
        let cache = EmbeddingCache::new(100);
        let model = ModelConfig::new(EmbeddingModel::MultilingualE5Small);

        let key_query = cache.compute_key("embed me", model, EmbeddingRole::Query);
        let key_passage = cache.compute_key("embed me", model, EmbeddingRole::Passage);

        cache.put(key_query, vec![1.0, 2.0]);

        assert!(
            cache.get(&key_passage).is_none(),
            "passage key must miss after storing under query key"
        );

        assert!(
            cache.get(&key_query).is_some(),
            "query key must hit after storing under query key"
        );
    }
}