use std::collections::{HashMap, VecDeque};
/// Point-in-time snapshot of an `EmbeddingCache`'s counters, as produced by
/// `EmbeddingCache::stats`.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
/// Number of lookups that found a cached embedding.
pub hits: u64,
/// Number of lookups that found nothing.
pub misses: u64,
/// Number of entries held when the snapshot was taken.
pub size: usize,
/// Capacity the cache was configured with.
pub max_size: usize,
/// Number of entries removed to make room for new ones.
pub evictions: u64,
}
/// A stored embedding together with its bookkeeping metadata.
#[derive(Debug, Clone)]
pub struct CachedEmbedding {
/// The embedding vector itself.
pub embedding: Vec<f32>,
/// Time the entry was stored, in microseconds since the Unix epoch.
pub created_at: u64,
/// Number of successful lookups since the entry was stored or last
/// overwritten (it starts at zero on every `put`).
pub hit_count: u32,
}
/// Bounded cache of embeddings with least-recently-used eviction.
///
/// Entries are identified only by a 64-bit hash of the `(model, text)` pair;
/// the strings themselves are not stored, so two distinct pairs whose hashes
/// collide would share a single entry.
pub struct EmbeddingCache {
/// Maximum number of entries kept before the oldest one is evicted.
max_size: usize,
/// Stored embeddings, keyed by the hash of `(model, text)`.
cache: HashMap<u64, CachedEmbedding>,
/// Keys in recency order: the front is evicted next, the back was used
/// most recently.
order: VecDeque<u64>,
/// Number of lookups that found an entry.
hits: u64,
/// Number of lookups that found nothing.
misses: u64,
/// Number of entries evicted to make room for new ones.
evictions: u64,
}
impl EmbeddingCache {
    /// Creates an empty cache that holds at most `max_size` embeddings.
    ///
    /// The backing collections are pre-allocated for `min(max_size, 1024)`
    /// entries, so a very large limit does not reserve memory up front.
    /// A cache created with `max_size == 0` never stores anything.
    pub fn new(max_size: usize) -> Self {
        let initial_capacity = max_size.min(1024);
        Self {
            max_size,
            cache: HashMap::with_capacity(initial_capacity),
            order: VecDeque::with_capacity(initial_capacity),
            hits: 0,
            misses: 0,
            evictions: 0,
        }
    }

    /// Creates a cache with a capacity of 10,000 entries.
    pub fn default_size() -> Self {
        Self::new(10_000)
    }

