use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
/// Point-in-time snapshot of an embedding cache's counters and memory accounting.
///
/// `PartialEq` and `Default` are derived so snapshots can be compared directly
/// (e.g. in tests) and a zeroed snapshot can be built without spelling out
/// every field.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct EmbeddingCacheStats {
    /// Number of lookups that found an entry.
    pub hits: u64,
    /// Number of lookups that found nothing.
    pub misses: u64,
    /// Number of entries stored when the snapshot was taken.
    pub entries: usize,
    /// Bytes accounted to stored embeddings when the snapshot was taken.
    pub bytes_used: usize,
    /// Configured upper bound for `bytes_used`.
    pub max_bytes: usize,
    /// Hit ratio expressed as a percentage (`0.0..=100.0`); `0.0` when no
    /// lookups have been recorded.
    pub hit_rate: f64,
}
/// A single cached embedding plus its position in the recency list.
///
/// The list is threaded through the entry map: each node names its
/// neighbours by key instead of holding pointers to them.
struct LruNode {
    /// Shared, immutable embedding payload handed out to readers.
    embedding: Arc<[f32]>,
    /// Size recorded for this entry in the cache's byte accounting.
    size_bytes: usize,
    /// Key of the neighbour towards the head of the list; `None` at the head.
    prev: Option<String>,
    /// Key of the neighbour towards the tail of the list; `None` at the tail.
    next: Option<String>,
}
/// Mutable interior of the cache: the entry map, the two ends of the recency
/// list that is threaded through it, and the running byte total.
struct CacheState {
    /// Every cached entry, keyed by cache key.
    entries: HashMap<String, LruNode>,
    /// Key at the front of the recency list (most recently used); `None` when empty.
    head: Option<String>,
    /// Key at the back of the recency list (next eviction candidate); `None` when empty.
    tail: Option<String>,
    /// Sum of `size_bytes` over the entries currently stored.
    bytes_used: usize,
}
impl CacheState {
    /// Creates an empty state: no entries, no head/tail, zero bytes accounted.
    fn new() -> Self {
        Self {
            entries: HashMap::new(),
            head: None,
            tail: None,
            bytes_used: 0,
        }
    }

    /// Marks `key` as most recently used by relinking its node at the head.
    ///
    /// Does nothing when `key` is already the head, or when `key` is not
    /// present in `entries`. The second guard matters: without it, `head`
    /// (and the old head's `prev` link) would be pointed at a key that has no
    /// node, corrupting the list.
    fn move_to_front(&mut self, key: &str) {
        if self.head.as_deref() == Some(key) {
            return;
        }

        // Snapshot the neighbours; an unknown key leaves the list untouched.
        let (prev, next) = match self.entries.get(key) {
            Some(node) => (node.prev.clone(), node.next.clone()),
            None => return,
        };

        // Unlink: make the neighbours point at each other.
        if let Some(prev_key) = prev.as_deref() {
            if let Some(prev_node) = self.entries.get_mut(prev_key) {
                prev_node.next = next.clone();
            }
        }
        if let Some(next_key) = next.as_deref() {
            if let Some(next_node) = self.entries.get_mut(next_key) {
                next_node.prev = prev.clone();
            }
        }
        if self.tail.as_deref() == Some(key) {
            self.tail = prev;
        }

        // Relink in front of the previous head.
        let old_head = self.head.replace(key.to_owned());
        if let Some(old_head_key) = old_head.as_deref() {
            if let Some(head_node) = self.entries.get_mut(old_head_key) {
                head_node.prev = Some(key.to_owned());
            }
        }
        if let Some(node) = self.entries.get_mut(key) {
            node.prev = None;
            node.next = old_head;
        }
        if self.tail.is_none() {
            self.tail = self.head.clone();
        }
    }

