use lru::LruCache;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::hash::{Hash, Hasher};
use std::time::{Duration, Instant};
/// A cached value bundled with the timing data needed to expire it.
struct CacheEntry<V> {
/// The cached payload.
value: V,
/// Moment the entry was created; expiry is measured from this instant.
created_at: Instant,
/// How long after `created_at` the entry remains valid.
ttl: Duration,
}
impl<V> CacheEntry<V> {
    /// Reports whether this entry has outlived its time-to-live.
    ///
    /// The comparison is strict: an entry whose age equals `ttl` exactly is
    /// still considered live.
    fn is_expired(&self) -> bool {
        let age = self.created_at.elapsed();
        self.ttl < age
    }
}
/// LRU cache of embedding vectors with per-entry expiry and hit/miss counters.
///
/// Entries are keyed by a 64-bit hash of the source text rather than by the
/// text itself.
pub struct EmbeddingCache {
/// LRU map from content hash to the cached embedding and its timing data.
inner: RwLock<LruCache<u64, CacheEntry<Vec<f32>>>>,
/// Time-to-live copied into every entry when it is inserted.
ttl: Duration,
/// Capacity figure reported through `CacheStats::capacity`.
max_entries: usize,
/// Number of lookups that returned a live entry.
hits: RwLock<u64>,
/// Number of lookups that found nothing, or only an expired entry.
misses: RwLock<u64>,
}
/// Point-in-time summary of an `EmbeddingCache`, as produced by its `stats` method.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
/// Lookups that returned a live entry.
pub hits: u64,
/// Lookups that returned nothing.
pub misses: u64,
/// `hits / (hits + misses)`, or `0.0` when no lookups have been recorded.
pub hit_rate: f64,
/// Number of entries currently stored.
pub size: usize,
/// Capacity figure recorded by the cache.
pub capacity: usize,
}
impl EmbeddingCache {
    /// Creates an empty cache.
    ///
    /// * `ttl_secs` - lifetime, in seconds, stamped onto every entry that is
    ///   inserted afterwards.
    /// * `max_entries` - maximum number of entries held before the least
    ///   recently used one is displaced. The underlying LRU map cannot have a
    ///   capacity of zero, so `0` is raised to `1`.
    ///
    /// The effective (clamped) capacity is what gets recorded, so
    /// [`CacheStats::capacity`] always matches the real limit of the cache.
    pub fn new(ttl_secs: u64, max_entries: usize) -> Self {
        use std::num::NonZeroUsize;

        let capacity = NonZeroUsize::new(max_entries).unwrap_or(NonZeroUsize::MIN);
        Self {
            inner: RwLock::new(LruCache::new(capacity)),
            ttl: Duration::from_secs(ttl_secs),
            // Store the clamped value, not the raw argument, so the reported
            // capacity cannot disagree with the LRU map's actual capacity.
            max_entries: capacity.get(),
            hits: RwLock::new(0),
            misses: RwLock::new(0),
        }
    }

    /// Hashes `content` into the 64-bit key used by the cache.
    ///
    /// The key is the only thing stored, so two different texts whose hashes
    /// collide are treated as the same entry. The hash comes from the standard
    /// library's `DefaultHasher`, whose algorithm is not guaranteed to be
    /// stable across Rust releases; do not persist these values.
    pub fn content_hash(content: &str) -> u64 {
        use std::collections::hash_map::DefaultHasher;

        let mut hasher = DefaultHasher::new();
        content.hash(&mut hasher);
        hasher.finish()
    }

    /// Looks up the embedding cached for `content`.
    ///
    /// Returns a clone of the stored vector and records a hit when a live
    /// entry exists. An expired entry is removed and, like a missing entry,
    /// recorded as a miss with `None` returned.
    pub fn get(&self, content: &str) -> Option<Vec<f32>> {
        let key = Self::content_hash(content);
        // A write guard is required even for lookups: `LruCache::get` updates
        // the recency order.
        let mut inner = self.inner.write();
        let cached = match inner.get(&key) {
            Some(entry) if !entry.is_expired() => Some(entry.value.clone()),
            Some(_) => {
                // Stale entry: drop it so it stops occupying a slot.
                inner.pop(&key);
                None
            }
            None => None,
        };

        let counter = if cached.is_some() {
            &self.hits
        } else {
            &self.misses
        };
        *counter.write() += 1;
        cached
    }

