1use std::sync::Mutex;
31use std::time::{Duration, Instant};
32
33use oxibonsai_rag::embedding::{Embedder, TfIdfEmbedder};
34use oxibonsai_rag::vector_store::cosine_similarity;
35
/// Tunable parameters for a [`SemanticCache`].
#[derive(Debug, Clone)]
pub struct SemanticCacheConfig {
    /// Minimum cosine similarity a stored entry must reach for a lookup to
    /// count as a hit.
    pub similarity_threshold: f32,
    /// Capacity of the cache; once it is reached, an insert evicts the least
    /// recently used entry first.
    pub max_entries: usize,
    /// Maximum age of an entry. Older entries are ignored by lookups and
    /// removed by `evict_expired`.
    pub ttl: Duration,
    /// Not consulted anywhere in this file; presumably tells callers whether
    /// streaming responses should be cached — TODO confirm against callers.
    pub cache_streaming: bool,
    /// Minimum prompt length for a prompt to be eligible for caching; shorter
    /// prompts are never stored and always count as misses.
    pub min_prompt_chars: usize,
}
55
56impl Default for SemanticCacheConfig {
57 fn default() -> Self {
58 Self {
59 similarity_threshold: 0.92,
60 max_entries: 1000,
61 ttl: Duration::from_secs(3600),
62 cache_streaming: false,
63 min_prompt_chars: 20,
64 }
65 }
66}
67
/// Snapshot of a cache entry handed back by a successful lookup.
#[derive(Debug, Clone)]
pub struct CachedResponse {
    /// The stored response text.
    pub response: String,
    /// The prompt the response was originally stored under (which may differ
    /// from the prompt that was looked up).
    pub prompt: String,
    /// Cosine similarity between the looked-up prompt and the stored prompt.
    pub similarity: f32,
    /// When the underlying cache entry was created.
    pub created_at: Instant,
    /// Number of hits recorded on the entry, including the one that produced
    /// this value.
    pub hit_count: u64,
}
86
87impl CachedResponse {
88 pub fn is_expired(&self, ttl: Duration) -> bool {
90 self.created_at.elapsed() > ttl
91 }
92
93 pub fn age(&self) -> Duration {
95 self.created_at.elapsed()
96 }
97}
98
/// Internal storage record for one cached prompt/response pair.
struct CacheEntry {
    /// Prompt text the entry was inserted under.
    prompt: String,
    /// Response text returned on a hit.
    response: String,
    /// Embedding of `prompt`, compared against query embeddings during lookup.
    vector: Vec<f32>,
    /// Creation time, used for TTL checks.
    created_at: Instant,
    /// Value of the cache's logical access clock at insertion or at the most
    /// recent hit; the entry with the smallest value is evicted first.
    last_accessed: u64,
    /// Number of lookups that have been served from this entry.
    hit_count: u64,
}
114
115#[derive(Debug, Clone, serde::Serialize)]
121pub struct SemanticCacheStats {
122 pub total_requests: u64,
124 pub cache_hits: u64,
126 pub cache_misses: u64,
128 pub hit_rate: f32,
130 pub entries: usize,
132 pub evictions: u64,
134 pub expired_evictions: u64,
136 pub avg_similarity_on_hit: f32,
138}
139
140impl Default for SemanticCacheStats {
141 fn default() -> Self {
142 Self {
143 total_requests: 0,
144 cache_hits: 0,
145 cache_misses: 0,
146 hit_rate: 0.0,
147 entries: 0,
148 evictions: 0,
149 expired_evictions: 0,
150 avg_similarity_on_hit: 0.0,
151 }
152 }
153}
154
/// Response cache keyed by prompt similarity rather than exact text.
///
/// Every piece of mutable state sits behind its own `Mutex`, so all methods
/// take `&self`.
pub struct SemanticCache {
    /// Settings supplied at construction; never modified afterwards.
    config: SemanticCacheConfig,
    /// Stored prompt/response pairs together with their embeddings.
    entries: Mutex<Vec<CacheEntry>>,
    /// Embedder used for both inserted prompts and lookup queries; replaced
    /// whenever it is refitted.
    embedder: Mutex<TfIdfEmbedder>,
    /// Running usage statistics.
    stats: Mutex<SemanticCacheStats>,
    /// Every prompt accepted by `insert` since the last `clear`; this is the
    /// corpus the embedder is refitted on.
    all_prompts: Mutex<Vec<String>>,
    /// Monotonically increasing logical clock used to order accesses for
    /// least-recently-used eviction.
    access_clock: Mutex<u64>,
    /// Sum of similarity scores over all hits, used to derive the average.
    similarity_sum: Mutex<f64>,
}
180
/// Feature budget used when fitting the initial bootstrap embedder; also the
/// lower bound on the feature budget of every later refit.
const BOOTSTRAP_DIM: usize = 64;

