Skip to main content

lattice_embed/
cache.rs

1//! Sharded embedding cache with LRU eviction.
2//!
3//! Caches embeddings to avoid re-computing for identical texts. Uses 16 independent
4//! shards to reduce write-lock contention — each `get()` on an LRU cache requires a
5//! write lock (to update access order), so a single `RwLock<LruCache>` serializes all
6//! reads under high QPS. Sharding reduces contention by a factor of `NUM_SHARDS`.
7//!
8//! # Design
9//!
10//! - **Shard selection**: First byte of the Blake3 cache key, masked to `NUM_SHARDS - 1`.
11//!   Blake3 output is uniformly distributed, so shard load is balanced.
//! - **Per-shard capacity**: `total_capacity / NUM_SHARDS`, rounded up and never below 1.
//!   Each shard independently evicts its own LRU entries.
14//! - **Per-shard statistics**: Hit/miss counters are per-shard `AtomicU64`s, aggregated
15//!   in `stats()`. This eliminates cross-shard atomic contention on the hot path.
//! - **Zero-capacity**: `capacity=0` disables caching entirely — lookups and inserts return
//!   immediately without locking; only `compute_key` still hashes its input.
18
19use crate::model::ModelConfig;
20use lru::LruCache;
21use parking_lot::RwLock;
22use std::num::NonZeroUsize;
23use std::sync::Arc;
24use std::sync::atomic::{AtomicU64, Ordering};
25use tracing::debug;
26
/// **Unstable**: internal implementation detail; type alias may change with cache redesign.
///
/// 32-byte Blake3 digest used as the lookup key for one cached embedding.
pub type CacheKey = [u8; 32];

/// **Unstable**: tuning constant; value may change as memory models evolve.
///
/// Default cache capacity, counted in embeddings. With 384-dimensional `f32` vectors,
/// 4000 entries hold roughly 6MB of embedding data.
pub const DEFAULT_CACHE_CAPACITY: usize = 4_000;

/// Number of cache shards. Kept a power of 2 so shard selection is a single bitwise AND
/// rather than a modulo. 16 shards on 8-core M4 Pro gives 2x oversubscription, keeping
/// contention low.
///
/// NOTE(review): shard selection reads a single key byte, so values above 256 would
/// leave shards unreachable — confirm before raising this.
const NUM_SHARDS: usize = 1 << 4;

/// Bit mask applied to the first key byte to obtain a shard index in `0..NUM_SHARDS`.
const SHARD_MASK: usize = NUM_SHARDS - 1;

