1use crate::model::ModelConfig;
20use lru::LruCache;
21use parking_lot::RwLock;
22use std::num::NonZeroUsize;
23use std::sync::Arc;
24use std::sync::atomic::{AtomicU64, Ordering};
25use tracing::debug;
26
/// Key identifying one cached embedding: a 32-byte digest.
pub type CacheKey = [u8; 32];

/// Total capacity used by [`EmbeddingCache::with_default_capacity`].
pub const DEFAULT_CACHE_CAPACITY: usize = 4000;

/// Number of independently locked shards the cache is split into.
const NUM_SHARDS: usize = 16;

/// Bit mask that reduces a key byte to a shard index in `0..NUM_SHARDS`.
const SHARD_MASK: usize = NUM_SHARDS - 1;

// Masking with `NUM_SHARDS - 1` only covers every shard when the count is a power of two.
const _: () = assert!(
    NUM_SHARDS.is_power_of_two(),
    "NUM_SHARDS must be a power of 2"
);

// The shard index is taken from a single key byte, so shards beyond the 256th could never be
// selected. Fail the build instead of silently wasting them.
const _: () = assert!(
    NUM_SHARDS <= u8::MAX as usize + 1,
    "NUM_SHARDS must not exceed 256: the shard index is derived from one key byte"
);
47
48struct CacheShard {
50 lru: RwLock<LruCache<CacheKey, Arc<[f32]>>>,
51 hits: AtomicU64,
52 misses: AtomicU64,
53}
54
55impl CacheShard {
56 fn new(capacity: NonZeroUsize) -> Self {
57 Self {
58 lru: RwLock::new(LruCache::new(capacity)),
59 hits: AtomicU64::new(0),
60 misses: AtomicU64::new(0),
61 }
62 }
63
64 #[inline]
65 fn get(&self, key: &CacheKey) -> Option<Arc<[f32]>> {
66 let mut lru = self.lru.write();
67 let result = lru.get(key).cloned();
68 if result.is_some() {
69 self.hits.fetch_add(1, Ordering::Relaxed);
70 } else {
71 self.misses.fetch_add(1, Ordering::Relaxed);
72 }
73 result
74 }
75
76 #[inline]
77 fn put(&self, key: CacheKey, embedding: Arc<[f32]>) {
78 let mut lru = self.lru.write();
79 lru.put(key, embedding);
80 }
81
82 fn len(&self) -> usize {
83 self.lru.read().len()
84 }
85
86 fn clear(&self) {
87 self.lru.write().clear();
88 }
89
90 fn hits(&self) -> u64 {
91 self.hits.load(Ordering::Relaxed)
92 }
93
94 fn misses(&self) -> u64 {
95 self.misses.load(Ordering::Relaxed)
96 }
97}
98
99pub struct EmbeddingCache {
132 shards: Vec<CacheShard>,
133 enabled: bool,
134 capacity: usize,
135}
136
137#[inline(always)]
140fn shard_index(key: &CacheKey) -> usize {
141 key[0] as usize & SHARD_MASK
142}
143
144impl EmbeddingCache {
145 pub fn new(capacity: usize) -> Self {
154 let enabled = capacity != 0;
155
156 let per_shard = if enabled {
160 let base = capacity.div_ceil(NUM_SHARDS);
162 if base == 0 { 1 } else { base }
163 } else {
164 1 };
166
167 let per_shard_nz = NonZeroUsize::new(per_shard).expect("per_shard is always >= 1");
168
169 let shards = (0..NUM_SHARDS)
170 .map(|_| CacheShard::new(per_shard_nz))
171 .collect();
172
173 Self {
174 shards,
175 enabled,
176 capacity,
177 }
178 }
179
180 pub fn with_default_capacity() -> Self {
182 Self::new(DEFAULT_CACHE_CAPACITY)
183 }
184
185 pub fn compute_key(&self, text: &str, model_config: ModelConfig) -> CacheKey {
191 let mut hasher = blake3::Hasher::new();
192 hasher.update(text.as_bytes());
193 let model_key = format!(
195 "{}:{}:{}",
196 model_config.model,
197 model_config.model.key_version(),
198 model_config.dimensions(),
199 );
200 hasher.update(model_key.as_bytes());
201 *hasher.finalize().as_bytes()
202 }
203
204 pub fn get(&self, key: &CacheKey) -> Option<Arc<[f32]>> {
209 if !self.enabled {
210 return None;
211 }
212
213 let idx = shard_index(key);
214 let result = self.shards[idx].get(key);
215
216 if result.is_some() {
217 debug!("cache hit for key {:?}", &key[..8]);
218 }
219
220 result
221 }
222
223 pub fn put(&self, key: CacheKey, embedding: Vec<f32>) {
228 if !self.enabled {
229 return;
230 }
231
232 let idx = shard_index(&key);
233 self.shards[idx].put(key, Arc::from(embedding));
234 debug!("cached embedding for key {:?}", &key[..8]);
235 }
236
237 pub fn get_many(&self, keys: &[CacheKey]) -> Vec<Option<Arc<[f32]>>> {
242 if !self.enabled {
243 return vec![None; keys.len()];
244 }
245
246 keys.iter()
247 .map(|key| {
248 let idx = shard_index(key);
249 self.shards[idx].get(key)
250 })
251 .collect()
252 }
253
254 pub fn put_many(&self, entries: Vec<(CacheKey, Vec<f32>)>) {
258 if !self.enabled {
259 return;
260 }
261
262 for (key, embedding) in entries {
263 let idx = shard_index(&key);
264 self.shards[idx].put(key, Arc::from(embedding));
265 }
266 }
267
268 pub fn stats(&self) -> CacheStats {
272 if !self.enabled {
273 let (hits, misses) = self.aggregate_counters();
274 return CacheStats {
275 size: 0,
276 capacity: 0,
277 hits,
278 misses,
279 };
280 }
281
282 let size: usize = self.shards.iter().map(CacheShard::len).sum();
283 let (hits, misses) = self.aggregate_counters();
284
285 CacheStats {
286 size,
287 capacity: self.capacity,
288 hits,
289 misses,
290 }
291 }
292
293 pub fn per_shard_stats(&self) -> Vec<ShardStats> {
297 self.shards
298 .iter()
299 .enumerate()
300 .map(|(i, s)| ShardStats {
301 shard_id: i,
302 size: s.len(),
303 hits: s.hits(),
304 misses: s.misses(),
305 })
306 .collect()
307 }
308
309 pub fn clear(&self) {
311 if !self.enabled {
312 return;
313 }
314
315 for shard in &self.shards {
316 shard.clear();
317 }
318 debug!("cache cleared");
319 }
320
321 #[inline]
323 pub fn is_enabled(&self) -> bool {
324 self.enabled
325 }
326
327 fn aggregate_counters(&self) -> (u64, u64) {
329 let hits: u64 = self.shards.iter().map(CacheShard::hits).sum();
330 let misses: u64 = self.shards.iter().map(CacheShard::misses).sum();
331 (hits, misses)
332 }
333}
334
335impl Default for EmbeddingCache {
336 fn default() -> Self {
337 Self::with_default_capacity()
338 }
339}
340
/// Aggregate statistics for the whole cache.
#[derive(Debug, Clone, Copy)]
pub struct CacheStats {
    /// Number of entries stored when the statistics were collected.
    pub size: usize,
    /// Capacity reported by the cache.
    pub capacity: usize,
    /// Number of lookups that found an entry.
    pub hits: u64,
    /// Number of lookups that found nothing.
    pub misses: u64,
}

