Skip to main content

lattice_embed/
cache.rs

1//! Sharded embedding cache with LRU eviction.
2//!
3//! Caches embeddings to avoid re-computing for identical texts. Uses 16 independent
4//! shards to reduce write-lock contention — each `get()` on an LRU cache requires a
//! write lock (to update access order), so a single `RwLock<LruCache>` serializes all
//! reads under high QPS. Sharding reduces contention by up to a factor of `NUM_SHARDS`
//! (the ideal case, when concurrent lookups spread evenly over the shards).
7//!
8//! # Design
9//!
10//! - **Shard selection**: First byte of the Blake3 cache key, masked to `NUM_SHARDS - 1`.
11//!   Blake3 output is uniformly distributed, so shard load is balanced.
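//! - **Per-shard capacity**: `ceil(total_capacity / NUM_SHARDS)` (at least 1), so the
//!   effective total can exceed the requested capacity by up to `NUM_SHARDS - 1`. Each
//!   shard independently evicts its own LRU entries.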
14//! - **Per-shard statistics**: Hit/miss counters are per-shard `AtomicU64`s, aggregated
15//!   in `stats()`. This eliminates cross-shard atomic contention on the hot path.
//! - **Zero-capacity**: `capacity=0` disables caching entirely — `get`/`put`, their batch
//!   variants, and `clear` become no-ops that take no locks. `compute_key` still hashes.
18
19use crate::model::ModelConfig;
20use crate::service::EmbeddingRole;
21use lru::LruCache;
22use parking_lot::RwLock;
23use std::num::NonZeroUsize;
24use std::sync::Arc;
25use std::sync::atomic::{AtomicU64, Ordering};
26use tracing::debug;
27
/// **Unstable**: internal implementation detail; type alias may change with cache redesign.
///
/// A 32-byte Blake3 digest identifying one cached embedding.
pub type CacheKey = [u8; 32];

/// **Unstable**: tuning constant; value may change as memory models evolve.
///
/// Default number of cached embeddings. For 384-dimensional `f32` vectors this is
/// 4000 * 384 * 4 bytes ≈ 6 MB of vector payload.
pub const DEFAULT_CACHE_CAPACITY: usize = 4_000;

/// Number of independent cache shards.
///
/// Kept a power of two so the shard for a key can be picked with a bitwise AND
/// instead of a modulo. Sixteen shards is 2x oversubscription on an 8-core machine,
/// which keeps per-shard lock contention low.
const NUM_SHARDS: usize = 1 << 4;

/// Bit mask turning a key byte into a shard index (`NUM_SHARDS - 1`).
const SHARD_MASK: usize = NUM_SHARDS - 1;

// `SHARD_MASK` is only a valid mask when NUM_SHARDS is a power of two; reject any
// other value at compile time.
const _: () = assert!(
    NUM_SHARDS.is_power_of_two(),
    "NUM_SHARDS must be a power of 2"
);
48
49/// A single cache shard with its own LRU cache and hit/miss counters.
50struct CacheShard {
51    lru: RwLock<LruCache<CacheKey, Arc<[f32]>>>,
52    hits: AtomicU64,
53    misses: AtomicU64,
54}
55
56impl CacheShard {
57    fn new(capacity: NonZeroUsize) -> Self {
58        Self {
59            lru: RwLock::new(LruCache::new(capacity)),
60            hits: AtomicU64::new(0),
61            misses: AtomicU64::new(0),
62        }
63    }
64
65    #[inline]
66    fn get(&self, key: &CacheKey) -> Option<Arc<[f32]>> {
67        let mut lru = self.lru.write();
68        let result = lru.get(key).cloned();
69        if result.is_some() {
70            self.hits.fetch_add(1, Ordering::Relaxed);
71        } else {
72            self.misses.fetch_add(1, Ordering::Relaxed);
73        }
74        result
75    }
76
77    #[inline]
78    fn put(&self, key: CacheKey, embedding: Arc<[f32]>) {
79        let mut lru = self.lru.write();
80        lru.put(key, embedding);
81    }
82
83    fn len(&self) -> usize {
84        self.lru.read().len()
85    }
86
87    fn clear(&self) {
88        self.lru.write().clear();
89    }
90
91    fn hits(&self) -> u64 {
92        self.hits.load(Ordering::Relaxed)
93    }
94
95    fn misses(&self) -> u64 {
96        self.misses.load(Ordering::Relaxed)
97    }
98}
99
100/// **Unstable**: internal LRU caching mechanism; shard count and eviction policy may change.
101///
102/// Embedding cache with sharded LRU eviction policy.
103///
104/// Thread-safe cache for storing computed embeddings. Uses Blake3 hashing
105/// for fast, collision-resistant cache keys. Internally sharded into 16
106/// independent LRU caches to reduce write-lock contention.
107///
108/// # Disabling
109///
110/// Pass `capacity=0` to disable caching. All cache operations become no-ops
111/// (no locking, no hashing work beyond key construction).
112///
113/// # Example
114///
115/// ```rust
116/// use lattice_embed::{EmbeddingCache, EmbeddingModel, ModelConfig};
117/// use lattice_embed::service::EmbeddingRole;
118///
119/// let cache = EmbeddingCache::new(1000);
120///
121/// // Cache miss - no embedding stored yet
122/// let key = cache.compute_key(
123///     "Hello, world!",
124///     ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
125///     EmbeddingRole::Generic,
126/// );
127/// assert!(cache.get(&key).is_none());
128///
129/// // Store embedding
130/// let embedding = vec![0.1, 0.2, 0.3];
131/// cache.put(key, embedding.clone());
132///
133/// // Cache hit — returns Arc<[f32]>
134/// let cached = cache.get(&key).unwrap();
135/// assert_eq!(&*cached, &embedding[..]);
136/// ```
137pub struct EmbeddingCache {
138    shards: Vec<CacheShard>,
139    enabled: bool,
140    capacity: usize,
141}
142
143/// Select shard index from a cache key. Uses first byte masked to shard count.
144/// Blake3 output is uniformly distributed, so this gives balanced load.
145#[inline(always)]
146fn shard_index(key: &CacheKey) -> usize {
147    key[0] as usize & SHARD_MASK
148}
149
150impl EmbeddingCache {
151    /// **Unstable**: constructor signature may change when shard count becomes configurable.
152    ///
153    /// The capacity is divided equally across 16 internal shards. Each shard
154    /// independently manages its own LRU eviction.
155    ///
156    /// # Arguments
157    ///
158    /// * `capacity` - Maximum total number of embeddings to cache. Use 0 to disable caching.
159    pub fn new(capacity: usize) -> Self {
160        let enabled = capacity != 0;
161
162        // Per-shard capacity: ceiling division ensures total actual capacity >= requested.
163        // E.g., capacity=4000, NUM_SHARDS=16 → 250/shard (exact).
164        // E.g., capacity=10, NUM_SHARDS=16 → 1/shard (at least 1).
165        let per_shard = if enabled {
166            // Ceiling division: (capacity + NUM_SHARDS - 1) / NUM_SHARDS, minimum 1.
167            let base = capacity.div_ceil(NUM_SHARDS);
168            if base == 0 { 1 } else { base }
169        } else {
170            1 // Dummy capacity for disabled cache
171        };
172
173        let per_shard_nz = NonZeroUsize::new(per_shard).expect("per_shard is always >= 1");
174
175        let shards = (0..NUM_SHARDS)
176            .map(|_| CacheShard::new(per_shard_nz))
177            .collect();
178
179        Self {
180            shards,
181            enabled,
182            capacity,
183        }
184    }
185
186    /// **Unstable**: convenience constructor; subject to change with cache redesign.
187    pub fn with_default_capacity() -> Self {
188        Self::new(DEFAULT_CACHE_CAPACITY)
189    }
190
191    /// **Unstable**: key scheme (Blake3 + EmbeddingKey canonical bytes) may change; don't store keys across sessions.
192    ///
193    /// Uses Blake3 hashing for fast, collision-resistant keys. The key includes the model
194    /// name, revision, and active dimension from the `ModelConfig`, so different MRL truncations
195    /// produce different cache keys.
196    ///
197    /// The role is also included so that `embed_query("hello")` and `embed_passage("hello")`
198    /// produce different cache entries even when the raw text and model config are identical.
199    /// Use `EmbeddingRole::Generic` for the backwards-compatible `embed()` path.
200    pub fn compute_key(
201        &self,
202        text: &str,
203        model_config: ModelConfig,
204        role: EmbeddingRole,
205    ) -> CacheKey {
206        let mut hasher = blake3::Hasher::new();
207        hasher.update(text.as_bytes());
208        // Unique identifier for the model config: "model_name:version:dims"
209        let model_key = format!(
210            "{}:{}:{}:{}",
211            model_config.model,
212            model_config.model.key_version(),
213            model_config.dimensions(),
214            role.cache_tag(),
215        );
216        hasher.update(model_key.as_bytes());
217        *hasher.finalize().as_bytes()
218    }
219
220    /// **Unstable**: return type (`Arc<[f32]>`) may change to a newtype; internal cache API.
221    ///
222    /// Returns `Some(Arc<[f32]>)` if found (cheap refcount bump), `None` otherwise.
223    /// Updates per-shard hit/miss counters for metrics.
224    pub fn get(&self, key: &CacheKey) -> Option<Arc<[f32]>> {
225        if !self.enabled {
226            return None;
227        }
228
229        let idx = shard_index(key);
230        let result = self.shards[idx].get(key);
231
232        if result.is_some() {
233            debug!("cache hit for key {:?}", &key[..8]);
234        }
235
236        result
237    }
238
239    /// **Unstable**: internal cache storage method; interface may change.
240    ///
241    /// Converts the Vec into `Arc<[f32]>` for shared-ownership storage.
242    /// If the shard is at capacity, its least recently used entry is evicted.
243    pub fn put(&self, key: CacheKey, embedding: Vec<f32>) {
244        if !self.enabled {
245            return;
246        }
247
248        let idx = shard_index(&key);
249        self.shards[idx].put(key, Arc::from(embedding));
250        debug!("cached embedding for key {:?}", &key[..8]);
251    }
252
253    /// **Unstable**: batch cache access; return type may change with cache redesign.
254    ///
255    /// Returns a vector of `Option<Arc<[f32]>>` for each key, in the same order.
256    /// Each hit is an O(1) refcount bump (no data copy).
257    pub fn get_many(&self, keys: &[CacheKey]) -> Vec<Option<Arc<[f32]>>> {
258        if !self.enabled {
259            return vec![None; keys.len()];
260        }
261
262        keys.iter()
263            .map(|key| {
264                let idx = shard_index(key);
265                self.shards[idx].get(key)
266            })
267            .collect()
268    }
269
270    /// **Unstable**: batch cache storage; interface may change with cache redesign.
271    ///
272    /// Converts each Vec into `Arc<[f32]>` for shared-ownership storage.
273    pub fn put_many(&self, entries: Vec<(CacheKey, Vec<f32>)>) {
274        if !self.enabled {
275            return;
276        }
277
278        for (key, embedding) in entries {
279            let idx = shard_index(&key);
280            self.shards[idx].put(key, Arc::from(embedding));
281        }
282    }
283
284    /// **Unstable**: returns `CacheStats` which is itself Unstable; metrics shape may evolve.
285    ///
286    /// Aggregates per-shard counters. The `size` field is the sum of all shard sizes.
287    pub fn stats(&self) -> CacheStats {
288        if !self.enabled {
289            let (hits, misses) = self.aggregate_counters();
290            return CacheStats {
291                size: 0,
292                capacity: 0,
293                hits,
294                misses,
295            };
296        }
297
298        let size: usize = self.shards.iter().map(CacheShard::len).sum();
299        let (hits, misses) = self.aggregate_counters();
300
301        CacheStats {
302            size,
303            capacity: self.capacity,
304            hits,
305            misses,
306        }
307    }
308
309    /// **Unstable**: internal monitoring hook; shard count and `ShardStats` shape may change.
310    ///
311    /// Returns a vector of `(size, hits, misses)` tuples, one per shard.
312    pub fn per_shard_stats(&self) -> Vec<ShardStats> {
313        self.shards
314            .iter()
315            .enumerate()
316            .map(|(i, s)| ShardStats {
317                shard_id: i,
318                size: s.len(),
319                hits: s.hits(),
320                misses: s.misses(),
321            })
322            .collect()
323    }
324
325    /// **Unstable**: internal cache management; may be removed in favor of capacity-based eviction.
326    pub fn clear(&self) {
327        if !self.enabled {
328            return;
329        }
330
331        for shard in &self.shards {
332            shard.clear();
333        }
334        debug!("cache cleared");
335    }
336
337    /// **Unstable**: internal state query; may be removed when zero-capacity is the only disable path.
338    #[inline]
339    pub fn is_enabled(&self) -> bool {
340        self.enabled
341    }
342
343    /// Aggregate hit/miss counters across all shards.
344    fn aggregate_counters(&self) -> (u64, u64) {
345        let hits: u64 = self.shards.iter().map(CacheShard::hits).sum();
346        let misses: u64 = self.shards.iter().map(CacheShard::misses).sum();
347        (hits, misses)
348    }
349}
350
351impl Default for EmbeddingCache {
352    fn default() -> Self {
353        Self::with_default_capacity()
354    }
355}
356
/// **Unstable**: metrics fields may be added/removed as monitoring needs evolve.
///
/// Snapshot of cache-wide statistics, summed over every shard.
#[derive(Debug, Clone, Copy)]
pub struct CacheStats {
    /// Number of embeddings currently held, summed across shards.
    pub size: usize,
    /// Total capacity the cache was configured with.
    pub capacity: usize,
    /// Lookups that found an entry, summed across shards.
    pub hits: u64,
    /// Lookups that found nothing, summed across shards.
    pub misses: u64,
}

impl CacheStats {
    /// **Unstable**: convenience metric; may move to a separate stats helper.
    ///
    /// Fraction of lookups that were hits, or `0.0` if no lookups were recorded.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
383
/// **Unstable**: shard count is an internal implementation detail; this struct may be removed.
///
/// Statistics for a single cache shard, for detailed monitoring.
#[derive(Debug, Clone, Copy)]
pub struct ShardStats {
    /// Index of the shard this snapshot describes (0 to NUM_SHARDS-1).
    pub shard_id: usize,
    /// Entries currently held by this shard.
    pub size: usize,
    /// Lookups routed to this shard that found an entry.
    pub hits: u64,
    /// Lookups routed to this shard that found nothing.
    pub misses: u64,
}

impl ShardStats {
    /// **Unstable**: per-shard metric; may be removed with `ShardStats`.
    ///
    /// Fraction of this shard's lookups that were hits, or `0.0` if it saw none.
    pub fn hit_rate(&self) -> f64 {
        let lookups = self.hits + self.misses;
        if lookups == 0 {
            return 0.0;
        }
        self.hits as f64 / lookups as f64
    }
}
410
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::EmbeddingModel;

    /// Cache key for `text` under the default test model and the generic role.
    fn generic_key(cache: &EmbeddingCache, text: &str) -> CacheKey {
        cache.compute_key(
            text,
            ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
            EmbeddingRole::Generic,
        )
    }

    #[test]
    fn test_cache_basic_operations() {
        let cache = EmbeddingCache::new(100);
        let key = generic_key(&cache, "hello");

        // Miss
        assert!(cache.get(&key).is_none());

        // Put
        let embedding = vec![0.1, 0.2, 0.3];
        cache.put(key, embedding.clone());

        // Hit — returns Arc<[f32]>
        let cached = cache.get(&key).unwrap();
        assert_eq!(&*cached, &embedding[..]);
    }

    #[test]
    fn test_cache_eviction() {
        // Capacity 16 over 16 shards gives exactly one slot per shard, so inserting
        // 32 distinct entries must evict at least 16 of them.
        let cache = EmbeddingCache::new(16);

        let mut keys = Vec::with_capacity(32);
        for i in 0..32u32 {
            let key = generic_key(&cache, &format!("text_{i}"));
            cache.put(key, vec![i as f32]);
            keys.push(key);
        }

        // Total size should not exceed capacity (16).
        let stats = cache.stats();
        assert!(stats.size <= 16, "size {} exceeds capacity 16", stats.size);

        // Nothing was inserted after the last key, so it cannot have been evicted.
        assert!(
            cache.get(&keys[31]).is_some(),
            "most recently inserted entry must survive"
        );

        // Eviction really happened, and exactly `size` of the inserted keys remain.
        let surviving = keys.iter().filter(|key| cache.get(key).is_some()).count();
        assert!(surviving < keys.len(), "no entry was evicted");
        assert_eq!(surviving, stats.size);
    }

    #[test]
    fn test_cache_lru_eviction_within_shard() {
        // Capacity 32 gives 2 slots per shard.
        let cache = EmbeddingCache::new(32);

        // Collect the first 3 keys that land in the same shard as a probe key.
        let target_shard = shard_index(&generic_key(&cache, "probe_0"));
        let same_shard_keys: Vec<(CacheKey, u32)> = (0u32..)
            .map(|i| (generic_key(&cache, &format!("lru_test_{i}")), i))
            .filter(|(key, _)| shard_index(key) == target_shard)
            .take(3)
            .collect();

        let (k1, v1) = same_shard_keys[0];
        let (k2, v2) = same_shard_keys[1];
        let (k3, v3) = same_shard_keys[2];

        // Insert k1 and k2 (shard capacity is 2).
        cache.put(k1, vec![v1 as f32]);
        cache.put(k2, vec![v2 as f32]);

        // Access k1 to make it recently used.
        assert!(cache.get(&k1).is_some());

        // Insert k3 — should evict k2 (least recently used in this shard).
        cache.put(k3, vec![v3 as f32]);

        assert!(
            cache.get(&k1).is_some(),
            "k1 should survive (recently accessed)"
        );
        assert!(cache.get(&k2).is_none(), "k2 should be evicted (LRU)");
        assert!(cache.get(&k3).is_some(), "k3 should exist (just inserted)");
    }

    #[test]
    fn test_cache_different_models_different_keys() {
        let cache = EmbeddingCache::new(100);

        let key_small = generic_key(&cache, "text");
        let key_base = cache.compute_key(
            "text",
            ModelConfig::new(EmbeddingModel::BgeBaseEnV15),
            EmbeddingRole::Generic,
        );

        // Same text, different models = different keys
        assert_ne!(key_small, key_base);
    }

    #[test]
    fn test_cache_stats() {
        let cache = EmbeddingCache::new(100);
        let key = generic_key(&cache, "hello");

        cache.get(&key); // Miss
        cache.put(key, vec![0.1]);
        cache.get(&key); // Hit

        let stats = cache.stats();
        assert_eq!(stats.size, 1);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_cache_get_many() {
        // Use capacity large enough that no shard evicts.
        let cache = EmbeddingCache::new(100);

        let key1 = generic_key(&cache, "one");
        let key2 = generic_key(&cache, "two");
        let key3 = generic_key(&cache, "three");

        cache.put(key1, vec![1.0]);
        cache.put(key3, vec![3.0]);

        let results = cache.get_many(&[key1, key2, key3]);
        assert_eq!(results.len(), 3);
        assert_eq!(&**results[0].as_ref().unwrap(), &[1.0f32]);
        assert!(results[1].is_none());
        assert_eq!(&**results[2].as_ref().unwrap(), &[3.0f32]);
    }

    #[test]
    fn test_cache_put_many() {
        // Use capacity large enough that no shard evicts (ceil(100/16) = 7 per shard).
        let cache = EmbeddingCache::new(100);

        let key1 = generic_key(&cache, "one");
        let key2 = generic_key(&cache, "two");

        cache.put_many(vec![(key1, vec![1.0]), (key2, vec![2.0])]);

        let v1 = cache.get(&key1).unwrap();
        assert_eq!(&*v1, [1.0f32].as_slice());
        let v2 = cache.get(&key2).unwrap();
        assert_eq!(&*v2, [2.0f32].as_slice());
    }

    #[test]
    fn test_cache_clear() {
        let cache = EmbeddingCache::new(100);
        let key = generic_key(&cache, "hello");

        cache.put(key, vec![0.1]);
        assert!(cache.get(&key).is_some());

        cache.clear();
        assert!(cache.get(&key).is_none());
        assert_eq!(cache.stats().size, 0);
    }

    #[test]
    fn test_cache_default_capacity() {
        let cache = EmbeddingCache::with_default_capacity();
        assert_eq!(cache.stats().capacity, DEFAULT_CACHE_CAPACITY);
    }

    #[test]
    fn test_cache_disabled_is_noop() {
        let cache = EmbeddingCache::new(0);
        assert!(!cache.is_enabled());

        let key = generic_key(&cache, "hello");
        cache.put(key, vec![0.1]);
        assert!(cache.get(&key).is_none());

        let stats = cache.stats();
        assert_eq!(stats.capacity, 0);
        assert_eq!(stats.size, 0);
    }

    #[test]
    fn test_concurrent_access() {
        use std::thread;

        // Capacity 4000 gives 250 slots per shard, far more than the ~50 per shard that
        // the 800 entries below produce, so no eviction occurs.
        let cache = Arc::new(EmbeddingCache::new(4000));

        // Spawn 8 threads, each doing 100 put+get operations on its own keys.
        let handles: Vec<_> = (0..8)
            .map(|t| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    for i in 0..100 {
                        let key = generic_key(&cache, &format!("thread_{t}_item_{i}"));
                        cache.put(key, vec![t as f32; 384]);

                        let result = cache.get(&key);
                        assert!(result.is_some(), "put followed by get must succeed");
                        assert_eq!(result.unwrap().len(), 384);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().expect("thread panicked");
        }

        // All 800 entries should be in cache (capacity 4000 >> 800), and every one of
        // the 800 lookups above was a hit.
        let stats = cache.stats();
        assert_eq!(stats.size, 800);
        assert_eq!(stats.hits, 800, "every lookup follows its own put");
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_shard_distribution() {
        // Use generous capacity to avoid eviction from uneven distribution.
        // 4000 total → 250/shard. We insert 800 entries → ~50/shard on average.
        let cache = EmbeddingCache::new(4000);

        let n = 800;
        for i in 0..n {
            let key = generic_key(&cache, &format!("item_{i}"));
            cache.put(key, vec![i as f32]);
        }

        let shard_stats = cache.per_shard_stats();
        assert_eq!(shard_stats.len(), NUM_SHARDS);

        // Each shard should have entries. With uniform hash, each shard gets ~50.
        for ss in &shard_stats {
            assert!(
                ss.size > 0,
                "shard {} is empty — distribution is pathological",
                ss.shard_id
            );
        }

        // No eviction since 800 << 4000. Total should be exactly 800.
        let total: usize = shard_stats.iter().map(|s| s.size).sum();
        assert_eq!(total, n);

        // Check distribution is reasonably uniform: no shard has >3x the average.
        let avg = n / NUM_SHARDS; // 50
        for ss in &shard_stats {
            assert!(
                ss.size <= avg * 3,
                "shard {} has {} entries (avg {}), distribution too skewed",
                ss.shard_id,
                ss.size,
                avg
            );
        }
    }

    #[test]
    fn test_per_shard_stats_hit_tracking() {
        let cache = EmbeddingCache::new(100);

        // Insert a few entries and access them.
        let key1 = generic_key(&cache, "hello");
        let key2 = generic_key(&cache, "world");

        cache.put(key1, vec![1.0]);
        cache.put(key2, vec![2.0]);

        // Access key1 three times, key2 once.
        cache.get(&key1);
        cache.get(&key1);
        cache.get(&key1);
        cache.get(&key2);

        let shard_stats = cache.per_shard_stats();
        let total_hits: u64 = shard_stats.iter().map(|s| s.hits).sum();
        assert_eq!(total_hits, 4, "total hits should be 4");

        let stats = cache.stats();
        assert_eq!(stats.hits, 4);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_small_capacity_rounds_up() {
        // Capacity smaller than NUM_SHARDS: each shard gets at least 1.
        let cache = EmbeddingCache::new(3);
        assert!(cache.is_enabled());

        let key = generic_key(&cache, "x");
        cache.put(key, vec![42.0]);
        assert!(cache.get(&key).is_some());
    }

    #[test]
    fn test_small_capacity_effective_total_is_rounded_up() {
        // Every shard keeps at least one slot, so the effective total (NUM_SHARDS)
        // exceeds the requested capacity (3), while `stats().capacity` still reports
        // the requested value.
        let cache = EmbeddingCache::new(3);

        let mut covered = [false; NUM_SHARDS];
        let mut i = 0u32;
        while covered.iter().any(|c| !c) {
            let key = generic_key(&cache, &format!("cover_{i}"));
            covered[shard_index(&key)] = true;
            cache.put(key, vec![i as f32]);
            i += 1;
        }

        let stats = cache.stats();
        assert_eq!(stats.size, NUM_SHARDS);
        assert_eq!(stats.capacity, 3);
    }

    // -------------------------------------------------------------------------
    // Role-aware cache key tests (P0-E2)
    // -------------------------------------------------------------------------

    /// Query and passage roles must produce different cache keys for the same raw text
    /// so that embed_query("hello") and embed_passage("hello") are stored separately.
    #[test]
    fn test_role_query_vs_passage_different_keys() {
        let cache = EmbeddingCache::new(100);
        let model = ModelConfig::new(EmbeddingModel::MultilingualE5Small);
        let text = "hello world";

        let key_query = cache.compute_key(text, model, EmbeddingRole::Query);
        let key_passage = cache.compute_key(text, model, EmbeddingRole::Passage);
        let key_generic = cache.compute_key(text, model, EmbeddingRole::Generic);

        assert_ne!(key_query, key_passage, "query vs passage must differ");
        assert_ne!(key_query, key_generic, "query vs generic must differ");
        assert_ne!(key_passage, key_generic, "passage vs generic must differ");
    }

    /// Role keys are consistent: same inputs always produce same key.
    #[test]
    fn test_role_key_deterministic() {
        let cache = EmbeddingCache::new(100);
        let model = ModelConfig::new(EmbeddingModel::BgeSmallEnV15);

        let k1 = cache.compute_key("test", model, EmbeddingRole::Query);
        let k2 = cache.compute_key("test", model, EmbeddingRole::Query);
        assert_eq!(k1, k2, "identical inputs must produce identical key");
    }

    /// Storing under one role does not pollute the other role's key.
    #[test]
    fn test_role_cache_isolation() {
        let cache = EmbeddingCache::new(100);
        let model = ModelConfig::new(EmbeddingModel::MultilingualE5Small);

        let key_query = cache.compute_key("embed me", model, EmbeddingRole::Query);
        let key_passage = cache.compute_key("embed me", model, EmbeddingRole::Passage);

        // Store under Query role.
        cache.put(key_query, vec![1.0, 2.0]);

        // Passage role key must still be a miss.
        assert!(
            cache.get(&key_passage).is_none(),
            "passage key must miss after storing under query key"
        );

        // Query role key must hit.
        assert!(
            cache.get(&key_query).is_some(),
            "query key must hit after storing under query key"
        );
    }
}