use crate::model::ModelConfig;
use lru::LruCache;
use parking_lot::RwLock;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::debug;
/// Key under which an embedding is cached: a 32-byte BLAKE3 digest as
/// returned by `EmbeddingCache::compute_key`.
pub type CacheKey = [u8; 32];

/// Total capacity used by `EmbeddingCache::with_default_capacity` (and thus `Default`).
pub const DEFAULT_CACHE_CAPACITY: usize = 4000;

/// Number of independently locked LRU shards. Must be a power of two so a
/// shard can be picked with a bit mask instead of a modulo.
const NUM_SHARDS: usize = 16;

/// Bit mask that reduces a value to a shard index in `0..NUM_SHARDS`.
const SHARD_MASK: usize = NUM_SHARDS - 1;

// Compile-time guard: `SHARD_MASK` reaches every shard only when
// `NUM_SHARDS` is a power of two.
const _: () = assert!(NUM_SHARDS.is_power_of_two(), "NUM_SHARDS must be a power of 2");
/// One independently locked LRU partition of the embedding cache, carrying
/// its own hit/miss counters.
struct CacheShard {
    /// Stored entries. Lookups take the *write* lock as well, because
    /// `LruCache::get` needs `&mut self` to move the entry to the
    /// most-recently-used position.
    lru: RwLock<LruCache<CacheKey, Arc<[f32]>>>,
    /// Number of `get` calls that found an entry.
    hits: AtomicU64,
    /// Number of `get` calls that found nothing.
    misses: AtomicU64,
}

impl CacheShard {
    /// Creates an empty shard holding at most `capacity` entries.
    fn new(capacity: NonZeroUsize) -> Self {
        let lru = RwLock::new(LruCache::new(capacity));
        Self {
            lru,
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
        }
    }

    /// Looks up `key`, promoting it to most-recently-used on a hit, and
    /// records the outcome in the hit or miss counter.
    #[inline]
    fn get(&self, key: &CacheKey) -> Option<Arc<[f32]>> {
        // The write guard is a temporary and is released at the end of this
        // statement; only a cheap `Arc` clone escapes it.
        let found = self.lru.write().get(key).map(Arc::clone);
        // Relaxed is enough: the counters are only ever read as statistics.
        let counter = match found {
            Some(_) => &self.hits,
            None => &self.misses,
        };
        counter.fetch_add(1, Ordering::Relaxed);
        found
    }

    /// Inserts or replaces the entry for `key`, letting the LRU evict its
    /// least-recently-used entry when the shard is full.
    #[inline]
    fn put(&self, key: CacheKey, embedding: Arc<[f32]>) {
        self.lru.write().put(key, embedding);
    }

    /// Number of entries currently stored in this shard.
    fn len(&self) -> usize {
        let guard = self.lru.read();
        guard.len()
    }

    /// Drops every entry; the hit/miss counters are left untouched.
    fn clear(&self) {
        let mut guard = self.lru.write();
        guard.clear();
    }

    /// Total hits recorded by this shard.
    fn hits(&self) -> u64 {
        self.hits.load(Ordering::Relaxed)
    }

    /// Total misses recorded by this shard.
    fn misses(&self) -> u64 {
        self.misses.load(Ordering::Relaxed)
    }
}
/// Embedding cache split into `NUM_SHARDS` independently locked LRU shards,
/// so concurrent callers touching different shards do not contend on one lock.
pub struct EmbeddingCache {
    /// `false` when constructed with capacity 0; lookups and inserts then
    /// short-circuit without touching the shards.
    enabled: bool,
    /// Capacity requested by the caller (this is what `stats()` reports). The
    /// shards together may hold somewhat more, since each shard's share is
    /// rounded up.
    capacity: usize,
    /// Exactly `NUM_SHARDS` shards, addressed via `shard_index`.
    shards: Vec<CacheShard>,
}
/// Picks the shard for `key` from the low bits of its first byte.
///
/// Only 8 bits are consulted, so this spreads keys over at most 256 shards.
/// Even spreading assumes keys are hash digests with uniform bytes.
#[inline(always)]
fn shard_index(key: &CacheKey) -> usize {
    let first_byte = usize::from(key[0]);
    first_byte & SHARD_MASK
}
impl EmbeddingCache {
    /// Creates a cache that holds roughly `capacity` embeddings.
    ///
    /// `capacity` is split across `NUM_SHARDS` independent LRU shards, each
    /// getting `capacity.div_ceil(NUM_SHARDS)` slots. Because of this rounding
    /// the cache can hold up to `capacity.div_ceil(NUM_SHARDS) * NUM_SHARDS`
    /// entries, which may exceed `capacity`; for example, `new(3)` yields 16
    /// one-slot shards. Eviction happens per shard, so an entry can be evicted
    /// while other shards still have free slots.
    ///
    /// A `capacity` of 0 disables the cache: lookups always miss and inserts
    /// are ignored.
    pub fn new(capacity: usize) -> Self {
        let enabled = capacity != 0;
        // `LruCache` needs a non-zero capacity, so even a disabled cache
        // (capacity 0) gets one-slot shards; they are never written to.
        let per_shard = capacity.div_ceil(NUM_SHARDS).max(1);
        let per_shard_nz = NonZeroUsize::new(per_shard).expect("per_shard is always >= 1");
        let shards = (0..NUM_SHARDS)
            .map(|_| CacheShard::new(per_shard_nz))
            .collect();
        Self {
            shards,
            enabled,
            capacity,
        }
    }