impl CacheStats {
    /// Fraction of recorded lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        // Widen before adding: `hits + misses` in `u64` can overflow, which panics in debug
        // builds and wraps (possibly to 0) in release builds.
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}
367
/// Statistics for a single shard of the cache.
#[derive(Debug, Clone, Copy)]
pub struct ShardStats {
    /// Index of the shard these numbers belong to.
    pub shard_id: usize,
    /// Number of entries stored in the shard when the statistics were collected.
    pub size: usize,
    /// Number of lookups on this shard that found an entry.
    pub hits: u64,
    /// Number of lookups on this shard that found nothing.
    pub misses: u64,
}

impl ShardStats {
    /// Fraction of this shard's recorded lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        // Widen before adding: `hits + misses` in `u64` can overflow, which panics in debug
        // builds and wraps (possibly to 0) in release builds.
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}
394
395#[cfg(test)]
396mod tests {
397 use super::*;
398 use crate::model::EmbeddingModel;
399
/// A key misses before insertion and returns the stored values afterwards.
#[test]
fn test_cache_basic_operations() {
    let cache = EmbeddingCache::new(100);
    let model = ModelConfig::new(EmbeddingModel::BgeSmallEnV15);
    let key = cache.compute_key("hello", model);

    // Nothing has been stored yet.
    assert!(cache.get(&key).is_none());

    let embedding = vec![0.1, 0.2, 0.3];
    cache.put(key, embedding.clone());

    // The cached slice must match what was inserted.
    let cached = cache.get(&key).unwrap();
    assert_eq!(&cached[..], &embedding[..]);
}
416
/// Inserting more entries than the cache can hold must not grow it past its capacity.
#[test]
fn test_cache_eviction() {
    // Capacity 16 spread over 16 shards leaves one slot per shard.
    let cache = EmbeddingCache::new(16);

    for i in 0..32u32 {
        let text = format!("text_{}", i);
        let key = cache.compute_key(&text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
        cache.put(key, vec![i as f32]);
    }

    let stats = cache.stats();
    assert!(stats.size <= 16, "size {} exceeds capacity 16", stats.size);
}
439
/// Within one shard, the least recently used entry is the one evicted.
#[test]
fn test_cache_lru_eviction_within_shard() {
    // Capacity 32 spread over 16 shards leaves two slots per shard.
    let cache = EmbeddingCache::new(32);
    let key_for =
        |text: &str| cache.compute_key(text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));

    let target_shard = shard_index(&key_for("probe_0"));

    // Collect the first three generated keys that land in the target shard.
    let same_shard_keys: Vec<(CacheKey, u32)> = (0u32..)
        .map(|i| (key_for(&format!("lru_test_{}", i)), i))
        .filter(|(key, _)| shard_index(key) == target_shard)
        .take(3)
        .collect();

    let [(k1, v1), (k2, v2), (k3, v3)] =
        [same_shard_keys[0], same_shard_keys[1], same_shard_keys[2]];

    // Fill both slots of the shard.
    cache.put(k1, vec![v1 as f32]);
    cache.put(k2, vec![v2 as f32]);

    // Touch k1 so that k2 becomes the least recently used entry.
    assert!(cache.get(&k1).is_some());

    // A third key in the same shard forces an eviction.
    cache.put(k3, vec![v3 as f32]);

    assert!(
        cache.get(&k1).is_some(),
        "k1 should survive (recently accessed)"
    );
    assert!(cache.get(&k2).is_none(), "k2 should be evicted (LRU)");
    assert!(cache.get(&k3).is_some(), "k3 should exist (just inserted)");
}
490
/// The same text must hash to different keys under different models.
#[test]
fn test_cache_different_models_different_keys() {
    let cache = EmbeddingCache::new(100);
    let key_with = |model: EmbeddingModel| cache.compute_key("text", ModelConfig::new(model));

    let key_small = key_with(EmbeddingModel::BgeSmallEnV15);
    let key_base = key_with(EmbeddingModel::BgeBaseEnV15);

    assert_ne!(key_small, key_base);
}
501
/// One miss followed by one hit yields size 1 and a 50% hit rate.
#[test]
fn test_cache_stats() {
    let cache = EmbeddingCache::new(100);
    let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));

    // First lookup misses, the second one (after the insert) hits.
    let _ = cache.get(&key);
    cache.put(key, vec![0.1]);
    let _ = cache.get(&key);

    let stats = cache.stats();
    assert_eq!(stats.size, 1);
    assert_eq!(stats.hits, 1);
    assert_eq!(stats.misses, 1);
    assert!((stats.hit_rate() - 0.5).abs() < 0.001);
}
517
/// `get_many` returns one slot per key, in order, with `None` for missing keys.
#[test]
fn test_cache_get_many() {
    let cache = EmbeddingCache::new(100);
    let key_for =
        |text: &str| cache.compute_key(text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));

    let key1 = key_for("one");
    let key2 = key_for("two");
    let key3 = key_for("three");

    // key2 is deliberately left out.
    cache.put(key1, vec![1.0]);
    cache.put(key3, vec![3.0]);

    let results = cache.get_many(&[key1, key2, key3]);
    assert_eq!(results.len(), 3);
    assert_eq!(results[0].as_deref(), Some(&[1.0f32][..]));
    assert!(results[1].is_none());
    assert_eq!(results[2].as_deref(), Some(&[3.0f32][..]));
}
536
/// Entries inserted through `put_many` are retrievable one by one.
#[test]
fn test_cache_put_many() {
    let cache = EmbeddingCache::new(100);
    let key_for =
        |text: &str| cache.compute_key(text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));

    let key1 = key_for("one");
    let key2 = key_for("two");

    cache.put_many(vec![(key1, vec![1.0]), (key2, vec![2.0])]);

    assert_eq!(cache.get(&key1).as_deref(), Some(&[1.0f32][..]));
    assert_eq!(cache.get(&key2).as_deref(), Some(&[2.0f32][..]));
}
552
/// `clear` drops stored entries and brings the reported size back to zero.
#[test]
fn test_cache_clear() {
    let cache = EmbeddingCache::new(100);
    let model = ModelConfig::new(EmbeddingModel::BgeSmallEnV15);
    let key = cache.compute_key("hello", model);

    cache.put(key, vec![0.1]);
    assert!(cache.get(&key).is_some());

    cache.clear();

    assert!(cache.get(&key).is_none());
    let size_after_clear = cache.stats().size;
    assert_eq!(size_after_clear, 0);
}
565
/// The default constructor reports `DEFAULT_CACHE_CAPACITY` as its capacity.
#[test]
fn test_cache_default_capacity() {
    let stats = EmbeddingCache::with_default_capacity().stats();
    assert_eq!(stats.capacity, DEFAULT_CACHE_CAPACITY);
}
571
/// A zero-capacity cache is disabled: inserts are ignored and nothing is ever returned.
#[test]
fn test_cache_disabled_is_noop() {
    let cache = EmbeddingCache::new(0);
    assert!(!cache.is_enabled());

    let model = ModelConfig::new(EmbeddingModel::BgeSmallEnV15);
    let key = cache.compute_key("hello", model);

    // The insert is silently dropped, so the lookup must miss.
    cache.put(key, vec![0.1]);
    assert!(cache.get(&key).is_none());

    let CacheStats { size, capacity, .. } = cache.stats();
    assert_eq!(capacity, 0);
    assert_eq!(size, 0);
}
585
/// Eight threads inserting and reading distinct keys must each see their own writes.
#[test]
fn test_concurrent_access() {
    use std::thread;

    let cache = Arc::new(EmbeddingCache::new(4000));
    let mut handles = Vec::new();

    for t in 0..8 {
        let cache = Arc::clone(&cache);
        handles.push(thread::spawn(move || {
            for i in 0..100 {
                let text = format!("thread_{}_item_{}", t, i);
                let key =
                    cache.compute_key(&text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
                // The vector is not needed after the insert, so it is moved rather than cloned.
                let embedding = vec![t as f32; 384];
                cache.put(key, embedding);

                let result = cache.get(&key);
                assert!(result.is_some(), "put followed by get must succeed");
                assert_eq!(result.unwrap().len(), 384);
            }
        }));
    }

    for h in handles {
        h.join().expect("thread panicked");
    }

    // 8 threads x 100 distinct keys, each read back once after its insert.
    let stats = cache.stats();
    assert_eq!(stats.size, 800);
    assert!(stats.hits >= 800, "at least 800 hits expected");
}
623
/// Keys should spread over all shards without any shard being empty or heavily overloaded.
#[test]
fn test_shard_distribution() {
    let cache = EmbeddingCache::new(4000);
    let key_for =
        |text: &str| cache.compute_key(text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));

    let n = 800;
    (0..n).for_each(|i| cache.put(key_for(&format!("item_{}", i)), vec![i as f32]));

    let shard_stats = cache.per_shard_stats();
    assert_eq!(shard_stats.len(), NUM_SHARDS);

    // Every shard must have received at least one entry.
    for ss in shard_stats.iter() {
        assert!(
            ss.size > 0,
            "shard {} is empty — distribution is pathological",
            ss.shard_id
        );
    }

    // Nothing was evicted or lost: the shard sizes add up to the number of inserts.
    let total = shard_stats.iter().fold(0usize, |acc, ss| acc + ss.size);
    assert_eq!(total, n);

    // No shard may hold more than three times the average load.
    let avg = n / NUM_SHARDS;
    let limit = avg * 3;
    for ss in shard_stats.iter() {
        assert!(
            ss.size <= limit,
            "shard {} has {} entries (avg {}), distribution too skewed",
            ss.shard_id,
            ss.size,
            avg
        );
    }
}
667
/// Per-shard hit counters must add up to the aggregate hit count.
#[test]
fn test_per_shard_stats_hit_tracking() {
    let cache = EmbeddingCache::new(100);
    let key_for =
        |text: &str| cache.compute_key(text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));

    let key1 = key_for("hello");
    let key2 = key_for("world");

    cache.put(key1, vec![1.0]);
    cache.put(key2, vec![2.0]);

    // Three hits on key1 followed by one hit on key2.
    for _ in 0..3 {
        let _ = cache.get(&key1);
    }
    let _ = cache.get(&key2);

    let total_hits = cache
        .per_shard_stats()
        .iter()
        .fold(0u64, |acc, shard| acc + shard.hits);
    assert_eq!(total_hits, 4, "total hits should be 4");

    let stats = cache.stats();
    assert_eq!(stats.hits, 4);
    assert_eq!(stats.misses, 0);
}
693
/// A capacity below the shard count still produces a usable cache.
#[test]
fn test_small_capacity_rounds_up() {
    let cache = EmbeddingCache::new(3);
    assert!(cache.is_enabled());

    let model = ModelConfig::new(EmbeddingModel::BgeSmallEnV15);
    let key = cache.compute_key("x", model);

    cache.put(key, vec![42.0]);
    let stored = cache.get(&key);
    assert!(stored.is_some());
}
704}