// Fail the build if NUM_SHARDS stops being a power of 2: exactly one bit may be set.
const _: () = assert!(
    NUM_SHARDS.count_ones() == 1,
    "NUM_SHARDS must be a power of 2"
);
47
48/// A single cache shard with its own LRU cache and hit/miss counters.
49struct CacheShard {
50    lru: RwLock<LruCache<CacheKey, Arc<[f32]>>>,
51    hits: AtomicU64,
52    misses: AtomicU64,
53}
54
55impl CacheShard {
56    fn new(capacity: NonZeroUsize) -> Self {
57        Self {
58            lru: RwLock::new(LruCache::new(capacity)),
59            hits: AtomicU64::new(0),
60            misses: AtomicU64::new(0),
61        }
62    }
63
64    #[inline]
65    fn get(&self, key: &CacheKey) -> Option<Arc<[f32]>> {
66        let mut lru = self.lru.write();
67        let result = lru.get(key).cloned();
68        if result.is_some() {
69            self.hits.fetch_add(1, Ordering::Relaxed);
70        } else {
71            self.misses.fetch_add(1, Ordering::Relaxed);
72        }
73        result
74    }
75
76    #[inline]
77    fn put(&self, key: CacheKey, embedding: Arc<[f32]>) {
78        let mut lru = self.lru.write();
79        lru.put(key, embedding);
80    }
81
82    fn len(&self) -> usize {
83        self.lru.read().len()
84    }
85
86    fn clear(&self) {
87        self.lru.write().clear();
88    }
89
90    fn hits(&self) -> u64 {
91        self.hits.load(Ordering::Relaxed)
92    }
93
94    fn misses(&self) -> u64 {
95        self.misses.load(Ordering::Relaxed)
96    }
97}
98
99/// **Unstable**: internal LRU caching mechanism; shard count and eviction policy may change.
100///
101/// Embedding cache with sharded LRU eviction policy.
102///
103/// Thread-safe cache for storing computed embeddings. Uses Blake3 hashing
104/// for fast, collision-resistant cache keys. Internally sharded into 16
105/// independent LRU caches to reduce write-lock contention.
106///
107/// # Disabling
108///
109/// Pass `capacity=0` to disable caching. All cache operations become no-ops
110/// (no locking, no hashing work beyond key construction).
111///
112/// # Example
113///
114/// ```rust
115/// use lattice_embed::{EmbeddingCache, EmbeddingModel, ModelConfig};
116///
117/// let cache = EmbeddingCache::new(1000);
118///
119/// // Cache miss - no embedding stored yet
120/// let key = cache.compute_key("Hello, world!", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
121/// assert!(cache.get(&key).is_none());
122///
123/// // Store embedding
124/// let embedding = vec![0.1, 0.2, 0.3];
125/// cache.put(key, embedding.clone());
126///
127/// // Cache hit — returns Arc<[f32]>
128/// let cached = cache.get(&key).unwrap();
129/// assert_eq!(&*cached, &embedding[..]);
130/// ```
131pub struct EmbeddingCache {
132    shards: Vec<CacheShard>,
133    enabled: bool,
134    capacity: usize,
135}
136
137/// Select shard index from a cache key. Uses first byte masked to shard count.
138/// Blake3 output is uniformly distributed, so this gives balanced load.
139#[inline(always)]
140fn shard_index(key: &CacheKey) -> usize {
141    key[0] as usize & SHARD_MASK
142}
143
144impl EmbeddingCache {
145    /// **Unstable**: constructor signature may change when shard count becomes configurable.
146    ///
147    /// The capacity is divided equally across 16 internal shards. Each shard
148    /// independently manages its own LRU eviction.
149    ///
150    /// # Arguments
151    ///
152    /// * `capacity` - Maximum total number of embeddings to cache. Use 0 to disable caching.
153    pub fn new(capacity: usize) -> Self {
154        let enabled = capacity != 0;
155
156        // Per-shard capacity: ceiling division ensures total actual capacity >= requested.
157        // E.g., capacity=4000, NUM_SHARDS=16 → 250/shard (exact).
158        // E.g., capacity=10, NUM_SHARDS=16 → 1/shard (at least 1).
159        let per_shard = if enabled {
160            // Ceiling division: (capacity + NUM_SHARDS - 1) / NUM_SHARDS, minimum 1.
161            let base = capacity.div_ceil(NUM_SHARDS);
162            if base == 0 { 1 } else { base }
163        } else {
164            1 // Dummy capacity for disabled cache
165        };
166
167        let per_shard_nz = NonZeroUsize::new(per_shard).expect("per_shard is always >= 1");
168
169        let shards = (0..NUM_SHARDS)
170            .map(|_| CacheShard::new(per_shard_nz))
171            .collect();
172
173        Self {
174            shards,
175            enabled,
176            capacity,
177        }
178    }
179
180    /// **Unstable**: convenience constructor; subject to change with cache redesign.
181    pub fn with_default_capacity() -> Self {
182        Self::new(DEFAULT_CACHE_CAPACITY)
183    }
184
185    /// **Unstable**: key scheme (Blake3 + EmbeddingKey canonical bytes) may change; don't store keys across sessions.
186    ///
187    /// Uses Blake3 hashing for fast, collision-resistant keys. The key includes the model
188    /// name, revision, and active dimension from the `ModelConfig`, so different MRL truncations
189    /// produce different cache keys.
190    pub fn compute_key(&self, text: &str, model_config: ModelConfig) -> CacheKey {
191        let mut hasher = blake3::Hasher::new();
192        hasher.update(text.as_bytes());
193        // Unique identifier for the model config: "model_name:version:dims"
194        let model_key = format!(
195            "{}:{}:{}",
196            model_config.model,
197            model_config.model.key_version(),
198            model_config.dimensions(),
199        );
200        hasher.update(model_key.as_bytes());
201        *hasher.finalize().as_bytes()
202    }
203
204    /// **Unstable**: return type (`Arc<[f32]>`) may change to a newtype; internal cache API.
205    ///
206    /// Returns `Some(Arc<[f32]>)` if found (cheap refcount bump), `None` otherwise.
207    /// Updates per-shard hit/miss counters for metrics.
208    pub fn get(&self, key: &CacheKey) -> Option<Arc<[f32]>> {
209        if !self.enabled {
210            return None;
211        }
212
213        let idx = shard_index(key);
214        let result = self.shards[idx].get(key);
215
216        if result.is_some() {
217            debug!("cache hit for key {:?}", &key[..8]);
218        }
219
220        result
221    }
222
223    /// **Unstable**: internal cache storage method; interface may change.
224    ///
225    /// Converts the Vec into `Arc<[f32]>` for shared-ownership storage.
226    /// If the shard is at capacity, its least recently used entry is evicted.
227    pub fn put(&self, key: CacheKey, embedding: Vec<f32>) {
228        if !self.enabled {
229            return;
230        }
231
232        let idx = shard_index(&key);
233        self.shards[idx].put(key, Arc::from(embedding));
234        debug!("cached embedding for key {:?}", &key[..8]);
235    }
236
237    /// **Unstable**: batch cache access; return type may change with cache redesign.
238    ///
239    /// Returns a vector of `Option<Arc<[f32]>>` for each key, in the same order.
240    /// Each hit is an O(1) refcount bump (no data copy).
241    pub fn get_many(&self, keys: &[CacheKey]) -> Vec<Option<Arc<[f32]>>> {
242        if !self.enabled {
243            return vec![None; keys.len()];
244        }
245
246        keys.iter()
247            .map(|key| {
248                let idx = shard_index(key);
249                self.shards[idx].get(key)
250            })
251            .collect()
252    }
253
254    /// **Unstable**: batch cache storage; interface may change with cache redesign.
255    ///
256    /// Converts each Vec into `Arc<[f32]>` for shared-ownership storage.
257    pub fn put_many(&self, entries: Vec<(CacheKey, Vec<f32>)>) {
258        if !self.enabled {
259            return;
260        }
261
262        for (key, embedding) in entries {
263            let idx = shard_index(&key);
264            self.shards[idx].put(key, Arc::from(embedding));
265        }
266    }
267
268    /// **Unstable**: returns `CacheStats` which is itself Unstable; metrics shape may evolve.
269    ///
270    /// Aggregates per-shard counters. The `size` field is the sum of all shard sizes.
271    pub fn stats(&self) -> CacheStats {
272        if !self.enabled {
273            let (hits, misses) = self.aggregate_counters();
274            return CacheStats {
275                size: 0,
276                capacity: 0,
277                hits,
278                misses,
279            };
280        }
281
282        let size: usize = self.shards.iter().map(CacheShard::len).sum();
283        let (hits, misses) = self.aggregate_counters();
284
285        CacheStats {
286            size,
287            capacity: self.capacity,
288            hits,
289            misses,
290        }
291    }
292
293    /// **Unstable**: internal monitoring hook; shard count and `ShardStats` shape may change.
294    ///
295    /// Returns a vector of `(size, hits, misses)` tuples, one per shard.
296    pub fn per_shard_stats(&self) -> Vec<ShardStats> {
297        self.shards
298            .iter()
299            .enumerate()
300            .map(|(i, s)| ShardStats {
301                shard_id: i,
302                size: s.len(),
303                hits: s.hits(),
304                misses: s.misses(),
305            })
306            .collect()
307    }
308
309    /// **Unstable**: internal cache management; may be removed in favor of capacity-based eviction.
310    pub fn clear(&self) {
311        if !self.enabled {
312            return;
313        }
314
315        for shard in &self.shards {
316            shard.clear();
317        }
318        debug!("cache cleared");
319    }
320
321    /// **Unstable**: internal state query; may be removed when zero-capacity is the only disable path.
322    #[inline]
323    pub fn is_enabled(&self) -> bool {
324        self.enabled
325    }
326
327    /// Aggregate hit/miss counters across all shards.
328    fn aggregate_counters(&self) -> (u64, u64) {
329        let hits: u64 = self.shards.iter().map(CacheShard::hits).sum();
330        let misses: u64 = self.shards.iter().map(CacheShard::misses).sum();
331        (hits, misses)
332    }
333}
334
335impl Default for EmbeddingCache {
336    fn default() -> Self {
337        Self::with_default_capacity()
338    }
339}
340
/// **Unstable**: metrics fields may be added/removed as monitoring needs evolve.
///
/// Cache statistics (aggregated across all shards).
#[derive(Debug, Clone, Copy)]
pub struct CacheStats {
    /// Current number of cached entries (sum across all shards).
    pub size: usize,
    /// Maximum total cache capacity.
    pub capacity: usize,
    /// Number of cache hits (sum across all shards).
    pub hits: u64,
    /// Number of cache misses (sum across all shards).
    pub misses: u64,
}