    /// Creates a cache sized to [`DEFAULT_CACHE_CAPACITY`].
    pub fn with_default_capacity() -> Self {
        Self::new(DEFAULT_CACHE_CAPACITY)
    }

    /// Derives the cache key for embedding `text` with `model_config`.
    ///
    /// The key is a BLAKE3 digest over the text and a
    /// `"{model}:{key_version}:{dimensions}"` descriptor. The same text
    /// therefore gets distinct keys per model, key version and output
    /// dimension.
    ///
    /// The text is length-prefixed so the boundary between text and descriptor
    /// is part of the hashed input. With plain concatenation, two different
    /// (text, model) pairs could yield identical bytes, for example when one
    /// model's display name is a suffix of another's. Such pairs would share a
    /// key and the cache would return an embedding computed by the wrong model.
    pub fn compute_key(&self, text: &str, model_config: ModelConfig) -> CacheKey {
        let mut hasher = blake3::Hasher::new();
        // `usize` -> `u64` is lossless on every target with <= 64-bit pointers.
        hasher.update(&(text.len() as u64).to_le_bytes());
        hasher.update(text.as_bytes());
        let model_key = format!(
            "{}:{}:{}",
            model_config.model,
            model_config.model.key_version(),
            model_config.dimensions(),
        );
        hasher.update(model_key.as_bytes());
        *hasher.finalize().as_bytes()
    }

    /// Returns the cached embedding for `key`, marking it most recently used.
    ///
    /// The lookup is counted as a hit or miss in the owning shard. Always
    /// returns `None` (and counts nothing) when the cache is disabled.
    pub fn get(&self, key: &CacheKey) -> Option<Arc<[f32]>> {
        if !self.enabled {
            return None;
        }
        let result = self.shards[shard_index(key)].get(key);
        if result.is_some() {
            debug!("cache hit for key {:?}", &key[..8]);
        }
        result
    }

    /// Stores `embedding` under `key`, replacing any previous value and
    /// possibly evicting the least-recently-used entry of the same shard.
    /// No-op when the cache is disabled.
    pub fn put(&self, key: CacheKey, embedding: Vec<f32>) {
        if !self.enabled {
            return;
        }
        self.shards[shard_index(&key)].put(key, Arc::from(embedding));
        debug!("cached embedding for key {:?}", &key[..8]);
    }

    /// Looks up every key in order; the result has one entry per key.
    ///
    /// Each lookup locks its shard independently, so the batch is not an
    /// atomic snapshot. All `None` when the cache is disabled.
    pub fn get_many(&self, keys: &[CacheKey]) -> Vec<Option<Arc<[f32]>>> {
        if !self.enabled {
            return vec![None; keys.len()];
        }
        keys.iter()
            .map(|key| self.shards[shard_index(key)].get(key))
            .collect()
    }

    /// Stores every `(key, embedding)` pair in order. If a key appears more
    /// than once, the last value wins. No-op when the cache is disabled.
    pub fn put_many(&self, entries: Vec<(CacheKey, Vec<f32>)>) {
        if !self.enabled {
            return;
        }
        for (key, embedding) in entries {
            self.shards[shard_index(&key)].put(key, Arc::from(embedding));
        }
    }

