use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use lru::LruCache;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use crate::types::Result;
/// Snapshot of cache activity counters, as returned by `EmbeddingCache::stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
/// Number of lookups that returned a cached embedding.
pub hits: u64,
/// Number of lookups that returned nothing (absent or, for expiring caches, stale).
pub misses: u64,
/// Bytes currently accounted to stored entries, as tracked by the implementation.
pub size_bytes: u64,
/// Number of entries currently stored.
pub entry_count: usize,
/// Number of entries displaced to make room for new ones (implementation-defined).
pub evictions: u64,
}
impl CacheStats {
    /// Returns the percentage (0.0–100.0) of lookups that were hits.
    ///
    /// Returns `0.0` when no lookups have been recorded. The total is formed in
    /// `f64` rather than `u64`, so arbitrarily large counters (for example values
    /// deserialized from external input) cannot trigger an integer-overflow panic.
    pub fn hit_rate(&self) -> f64 {
        if self.hits == 0 && self.misses == 0 {
            return 0.0;
        }
        let hits = self.hits as f64;
        let total = hits + self.misses as f64;
        (hits / total) * 100.0
    }
}
/// Settings for `LruEmbeddingCache`; every field falls back to a default when
/// missing from deserialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
/// Target maximum cache size in bytes (default: 256 MiB, see `default_max_size_bytes`).
/// `LruEmbeddingCache::new` derives its entry capacity from this value.
#[serde(default = "default_max_size_bytes")]
pub max_size_bytes: u64,
/// Time-to-live applied when `set` is called without an explicit TTL.
/// `None` (the default) means such entries never expire.
#[serde(default)]
pub default_ttl: Option<Duration>,
/// Whether the cache stores and returns entries at all (default: `true`).
#[serde(default = "default_enabled")]
pub enabled: bool,
}
/// Serde/`Default` value for `CacheConfig::max_size_bytes`: 256 MiB.
fn default_max_size_bytes() -> u64 {
    const MIB: u64 = 1 << 20;
    256 * MIB
}
/// Serde/`Default` value for `CacheConfig::enabled`: caching is on unless turned off.
fn default_enabled() -> bool {
    let enabled_by_default = true;
    enabled_by_default
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_size_bytes: default_max_size_bytes(),
default_ttl: None,
enabled: default_enabled(),
}
}
}
/// A key/value store for computed embedding vectors.
///
/// Implementors must be `Send + Sync`, and every method takes `&self`, so any
/// mutation has to go through interior mutability.
pub trait EmbeddingCache: Send + Sync {
    /// Returns the embedding stored under `key`, or `None` when there is no usable entry.
    fn get(&self, key: &str) -> Option<Vec<f32>>;
    /// Stores `embedding` under `key`; `ttl` optionally bounds how long it stays valid.
    fn set(&self, key: &str, embedding: Vec<f32>, ttl: Option<Duration>) -> Result<()>;
    /// Drops whatever is stored under `key`, if anything.
    fn invalidate(&self, key: &str) -> Result<()>;
    /// Drops all stored entries.
    fn clear(&self) -> Result<()>;
    /// Returns a snapshot of the cache's activity counters.
    fn stats(&self) -> CacheStats;
    /// Derives a cache key for `text` embedded with `model`.
    ///
    /// The key is the lowercase hex SHA-256 digest of: the byte length of `text`
    /// (little-endian `u64`), the bytes of `text`, then the bytes of `model`.
    /// The length prefix makes the text/model boundary unambiguous. With a plain
    /// separator, `("a|b", "c")` and `("a", "b|c")` would hash identically.
    fn compute_key(&self, text: &str, model: &str) -> String {
        let mut hasher = Sha256::new();
        hasher.update((text.len() as u64).to_le_bytes());
        hasher.update(text.as_bytes());
        hasher.update(model.as_bytes());
        format!("{:x}", hasher.finalize())
    }
    /// Reports whether this cache actually stores entries.
    fn is_enabled(&self) -> bool;
}
/// One stored embedding together with its expiry deadline and accounted size.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached vector.
    embedding: Vec<f32>,
    /// Instant after which the entry is stale; `None` means it never expires.
    expires_at: Option<Instant>,
    /// Payload size in bytes (`len * size_of::<f32>()`), used for size accounting.
    size_bytes: usize,
}
impl CacheEntry {
    /// Wraps `embedding`, recording its byte size and, if `ttl` is given, its deadline.
    ///
    /// A TTL too large to add to the current instant (e.g. `Duration::MAX`) is
    /// treated as "never expires". The previous `Instant + Duration` panicked on
    /// such input.
    fn new(embedding: Vec<f32>, ttl: Option<Duration>) -> Self {
        let now = Instant::now();
        let size_bytes = embedding.len() * std::mem::size_of::<f32>();
        Self {
            embedding,
            expires_at: ttl.and_then(|d| now.checked_add(d)),
            size_bytes,
        }
    }
    /// Returns `true` once the current instant is strictly past the deadline.
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => Instant::now() > deadline,
            None => false,
        }
    }
}
/// Fallback entry capacity, used when a computed or requested capacity is zero.
const DEFAULT_MAX_ENTRIES: usize = 10_000;
/// In-memory `EmbeddingCache` with least-recently-used eviction.
///
/// Entries live in a `parking_lot::Mutex`-guarded `lru::LruCache`. Byte size and
/// hit/miss/eviction counters are kept in atomics next to it.
pub struct LruEmbeddingCache {
/// LRU map from cache key to stored entry.
cache: Mutex<LruCache<String, CacheEntry>>,
/// Settings supplied at construction.
config: CacheConfig,
/// Sum of `size_bytes` over all stored entries.
current_size: AtomicU64,
/// Lookups that returned an embedding.
hits: AtomicU64,
/// Lookups that returned nothing (absent or expired).
misses: AtomicU64,
/// Entries displaced to make room for new ones.
evictions: AtomicU64,
}
impl LruEmbeddingCache {
    /// Creates a cache whose entry capacity is estimated from `config.max_size_bytes`.
    ///
    /// The estimate assumes 384-dimensional `f32` embeddings (1536 bytes each) and
    /// never goes below 100 entries.
    pub fn new(config: CacheConfig) -> Self {
        // Assumed typical entry: 384 `f32` components.
        let avg_entry_size = (384 * std::mem::size_of::<f32>()) as u64;
        // Divide in u64 *before* narrowing. Casting the byte budget to `usize` first
        // truncates budgets of 4 GiB or more on 32-bit targets.
        let estimated = config.max_size_bytes / avg_entry_size;
        let max_entries = (estimated.min(usize::MAX as u64) as usize).max(100);
        Self::from_config_and_capacity(config, max_entries)
    }
    /// Creates a cache using `CacheConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(CacheConfig::default())
    }
    /// Creates a cache from the default config with `max_size_bytes` overridden.
    pub fn with_max_size(max_size_bytes: u64) -> Self {
        Self::new(CacheConfig {
            max_size_bytes,
            ..Default::default()
        })
    }
    /// Creates a cache holding at most `max_entries` entries, with the default config.
    ///
    /// A `max_entries` of zero falls back to `DEFAULT_MAX_ENTRIES`.
    pub fn with_max_entries(max_entries: usize) -> Self {
        Self::from_config_and_capacity(CacheConfig::default(), max_entries)
    }
    /// Builds an empty cache with zeroed counters. A zero capacity falls back to
    /// `DEFAULT_MAX_ENTRIES`.
    fn from_config_and_capacity(config: CacheConfig, max_entries: usize) -> Self {
        let capacity = NonZeroUsize::new(max_entries).unwrap_or_else(|| {
            NonZeroUsize::new(DEFAULT_MAX_ENTRIES).expect("DEFAULT_MAX_ENTRIES is non-zero")
        });
        Self {
            cache: Mutex::new(LruCache::new(capacity)),
            config,
            current_size: AtomicU64::new(0),
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            evictions: AtomicU64::new(0),
        }
    }
    /// Removes every expired entry and subtracts its size from the byte total.
    ///
    /// These removals are not counted in the `evictions` statistic.
    pub fn cleanup_expired(&self) {
        let mut cache = self.cache.lock();
        // Collect the keys first, because the cache cannot be mutated while iterating.
        let expired: Vec<String> = cache
            .iter()
            .filter(|(_, entry)| entry.is_expired())
            .map(|(key, _)| key.clone())
            .collect();
        for key in &expired {
            if let Some(entry) = cache.pop(key) {
                self.current_size
                    .fetch_sub(entry.size_bytes as u64, Ordering::Relaxed);
            }
        }
    }
    /// Bytes currently accounted to stored entries.
    pub fn size_bytes(&self) -> u64 {
        self.current_size.load(Ordering::Relaxed)
    }
    /// Number of stored entries, including expired ones not yet removed.
    pub fn len(&self) -> usize {
        self.cache.lock().len()
    }
    /// Returns `true` when no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.cache.lock().is_empty()
    }
}
impl EmbeddingCache for LruEmbeddingCache {
    /// Looks up `key` and counts a hit or a miss.
    ///
    /// An expired entry is removed and reported as a miss. When the cache is
    /// disabled this returns `None` without touching any counter.
    fn get(&self, key: &str) -> Option<Vec<f32>> {
        if !self.config.enabled {
            return None;
        }
        let mut cache = self.cache.lock();
        match cache.get(key) {
            Some(entry) if !entry.is_expired() => {
                self.hits.fetch_add(1, Ordering::Relaxed);
                return Some(entry.embedding.clone());
            }
            Some(_) => {}
            None => {
                self.misses.fetch_add(1, Ordering::Relaxed);
                return None;
            }
        }
        // Expired: drop it now so it stops counting toward size and capacity.
        if let Some(stale) = cache.pop(key) {
            self.current_size
                .fetch_sub(stale.size_bytes as u64, Ordering::Relaxed);
        }
        self.misses.fetch_add(1, Ordering::Relaxed);
        None
    }
    /// Stores `embedding` under `key`, replacing any previous value.
    ///
    /// `ttl` falls back to `config.default_ttl`. Least-recently-used entries are
    /// evicted until the new entry fits within `config.max_size_bytes` (and within
    /// the LRU's entry capacity). An embedding larger than the whole byte budget is
    /// not stored, but any previous value for `key` is still removed. A no-op when
    /// the cache is disabled.
    fn set(&self, key: &str, embedding: Vec<f32>, ttl: Option<Duration>) -> Result<()> {
        if !self.config.enabled {
            return Ok(());
        }
        let entry = CacheEntry::new(embedding, ttl.or(self.config.default_ttl));
        let entry_size = entry.size_bytes as u64;
        let budget = self.config.max_size_bytes;
        let mut cache = self.cache.lock();
        // The previous value for this key is superseded whether or not the new one fits.
        if let Some(old_entry) = cache.pop(key) {
            self.current_size
                .fetch_sub(old_entry.size_bytes as u64, Ordering::Relaxed);
        }
        // An entry that alone exceeds the budget could never fit. Storing it would
        // only flush the cache and still break the limit.
        if entry_size > budget {
            return Ok(());
        }
        // Enforce the byte budget. The entry-count capacity alone assumes 384-dim
        // vectors and lets larger embeddings overshoot `max_size_bytes`.
        while self
            .current_size
            .load(Ordering::Relaxed)
            .saturating_add(entry_size)
            > budget
        {
            match cache.pop_lru() {
                Some((_, evicted)) => {
                    self.current_size
                        .fetch_sub(evicted.size_bytes as u64, Ordering::Relaxed);
                    self.evictions.fetch_add(1, Ordering::Relaxed);
                }
                None => break,
            }
        }
        // `key` was removed above, so `push` only returns an entry when the
        // entry-count capacity forced an eviction.
        if let Some((_, evicted)) = cache.push(key.to_string(), entry) {
            self.current_size
                .fetch_sub(evicted.size_bytes as u64, Ordering::Relaxed);
            self.evictions.fetch_add(1, Ordering::Relaxed);
        }
        self.current_size.fetch_add(entry_size, Ordering::Relaxed);
        Ok(())
    }
    /// Removes the entry for `key`, if present, and releases its accounted bytes.
    fn invalidate(&self, key: &str) -> Result<()> {
        let mut cache = self.cache.lock();
        if let Some(entry) = cache.pop(key) {
            self.current_size
                .fetch_sub(entry.size_bytes as u64, Ordering::Relaxed);
        }
        Ok(())
    }
    /// Removes every entry and resets the byte total. Hit, miss and eviction
    /// counters are kept.
    fn clear(&self) -> Result<()> {
        let mut cache = self.cache.lock();
        cache.clear();
        self.current_size.store(0, Ordering::Relaxed);
        Ok(())
    }
    /// Returns a snapshot of the counters.
    fn stats(&self) -> CacheStats {
        // Every counter in this type is updated while the lock is held, so reading
        // them under the lock gives a mutually consistent snapshot.
        let cache = self.cache.lock();
        CacheStats {
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
            size_bytes: self.current_size.load(Ordering::Relaxed),
            entry_count: cache.len(),
            evictions: self.evictions.load(Ordering::Relaxed),
        }
    }
    /// Mirrors `config.enabled`.
    fn is_enabled(&self) -> bool {
        self.config.enabled
    }
}
/// An `EmbeddingCache` that never stores anything; use it to turn caching off.
#[derive(Debug, Default)]
pub struct NoOpCache;
impl NoOpCache {
    /// Returns the (stateless) no-op cache.
    pub fn new() -> Self {
        NoOpCache
    }
}
impl EmbeddingCache for NoOpCache {
    /// Always reports the cache as disabled.
    fn is_enabled(&self) -> bool {
        false
    }

