1use std::num::NonZeroUsize;
45use std::sync::atomic::{AtomicU64, Ordering};
46use std::time::{Duration, Instant};
47
48use lru::LruCache;
49use parking_lot::Mutex;
50use serde::{Deserialize, Serialize};
51use sha2::{Digest, Sha256};
52
53use crate::types::Result;
54
55#[derive(Debug, Clone, Default, Serialize, Deserialize)]
61pub struct CacheStats {
62 pub hits: u64,
64 pub misses: u64,
66 pub size_bytes: u64,
68 pub entry_count: usize,
70 pub evictions: u64,
72}
73
74impl CacheStats {
75 pub fn hit_rate(&self) -> f64 {
77 let total = self.hits + self.misses;
78 if total == 0 {
79 0.0
80 } else {
81 (self.hits as f64 / total as f64) * 100.0
82 }
83 }
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct CacheConfig {
89 #[serde(default = "default_max_size_bytes")]
91 pub max_size_bytes: u64,
92
93 #[serde(default)]
95 pub default_ttl: Option<Duration>,
96
97 #[serde(default = "default_enabled")]
99 pub enabled: bool,
100}
101
102fn default_max_size_bytes() -> u64 {
103 256 * 1024 * 1024 }
105
106fn default_enabled() -> bool {
107 true
108}
109
110impl Default for CacheConfig {
111 fn default() -> Self {
112 Self {
113 max_size_bytes: default_max_size_bytes(),
114 default_ttl: None,
115 enabled: default_enabled(),
116 }
117 }
118}
119
120pub trait EmbeddingCache: Send + Sync {
129 fn get(&self, key: &str) -> Option<Vec<f32>>;
131
132 fn set(&self, key: &str, embedding: Vec<f32>, ttl: Option<Duration>) -> Result<()>;
134
135 fn invalidate(&self, key: &str) -> Result<()>;
137
138 fn clear(&self) -> Result<()>;
140
141 fn stats(&self) -> CacheStats;
143
144 fn compute_key(&self, text: &str, model: &str) -> String {
146 let mut hasher = Sha256::new();
147 hasher.update(text.as_bytes());
148 hasher.update(b"|");
149 hasher.update(model.as_bytes());
150 format!("{:x}", hasher.finalize())
151 }
152
153 fn is_enabled(&self) -> bool;
155}
156
/// One stored embedding plus the bookkeeping needed for expiry and size accounting.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached vector.
    embedding: Vec<f32>,
    /// Point after which the entry is stale; `None` means it never expires.
    expires_at: Option<Instant>,
    /// Payload size in bytes (`embedding.len() * size_of::<f32>()`), stored so
    /// size accounting does not have to recompute it.
    size_bytes: usize,
}

impl CacheEntry {
    /// Wraps `embedding` and, when `ttl` is given, stamps an expiry `ttl` from now.
    ///
    /// If `now + ttl` cannot be represented as an `Instant` (e.g. `Duration::MAX`,
    /// a common way to say "forever"), the entry is treated as never expiring.
    /// Plain `Instant + Duration` would panic on that overflow.
    fn new(embedding: Vec<f32>, ttl: Option<Duration>) -> Self {
        let now = Instant::now();
        let size_bytes = embedding.len() * std::mem::size_of::<f32>();
        Self {
            embedding,
            expires_at: ttl.and_then(|d| now.checked_add(d)),
            size_bytes,
        }
    }