    /// Derives the map key for a `(model, text)` pair.
    ///
    /// The key is a 64-bit FNV-1a hash over the model bytes, a separator
    /// byte, and the text bytes. Only the hash identifies an entry, so two
    /// distinct pairs with colliding hashes would share one slot.
    fn cache_key(model: &str, text: &str) -> u64 {
        const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
        const FNV_PRIME: u64 = 0x100000001b3;
        // 0xff never occurs in valid UTF-8, so ("ab", "c") and ("a", "bc")
        // cannot produce the same byte stream.
        const SEPARATOR: u8 = 0xff;

        model
            .bytes()
            .chain(std::iter::once(SEPARATOR))
            .chain(text.bytes())
            .fold(FNV_OFFSET_BASIS, |hash, byte| {
                (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
            })
    }

    /// Moves `key` to the most-recently-used end of the recency queue.
    ///
    /// This scans the whole queue, so it is linear in the number of entries.
    fn mark_most_recent(order: &mut VecDeque<u64>, key: u64) {
        order.retain(|queued| *queued != key);
        order.push_back(key);
    }

    /// Looks up the embedding cached for `text` under `model`.
    ///
    /// A hit increments the hit counter and the entry's `hit_count`, and
    /// marks the entry as most recently used. A miss increments the miss
    /// counter and returns `None`.
    pub fn get(&mut self, text: &str, model: &str) -> Option<&[f32]> {
        let key = Self::cache_key(model, text);
        match self.cache.get_mut(&key) {
            Some(entry) => {
                self.hits += 1;
                // Saturate rather than overflow on an extremely hot entry.
                entry.hit_count = entry.hit_count.saturating_add(1);
                Self::mark_most_recent(&mut self.order, key);
                Some(entry.embedding.as_slice())
            }
            None => {
                self.misses += 1;
                None
            }
        }
    }

    /// Stores `embedding` for `text` under `model`.
    ///
    /// Overwriting an existing entry replaces its value, resets its
    /// `hit_count` and creation time, and marks it as most recently used
    /// without evicting anything. Inserting a new entry into a full cache
    /// first evicts least-recently-used entries. If the cache was created
    /// with a capacity of zero, the call is a no-op.
    pub fn put(&mut self, text: &str, model: &str, embedding: Vec<f32>) {
        // A zero-capacity cache has no room for any entry; without this
        // guard the eviction loop finds nothing to evict and the insert
        // below would push the cache past its limit.
        if self.max_size == 0 {
            return;
        }

        let key = Self::cache_key(model, text);
        let entry = CachedEmbedding {
            embedding,
            created_at: Self::now_micros(),
            hit_count: 0,
        };

        if let Some(existing) = self.cache.get_mut(&key) {
            *existing = entry;
            Self::mark_most_recent(&mut self.order, key);
            return;
        }

        // Make room before inserting so the map never exceeds `max_size`.
        while self.cache.len() >= self.max_size {
            match self.order.pop_front() {
                Some(oldest) => {
                    self.cache.remove(&oldest);
                    self.evictions += 1;
                }
                None => break,
            }
        }

        self.cache.insert(key, entry);
        self.order.push_back(key);
    }

    /// Returns a snapshot of the current counters, size, and capacity.
    pub fn stats(&self) -> CacheStats {
        CacheStats {
            hits: self.hits,
            misses: self.misses,
            size: self.cache.len(),
            max_size: self.max_size,
            evictions: self.evictions,
        }
    }

    /// Removes every cached entry.
    ///
    /// The hit, miss, and eviction counters are left untouched.
    pub fn clear(&mut self) {
        self.cache.clear();
        self.order.clear();
    }

    /// Current wall-clock time in microseconds since the Unix epoch.
    ///
    /// Yields 0 if the system clock reports a time before the epoch.
    fn now_micros() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            // Narrowing from u128: a u64 of microseconds spans several
            // hundred thousand years, so truncation is not a practical concern.
            .as_micros() as u64
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A lookup misses before the entry is stored and hits afterwards, and
    /// both outcomes show up in the statistics.
    #[test]
    fn test_cache_hit_miss() {
        let mut cache = EmbeddingCache::new(10);

        // Nothing has been stored yet, so the first lookup is a miss.
        let before_insert = cache.get("hello", "model");
        assert!(before_insert.is_none());
        assert_eq!(cache.stats().misses, 1);

        cache.put("hello", "model", vec![1.0, 2.0, 3.0]);

        // The same text/model pair now returns the stored vector.
        let found = cache
            .get("hello", "model")
            .expect("entry should be present after put");
        assert_eq!(found, &[1.0, 2.0, 3.0]);
        assert_eq!(cache.stats().hits, 1);
    }

    /// Inserting a fourth entry into a three-slot cache drops the oldest
    /// entry and keeps the rest.
    #[test]
    fn test_lru_eviction() {
        let mut cache = EmbeddingCache::new(3);
        for (text, value) in [("a", 1.0), ("b", 2.0), ("c", 3.0), ("d", 4.0)] {
            cache.put(text, "m", vec![value]);
        }

        // "a" was inserted first and never touched again, so it was evicted.
        assert!(cache.get("a", "m").is_none());
        for survivor in ["b", "c", "d"] {
            assert!(cache.get(survivor, "m").is_some());
        }
        assert_eq!(cache.stats().evictions, 1);
    }
}