    /// Removes the entry at the tail of the list (the least recently used).
    ///
    /// Returns the number of bytes released, or `None` when there is no tail
    /// or the tail key has no node in `entries`.
    fn evict_lru(&mut self) -> Option<usize> {
        let tail_key = self.tail.take()?;
        let node = self.entries.remove(&tail_key)?;

        // The evicted node's predecessor becomes the new tail.
        self.tail = node.prev;
        if let Some(new_tail_key) = self.tail.as_deref() {
            if let Some(new_tail) = self.entries.get_mut(new_tail_key) {
                new_tail.next = None;
            }
        }
        if self.head.as_deref() == Some(tail_key.as_str()) {
            self.head = None;
        }

        // Saturate so a broken accounting invariant cannot panic here.
        self.bytes_used = self.bytes_used.saturating_sub(node.size_bytes);
        Some(node.size_bytes)
    }
}
/// Byte-bounded LRU cache of embeddings, shareable between threads.
///
/// Entry storage and recency order live behind a mutex; the hit/miss
/// counters are plain atomics alongside it.
pub struct EmbeddingCache {
    /// Entry map and recency list, guarded by a single lock.
    state: Mutex<CacheState>,
    /// Upper bound on the bytes accounted to stored embeddings.
    max_bytes: usize,
    /// Count of lookups that found an entry.
    hits: AtomicU64,
    /// Count of lookups that found nothing.
    misses: AtomicU64,
}
impl EmbeddingCache {
pub fn new(max_bytes: usize) -> Self {
Self {
state: Mutex::new(CacheState::new()),
max_bytes,
hits: AtomicU64::new(0),
misses: AtomicU64::new(0),
}
}
pub fn default_capacity() -> Self {
Self::new(100 * 1024 * 1024) }
pub fn get(&self, key: &str) -> Option<Arc<[f32]>> {
let mut state = self.state.lock().unwrap();
if state.entries.contains_key(key) {
state.move_to_front(key);
self.hits.fetch_add(1, Ordering::Relaxed);
state.entries.get(key).map(|n| n.embedding.clone())
} else {
self.misses.fetch_add(1, Ordering::Relaxed);
None
}
}
pub fn put(&self, key: String, embedding: Vec<f32>) {
let size_bytes = embedding.len() * std::mem::size_of::<f32>();
if size_bytes > self.max_bytes {
return;
}
let arc: Arc<[f32]> = embedding.into();
let mut state = self.state.lock().unwrap();
if let Some(old_node) = state.entries.remove(&key) {
state.bytes_used -= old_node.size_bytes;
if let Some(ref prev_key) = old_node.prev {
if let Some(prev_node) = state.entries.get_mut(prev_key) {
prev_node.next = old_node.next.clone();
}
}
if let Some(ref next_key) = old_node.next {
if let Some(next_node) = state.entries.get_mut(next_key) {
next_node.prev = old_node.prev.clone();
}
}
if state.head.as_deref() == Some(&key) {
state.head = old_node.next.clone();
}
if state.tail.as_deref() == Some(&key) {
state.tail = old_node.prev.clone();
}
}
while state.bytes_used + size_bytes > self.max_bytes {
if state.evict_lru().is_none() {
break;
}
}
let old_head = state.head.clone();
let node = LruNode {
embedding: arc,
size_bytes,
prev: None,
next: old_head.clone(),
};
if let Some(ref old_head_key) = old_head {
if let Some(head_node) = state.entries.get_mut(old_head_key) {
head_node.prev = Some(key.clone());
}
}
state.entries.insert(key.clone(), node);
state.bytes_used += size_bytes;
state.head = Some(key);
if state.tail.is_none() {
state.tail = state.head.clone();
}
}
pub fn stats(&self) -> EmbeddingCacheStats {
let state = self.state.lock().unwrap();
let hits = self.hits.load(Ordering::Relaxed);
let misses = self.misses.load(Ordering::Relaxed);
let total = hits + misses;
EmbeddingCacheStats {
hits,
misses,
entries: state.entries.len(),
bytes_used: state.bytes_used,
max_bytes: self.max_bytes,
hit_rate: if total > 0 {
(hits as f64 / total as f64) * 100.0
} else {
0.0
},
}
}
pub fn clear(&self) {
let mut state = self.state.lock().unwrap();
state.entries.clear();
state.head = None;
state.tail = None;
state.bytes_used = 0;
}
pub fn len(&self) -> usize {
self.state.lock().unwrap().entries.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl Default for EmbeddingCache {
fn default() -> Self {
Self::default_capacity()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored embedding comes back on lookup, and hits/misses are counted.
    #[test]
    fn test_basic_operations() {
        let cache = EmbeddingCache::new(1024 * 1024);
        cache.put(String::from("test-key"), vec![1.0, 2.0, 3.0]);

        let found = cache.get("test-key").unwrap();
        assert_eq!(&*found, &[1.0, 2.0, 3.0]);
        assert!(cache.get("nonexistent").is_none());

        let snapshot = cache.stats();
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 1);
        assert_eq!(snapshot.entries, 1);
    }

    /// With room for exactly three 16-byte entries, a fourth insert evicts
    /// the oldest one.
    #[test]
    fn test_lru_eviction() {
        let cache = EmbeddingCache::new(48);
        cache.put(String::from("a"), vec![1.0, 2.0, 3.0, 4.0]);
        cache.put(String::from("b"), vec![5.0, 6.0, 7.0, 8.0]);
        cache.put(String::from("c"), vec![9.0, 10.0, 11.0, 12.0]);
        assert_eq!(cache.len(), 3);

        cache.put(String::from("d"), vec![13.0, 14.0, 15.0, 16.0]);
        assert_eq!(cache.len(), 3);

        assert!(cache.get("a").is_none());
        assert!(cache.get("b").is_some());
        assert!(cache.get("c").is_some());
        assert!(cache.get("d").is_some());
    }

    /// Reading an entry protects it from the next eviction.
    #[test]
    fn test_access_updates_lru() {
        let cache = EmbeddingCache::new(32);
        cache.put(String::from("a"), vec![1.0, 2.0, 3.0, 4.0]);
        cache.put(String::from("b"), vec![5.0, 6.0, 7.0, 8.0]);

        // Touch "a" so that "b" becomes the eviction candidate.
        let _ = cache.get("a");
        cache.put(String::from("c"), vec![9.0, 10.0, 11.0, 12.0]);

        assert!(cache.get("a").is_some());
        assert!(cache.get("b").is_none());
        assert!(cache.get("c").is_some());
    }

    /// `clear` drops every entry and zeroes the byte accounting.
    #[test]
    fn test_clear() {
        let cache = EmbeddingCache::new(1024 * 1024);
        cache.put(String::from("a"), vec![1.0, 2.0, 3.0]);
        cache.put(String::from("b"), vec![4.0, 5.0, 6.0]);
        assert_eq!(cache.len(), 2);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert!(cache.get("a").is_none());
        assert!(cache.get("b").is_none());
        let snapshot = cache.stats();
        assert_eq!(snapshot.entries, 0);
        assert_eq!(snapshot.bytes_used, 0);
    }

    /// Putting under an existing key replaces the value without adding an entry.
    #[test]
    fn test_update_existing() {
        let cache = EmbeddingCache::new(1024 * 1024);

        cache.put(String::from("key"), vec![1.0, 2.0, 3.0]);
        let before = cache.get("key").unwrap();
        assert_eq!(&*before, &[1.0, 2.0, 3.0]);

        cache.put(String::from("key"), vec![4.0, 5.0, 6.0, 7.0]);
        let after = cache.get("key").unwrap();
        assert_eq!(&*after, &[4.0, 5.0, 6.0, 7.0]);

        assert_eq!(cache.len(), 1);
    }

    /// Repeated lookups hand out the same allocation rather than copies.
    #[test]
    fn test_zero_copy() {
        let cache = EmbeddingCache::new(1024 * 1024);
        cache.put(String::from("key"), vec![1.0, 2.0, 3.0]);

        let first = cache.get("key").unwrap();
        let second = cache.get("key").unwrap();
        assert!(Arc::ptr_eq(&first, &second));
    }

    /// Counters start at zero and the hit rate is reported as a percentage.
    #[test]
    fn test_stats_tracking() {
        let cache = EmbeddingCache::new(1024 * 1024);

        let fresh = cache.stats();
        assert_eq!(fresh.hits, 0);
        assert_eq!(fresh.misses, 0);
        assert_eq!(fresh.hit_rate, 0.0);

        cache.put(String::from("a"), vec![1.0, 2.0]);
        let _ = cache.get("a");
        let _ = cache.get("nonexistent");
        let _ = cache.get("a");

        let used = cache.stats();
        assert_eq!(used.hits, 2);
        assert_eq!(used.misses, 1);
        assert!((used.hit_rate - 66.666).abs() < 1.0);
    }
}