    /// True once the current time is strictly past `expires_at`.
    /// Entries without an expiry never expire.
    fn is_expired(&self) -> bool {
        matches!(self.expires_at, Some(exp) if Instant::now() > exp)
    }
}
189
190const DEFAULT_MAX_ENTRIES: usize = 10_000;
196
197pub struct LruEmbeddingCache {
208 cache: Mutex<LruCache<String, CacheEntry>>,
210 config: CacheConfig,
212 current_size: AtomicU64,
214 hits: AtomicU64,
216 misses: AtomicU64,
218 evictions: AtomicU64,
220}
221
222impl LruEmbeddingCache {
223 pub fn new(config: CacheConfig) -> Self {
225 let avg_entry_size = 384 * std::mem::size_of::<f32>(); let max_entries = (config.max_size_bytes as usize / avg_entry_size).max(100);
229 let capacity = NonZeroUsize::new(max_entries)
230 .unwrap_or(NonZeroUsize::new(DEFAULT_MAX_ENTRIES).unwrap());
231
232 Self {
233 cache: Mutex::new(LruCache::new(capacity)),
234 config,
235 current_size: AtomicU64::new(0),
236 hits: AtomicU64::new(0),
237 misses: AtomicU64::new(0),
238 evictions: AtomicU64::new(0),
239 }
240 }
241
242 pub fn with_defaults() -> Self {
244 Self::new(CacheConfig::default())
245 }
246
247 pub fn with_max_size(max_size_bytes: u64) -> Self {
249 Self::new(CacheConfig {
250 max_size_bytes,
251 ..Default::default()
252 })
253 }
254
255 pub fn with_max_entries(max_entries: usize) -> Self {
257 let capacity = NonZeroUsize::new(max_entries)
258 .unwrap_or(NonZeroUsize::new(DEFAULT_MAX_ENTRIES).unwrap());
259 Self {
260 cache: Mutex::new(LruCache::new(capacity)),
261 config: CacheConfig::default(),
262 current_size: AtomicU64::new(0),
263 hits: AtomicU64::new(0),
264 misses: AtomicU64::new(0),
265 evictions: AtomicU64::new(0),
266 }
267 }
268
269 pub fn cleanup_expired(&self) {
271 let mut cache = self.cache.lock();
272 let mut expired_keys = Vec::new();
273
274 for (key, entry) in cache.iter() {
276 if entry.is_expired() {
277 expired_keys.push(key.clone());
278 }
279 }
280
281 for key in expired_keys {
283 if let Some(entry) = cache.pop(&key) {
284 self.current_size
285 .fetch_sub(entry.size_bytes as u64, Ordering::Relaxed);
286 }
287 }
288 }
289
290 pub fn size_bytes(&self) -> u64 {
292 self.current_size.load(Ordering::Relaxed)
293 }
294
295 pub fn len(&self) -> usize {
297 self.cache.lock().len()
298 }
299
300 pub fn is_empty(&self) -> bool {
302 self.cache.lock().is_empty()
303 }
304}
305
306impl EmbeddingCache for LruEmbeddingCache {
307 fn get(&self, key: &str) -> Option<Vec<f32>> {
308 if !self.config.enabled {
309 return None;
310 }
311
312 let mut cache = self.cache.lock();
313
314 if let Some(entry) = cache.get(key) {
316 if entry.is_expired() {
317 let entry = cache.pop(key).unwrap();
319 self.current_size
320 .fetch_sub(entry.size_bytes as u64, Ordering::Relaxed);
321 self.misses.fetch_add(1, Ordering::Relaxed);
322 return None;
323 }
324 self.hits.fetch_add(1, Ordering::Relaxed);
325 Some(entry.embedding.clone())
326 } else {
327 self.misses.fetch_add(1, Ordering::Relaxed);
328 None
329 }
330 }
331
332 fn set(&self, key: &str, embedding: Vec<f32>, ttl: Option<Duration>) -> Result<()> {
333 if !self.config.enabled {
334 return Ok(());
335 }
336
337 let entry = CacheEntry::new(embedding, ttl.or(self.config.default_ttl));
338 let entry_size = entry.size_bytes;
339
340 let mut cache = self.cache.lock();
341
342 if let Some(old_entry) = cache.pop(key) {
344 self.current_size
345 .fetch_sub(old_entry.size_bytes as u64, Ordering::Relaxed);
346 }
347
348 let was_at_capacity = cache.len() == cache.cap().get();
350
351 if let Some((_, evicted)) = cache.push(key.to_string(), entry) {
353 self.current_size
355 .fetch_sub(evicted.size_bytes as u64, Ordering::Relaxed);
356 self.evictions.fetch_add(1, Ordering::Relaxed);
357 } else if was_at_capacity {
358 self.evictions.fetch_add(1, Ordering::Relaxed);
361 }
362
363 self.current_size
365 .fetch_add(entry_size as u64, Ordering::Relaxed);
366
367 Ok(())
368 }
369
370 fn invalidate(&self, key: &str) -> Result<()> {
371 let mut cache = self.cache.lock();
372 if let Some(entry) = cache.pop(key) {
373 self.current_size
374 .fetch_sub(entry.size_bytes as u64, Ordering::Relaxed);
375 }
376 Ok(())
377 }
378
379 fn clear(&self) -> Result<()> {
380 let mut cache = self.cache.lock();
381 cache.clear();
382 self.current_size.store(0, Ordering::Relaxed);
383 Ok(())
384 }
385
386 fn stats(&self) -> CacheStats {
387 CacheStats {
388 hits: self.hits.load(Ordering::Relaxed),
389 misses: self.misses.load(Ordering::Relaxed),
390 size_bytes: self.current_size.load(Ordering::Relaxed),
391 entry_count: self.cache.lock().len(),
392 evictions: self.evictions.load(Ordering::Relaxed),
393 }
394 }
395
396 fn is_enabled(&self) -> bool {
397 self.config.enabled
398 }
399}
400
/// Stateless cache that never stores anything.
///
/// Use it where an [`EmbeddingCache`] is required but caching should be off.
/// Writes are discarded, lookups always miss, and its statistics stay at zero.
#[derive(Debug, Default)]
pub struct NoOpCache;

