use crate::model::ModelConfig;
use crate::service::EmbeddingRole;
use lru::LruCache;
use parking_lot::RwLock;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::debug;
/// Fixed-size cache key; `EmbeddingCache::compute_key` fills it with a BLAKE3 digest.
pub type CacheKey = [u8; 32];
/// Number of embeddings held by `EmbeddingCache::with_default_capacity`.
pub const DEFAULT_CACHE_CAPACITY: usize = 4000;
/// Number of independently locked LRU shards.
const NUM_SHARDS: usize = 16;
/// Mask that reduces a byte to a shard index. Only correct because
/// `NUM_SHARDS` is a power of two, which the compile-time check below enforces.
const SHARD_MASK: usize = NUM_SHARDS - 1;
const _: () = {
    assert!(
        NUM_SHARDS.is_power_of_two(),
        "NUM_SHARDS must be a power of 2"
    );
};
/// One independently locked slice of the embedding cache.
///
/// Each shard owns its own LRU map and hit/miss counters. Keys that map to
/// different shards therefore never contend on the same lock.
struct CacheShard {
    /// LRU map from key to shared embedding. Lookups also take the write lock
    /// because `LruCache::get` updates recency order.
    lru: RwLock<LruCache<CacheKey, Arc<[f32]>>>,
    /// Number of `get` calls on this shard that found an entry.
    hits: AtomicU64,
    /// Number of `get` calls on this shard that found nothing.
    misses: AtomicU64,
}

impl CacheShard {
    /// Creates an empty shard that holds at most `capacity` entries.
    fn new(capacity: NonZeroUsize) -> Self {
        let lru = RwLock::new(LruCache::new(capacity));
        Self {
            lru,
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
        }
    }

    /// Looks up `key`. A hit also promotes the entry to most recently used.
    /// Either outcome is recorded in the shard's counters.
    #[inline]
    fn get(&self, key: &CacheKey) -> Option<Arc<[f32]>> {
        // The write guard is a temporary and is released at the end of this statement.
        let found = self.lru.write().get(key).map(Arc::clone);
        let counter = if found.is_some() {
            &self.hits
        } else {
            &self.misses
        };
        counter.fetch_add(1, Ordering::Relaxed);
        found
    }

    /// Inserts or replaces the entry for `key`.
    #[inline]
    fn put(&self, key: CacheKey, embedding: Arc<[f32]>) {
        self.lru.write().put(key, embedding);
    }

    /// Current number of entries in this shard.
    fn len(&self) -> usize {
        let guard = self.lru.read();
        guard.len()
    }

    /// Drops every entry. Hit and miss counters are left untouched.
    fn clear(&self) {
        let mut guard = self.lru.write();
        guard.clear();
    }

    /// Cumulative hit count.
    fn hits(&self) -> u64 {
        self.hits.load(Ordering::Relaxed)
    }

