use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// A single cached embedding plus the timestamps that drive expiry and eviction.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached embedding vector.
    embedding: Vec<f32>,
    /// Insertion time; TTL expiry is measured from this instant.
    created_at: Instant,
    /// Time of insertion or of the most recent successful read. The entry with
    /// the oldest value is the LRU eviction victim.
    last_accessed: Instant,
}

impl CacheEntry {
    /// Wraps `embedding` in a fresh entry whose two timestamps are both "now".
    fn new(embedding: Vec<f32>) -> Self {
        let now = Instant::now();
        Self {
            embedding,
            created_at: now,
            last_accessed: now,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since creation.
    fn is_expired(&self, ttl: Duration) -> bool {
        self.created_at.elapsed() > ttl
    }

    /// Records that the entry was just used, for LRU ordering.
    fn touch(&mut self) {
        self.last_accessed = Instant::now();
    }
}

/// A bounded, lock-protected cache mapping string keys to embedding vectors.
///
/// At most `max_entries` entries are kept; when the cache is full the least
/// recently used entry is evicted. Entries optionally expire `ttl_secs`
/// seconds after insertion.
pub struct EmbeddingCache {
    /// Key-to-entry storage behind a reader/writer lock.
    cache: Arc<RwLock<HashMap<String, CacheEntry>>>,
    /// Maximum number of entries held at once; `0` means nothing is stored.
    max_entries: usize,
    /// Entry lifetime in seconds; `0` disables expiry.
    ttl_secs: u64,
}

impl EmbeddingCache {
    /// Creates an empty cache.
    ///
    /// * `max_entries` - capacity; with `0` every [`put`](Self::put) is a no-op.
    /// * `ttl_secs` - entry lifetime in seconds; `0` means entries never expire.
    pub fn new(max_entries: usize, ttl_secs: u64) -> Self {
        Self {
            cache: Arc::new(RwLock::new(HashMap::new())),
            max_entries,
            ttl_secs,
        }
    }

    /// Returns `true` when expiry is enabled and `entry` has outlived the TTL.
    fn is_stale(&self, entry: &CacheEntry) -> bool {
        self.ttl_secs > 0 && entry.is_expired(Duration::from_secs(self.ttl_secs))
    }

    /// Looks up `key` and returns a copy of its embedding.
    ///
    /// A hit refreshes the entry's LRU timestamp. An entry that has outlived
    /// the TTL is removed and reported as a miss. Returns `None` as well when
    /// the lock is poisoned.
    pub fn get(&self, key: &str) -> Option<Vec<f32>> {
        // A write lock is needed even for reads: a hit mutates `last_accessed`
        // and an expired hit removes the entry.
        let mut cache = self.cache.write().ok()?;
        let entry = cache.get_mut(key)?;
        if self.is_stale(entry) {
            cache.remove(key);
            return None;
        }
        entry.touch();
        Some(entry.embedding.clone())
    }

    /// Inserts or replaces the embedding stored under `key`.
    ///
    /// Replacing an existing key resets both of its timestamps. When a new key
    /// arrives and the cache is full, expired entries are dropped first; only
    /// if that frees no slot is the least recently used entry evicted. With a
    /// capacity of `0` nothing is stored.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned.
    pub fn put(&self, key: String, embedding: Vec<f32>) {
        // A zero-capacity cache must stay empty; evicting from an empty map
        // cannot make room, so inserting here would exceed `max_entries`.
        if self.max_entries == 0 {
            return;
        }
        let mut cache = self.cache.write().expect("Failed to acquire cache lock");
        if cache.len() >= self.max_entries && !cache.contains_key(&key) {
            // Dead entries are the cheapest victims: reclaim them before
            // sacrificing an entry that can still be served.
            if self.ttl_secs > 0 {
                cache.retain(|_, entry| !self.is_stale(entry));
            }
            if cache.len() >= self.max_entries {
                self.evict_lru(&mut cache);
            }
        }
        cache.insert(key, CacheEntry::new(embedding));
    }

    /// Removes the entry with the oldest `last_accessed` timestamp, if any.
    fn evict_lru(&self, cache: &mut HashMap<String, CacheEntry>) {
        // The key is cloned so the shared borrow from `iter` ends before the
        // mutable `remove`.
        let victim = cache
            .iter()
            .min_by_key(|(_, entry)| entry.last_accessed)
            .map(|(key, _)| key.clone());
        if let Some(key) = victim {
            cache.remove(&key);
        }
    }

    /// Drops every entry that has outlived the TTL; does nothing when expiry
    /// is disabled (`ttl_secs == 0`).
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned and expiry is enabled.
    pub fn clear_expired(&self) {
        if self.ttl_secs == 0 {
            return;
        }
        let mut cache = self.cache.write().expect("Failed to acquire cache lock");
        cache.retain(|_, entry| !self.is_stale(entry));
    }

    /// Removes all entries.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned.
    pub fn clear(&self) {
        let mut cache = self.cache.write().expect("Failed to acquire cache lock");
        cache.clear();
    }

    /// Returns a snapshot of the cache's size and occupancy.
    ///
    /// `expired_entries` counts entries that have outlived the TTL but are
    /// still stored. `utilization` is `total_entries / max_entries`, reported
    /// as `0.0` for a zero-capacity cache.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned.
    pub fn stats(&self) -> CacheStats {
        let cache = self.cache.read().expect("Failed to acquire cache lock");
        let total_entries = cache.len();
        let expired_entries = if self.ttl_secs > 0 {
            cache.values().filter(|entry| self.is_stale(entry)).count()
        } else {
            0
        };
        // Guard the division: a zero capacity would otherwise yield inf or NaN.
        let utilization = if self.max_entries == 0 {
            0.0
        } else {
            total_entries as f32 / self.max_entries as f32
        };
        CacheStats {
            total_entries,
            max_entries: self.max_entries,
            expired_entries,
            utilization,
        }
    }

    /// Returns the number of stored entries, including expired ones that have
    /// not been removed yet. Returns `0` if the lock is poisoned.
    pub fn len(&self) -> usize {
        self.cache.read().ok().map(|c| c.len()).unwrap_or(0)
    }

    /// Returns `true` when [`len`](Self::len) is `0`.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

/// Point-in-time statistics produced by [`EmbeddingCache::stats`].
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries currently stored, expired or not.
    pub total_entries: usize,
    /// Configured capacity of the cache.
    pub max_entries: usize,
    /// Number of stored entries that have outlived the TTL.
    pub expired_entries: usize,
    /// `total_entries / max_entries`; `0.0` when the capacity is `0`.
    pub utilization: f32,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Stored vectors come back unchanged and an unknown key is a miss.
    #[test]
    fn test_cache_basic() {
        let store = EmbeddingCache::new(3, 0);
        store.put(String::from("key1"), vec![1.0, 2.0, 3.0]);
        store.put(String::from("key2"), vec![4.0, 5.0, 6.0]);

        let first = store.get("key1");
        let second = store.get("key2");
        let missing = store.get("key3");

        assert_eq!(first, Some(vec![1.0, 2.0, 3.0]));
        assert_eq!(second, Some(vec![4.0, 5.0, 6.0]));
        assert_eq!(missing, None);
    }

    /// With capacity 2, inserting a third key evicts the entry that was not
    /// read in between.
    #[test]
    fn test_cache_lru_eviction() {
        let store = EmbeddingCache::new(2, 0);
        store.put(String::from("key1"), vec![1.0]);
        store.put(String::from("key2"), vec![2.0]);

        // Reading key1 refreshes its access time, so key2 becomes the
        // eviction candidate.
        let _ = store.get("key1");
        store.put(String::from("key3"), vec![3.0]);

        assert_eq!(store.get("key1"), Some(vec![1.0]));
        assert_eq!(store.get("key2"), None);
        assert_eq!(store.get("key3"), Some(vec![3.0]));
    }

    /// An entry is served before its one-second TTL elapses and missed after.
    #[test]
    fn test_cache_ttl() {
        let store = EmbeddingCache::new(10, 1);
        store.put(String::from("key1"), vec![1.0, 2.0]);
        assert_eq!(store.get("key1"), Some(vec![1.0, 2.0]));

        let past_ttl = Duration::from_millis(1100);
        thread::sleep(past_ttl);

        assert_eq!(store.get("key1"), None);
    }

    /// Two entries in a capacity-5 cache give 40% utilization.
    #[test]
    fn test_cache_stats() {
        let store = EmbeddingCache::new(5, 0);
        store.put(String::from("key1"), vec![1.0]);
        store.put(String::from("key2"), vec![2.0]);

        let CacheStats {
            total_entries,
            max_entries,
            utilization,
            ..
        } = store.stats();

        assert_eq!(total_entries, 2);
        assert_eq!(max_entries, 5);
        assert_eq!(utilization, 0.4);
    }

    /// `clear` empties the cache.
    #[test]
    fn test_cache_clear() {
        let store = EmbeddingCache::new(10, 0);
        store.put(String::from("key1"), vec![1.0]);
        store.put(String::from("key2"), vec![2.0]);
        assert_eq!(store.len(), 2);

        store.clear();

        assert_eq!(store.len(), 0);
        assert!(store.is_empty());
    }

    /// `clear_expired` removes entries once their one-second TTL has elapsed.
    #[test]
    fn test_cache_clear_expired() {
        let store = EmbeddingCache::new(10, 1);
        store.put(String::from("key1"), vec![1.0]);
        store.put(String::from("key2"), vec![2.0]);

        let past_ttl = Duration::from_millis(1100);
        thread::sleep(past_ttl);
        store.clear_expired();

        assert_eq!(store.len(), 0);
    }
}