impl CacheStats {
    /// **Unstable**: convenience metric; may move to a separate stats helper.
    ///
    /// Fraction of lookups that were hits, in `0.0..=1.0`. Returns `0.0` when no
    /// lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        // Sum in u128: the fields are public, so `hits + misses` is not guaranteed to
        // fit in u64 (it would panic in debug builds and wrap in release builds).
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}
367
/// **Unstable**: shard count is an internal implementation detail; this struct may be removed.
///
/// Per-shard statistics for detailed monitoring.
#[derive(Debug, Clone, Copy)]
pub struct ShardStats {
    /// Shard index (0 to NUM_SHARDS-1).
    pub shard_id: usize,
    /// Current number of entries in this shard.
    pub size: usize,
    /// Number of cache hits in this shard.
    pub hits: u64,
    /// Number of cache misses in this shard.
    pub misses: u64,
}

impl ShardStats {
    /// **Unstable**: per-shard metric; may be removed with `ShardStats`.
    ///
    /// Fraction of this shard's lookups that were hits, in `0.0..=1.0`. Returns `0.0`
    /// when the shard has recorded no lookups.
    pub fn hit_rate(&self) -> f64 {
        // Sum in u128: the fields are public, so `hits + misses` is not guaranteed to
        // fit in u64 (it would panic in debug builds and wrap in release builds).
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}
394
395#[cfg(test)]
396mod tests {
397    use super::*;
398    use crate::model::EmbeddingModel;
399
400    #[test]
401    fn test_cache_basic_operations() {
402        let cache = EmbeddingCache::new(100);
403        let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
404
405        // Miss
406        assert!(cache.get(&key).is_none());
407
408        // Put
409        let embedding = vec![0.1, 0.2, 0.3];
410        cache.put(key, embedding.clone());
411
412        // Hit — returns Arc<[f32]>
413        let cached = cache.get(&key).unwrap();
414        assert_eq!(&*cached, &embedding[..]);
415    }
416
417    #[test]
418    fn test_cache_eviction() {
419        // With 16 shards, a capacity of 16 gives 1 entry per shard.
420        // To test eviction, we need keys that hash to the same shard.
421        // Use a larger capacity and fill it up.
422        let cache = EmbeddingCache::new(16);
423
424        // Insert 32 entries — each shard has capacity 1, so each shard
425        // can only hold 1 entry. Inserting 2 entries to the same shard
426        // will evict the first.
427        let mut keys = Vec::new();
428        for i in 0..32u32 {
429            let text = format!("text_{}", i);
430            let key = cache.compute_key(&text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
431            keys.push(key);
432            cache.put(key, vec![i as f32]);
433        }
434
435        // Total size should not exceed capacity (16)
436        let stats = cache.stats();
437        assert!(stats.size <= 16, "size {} exceeds capacity 16", stats.size);
438    }
439
440    #[test]
441    fn test_cache_lru_eviction_within_shard() {
442        // Create cache with capacity 32 (2 per shard).
443        let cache = EmbeddingCache::new(32);
444
445        // Find 3 keys that land in the same shard.
446        let mut same_shard_keys = Vec::new();
447        let mut i = 0u32;
448        let target_shard;
449
450        // Find the first key's shard and collect 3 keys for it.
451        let first_key =
452            cache.compute_key("probe_0", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
453        target_shard = shard_index(&first_key);
454
455        loop {
456            let key = cache.compute_key(
457                &format!("lru_test_{}", i),
458                ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
459            );
460            if shard_index(&key) == target_shard {
461                same_shard_keys.push((key, i));
462            }
463            if same_shard_keys.len() == 3 {
464                break;
465            }
466            i += 1;
467        }
468
469        let (k1, v1) = same_shard_keys[0];
470        let (k2, v2) = same_shard_keys[1];
471        let (k3, v3) = same_shard_keys[2];
472
473        // Insert k1 and k2 (shard capacity is 2).
474        cache.put(k1, vec![v1 as f32]);
475        cache.put(k2, vec![v2 as f32]);
476
477        // Access k1 to make it recently used.
478        assert!(cache.get(&k1).is_some());
479
480        // Insert k3 — should evict k2 (least recently used in this shard).
481        cache.put(k3, vec![v3 as f32]);
482
483        assert!(
484            cache.get(&k1).is_some(),
485            "k1 should survive (recently accessed)"
486        );
487        assert!(cache.get(&k2).is_none(), "k2 should be evicted (LRU)");
488        assert!(cache.get(&k3).is_some(), "k3 should exist (just inserted)");
489    }
490
491    #[test]
492    fn test_cache_different_models_different_keys() {
493        let cache = EmbeddingCache::new(100);
494
495        let key_small = cache.compute_key("text", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
496        let key_base = cache.compute_key("text", ModelConfig::new(EmbeddingModel::BgeBaseEnV15));
497
498        // Same text, different models = different keys
499        assert_ne!(key_small, key_base);
500    }
501
502    #[test]
503    fn test_cache_stats() {
504        let cache = EmbeddingCache::new(100);
505        let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
506
507        cache.get(&key); // Miss
508        cache.put(key, vec![0.1]);
509        cache.get(&key); // Hit
510
511        let stats = cache.stats();
512        assert_eq!(stats.size, 1);
513        assert_eq!(stats.hits, 1);
514        assert_eq!(stats.misses, 1);
515        assert!((stats.hit_rate() - 0.5).abs() < 0.001);
516    }
517
518    #[test]
519    fn test_cache_get_many() {
520        // Use capacity large enough that no shard evicts.
521        let cache = EmbeddingCache::new(100);
522
523        let key1 = cache.compute_key("one", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
524        let key2 = cache.compute_key("two", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
525        let key3 = cache.compute_key("three", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
526
527        cache.put(key1, vec![1.0]);
528        cache.put(key3, vec![3.0]);
529
530        let results = cache.get_many(&[key1, key2, key3]);
531        assert_eq!(results.len(), 3);
532        assert_eq!(&**results[0].as_ref().unwrap(), &[1.0f32]);
533        assert!(results[1].is_none());
534        assert_eq!(&**results[2].as_ref().unwrap(), &[3.0f32]);
535    }
536
537    #[test]
538    fn test_cache_put_many() {
539        // Use capacity large enough that no shard evicts (ceil(100/16) = 7 per shard).
540        let cache = EmbeddingCache::new(100);
541
542        let key1 = cache.compute_key("one", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
543        let key2 = cache.compute_key("two", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
544
545        cache.put_many(vec![(key1, vec![1.0]), (key2, vec![2.0])]);
546
547        let v1 = cache.get(&key1).unwrap();
548        assert_eq!(&*v1, [1.0f32].as_slice());
549        let v2 = cache.get(&key2).unwrap();
550        assert_eq!(&*v2, [2.0f32].as_slice());
551    }
552
553    #[test]
554    fn test_cache_clear() {
555        let cache = EmbeddingCache::new(100);
556        let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
557
558        cache.put(key, vec![0.1]);
559        assert!(cache.get(&key).is_some());
560
561        cache.clear();
562        assert!(cache.get(&key).is_none());
563        assert_eq!(cache.stats().size, 0);
564    }
565
566    #[test]
567    fn test_cache_default_capacity() {
568        let cache = EmbeddingCache::with_default_capacity();
569        assert_eq!(cache.stats().capacity, DEFAULT_CACHE_CAPACITY);
570    }
571
572    #[test]
573    fn test_cache_disabled_is_noop() {
574        let cache = EmbeddingCache::new(0);
575        assert!(!cache.is_enabled());
576
577        let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
578        cache.put(key, vec![0.1]);
579        assert!(cache.get(&key).is_none());
580
581        let stats = cache.stats();
582        assert_eq!(stats.capacity, 0);
583        assert_eq!(stats.size, 0);
584    }
585
586    #[test]
587    fn test_concurrent_access() {
588        use std::thread;
589
590        // Use large capacity so no eviction occurs (800 entries across 16 shards).
591        // Per-shard capacity = ceil(4000/16) = 250, so 800 entries fit easily.
592        let cache = Arc::new(EmbeddingCache::new(4000));
593        let mut handles = Vec::new();
594
595        // Spawn 8 threads, each doing 100 put+get operations.
596        for t in 0..8 {
597            let cache = Arc::clone(&cache);
598            handles.push(thread::spawn(move || {
599                for i in 0..100 {
600                    // Each thread uses unique keys to avoid contention on same entry.
601                    let text = format!("thread_{}_item_{}", t, i);
602                    let key =
603                        cache.compute_key(&text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
604                    let embedding = vec![t as f32; 384];
605                    cache.put(key, embedding.clone());
606
607                    let result = cache.get(&key);
608                    assert!(result.is_some(), "put followed by get must succeed");
609                    assert_eq!(result.unwrap().len(), 384);
610                }
611            }));
612        }
613
614        for h in handles {
615            h.join().expect("thread panicked");
616        }
617
618        // All 800 entries should be in cache (capacity 4000 >> 800).
619        let stats = cache.stats();
620        assert_eq!(stats.size, 800);
621        assert!(stats.hits >= 800, "at least 800 hits expected");
622    }
623
624    #[test]
625    fn test_shard_distribution() {
626        // Use generous capacity to avoid eviction from uneven distribution.
627        // 4000 total → 250/shard. We insert 800 entries → ~50/shard on average.
628        let cache = EmbeddingCache::new(4000);
629
630        let n = 800;
631        for i in 0..n {
632            let key = cache.compute_key(
633                &format!("item_{}", i),
634                ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
635            );
636            cache.put(key, vec![i as f32]);
637        }
638
639        let shard_stats = cache.per_shard_stats();
640        assert_eq!(shard_stats.len(), NUM_SHARDS);
641
642        // Each shard should have entries. With uniform hash, each shard gets ~50.
643        for ss in &shard_stats {
644            assert!(
645                ss.size > 0,
646                "shard {} is empty — distribution is pathological",
647                ss.shard_id
648            );
649        }
650
651        // No eviction since 800 << 4000. Total should be exactly 800.
652        let total: usize = shard_stats.iter().map(|s| s.size).sum();
653        assert_eq!(total, n);
654
655        // Check distribution is reasonably uniform: no shard has >3x the average.
656        let avg = n / NUM_SHARDS; // 50
657        for ss in &shard_stats {
658            assert!(
659                ss.size <= avg * 3,
660                "shard {} has {} entries (avg {}), distribution too skewed",
661                ss.shard_id,
662                ss.size,
663                avg
664            );
665        }
666    }
667
668    #[test]
669    fn test_per_shard_stats_hit_tracking() {
670        let cache = EmbeddingCache::new(100);
671
672        // Insert a few entries and access them.
673        let key1 = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
674        let key2 = cache.compute_key("world", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
675
676        cache.put(key1, vec![1.0]);
677        cache.put(key2, vec![2.0]);
678
679        // Access key1 three times, key2 once.
680        cache.get(&key1);
681        cache.get(&key1);
682        cache.get(&key1);
683        cache.get(&key2);
684
685        let shard_stats = cache.per_shard_stats();
686        let total_hits: u64 = shard_stats.iter().map(|s| s.hits).sum();
687        assert_eq!(total_hits, 4, "total hits should be 4");
688
689        let stats = cache.stats();
690        assert_eq!(stats.hits, 4);
691        assert_eq!(stats.misses, 0);
692    }
693
694    #[test]
695    fn test_small_capacity_rounds_up() {
696        // Capacity smaller than NUM_SHARDS: each shard gets at least 1.
697        let cache = EmbeddingCache::new(3);
698        assert!(cache.is_enabled());
699
700        let key = cache.compute_key("x", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
701        cache.put(key, vec![42.0]);
702        assert!(cache.get(&key).is_some());
703    }
704}