    /// Cumulative miss count.
    fn misses(&self) -> u64 {
        self.misses.load(Ordering::Relaxed)
    }
}
/// Thread-safe, sharded LRU cache of embedding vectors keyed by [`CacheKey`].
///
/// Built with a capacity of zero, the cache is disabled: lookups always miss
/// and writes are ignored.
pub struct EmbeddingCache {
    /// Exactly `NUM_SHARDS` shards. `shard_index` picks the shard for a key.
    shards: Vec<CacheShard>,
    /// `false` when constructed with capacity 0. Every operation then short-circuits.
    enabled: bool,
    /// Capacity requested at construction. This is the value `stats` reports.
    capacity: usize,
}
/// Maps a key to its shard from the low bits of the key's first byte.
///
/// Keys from `EmbeddingCache::compute_key` are BLAKE3 digests, so one byte is
/// enough to spread them evenly across the shards.
#[inline(always)]
fn shard_index(key: &CacheKey) -> usize {
    let first_byte = usize::from(key[0]);
    first_byte & SHARD_MASK
}
impl EmbeddingCache {
pub fn new(capacity: usize) -> Self {
let enabled = capacity != 0;
let per_shard = if enabled {
let base = capacity.div_ceil(NUM_SHARDS);
if base == 0 { 1 } else { base }
} else {
1 };
let per_shard_nz = NonZeroUsize::new(per_shard).expect("per_shard is always >= 1");
let shards = (0..NUM_SHARDS)
.map(|_| CacheShard::new(per_shard_nz))
.collect();
Self {
shards,
enabled,
capacity,
}
}
pub fn with_default_capacity() -> Self {
Self::new(DEFAULT_CACHE_CAPACITY)
}
pub fn compute_key(
&self,
text: &str,
model_config: ModelConfig,
role: EmbeddingRole,
) -> CacheKey {
let mut hasher = blake3::Hasher::new();
hasher.update(text.as_bytes());
let model_key = format!(
"{}:{}:{}:{}",
model_config.model,
model_config.model.key_version(),
model_config.dimensions(),
role.cache_tag(),
);
hasher.update(model_key.as_bytes());
*hasher.finalize().as_bytes()
}
pub fn get(&self, key: &CacheKey) -> Option<Arc<[f32]>> {
if !self.enabled {
return None;
}
let idx = shard_index(key);
let result = self.shards[idx].get(key);
if result.is_some() {
debug!("cache hit for key {:?}", &key[..8]);
}
result
}
pub fn put(&self, key: CacheKey, embedding: Vec<f32>) {
if !self.enabled {
return;
}
let idx = shard_index(&key);
self.shards[idx].put(key, Arc::from(embedding));
debug!("cached embedding for key {:?}", &key[..8]);
}
pub fn get_many(&self, keys: &[CacheKey]) -> Vec<Option<Arc<[f32]>>> {
if !self.enabled {
return vec![None; keys.len()];
}
keys.iter()
.map(|key| {
let idx = shard_index(key);
self.shards[idx].get(key)
})
.collect()
}
pub fn put_many(&self, entries: Vec<(CacheKey, Vec<f32>)>) {
if !self.enabled {
return;
}
for (key, embedding) in entries {
let idx = shard_index(&key);
self.shards[idx].put(key, Arc::from(embedding));
}
}
pub fn stats(&self) -> CacheStats {
if !self.enabled {
let (hits, misses) = self.aggregate_counters();
return CacheStats {
size: 0,
capacity: 0,
hits,
misses,
};
}
let size: usize = self.shards.iter().map(CacheShard::len).sum();
let (hits, misses) = self.aggregate_counters();
CacheStats {
size,
capacity: self.capacity,
hits,
misses,
}
}
pub fn per_shard_stats(&self) -> Vec<ShardStats> {
self.shards
.iter()
.enumerate()
.map(|(i, s)| ShardStats {
shard_id: i,
size: s.len(),
hits: s.hits(),
misses: s.misses(),
})
.collect()
}
pub fn clear(&self) {
if !self.enabled {
return;
}
for shard in &self.shards {
shard.clear();
}
debug!("cache cleared");
}
#[inline]
pub fn is_enabled(&self) -> bool {
self.enabled
}
fn aggregate_counters(&self) -> (u64, u64) {
let hits: u64 = self.shards.iter().map(CacheShard::hits).sum();
let misses: u64 = self.shards.iter().map(CacheShard::misses).sum();
(hits, misses)
}
}
impl Default for EmbeddingCache {
fn default() -> Self {
Self::with_default_capacity()
}
}
/// Point-in-time snapshot of cache occupancy and cumulative hit/miss counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct CacheStats {
    /// Number of embeddings currently stored across all shards.
    pub size: usize,
    /// Capacity requested when the cache was created (0 for a disabled cache).
    pub capacity: usize,
    /// Number of lookups that found an entry.
    pub hits: u64,
    /// Number of lookups that found nothing.
    pub misses: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded. The total is summed
    /// in `f64` so that large or hand-constructed counters cannot overflow
    /// `u64`, which would panic in debug builds.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        let hits = self.hits as f64;
        let total = hits + self.misses as f64;
        if total == 0.0 {
            0.0
        } else {
            hits / total
        }
    }
}
/// Point-in-time snapshot of a single shard's occupancy and hit/miss counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct ShardStats {
    /// Index of the shard within the cache.
    pub shard_id: usize,
    /// Number of embeddings currently stored in this shard.
    pub size: usize,
    /// Number of lookups on this shard that found an entry.
    pub hits: u64,
    /// Number of lookups on this shard that found nothing.
    pub misses: u64,
}