    /// Returns aggregate statistics across all shards.
    ///
    /// `capacity` is the value passed to [`EmbeddingCache::new`] (0 when
    /// disabled), not the rounded-up total of the shards. `size` can
    /// therefore exceed it when the capacity is not a multiple of
    /// `NUM_SHARDS`. Shards are read one after another, so the snapshot is
    /// not atomic under concurrent writes.
    pub fn stats(&self) -> CacheStats {
        // A disabled cache never stores anything, so skip locking the shards.
        let size: usize = if self.enabled {
            self.shards.iter().map(CacheShard::len).sum()
        } else {
            0
        };
        let (hits, misses) = self.aggregate_counters();
        CacheStats {
            size,
            // `self.capacity` is 0 whenever the cache is disabled (see `new`).
            capacity: self.capacity,
            hits,
            misses,
        }
    }

    /// Returns size and hit/miss counters for each shard, in shard order.
    pub fn per_shard_stats(&self) -> Vec<ShardStats> {
        self.shards
            .iter()
            .enumerate()
            .map(|(i, s)| ShardStats {
                shard_id: i,
                size: s.len(),
                hits: s.hits(),
                misses: s.misses(),
            })
            .collect()
    }

    /// Removes every cached entry. Hit/miss counters are *not* reset.
    pub fn clear(&self) {
        if !self.enabled {
            return;
        }
        for shard in &self.shards {
            shard.clear();
        }
        debug!("cache cleared");
    }

    /// `false` when the cache was created with capacity 0.
    #[inline]
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Sums the per-shard `(hits, misses)` counters.
    fn aggregate_counters(&self) -> (u64, u64) {
        let hits: u64 = self.shards.iter().map(CacheShard::hits).sum();
        let misses: u64 = self.shards.iter().map(CacheShard::misses).sum();
        (hits, misses)
    }
}
impl Default for EmbeddingCache {
fn default() -> Self {
Self::with_default_capacity()
}
}
/// Aggregate statistics for the whole embedding cache.
#[derive(Debug, Clone, Copy)]
pub struct CacheStats {
    /// Number of entries currently stored across all shards.
    pub size: usize,
    /// Capacity reported by the cache (0 for a disabled cache).
    pub capacity: usize,
    /// Lookups that found an entry.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`; `0.0` when no
    /// lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        hit_ratio(self.hits, self.misses)
    }
}

/// Statistics for a single cache shard.
#[derive(Debug, Clone, Copy)]
pub struct ShardStats {
    /// Index of the shard, `0..NUM_SHARDS`.
    pub shard_id: usize,
    /// Number of entries currently stored in this shard.
    pub size: usize,
    /// Lookups in this shard that found an entry.
    pub hits: u64,
    /// Lookups in this shard that found nothing.
    pub misses: u64,
}

impl ShardStats {
    /// Fraction of this shard's lookups that were hits, in `[0.0, 1.0]`;
    /// `0.0` when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        hit_ratio(self.hits, self.misses)
    }
}