/// The embedder is refitted each time the number of recorded prompts reaches a
/// multiple of this value (and once more on the very first insert).
const REFIT_BATCH_SIZE: usize = 16;
188
189impl SemanticCache {
190 pub fn new(config: SemanticCacheConfig) -> Self {
195 let bootstrap_docs = [
197 "hello world query prompt response cache",
198 "semantic similarity cosine embedding language model",
199 "retrieval augmented generation inference rust",
200 ];
201 let embedder = TfIdfEmbedder::fit(&bootstrap_docs, BOOTSTRAP_DIM);
202
203 Self {
204 config,
205 entries: Mutex::new(Vec::new()),
206 embedder: Mutex::new(embedder),
207 stats: Mutex::new(SemanticCacheStats::default()),
208 all_prompts: Mutex::new(Vec::new()),
209 access_clock: Mutex::new(0),
210 similarity_sum: Mutex::new(0.0),
211 }
212 }
213
214 pub fn lookup(&self, prompt: &str) -> Option<CachedResponse> {
221 if !self.is_cacheable(prompt) {
222 let mut stats = self.stats.lock().expect("stats lock poisoned");
223 stats.total_requests += 1;
224 stats.cache_misses += 1;
225 self.update_hit_rate(&mut stats);
226 return None;
227 }
228
229 let query_vec = {
231 let embedder = self.embedder.lock().expect("embedder lock poisoned");
232 match embedder.embed(prompt) {
233 Ok(v) => v,
234 Err(_) => {
235 let mut stats = self.stats.lock().expect("stats lock poisoned");
236 stats.total_requests += 1;
237 stats.cache_misses += 1;
238 self.update_hit_rate(&mut stats);
239 return None;
240 }
241 }
242 };
243
244 let mut entries = self.entries.lock().expect("entries lock poisoned");
245 let ttl = self.config.ttl;
246 let threshold = self.config.similarity_threshold;
247
248 let mut best_score = f32::NEG_INFINITY;
250 let mut best_idx: Option<usize> = None;
251
252 for (idx, entry) in entries.iter().enumerate() {
253 if entry.created_at.elapsed() > ttl {
254 continue; }
256 if entry.vector.len() != query_vec.len() {
257 continue; }
259 let score = cosine_similarity(&query_vec, &entry.vector);
260 if score >= threshold && score > best_score {
261 best_score = score;
262 best_idx = Some(idx);
263 }
264 }
265
266 let mut stats = self.stats.lock().expect("stats lock poisoned");
267 stats.total_requests += 1;
268
269 match best_idx {
270 Some(idx) => {
271 let clock = {
273 let mut c = self.access_clock.lock().expect("clock lock poisoned");
274 *c += 1;
275 *c
276 };
277 let entry = &mut entries[idx];
278 entry.hit_count += 1;
279 entry.last_accessed = clock;
280
281 let response = CachedResponse {
282 response: entry.response.clone(),
283 prompt: entry.prompt.clone(),
284 similarity: best_score,
285 created_at: entry.created_at,
286 hit_count: entry.hit_count,
287 };
288
289 stats.cache_hits += 1;
290 self.update_hit_rate(&mut stats);
291
292 {
294 let mut sim_sum = self
295 .similarity_sum
296 .lock()
297 .expect("similarity_sum lock poisoned");
298 *sim_sum += best_score as f64;
299 stats.avg_similarity_on_hit = (*sim_sum / stats.cache_hits as f64) as f32;
300 }
301
302 Some(response)
303 }
304 None => {
305 stats.cache_misses += 1;
306 self.update_hit_rate(&mut stats);
307 None
308 }
309 }
310 }
311
312 pub fn insert(&self, prompt: &str, response: &str) {
317 if !self.is_cacheable(prompt) {
318 return;
319 }
320
321 {
323 let mut all_prompts = self.all_prompts.lock().expect("all_prompts lock poisoned");
324 all_prompts.push(prompt.to_string());
325
326 let should_refit = all_prompts.len() == 1 || all_prompts.len() % REFIT_BATCH_SIZE == 0;
328 drop(all_prompts); if should_refit {
331 self.refit_embedder();
332 }
333 }
334
335 let vector = {
337 let embedder = self.embedder.lock().expect("embedder lock poisoned");
338 match embedder.embed(prompt) {
339 Ok(v) => v,
340 Err(_) => return, }
342 };
343
344 let clock = {
345 let mut c = self.access_clock.lock().expect("clock lock poisoned");
346 *c += 1;
347 *c
348 };
349
350 let mut entries = self.entries.lock().expect("entries lock poisoned");
351
352 if entries.len() >= self.config.max_entries {
354 let lru_idx = entries
355 .iter()
356 .enumerate()
357 .min_by_key(|(_, e)| e.last_accessed)
358 .map(|(i, _)| i)
359 .expect("entries is non-empty");
360 entries.swap_remove(lru_idx);
361
362 let mut stats = self.stats.lock().expect("stats lock poisoned");
363 stats.evictions += 1;
364 }
365
366 entries.push(CacheEntry {
367 prompt: prompt.to_string(),
368 response: response.to_string(),
369 vector,
370 created_at: Instant::now(),
371 last_accessed: clock,
372 hit_count: 0,
373 });
374
375 let mut stats = self.stats.lock().expect("stats lock poisoned");
376 stats.entries = entries.len();
377 }
378
379 pub fn evict_expired(&self) -> usize {
383 let ttl = self.config.ttl;
384 let mut entries = self.entries.lock().expect("entries lock poisoned");
385 let before = entries.len();
386 entries.retain(|e| e.created_at.elapsed() <= ttl);
387 let removed = before - entries.len();
388
389 let mut stats = self.stats.lock().expect("stats lock poisoned");
390 stats.expired_evictions += removed as u64;
391 stats.entries = entries.len();
392
393 removed
394 }
395
396 pub fn clear(&self) {
398 self.entries.lock().expect("entries lock poisoned").clear();
399 self.all_prompts
400 .lock()
401 .expect("all_prompts lock poisoned")
402 .clear();
403 *self
404 .similarity_sum
405 .lock()
406 .expect("similarity_sum lock poisoned") = 0.0;
407 *self.stats.lock().expect("stats lock poisoned") = SemanticCacheStats::default();
408 }
409
410 pub fn len(&self) -> usize {
412 self.entries.lock().expect("entries lock poisoned").len()
413 }
414
415 pub fn is_empty(&self) -> bool {
417 self.len() == 0
418 }
419
420 pub fn stats(&self) -> SemanticCacheStats {
422 self.stats.lock().expect("stats lock poisoned").clone()
423 }
424
425 fn is_cacheable(&self, prompt: &str) -> bool {
429 prompt.len() >= self.config.min_prompt_chars
430 }
431
432 fn refit_embedder(&self) {
438 let all_prompts = self.all_prompts.lock().expect("all_prompts lock poisoned");
439 if all_prompts.is_empty() {
440 return;
441 }
442
443 let max_features = BOOTSTRAP_DIM.max(all_prompts.len() * 4).min(4096);
446
447 let doc_refs: Vec<&str> = all_prompts.iter().map(|s| s.as_str()).collect();
448 let new_embedder = TfIdfEmbedder::fit(&doc_refs, max_features);
449 drop(all_prompts);
450
451 let mut embedder = self.embedder.lock().expect("embedder lock poisoned");
452 *embedder = new_embedder;
453 }
454
455 fn update_hit_rate(&self, stats: &mut SemanticCacheStats) {
457 stats.hit_rate = if stats.total_requests == 0 {
458 0.0
459 } else {
460 stats.cache_hits as f32 / stats.total_requests as f32
461 };
462 }
463}
464
/// Thin wrapper that puts a [`SemanticCache`] in front of an inference call
/// (see `run_or_cache`).
pub struct CachedInference {
    /// The underlying cache, exposed so callers can inspect or manage it
    /// directly.
    pub cache: SemanticCache,
}
488
489impl CachedInference {
490 pub fn new(config: SemanticCacheConfig) -> Self {
492 Self {
493 cache: SemanticCache::new(config),
494 }
495 }
496
497 pub fn run_or_cache<F>(&self, prompt: &str, run_inference: F) -> (String, bool)
505 where
506 F: FnOnce() -> String,
507 {
508 if let Some(cached) = self.cache.lookup(prompt) {
510 return (cached.response, true);
511 }
512
513 let response = run_inference();
515 self.cache.insert(prompt, &response);
516 (response, false)
517 }
518}
519
// Unit tests for the semantic cache and the `CachedInference` wrapper.
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with a 50 ms TTL so expiry can be observed with a
    /// short sleep.
    fn short_ttl_config() -> SemanticCacheConfig {
        SemanticCacheConfig {
            ttl: Duration::from_millis(50),
            ..Default::default()
        }
    }

    /// Default configuration with a permissive 0.1 similarity threshold.
    fn low_threshold_config() -> SemanticCacheConfig {
        SemanticCacheConfig {
            similarity_threshold: 0.1,
            ..Default::default()
        }
    }

    /// A lookup on a cache with no entries is a miss.
    #[test]
    fn test_semantic_cache_miss_on_empty() {
        let cache = SemanticCache::new(SemanticCacheConfig::default());
        assert!(cache.lookup("What is the meaning of life?").is_none());
    }

    /// Looking up the exact prompt that was inserted hits with a similarity
    /// above 0.9 and returns the stored response.
    #[test]
    fn test_semantic_cache_exact_match() {
        let cache = SemanticCache::new(low_threshold_config());
        let prompt = "What is the capital of France and why is it important?";
        cache.insert(prompt, "Paris is the capital of France.");
        let result = cache.lookup(prompt);
        assert!(result.is_some(), "exact prompt should hit the cache");
        let cached = result.expect("just asserted Some");
        assert_eq!(cached.response, "Paris is the capital of France.");
        assert!(cached.similarity > 0.9, "similarity={}", cached.similarity);
    }

    /// An inserted entry is counted by `len` and found by a lookup.
    #[test]
    fn test_semantic_cache_insert_and_lookup() {
        let config = SemanticCacheConfig {
            similarity_threshold: 0.5,
            ..Default::default()
        };
        let cache = SemanticCache::new(config);
        let prompt = "Explain the concept of machine learning in detail";
        cache.insert(prompt, "Machine learning is a branch of AI.");
        assert_eq!(cache.len(), 1);
        let hit = cache.lookup(prompt);
        assert!(hit.is_some());
    }

    /// An entry is served before its TTL elapses and ignored afterwards.
    #[test]
    fn test_semantic_cache_ttl_expiry() {
        let config = short_ttl_config();
        let cache = SemanticCache::new(config);
        let prompt = "Tell me everything about neural networks and deep learning";
        cache.insert(prompt, "Neural networks are computational graphs.");
        assert!(
            cache.lookup(prompt).is_some(),
            "should hit before TTL expires"
        );
        // Sleep for twice the 50 ms TTL so the entry is certainly stale.
        std::thread::sleep(Duration::from_millis(100));
        assert!(
            cache.lookup(prompt).is_none(),
            "should miss after TTL expires"
        );
    }

    /// Prompts shorter than `min_prompt_chars` are neither stored nor matched.
    #[test]
    fn test_semantic_cache_min_prompt_length() {
        let cache = SemanticCache::new(SemanticCacheConfig::default());
        let short = "Hi";
        cache.insert(short, "Hello!");
        assert_eq!(cache.len(), 0, "short prompt should not be cached");
        assert!(cache.lookup(short).is_none());
    }

    /// `evict_expired` removes every stale entry and records the count in the
    /// statistics.
    #[test]
    fn test_semantic_cache_evict_expired() {
        let config = short_ttl_config();
        let cache = SemanticCache::new(config);

        for i in 0..5 {
            let prompt = format!(
                "This is a sufficiently long prompt number {} for caching purposes",
                i
            );
            cache.insert(&prompt, "response");
        }
        assert_eq!(cache.len(), 5);

        // Sleep for twice the 50 ms TTL so all five entries are stale.
        std::thread::sleep(Duration::from_millis(100));
        let removed = cache.evict_expired();
        assert_eq!(removed, 5, "all entries should have expired");
        assert_eq!(cache.len(), 0);

        let stats = cache.stats();
        assert_eq!(stats.expired_evictions, 5);
    }

    /// One hit plus one miss yields a hit rate of 0.5.
    #[test]
    fn test_semantic_cache_stats_hit_rate() {
        let config = low_threshold_config();
        let cache = SemanticCache::new(config);

        let prompt = "Describe the architecture of transformer neural networks in depth";
        cache.insert(prompt, "Transformers use attention mechanisms.");

        // First lookup is expected to hit, the second to miss.
        let _ = cache.lookup(prompt);
        let _ = cache.lookup("Completely unrelated gibberish zzzzzzzz that matches nothing");

        let stats = cache.stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.total_requests, 2);
        assert!(
            (stats.hit_rate - 0.5).abs() < 1e-5,
            "hit_rate={}",
            stats.hit_rate
        );
    }

    /// `clear` empties the cache and resets the statistics.
    #[test]
    fn test_semantic_cache_clear() {
        let config = low_threshold_config();
        let cache = SemanticCache::new(config);

        for i in 0..10 {
            let prompt = format!(
                "This is prompt number {} that is long enough to be cached by the system",
                i
            );
            cache.insert(&prompt, "some response");
        }
        assert!(!cache.is_empty());
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.stats().total_requests, 0);
    }

    /// The second identical call is answered from the cache without invoking
    /// the inference closure.
    #[test]
    fn test_cached_inference_returns_cached() {
        let config = low_threshold_config();
        let ci = CachedInference::new(config);

        let prompt = "What is Rust and why is it used for systems programming?";
        let (r1, hit1) = ci.run_or_cache(prompt, || "Rust is a systems language.".to_string());
        assert!(!hit1, "first call must be a miss");
        assert_eq!(r1, "Rust is a systems language.");

        let (r2, hit2) = ci.run_or_cache(prompt, || panic!("should not be called"));
        assert!(hit2, "second identical call must be a hit");
        assert_eq!(r2, "Rust is a systems language.");
    }

    /// On a miss the inference closure runs and its output is returned.
    #[test]
    fn test_cached_inference_calls_fn_on_miss() {
        let ci = CachedInference::new(SemanticCacheConfig::default());
        let mut called = false;
        let (resp, hit) = ci.run_or_cache(
            "Explain quantum entanglement in detail for a physics student",
            || {
                called = true;
                "Quantum entanglement is a phenomenon…".to_string()
            },
        );
        assert!(!hit);
        assert!(called);
        assert!(!resp.is_empty());
    }

    /// Pins the values produced by `SemanticCacheConfig::default()`.
    #[test]
    fn test_cache_config_defaults() {
        let cfg = SemanticCacheConfig::default();
        assert!((cfg.similarity_threshold - 0.92).abs() < 1e-6);
        assert_eq!(cfg.max_entries, 1000);
        assert_eq!(cfg.ttl, Duration::from_secs(3600));
        assert!(!cfg.cache_streaming);
        assert_eq!(cfg.min_prompt_chars, 20);
    }

    /// `is_expired` is false for a generous TTL and true for a zero TTL once
    /// any time has passed.
    #[test]
    fn test_cached_response_is_expired() {
        let resp = CachedResponse {
            response: "answer".to_string(),
            prompt: "question".to_string(),
            similarity: 0.95,
            created_at: Instant::now(),
            hit_count: 1,
        };
        assert!(!resp.is_expired(Duration::from_secs(60)));
        // Let a measurable amount of time pass so the age is strictly above zero.
        std::thread::sleep(Duration::from_millis(1));
        assert!(resp.is_expired(Duration::ZERO));
    }
}