impl NoOpCache {
    /// Creates the no-op cache.
    pub fn new() -> Self {
        NoOpCache
    }
}
417
418impl EmbeddingCache for NoOpCache {
419 fn get(&self, _key: &str) -> Option<Vec<f32>> {
420 None
421 }
422
423 fn set(&self, _key: &str, _embedding: Vec<f32>, _ttl: Option<Duration>) -> Result<()> {
424 Ok(())
425 }
426
427 fn invalidate(&self, _key: &str) -> Result<()> {
428 Ok(())
429 }
430
431 fn clear(&self) -> Result<()> {
432 Ok(())
433 }
434
435 fn stats(&self) -> CacheStats {
436 CacheStats::default()
437 }
438
439 fn is_enabled(&self) -> bool {
440 false
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    const SMALL_MODEL: &str = "bge-small-en-v1.5";

    #[test]
    fn test_cache_key_computation() {
        let cache = LruEmbeddingCache::with_defaults();
        let base = cache.compute_key("hello world", SMALL_MODEL);

        // Same inputs give the same key.
        assert_eq!(base, cache.compute_key("hello world", SMALL_MODEL));
        // A different model gives a different key.
        assert_ne!(base, cache.compute_key("hello world", "bge-base-en-v1.5"));
        // Different text gives a different key.
        assert_ne!(base, cache.compute_key("different text", SMALL_MODEL));
    }

    #[test]
    fn test_cache_set_and_get() {
        let cache = LruEmbeddingCache::with_defaults();
        let embedding = vec![1.0, 2.0, 3.0, 4.0];

        assert_eq!(cache.get("test_key"), None);
        assert_eq!(cache.stats().misses, 1);

        cache.set("test_key", embedding.clone(), None).unwrap();

        assert_eq!(cache.get("test_key"), Some(embedding));
        assert_eq!(cache.stats().hits, 1);
    }

    #[test]
    fn test_cache_invalidate() {
        let cache = LruEmbeddingCache::with_defaults();

        cache.set("test_key", vec![1.0, 2.0, 3.0], None).unwrap();
        assert!(cache.get("test_key").is_some());

        cache.invalidate("test_key").unwrap();
        assert!(cache.get("test_key").is_none());
    }

    #[test]
    fn test_cache_clear() {
        let cache = LruEmbeddingCache::with_defaults();
        cache.set("key1", vec![1.0, 2.0], None).unwrap();
        cache.set("key2", vec![3.0, 4.0], None).unwrap();

        assert_eq!(cache.len(), 2);
        assert!(cache.size_bytes() > 0);

        cache.clear().unwrap();

        assert!(cache.is_empty());
        assert_eq!(cache.size_bytes(), 0);
    }

    #[test]
    fn test_cache_lru_eviction() {
        let cache = LruEmbeddingCache::with_max_entries(2);

        cache.set("key1", vec![1.0, 2.0, 3.0, 4.0], None).unwrap();
        cache.set("key2", vec![5.0, 6.0, 7.0, 8.0], None).unwrap();

        // Touch both so key1 becomes the least recently used entry.
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_some());

        cache.set("key3", vec![9.0, 10.0, 11.0, 12.0], None).unwrap();

        assert!(cache.get("key1").is_none(), "LRU entry should have been evicted");
        assert!(cache.get("key2").is_some());
        assert!(cache.get("key3").is_some());
        assert!(cache.stats().evictions > 0);
    }

    #[test]
    fn test_cache_ttl_expiry() {
        let cache = LruEmbeddingCache::with_defaults();
        cache
            .set("test_key", vec![1.0, 2.0, 3.0], Some(Duration::from_nanos(1)))
            .unwrap();

        // Sleep well past the 1ns TTL so the entry is definitely stale.
        std::thread::sleep(Duration::from_millis(1));

        assert!(cache.get("test_key").is_none());
    }

    #[test]
    fn test_cache_stats() {
        let cache = LruEmbeddingCache::with_defaults();
        cache.set("key1", vec![1.0, 2.0], None).unwrap();

        // One hit (key1) followed by two misses (key2, key3).
        for key in ["key1", "key2", "key3"] {
            let _ = cache.get(key);
        }

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 2);
        assert_eq!(stats.entry_count, 1);
        assert!(stats.size_bytes > 0);
    }

    #[test]
    fn test_cache_hit_rate() {
        let stats = CacheStats {
            hits: 75,
            misses: 25,
            ..CacheStats::default()
        };

        assert!((stats.hit_rate() - 75.0).abs() < 0.001);
    }

    #[test]
    fn test_noop_cache() {
        let cache = NoOpCache::new();

        cache.set("key", vec![1.0, 2.0], None).unwrap();
        assert!(cache.get("key").is_none());

        let stats = cache.stats();
        assert_eq!((stats.hits, stats.misses), (0, 0));
        assert!(!cache.is_enabled());
    }

    #[test]
    fn test_cache_disabled() {
        let cache = LruEmbeddingCache::new(CacheConfig {
            enabled: false,
            ..Default::default()
        });

        cache.set("key", vec![1.0, 2.0], None).unwrap();

        assert!(cache.get("key").is_none());
        assert!(!cache.is_enabled());
    }

    #[test]
    fn test_cache_update_existing() {
        let cache = LruEmbeddingCache::with_defaults();

        cache.set("test_key", vec![1.0, 2.0], None).unwrap();
        let before = cache.size_bytes();

        cache.set("test_key", vec![3.0, 4.0, 5.0, 6.0], None).unwrap();
        let after = cache.size_bytes();

        assert!(after > before, "a larger replacement should grow the byte total");
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get("test_key").unwrap(), vec![3.0, 4.0, 5.0, 6.0]);
    }
}