    /// Nothing is ever stored, so nothing is ever found.
    fn get(&self, _key: &str) -> Option<Vec<f32>> {
        Option::None
    }

    /// Discards the embedding and reports success.
    fn set(&self, _key: &str, _embedding: Vec<f32>, _ttl: Option<Duration>) -> Result<()> {
        Ok(())
    }

    /// Nothing to remove; always succeeds.
    fn invalidate(&self, _key: &str) -> Result<()> {
        Ok(())
    }

    /// Nothing to clear; always succeeds.
    fn clear(&self) -> Result<()> {
        Ok(())
    }

    /// Every counter is permanently zero.
    fn stats(&self) -> CacheStats {
        Default::default()
    }
}
// Unit tests for key derivation, LruEmbeddingCache behaviour (hits/misses, LRU
// eviction, TTL expiry, size accounting), CacheStats::hit_rate and NoOpCache.
#[cfg(test)]
mod tests {
use super::*;
// Same (text, model) pair must hash identically; changing either must change the key.
#[test]
fn test_cache_key_computation() {
let cache = LruEmbeddingCache::with_defaults();
let key1 = cache.compute_key("hello world", "bge-small-en-v1.5");
let key2 = cache.compute_key("hello world", "bge-small-en-v1.5");
let key3 = cache.compute_key("hello world", "bge-base-en-v1.5");
let key4 = cache.compute_key("different text", "bge-small-en-v1.5");
assert_eq!(key1, key2);
assert_ne!(key1, key3);
assert_ne!(key1, key4);
}
#[test]
fn test_cache_set_and_get() {
let cache = LruEmbeddingCache::with_defaults();
let key = "test_key";
let embedding = vec![1.0, 2.0, 3.0, 4.0];
// Looking up a key that was never set counts as exactly one miss.
assert!(cache.get(key).is_none());
assert_eq!(cache.stats().misses, 1);
cache.set(key, embedding.clone(), None).unwrap();
let retrieved = cache.get(key);
assert!(retrieved.is_some());
assert_eq!(retrieved.unwrap(), embedding);
assert_eq!(cache.stats().hits, 1);
}
#[test]
fn test_cache_invalidate() {
let cache = LruEmbeddingCache::with_defaults();
let key = "test_key";
let embedding = vec![1.0, 2.0, 3.0];
cache.set(key, embedding, None).unwrap();
assert!(cache.get(key).is_some());
cache.invalidate(key).unwrap();
assert!(cache.get(key).is_none());
}
// clear() must empty the map and reset the byte accounting to zero.
#[test]
fn test_cache_clear() {
let cache = LruEmbeddingCache::with_defaults();
cache.set("key1", vec![1.0, 2.0], None).unwrap();
cache.set("key2", vec![3.0, 4.0], None).unwrap();
assert_eq!(cache.len(), 2);
assert!(cache.size_bytes() > 0);
cache.clear().unwrap();
assert_eq!(cache.len(), 0);
assert_eq!(cache.size_bytes(), 0);
}
#[test]
fn test_cache_lru_eviction() {
let cache = LruEmbeddingCache::with_max_entries(2);
let embedding1 = vec![1.0, 2.0, 3.0, 4.0];
let embedding2 = vec![5.0, 6.0, 7.0, 8.0];
let embedding3 = vec![9.0, 10.0, 11.0, 12.0];
cache.set("key1", embedding1.clone(), None).unwrap();
cache.set("key2", embedding2.clone(), None).unwrap();
// Reading key1 and then key2 leaves key1 as the least recently used entry...
assert!(cache.get("key1").is_some());
assert!(cache.get("key2").is_some());
// ...so inserting a third key into this 2-entry cache evicts key1.
cache.set("key3", embedding3.clone(), None).unwrap();
assert!(cache.get("key1").is_none());
assert!(cache.get("key2").is_some());
assert!(cache.get("key3").is_some());
assert!(cache.stats().evictions > 0);
}
// A 1 ns TTL has certainly elapsed after sleeping 1 ms.
#[test]
fn test_cache_ttl_expiry() {
let cache = LruEmbeddingCache::with_defaults();
let key = "test_key";
let embedding = vec![1.0, 2.0, 3.0];
cache
.set(key, embedding, Some(Duration::from_nanos(1)))
.unwrap();
std::thread::sleep(Duration::from_millis(1));
assert!(cache.get(key).is_none());
}
#[test]
fn test_cache_stats() {
let cache = LruEmbeddingCache::with_defaults();
cache.set("key1", vec![1.0, 2.0], None).unwrap();
// One hit (key1) followed by two misses (key2, key3).
let _ = cache.get("key1"); let _ = cache.get("key2"); let _ = cache.get("key3");
let stats = cache.stats();
assert_eq!(stats.hits, 1);
assert_eq!(stats.misses, 2);
assert_eq!(stats.entry_count, 1);
assert!(stats.size_bytes > 0);
}
// hit_rate is expressed as a percentage, not a fraction.
#[test]
fn test_cache_hit_rate() {
let stats = CacheStats {
hits: 75,
misses: 25,
size_bytes: 0,
entry_count: 0,
evictions: 0,
};
assert!((stats.hit_rate() - 75.0).abs() < 0.001);
}
#[test]
fn test_noop_cache() {
let cache = NoOpCache::new();
cache.set("key", vec![1.0, 2.0], None).unwrap();
assert!(cache.get("key").is_none());
let stats = cache.stats();
assert_eq!(stats.hits, 0);
assert_eq!(stats.misses, 0);
assert!(!cache.is_enabled());
}
// A disabled LruEmbeddingCache silently ignores writes and never returns data.
#[test]
fn test_cache_disabled() {
let cache = LruEmbeddingCache::new(CacheConfig {
enabled: false,
..Default::default()
});
cache.set("key", vec![1.0, 2.0], None).unwrap();
assert!(cache.get("key").is_none());
assert!(!cache.is_enabled());
}
// Overwriting a key keeps a single entry whose accounted size reflects the new,
// larger embedding and whose value is the new vector.
#[test]
fn test_cache_update_existing() {
let cache = LruEmbeddingCache::with_defaults();
let key = "test_key";
cache.set(key, vec![1.0, 2.0], None).unwrap();
let size1 = cache.size_bytes();
cache.set(key, vec![3.0, 4.0, 5.0, 6.0], None).unwrap();
let size2 = cache.size_bytes();
assert!(size2 > size1);
assert_eq!(cache.len(), 1);
let retrieved = cache.get(key).unwrap();
assert_eq!(retrieved, vec![3.0, 4.0, 5.0, 6.0]);
}
}