impl ShardStats {
    /// Fraction of this shard's lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded. The total is summed
    /// in `f64` so that large or hand-constructed counters cannot overflow
    /// `u64`, which would panic in debug builds.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        let hits = self.hits as f64;
        let total = hits + self.misses as f64;
        if total == 0.0 {
            0.0
        } else {
            hits / total
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::EmbeddingModel;

    /// Cache key for `text` under the BGE-small model and the generic role,
    /// which is the configuration most tests use.
    fn generic_key(cache: &EmbeddingCache, text: &str) -> CacheKey {
        cache.compute_key(
            text,
            ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
            EmbeddingRole::Generic,
        )
    }

    #[test]
    fn test_cache_basic_operations() {
        let cache = EmbeddingCache::new(100);
        let key = generic_key(&cache, "hello");
        assert!(cache.get(&key).is_none());
        let embedding = vec![0.1, 0.2, 0.3];
        cache.put(key, embedding.clone());
        let cached = cache.get(&key).unwrap();
        assert_eq!(&*cached, &embedding[..]);
    }

    #[test]
    fn test_put_overwrites_existing_key() {
        let cache = EmbeddingCache::new(100);
        let key = generic_key(&cache, "hello");
        cache.put(key, vec![1.0]);
        cache.put(key, vec![2.0]);
        assert_eq!(&*cache.get(&key).unwrap(), &[2.0f32]);
        assert_eq!(cache.stats().size, 1);
    }

    #[test]
    fn test_cache_eviction() {
        // 16 entries over 16 shards leaves one slot per shard, so 32 inserts
        // must evict.
        let cache = EmbeddingCache::new(16);
        for i in 0..32u32 {
            let key = generic_key(&cache, &format!("text_{i}"));
            cache.put(key, vec![i as f32]);
        }
        let stats = cache.stats();
        assert!(stats.size <= 16, "size {} exceeds capacity 16", stats.size);
    }

    #[test]
    fn test_cache_lru_eviction_within_shard() {
        // 32 entries over 16 shards leaves two slots per shard. A third key in
        // the same shard must evict that shard's least recently used entry.
        let cache = EmbeddingCache::new(32);
        let target_shard = shard_index(&generic_key(&cache, "probe_0"));
        let same_shard_keys: Vec<(CacheKey, u32)> = (0u32..)
            .map(|i| (generic_key(&cache, &format!("lru_test_{i}")), i))
            .filter(|(key, _)| shard_index(key) == target_shard)
            .take(3)
            .collect();
        let (k1, v1) = same_shard_keys[0];
        let (k2, v2) = same_shard_keys[1];
        let (k3, v3) = same_shard_keys[2];
        cache.put(k1, vec![v1 as f32]);
        cache.put(k2, vec![v2 as f32]);
        // Touching k1 makes k2 the least recently used entry.
        assert!(cache.get(&k1).is_some());
        cache.put(k3, vec![v3 as f32]);
        assert!(
            cache.get(&k1).is_some(),
            "k1 should survive (recently accessed)"
        );
        assert!(cache.get(&k2).is_none(), "k2 should be evicted (LRU)");
        assert!(cache.get(&k3).is_some(), "k3 should exist (just inserted)");
    }

    #[test]
    fn test_cache_different_models_different_keys() {
        let cache = EmbeddingCache::new(100);
        let key_small = generic_key(&cache, "text");
        let key_base = cache.compute_key(
            "text",
            ModelConfig::new(EmbeddingModel::BgeBaseEnV15),
            EmbeddingRole::Generic,
        );
        assert_ne!(key_small, key_base);
    }

    #[test]
    fn test_cache_stats() {
        let cache = EmbeddingCache::new(100);
        let key = generic_key(&cache, "hello");
        // One miss, then an insert, then one hit.
        assert!(cache.get(&key).is_none());
        cache.put(key, vec![0.1]);
        assert!(cache.get(&key).is_some());
        let stats = cache.stats();
        assert_eq!(stats.size, 1);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_hit_rate_without_lookups_is_zero() {
        let cache = EmbeddingCache::new(100);
        assert_eq!(cache.stats().hit_rate(), 0.0);
    }

    #[test]
    fn test_cache_get_many() {
        let cache = EmbeddingCache::new(100);
        let key1 = generic_key(&cache, "one");
        let key2 = generic_key(&cache, "two");
        let key3 = generic_key(&cache, "three");
        cache.put(key1, vec![1.0]);
        cache.put(key3, vec![3.0]);
        let results = cache.get_many(&[key1, key2, key3]);
        assert_eq!(results.len(), 3);
        assert_eq!(&**results[0].as_ref().unwrap(), &[1.0f32]);
        assert!(results[1].is_none());
        assert_eq!(&**results[2].as_ref().unwrap(), &[3.0f32]);
    }

    #[test]
    fn test_cache_put_many() {
        let cache = EmbeddingCache::new(100);
        let key1 = generic_key(&cache, "one");
        let key2 = generic_key(&cache, "two");
        cache.put_many(vec![(key1, vec![1.0]), (key2, vec![2.0])]);
        let v1 = cache.get(&key1).unwrap();
        assert_eq!(&*v1, [1.0f32].as_slice());
        let v2 = cache.get(&key2).unwrap();
        assert_eq!(&*v2, [2.0f32].as_slice());
    }

    #[test]
    fn test_cache_clear() {
        let cache = EmbeddingCache::new(100);
        let key = generic_key(&cache, "hello");
        cache.put(key, vec![0.1]);
        assert!(cache.get(&key).is_some());
        cache.clear();
        assert!(cache.get(&key).is_none());
        assert_eq!(cache.stats().size, 0);
    }

    #[test]
    fn test_cache_default_capacity() {
        let cache = EmbeddingCache::with_default_capacity();
        assert_eq!(cache.stats().capacity, DEFAULT_CACHE_CAPACITY);
    }

    #[test]
    fn test_cache_disabled_is_noop() {
        let cache = EmbeddingCache::new(0);
        assert!(!cache.is_enabled());
        let key = generic_key(&cache, "hello");
        cache.put(key, vec![0.1]);
        assert!(cache.get(&key).is_none());
        assert!(cache.get_many(&[key, key]).iter().all(Option::is_none));
        let stats = cache.stats();
        assert_eq!(stats.capacity, 0);
        assert_eq!(stats.size, 0);
        // A disabled cache short-circuits before any shard counter is touched.
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_concurrent_access() {
        use std::thread;
        let cache = Arc::new(EmbeddingCache::new(4000));
        let handles: Vec<_> = (0..8)
            .map(|t| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    for i in 0..100 {
                        let key = generic_key(&cache, &format!("thread_{t}_item_{i}"));
                        let embedding = vec![t as f32; 384];
                        cache.put(key, embedding.clone());
                        let result = cache.get(&key);
                        assert!(result.is_some(), "put followed by get must succeed");
                        assert_eq!(result.unwrap().len(), 384);
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
        let stats = cache.stats();
        assert_eq!(stats.size, 800);
        assert!(stats.hits >= 800, "at least 800 hits expected");
    }

    #[test]
    fn test_shard_distribution() {
        let cache = EmbeddingCache::new(4000);
        let n = 800;
        for i in 0..n {
            let key = generic_key(&cache, &format!("item_{i}"));
            cache.put(key, vec![i as f32]);
        }
        let shard_stats = cache.per_shard_stats();
        assert_eq!(shard_stats.len(), NUM_SHARDS);
        for ss in &shard_stats {
            assert!(
                ss.size > 0,
                "shard {} is empty — distribution is pathological",
                ss.shard_id
            );
        }
        let total: usize = shard_stats.iter().map(|s| s.size).sum();
        assert_eq!(total, n);
        let avg = n / NUM_SHARDS;
        for ss in &shard_stats {
            assert!(
                ss.size <= avg * 3,
                "shard {} has {} entries (avg {}), distribution too skewed",
                ss.shard_id,
                ss.size,
                avg
            );
        }
    }

    #[test]
    fn test_per_shard_stats_hit_tracking() {
        let cache = EmbeddingCache::new(100);
        let key1 = generic_key(&cache, "hello");
        let key2 = generic_key(&cache, "world");
        cache.put(key1, vec![1.0]);
        cache.put(key2, vec![2.0]);
        for _ in 0..3 {
            assert!(cache.get(&key1).is_some());
        }
        assert!(cache.get(&key2).is_some());
        let shard_stats = cache.per_shard_stats();
        let total_hits: u64 = shard_stats.iter().map(|s| s.hits).sum();
        assert_eq!(total_hits, 4, "total hits should be 4");
        let stats = cache.stats();
        assert_eq!(stats.hits, 4);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_small_capacity_rounds_up() {
        let cache = EmbeddingCache::new(3);
        assert!(cache.is_enabled());
        let key = generic_key(&cache, "x");
        cache.put(key, vec![42.0]);
        assert!(cache.get(&key).is_some());
    }

    #[test]
    fn test_role_query_vs_passage_different_keys() {
        let cache = EmbeddingCache::new(100);
        let model = ModelConfig::new(EmbeddingModel::MultilingualE5Small);
        let text = "hello world";
        let key_query = cache.compute_key(text, model, EmbeddingRole::Query);
        let key_passage = cache.compute_key(text, model, EmbeddingRole::Passage);
        let key_generic = cache.compute_key(text, model, EmbeddingRole::Generic);
        assert_ne!(key_query, key_passage, "query vs passage must differ");
        assert_ne!(key_query, key_generic, "query vs generic must differ");
        assert_ne!(key_passage, key_generic, "passage vs generic must differ");
    }

    #[test]
    fn test_role_key_deterministic() {
        let cache = EmbeddingCache::new(100);
        let model = ModelConfig::new(EmbeddingModel::BgeSmallEnV15);
        let k1 = cache.compute_key("test", model, EmbeddingRole::Query);
        let k2 = cache.compute_key("test", model, EmbeddingRole::Query);
        assert_eq!(k1, k2, "identical inputs must produce identical key");
    }

    #[test]
    fn test_role_cache_isolation() {
        let cache = EmbeddingCache::new(100);
        let model = ModelConfig::new(EmbeddingModel::MultilingualE5Small);
        let key_query = cache.compute_key("embed me", model, EmbeddingRole::Query);
        let key_passage = cache.compute_key("embed me", model, EmbeddingRole::Passage);
        cache.put(key_query, vec![1.0, 2.0]);
        assert!(
            cache.get(&key_passage).is_none(),
            "passage key must miss after storing under query key"
        );
        assert!(
            cache.get(&key_query).is_some(),
            "query key must hit after storing under query key"
        );
    }
}