/// Shared body of both `hit_rate` methods: `hits / (hits + misses)`, or `0.0`
/// when both are zero.
///
/// The total is summed in `u128`, so counters near `u64::MAX` cannot overflow.
/// A `u64` sum would panic in debug builds and wrap silently in release.
fn hit_ratio(hits: u64, misses: u64) -> f64 {
    let total = u128::from(hits) + u128::from(misses);
    if total == 0 {
        0.0
    } else {
        hits as f64 / total as f64
    }
}
// Unit tests for the sharded embedding cache. Several tests depend on the
// exact sequence of `get`/`put` calls (LRU promotion order, hit/miss counts),
// so statements within a test must not be reordered.
#[cfg(test)]
mod tests {
use super::*;
use crate::model::EmbeddingModel;
// Round trip: a miss before insertion, then the stored values come back intact.
#[test]
fn test_cache_basic_operations() {
let cache = EmbeddingCache::new(100);
let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
assert!(cache.get(&key).is_none());
let embedding = vec![0.1, 0.2, 0.3];
cache.put(key, embedding.clone());
let cached = cache.get(&key).unwrap();
assert_eq!(&*cached, &embedding[..]);
}
// Capacity 16 gives 16 shards of one slot each (16.div_ceil(16) == 1), so
// inserting 32 distinct keys must never leave more than 16 entries.
#[test]
fn test_cache_eviction() {
let cache = EmbeddingCache::new(16);
// NOTE(review): `keys` is only filled here and never read afterwards.
let mut keys = Vec::new();
for i in 0..32u32 {
let text = format!("text_{}", i);
let key = cache.compute_key(&text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
keys.push(key);
cache.put(key, vec![i as f32]);
}
let stats = cache.stats();
assert!(stats.size <= 16, "size {} exceeds capacity 16", stats.size);
}
// Capacity 32 gives two slots per shard. Three keys that land in the same
// shard exercise LRU order: reading k1 makes k2 the least recently used, so
// inserting k3 must evict k2 and keep k1.
#[test]
fn test_cache_lru_eviction_within_shard() {
let cache = EmbeddingCache::new(32);
let mut same_shard_keys = Vec::new();
let mut i = 0u32;
let target_shard;
// The probe key is never inserted; it only selects which shard to target.
let first_key =
cache.compute_key("probe_0", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
target_shard = shard_index(&first_key);
// Scan "lru_test_{i}" texts until three of them map to the target shard.
loop {
let key = cache.compute_key(
&format!("lru_test_{}", i),
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
);
if shard_index(&key) == target_shard {
same_shard_keys.push((key, i));
}
if same_shard_keys.len() == 3 {
break;
}
i += 1;
}
let (k1, v1) = same_shard_keys[0];
let (k2, v2) = same_shard_keys[1];
let (k3, v3) = same_shard_keys[2];
cache.put(k1, vec![v1 as f32]);
cache.put(k2, vec![v2 as f32]);
// Reading k1 promotes it to most recently used, leaving k2 as the LRU entry.
assert!(cache.get(&k1).is_some());
// The shard is now full (2 slots), so this insert evicts k2.
cache.put(k3, vec![v3 as f32]);
assert!(
cache.get(&k1).is_some(),
"k1 should survive (recently accessed)"
);
assert!(cache.get(&k2).is_none(), "k2 should be evicted (LRU)");
assert!(cache.get(&k3).is_some(), "k3 should exist (just inserted)");
}
// The model is part of the key, so the same text under two models must differ.
#[test]
fn test_cache_different_models_different_keys() {
let cache = EmbeddingCache::new(100);
let key_small = cache.compute_key("text", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key_base = cache.compute_key("text", ModelConfig::new(EmbeddingModel::BgeBaseEnV15));
assert_ne!(key_small, key_base);
}
// One miss (lookup before insert) plus one hit gives a 0.5 hit rate.
#[test]
fn test_cache_stats() {
let cache = EmbeddingCache::new(100);
let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
cache.get(&key); cache.put(key, vec![0.1]); // miss on the empty cache, then insert
cache.get(&key); // hit
let stats = cache.stats();
assert_eq!(stats.size, 1);
assert_eq!(stats.hits, 1);
assert_eq!(stats.misses, 1);
assert!((stats.hit_rate() - 0.5).abs() < 0.001);
}
// Batch lookup preserves key order and reports absent keys as `None`.
#[test]
fn test_cache_get_many() {
let cache = EmbeddingCache::new(100);
let key1 = cache.compute_key("one", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key2 = cache.compute_key("two", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key3 = cache.compute_key("three", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
cache.put(key1, vec![1.0]);
cache.put(key3, vec![3.0]);
let results = cache.get_many(&[key1, key2, key3]);
assert_eq!(results.len(), 3);
assert_eq!(&**results[0].as_ref().unwrap(), &[1.0f32]);
assert!(results[1].is_none());
assert_eq!(&**results[2].as_ref().unwrap(), &[3.0f32]);
}
// Batch insert makes every entry retrievable through `get`.
#[test]
fn test_cache_put_many() {
let cache = EmbeddingCache::new(100);
let key1 = cache.compute_key("one", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key2 = cache.compute_key("two", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
cache.put_many(vec![(key1, vec![1.0]), (key2, vec![2.0])]);
let v1 = cache.get(&key1).unwrap();
assert_eq!(&*v1, [1.0f32].as_slice());
let v2 = cache.get(&key2).unwrap();
assert_eq!(&*v2, [2.0f32].as_slice());
}
// `clear` empties every shard.
#[test]
fn test_cache_clear() {
let cache = EmbeddingCache::new(100);
let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
cache.put(key, vec![0.1]);
assert!(cache.get(&key).is_some());
cache.clear();
assert!(cache.get(&key).is_none());
assert_eq!(cache.stats().size, 0);
}
// The reported capacity is the requested one.
#[test]
fn test_cache_default_capacity() {
let cache = EmbeddingCache::with_default_capacity();
assert_eq!(cache.stats().capacity, DEFAULT_CACHE_CAPACITY);
}
// Capacity 0 disables the cache: inserts are dropped and lookups miss.
#[test]
fn test_cache_disabled_is_noop() {
let cache = EmbeddingCache::new(0);
assert!(!cache.is_enabled());
let key = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
cache.put(key, vec![0.1]);
assert!(cache.get(&key).is_none());
let stats = cache.stats();
assert_eq!(stats.capacity, 0);
assert_eq!(stats.size, 0);
}
// 8 threads each insert and immediately read back 100 distinct keys (800
// total). Capacity 4000 gives 250 slots per shard; the exact-size assertion
// assumes no shard receives more than 250 of the 800 keys, so nothing is
// evicted.
#[test]
fn test_concurrent_access() {
use std::thread;
let cache = Arc::new(EmbeddingCache::new(4000));
let mut handles = Vec::new();
for t in 0..8 {
let cache = Arc::clone(&cache);
handles.push(thread::spawn(move || {
for i in 0..100 {
let text = format!("thread_{}_item_{}", t, i);
let key =
cache.compute_key(&text, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let embedding = vec![t as f32; 384];
// NOTE(review): this clone is redundant; `embedding` is not used after the put.
cache.put(key, embedding.clone());
let result = cache.get(&key);
assert!(result.is_some(), "put followed by get must succeed");
assert_eq!(result.unwrap().len(), 384);
}
}));
}
for h in handles {
h.join().expect("thread panicked");
}
let stats = cache.stats();
assert_eq!(stats.size, 800);
assert!(stats.hits >= 800, "at least 800 hits expected");
}
// 800 keys must reach every shard and none may hold more than 3x the average.
#[test]
fn test_shard_distribution() {
let cache = EmbeddingCache::new(4000);
let n = 800;
for i in 0..n {
let key = cache.compute_key(
&format!("item_{}", i),
ModelConfig::new(EmbeddingModel::BgeSmallEnV15),
);
cache.put(key, vec![i as f32]);
}
let shard_stats = cache.per_shard_stats();
assert_eq!(shard_stats.len(), NUM_SHARDS);
for ss in &shard_stats {
assert!(
ss.size > 0,
"shard {} is empty — distribution is pathological",
ss.shard_id
);
}
let total: usize = shard_stats.iter().map(|s| s.size).sum();
assert_eq!(total, n);
// Integer average: 800 / 16 = 50 entries per shard.
let avg = n / NUM_SHARDS; for ss in &shard_stats {
assert!(
ss.size <= avg * 3,
"shard {} has {} entries (avg {}), distribution too skewed",
ss.shard_id,
ss.size,
avg
);
}
}
// Per-shard hit counters sum to the aggregate: 4 lookups, all hits.
#[test]
fn test_per_shard_stats_hit_tracking() {
let cache = EmbeddingCache::new(100);
let key1 = cache.compute_key("hello", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
let key2 = cache.compute_key("world", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
cache.put(key1, vec![1.0]);
cache.put(key2, vec![2.0]);
cache.get(&key1);
cache.get(&key1);
cache.get(&key1);
cache.get(&key2);
let shard_stats = cache.per_shard_stats();
let total_hits: u64 = shard_stats.iter().map(|s| s.hits).sum();
assert_eq!(total_hits, 4, "total hits should be 4");
let stats = cache.stats();
assert_eq!(stats.hits, 4);
assert_eq!(stats.misses, 0);
}
// Capacity 3 (< NUM_SHARDS) still yields a usable cache: each shard's share
// rounds up to one slot.
#[test]
fn test_small_capacity_rounds_up() {
let cache = EmbeddingCache::new(3);
assert!(cache.is_enabled());
let key = cache.compute_key("x", ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
cache.put(key, vec![42.0]);
assert!(cache.get(&key).is_some());
}
}