    /// Stores `embedding` under the hash of `content`, stamped with the
    /// current time and the cache-wide TTL.
    ///
    /// Any entry displaced by the insertion (a previous value for the same
    /// key, or the least recently used entry when the cache is full) is
    /// dropped.
    pub fn insert(&self, content: &str, embedding: Vec<f32>) {
        let key = Self::content_hash(content);
        let entry = CacheEntry {
            value: embedding,
            created_at: Instant::now(),
            ttl: self.ttl,
        };
        self.inner.write().push(key, entry);
    }

    /// Removes every entry whose TTL has elapsed and returns how many were
    /// removed.
    pub fn evict_expired(&self) -> usize {
        let mut inner = self.inner.write();
        // Keys are collected first because entries cannot be removed while
        // the map is being iterated.
        let expired: Vec<u64> = inner
            .iter()
            .filter(|(_, entry)| entry.is_expired())
            .map(|(&key, _)| key)
            .collect();
        for key in &expired {
            inner.pop(key);
        }
        expired.len()
    }

    /// Removes least recently used entries until at most `target_size`
    /// remain, returning the number of entries removed.
    pub fn evict_lru(&self, target_size: usize) -> usize {
        let mut inner = self.inner.write();
        let mut evicted = 0;
        while inner.len() > target_size && inner.pop_lru().is_some() {
            evicted += 1;
        }
        evicted
    }

    /// Returns a snapshot of the hit/miss counters, the current entry count
    /// and the cache capacity.
    ///
    /// `hit_rate` is `0.0` when no lookups have been recorded yet.
    pub fn stats(&self) -> CacheStats {
        let hits = *self.hits.read();
        let misses = *self.misses.read();
        let total = hits + misses;
        let hit_rate = if total > 0 {
            hits as f64 / total as f64
        } else {
            0.0
        };
        CacheStats {
            hits,
            misses,
            hit_rate,
            size: self.inner.read().len(),
            capacity: self.max_entries,
        }
    }

    /// Removes every cached entry. The hit/miss counters are left untouched.
    pub fn clear(&self) {
        self.inner.write().clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// A stored embedding comes back unchanged and is counted as one hit.
    #[test]
    fn test_cache_basic() {
        let cache = EmbeddingCache::new(60, 100);
        cache.insert("hello", vec![1.0, 2.0, 3.0]);

        assert_eq!(cache.get("hello"), Some(vec![1.0, 2.0, 3.0]));

        let snapshot = cache.stats();
        assert_eq!((snapshot.hits, snapshot.misses), (1, 0));
    }

    /// Looking up an unknown key yields nothing and is counted as one miss.
    #[test]
    fn test_cache_miss() {
        let cache = EmbeddingCache::new(60, 100);

        assert_eq!(cache.get("nonexistent"), None);

        let snapshot = cache.stats();
        assert_eq!((snapshot.hits, snapshot.misses), (0, 1));
    }

    /// An entry is served while fresh and disappears once its TTL has passed.
    #[test]
    fn test_cache_ttl() {
        let one_second_cache = EmbeddingCache::new(1, 100);
        one_second_cache.insert("test", vec![1.0]);
        assert!(one_second_cache.get("test").is_some());

        // Wait well past the one-second TTL.
        thread::sleep(Duration::from_secs(2));
        assert!(one_second_cache.get("test").is_none());
    }

    /// Inserting beyond capacity displaces the least recently used entry.
    #[test]
    fn test_cache_eviction() {
        let tiny = EmbeddingCache::new(60, 2);
        for (text, component) in [("a", 1.0), ("b", 2.0), ("c", 3.0)] {
            tiny.insert(text, vec![component]);
        }

        assert!(tiny.get("a").is_none());
        assert!(tiny.get("b").is_some());
        assert!(tiny.get("c").is_some());
    }
}