1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use sha2::{Digest, Sha256};
5
/// A single cached model response together with the bookkeeping the cache
/// needs to expire, rank and semantically match it.
#[derive(Debug, Clone)]
pub struct CachedResponse {
    /// The cached response body.
    pub content: String,
    /// Model identifier stored alongside the response; the cache itself never
    /// inspects it.
    pub model: String,
    /// Carried through store/export/import unchanged; presumably the number of
    /// tokens a hit avoids spending — confirm with the producer.
    pub tokens_saved: u32,
    /// Insertion time. Overwritten by `SemanticCache::store` and
    /// `SemanticCache::import_entries`, so caller-supplied values are ignored.
    pub created_at: Instant,
    /// Instant after which the entry is no longer served. Also overwritten on
    /// store/import.
    pub expires_at: Instant,
    /// Number of lookups that returned this entry; used as the LFU eviction key.
    pub hits: u32,
    /// When `true`, `SemanticCache::lookup_tool_ttl` applies the shorter tool
    /// TTL to this entry.
    pub involved_tools: bool,
    /// Embedding compared by cosine similarity during semantic lookups.
    /// Entries with `None` can only be found by their exact key.
    pub embedding: Option<Vec<f32>>,
}
17
/// In-memory response cache keyed by prompt hash, with an optional
/// similarity-based fallback over stored embeddings.
#[derive(Debug)]
pub struct SemanticCache {
    /// When `false`, lookups return `None` and store/import do nothing.
    enabled: bool,
    /// Lifetime given to every entry written through `store`.
    ttl: Duration,
    /// Shorter lifetime used by `lookup_tool_ttl` for entries flagged
    /// `involved_tools`; initialised to a quarter of `ttl`.
    tool_ttl: Duration,
    /// Size at which `store` evicts an entry before inserting.
    max_entries: usize,
    /// Minimum cosine similarity a stored embedding must reach to count as a
    /// semantic hit.
    similarity_threshold: f32,
    /// Cached responses keyed by prompt hash.
    entries: HashMap<String, CachedResponse>,
    /// Running count of successful lookups.
    hit_count: usize,
    /// Running count of misses recorded by the exact and semantic lookups.
    miss_count: usize,
}
37
38impl SemanticCache {
39 pub fn new(enabled: bool, ttl_seconds: u64, max_entries: usize) -> Self {
40 Self::with_threshold(enabled, ttl_seconds, max_entries, 0.85)
41 }
42
43 pub fn with_threshold(
44 enabled: bool,
45 ttl_seconds: u64,
46 max_entries: usize,
47 similarity_threshold: f32,
48 ) -> Self {
49 Self {
50 enabled,
51 ttl: Duration::from_secs(ttl_seconds),
52 tool_ttl: Duration::from_secs(ttl_seconds / 4),
53 max_entries,
54 similarity_threshold,
55 entries: HashMap::new(),
56 hit_count: 0,
57 miss_count: 0,
58 }
59 }
60
61 pub fn lookup_exact(&mut self, prompt_hash: &str) -> Option<CachedResponse> {
63 if !self.enabled {
64 self.miss_count += 1;
65 return None;
66 }
67
68 if let Some(entry) = self.entries.get_mut(prompt_hash)
69 && Instant::now() < entry.expires_at
70 {
71 entry.hits += 1;
72 self.hit_count += 1;
73 return Some(entry.clone());
74 }
75
76 self.miss_count += 1;
77 None
78 }
79
80 pub fn lookup_semantic(&mut self, prompt: &str) -> Option<CachedResponse> {
84 if !self.enabled {
85 return None;
86 }
87
88 let query_emb = compute_ngram_embedding(prompt);
89 let now = Instant::now();
90
91 let mut best_match: Option<(&str, f32)> = None;
92
93 for (key, entry) in &self.entries {
94 if now >= entry.expires_at {
95 continue;
96 }
97 if let Some(ref emb) = entry.embedding {
98 let sim = cosine_similarity(&query_emb, emb);
99 if sim >= self.similarity_threshold
100 && best_match
101 .as_ref()
102 .is_none_or(|(_, best_sim)| sim > *best_sim)
103 {
104 best_match = Some((key, sim));
105 }
106 }
107 }
108
109 if let Some((key, _)) = best_match {
110 let key = key.to_string();
111 if let Some(entry) = self.entries.get_mut(&key) {
112 entry.hits += 1;
113 self.hit_count += 1;
114 return Some(entry.clone());
115 }
116 }
117
118 self.miss_count += 1;
119 None
120 }
121
122 pub fn lookup_tool_ttl(&mut self, prompt_hash: &str) -> Option<CachedResponse> {
126 if !self.enabled {
127 return None;
128 }
129
130 if let Some(entry) = self.entries.get_mut(prompt_hash) {
131 let effective_ttl = if entry.involved_tools {
132 entry.created_at + self.tool_ttl
133 } else {
134 entry.expires_at
135 };
136
137 if Instant::now() < effective_ttl {
138 entry.hits += 1;
139 self.hit_count += 1;
140 return Some(entry.clone());
141 }
142 }
143
144 None
145 }
146
147 pub fn lookup_semantic_with_embedding(
150 &mut self,
151 query_embedding: &[f32],
152 ) -> Option<CachedResponse> {
153 if !self.enabled {
154 return None;
155 }
156
157 let now = Instant::now();
158 let mut best_match: Option<(&str, f32)> = None;
159
160 for (key, entry) in &self.entries {
161 if now >= entry.expires_at {
162 continue;
163 }
164 if let Some(ref emb) = entry.embedding {
165 let sim = cosine_similarity(query_embedding, emb);
166 if sim >= self.similarity_threshold
167 && best_match
168 .as_ref()
169 .is_none_or(|(_, best_sim)| sim > *best_sim)
170 {
171 best_match = Some((key, sim));
172 }
173 }
174 }
175
176 if let Some((key, _)) = best_match {
177 let key = key.to_string();
178 if let Some(entry) = self.entries.get_mut(&key) {
179 entry.hits += 1;
180 self.hit_count += 1;
181 return Some(entry.clone());
182 }
183 }
184
185 self.miss_count += 1;
186 None
187 }
188
189 pub fn lookup(&mut self, prompt_hash: &str, prompt_text: &str) -> Option<CachedResponse> {
191 if let Some(hit) = self.lookup_exact(prompt_hash) {
192 return Some(hit);
193 }
194 if let Some(hit) = self.lookup_tool_ttl(prompt_hash) {
195 return Some(hit);
196 }
197 self.lookup_semantic(prompt_text)
198 }
199
200 pub fn lookup_strict(&mut self, prompt_hash: &str) -> Option<CachedResponse> {
205 if let Some(hit) = self.lookup_exact(prompt_hash) {
206 return Some(hit);
207 }
208 self.lookup_tool_ttl(prompt_hash)
209 }
210
211 pub fn lookup_with_embedding(
213 &mut self,
214 prompt_hash: &str,
215 query_embedding: &[f32],
216 ) -> Option<CachedResponse> {
217 if let Some(hit) = self.lookup_exact(prompt_hash) {
218 return Some(hit);
219 }
220 if let Some(hit) = self.lookup_tool_ttl(prompt_hash) {
221 return Some(hit);
222 }
223 self.lookup_semantic_with_embedding(query_embedding)
224 }
225
226 pub fn store(&mut self, prompt_hash: &str, response: CachedResponse) {
227 if !self.enabled {
228 return;
229 }
230
231 if self.entries.len() >= self.max_entries {
232 self.evict_lfu();
233 }
234
235 let now = Instant::now();
236 let entry = CachedResponse {
237 created_at: now,
238 expires_at: now + self.ttl,
239 hits: 0,
240 ..response
241 };
242 self.entries.insert(prompt_hash.to_string(), entry);
243 }
244
245 pub fn store_with_embedding(
247 &mut self,
248 prompt_hash: &str,
249 prompt_text: &str,
250 mut response: CachedResponse,
251 ) {
252 response.embedding = Some(compute_ngram_embedding(prompt_text));
253 self.store(prompt_hash, response);
254 }
255
256 pub fn compute_hash(system: &str, messages: &str, user_msg: &str) -> String {
257 let mut hasher = Sha256::new();
258 hasher.update(system.as_bytes());
259 hasher.update(b"|");
260 hasher.update(messages.as_bytes());
261 hasher.update(b"|");
262 hasher.update(user_msg.as_bytes());
263 format!("{:x}", hasher.finalize())
264 }
265
266 pub fn evict_expired(&mut self) {
267 let now = Instant::now();
268 self.entries.retain(|_, v| v.expires_at > now);
269 }
270
271 pub fn evict_lfu(&mut self) {
273 if let Some(key) = self
274 .entries
275 .iter()
276 .min_by_key(|(_, v)| v.hits)
277 .map(|(k, _)| k.clone())
278 {
279 self.entries.remove(&key);
280 }
281 }
282
/// Returns how many lookups have been answered from the cache so far.
pub fn hit_count(&self) -> usize {
    self.hit_count
}
286
/// Returns the running miss counter.
///
/// The counter is maintained per lookup stage (the exact and semantic lookups
/// record misses, `lookup_tool_ttl` does not), so it is not a count of failed
/// multi-stage `lookup` calls.
pub fn miss_count(&self) -> usize {
    self.miss_count
}
290
/// Returns the number of entries currently held, including any that have
/// expired but have not been evicted yet.
pub fn size(&self) -> usize {
    self.entries.len()
}
294
295 pub fn export_entries(&self) -> Vec<(String, ExportedCacheEntry)> {
297 self.entries
298 .iter()
299 .map(|(key, entry)| {
300 let ttl_remaining = entry
301 .expires_at
302 .checked_duration_since(Instant::now())
303 .unwrap_or_default();
304 (
305 key.clone(),
306 ExportedCacheEntry {
307 content: entry.content.clone(),
308 model: entry.model.clone(),
309 tokens_saved: entry.tokens_saved,
310 hits: entry.hits,
311 involved_tools: entry.involved_tools,
312 embedding: entry.embedding.clone(),
313 ttl_remaining_secs: ttl_remaining.as_secs(),
314 },
315 )
316 })
317 .collect()
318 }
319
320 pub fn import_entries(&mut self, entries: Vec<(String, ExportedCacheEntry)>) {
322 if !self.enabled {
323 return;
324 }
325
326 for (key, exported) in entries {
327 if exported.ttl_remaining_secs == 0 {
328 continue;
329 }
330
331 let now = Instant::now();
332 let expires = now + Duration::from_secs(exported.ttl_remaining_secs);
333
334 self.entries.insert(
335 key,
336 CachedResponse {
337 content: exported.content,
338 model: exported.model,
339 tokens_saved: exported.tokens_saved,
340 created_at: now,
341 expires_at: expires,
342 hits: exported.hits,
343 involved_tools: exported.involved_tools,
344 embedding: exported.embedding,
345 },
346 );
347 }
348 }
349}
350
/// Snapshot of a cache entry that carries a relative remaining lifetime
/// instead of absolute `Instant`s, as produced by
/// `SemanticCache::export_entries` and consumed by
/// `SemanticCache::import_entries`.
#[derive(Debug, Clone)]
pub struct ExportedCacheEntry {
    /// The cached response body.
    pub content: String,
    /// Model identifier stored with the response.
    pub model: String,
    /// Copied verbatim from the cached entry.
    pub tokens_saved: u32,
    /// Hit count at export time; restored on import.
    pub hits: u32,
    /// Whether the entry is subject to the shorter tool TTL.
    pub involved_tools: bool,
    /// Embedding used for semantic lookups, if the entry had one.
    pub embedding: Option<Vec<f32>>,
    /// Whole seconds of lifetime left at export time; `0` marks an entry that
    /// import will skip.
    pub ttl_remaining_secs: u64,
}
361
/// Number of buckets (and therefore dimensions) of the trigram embedding.
const NGRAM_DIM: usize = 128;

/// Builds a fixed-size bag-of-character-trigrams embedding of `text`.
///
/// The text is lowercased, every window of three consecutive characters is
/// hashed (`acc * 31 + char`, wrapping in `u32`) into one of `NGRAM_DIM`
/// buckets, and the bucket counts are scaled to unit L2 length. Texts with
/// fewer than three characters produce the all-zero vector.
fn compute_ngram_embedding(text: &str) -> Vec<f32> {
    let symbols: Vec<char> = text.to_lowercase().chars().collect();
    let mut buckets = vec![0.0f32; NGRAM_DIM];

    // `windows(3)` yields nothing for inputs shorter than three characters,
    // which leaves the vector at zero.
    for trigram in symbols.windows(3) {
        let mut code = 0u32;
        for &ch in trigram {
            code = code.wrapping_mul(31).wrapping_add(ch as u32);
        }
        buckets[(code as usize) % NGRAM_DIM] += 1.0;
    }

    let magnitude = buckets.iter().map(|w| w * w).sum::<f32>().sqrt();
    if magnitude > 0.0 {
        buckets.iter_mut().for_each(|w| *w /= magnitude);
    }
    buckets
}
386
/// Cosine similarity of two vectors.
///
/// Returns `0.0` when the slices differ in length or when either has zero
/// magnitude (which includes empty slices).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let l2_norm = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();

    if a.len() != b.len() {
        return 0.0;
    }

    let (len_a, len_b) = (l2_norm(a), l2_norm(b));
    if len_a == 0.0 || len_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (len_a * len_b)
}
399
/// Unit and property tests for the cache and its embedding helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// A non-degenerate vector has similarity 1 with itself.
        #[test]
        fn cosine_self_similarity_is_one(v in proptest::collection::vec(-1.0f32..1.0, 8..32)) {
            let sim = cosine_similarity(&v, &v);
            prop_assert!((sim - 1.0).abs() < 0.001);
        }

        /// Argument order does not matter.
        #[test]
        fn proptest_cosine_similarity_is_commutative(
            a in proptest::collection::vec(-1.0f32..1.0, 8..32),
            b in proptest::collection::vec(-1.0f32..1.0, 8..32),
        ) {
            // Trim to a common length; mismatched lengths short-circuit to 0.
            let len = a.len().min(b.len());
            let a = &a[..len];
            let b = &b[..len];
            let sim_ab = cosine_similarity(a, b);
            let sim_ba = cosine_similarity(b, a);
            prop_assert!((sim_ab - sim_ba).abs() < 0.001,
                "cosine_similarity not commutative: sim(a,b)={} vs sim(b,a)={}", sim_ab, sim_ba);
        }

        /// Non-negative vectors yield a similarity in [0, 1] (with tolerance).
        #[test]
        fn proptest_cosine_similarity_bounded_for_nonneg(
            a in proptest::collection::vec(0.0f32..1.0, 8..32),
            b in proptest::collection::vec(0.0f32..1.0, 8..32),
        ) {
            let len = a.len().min(b.len());
            let a = &a[..len];
            let b = &b[..len];
            let sim = cosine_similarity(a, b);
            prop_assert!((-0.001..=1.001).contains(&sim),
                "cosine similarity {} out of bounds [0, 1] for non-negative vectors", sim);
        }

        /// The embedding length is independent of the input text.
        #[test]
        fn proptest_ngram_embedding_has_fixed_dimension(text in "\\PC{1,200}") {
            let emb = compute_ngram_embedding(&text);
            prop_assert_eq!(emb.len(), NGRAM_DIM,
                "embedding dimension should be {} but was {}", NGRAM_DIM, emb.len());
        }

        /// Embedding the same text twice gives the same vector.
        #[test]
        fn proptest_ngram_embedding_is_deterministic(text in "[a-zA-Z0-9 ]{1,100}") {
            let emb1 = compute_ngram_embedding(&text);
            let emb2 = compute_ngram_embedding(&text);
            prop_assert_eq!(emb1, emb2, "same text must produce identical embeddings");
        }
    }

    /// Builds a plain (non-tool, no embedding) response with the given content.
    fn make_response(content: &str) -> CachedResponse {
        let now = Instant::now();
        CachedResponse {
            content: content.into(),
            model: "test-model".into(),
            tokens_saved: 100,
            created_at: now,
            expires_at: now + Duration::from_secs(3600),
            hits: 0,
            involved_tools: false,
            embedding: None,
        }
    }

    /// Same as `make_response`, but flagged as involving tools.
    fn make_tool_response(content: &str) -> CachedResponse {
        let mut r = make_response(content);
        r.involved_tools = true;
        r
    }

    #[test]
    fn store_and_exact_hit() {
        let mut cache = SemanticCache::new(true, 3600, 100);
        let hash = SemanticCache::compute_hash("sys", "msgs", "hello");

        cache.store(&hash, make_response("world"));
        let result = cache.lookup_exact(&hash);
        assert!(result.is_some());
        assert_eq!(result.unwrap().content, "world");
        assert_eq!(cache.hit_count(), 1);
        assert_eq!(cache.size(), 1);
    }

    #[test]
    fn miss_for_unknown_hash() {
        let mut cache = SemanticCache::new(true, 3600, 100);
        let result = cache.lookup_exact("nonexistent_hash");
        assert!(result.is_none());
        assert_eq!(cache.miss_count(), 1);
    }

    #[test]
    fn expiration_eviction() {
        // A zero TTL makes the entry expire immediately.
        let mut cache = SemanticCache::new(true, 0, 100);
        let hash = SemanticCache::compute_hash("sys", "msgs", "expire-me");

        cache.store(&hash, make_response("ephemeral"));
        std::thread::sleep(Duration::from_millis(5));
        cache.evict_expired();
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn lfu_eviction_at_capacity() {
        let mut cache = SemanticCache::new(true, 3600, 2);

        let h1 = "hash_1".to_string();
        let h2 = "hash_2".to_string();
        let h3 = "hash_3".to_string();

        cache.store(&h1, make_response("first"));
        cache.store(&h2, make_response("second"));

        // Give h2 a hit so that h1 is the least frequently used entry.
        cache.lookup_exact(&h2);

        cache.store(&h3, make_response("third"));
        assert_eq!(cache.size(), 2);
        assert!(cache.lookup_exact(&h1).is_none());
        assert!(cache.lookup_exact(&h2).is_some());
    }

    #[test]
    fn semantic_similarity_finds_near_matches() {
        let mut cache = SemanticCache::new(true, 3600, 100);
        let prompt1 = "What is the capital city of France?";
        let hash1 = SemanticCache::compute_hash("sys", "", prompt1);

        cache.store_with_embedding(&hash1, prompt1, make_response("Paris"));

        let similar_prompt = "What is the capital of France?";
        let result = cache.lookup_semantic(similar_prompt);
        assert!(result.is_some(), "semantically similar prompt should hit");
        assert_eq!(result.unwrap().content, "Paris");
    }

    #[test]
    fn semantic_dissimilar_miss() {
        let mut cache = SemanticCache::new(true, 3600, 100);
        let prompt1 = "What is the capital city of France?";
        let hash1 = SemanticCache::compute_hash("sys", "", prompt1);

        cache.store_with_embedding(&hash1, prompt1, make_response("Paris"));

        let different_prompt = "How do quantum computers work in detail?";
        let result = cache.lookup_semantic(different_prompt);
        assert!(result.is_none(), "dissimilar prompt should miss");
    }

    #[test]
    fn tool_ttl_shorter_than_normal() {
        let mut cache = SemanticCache::new(true, 100, 100);

        let hash = "tool_hash";
        cache.store(hash, make_tool_response("tool result"));

        let hit = cache.lookup_tool_ttl(hash);
        assert!(hit.is_some(), "fresh tool entry should hit");

        let non_tool_hash = "normal_hash";
        cache.store(non_tool_hash, make_response("normal result"));
        let hit = cache.lookup_tool_ttl(non_tool_hash);
        assert!(
            hit.is_some(),
            "fresh non-tool entry should hit via tool_ttl"
        );
    }

    #[test]
    fn multi_level_lookup_prefers_exact() {
        let mut cache = SemanticCache::new(true, 3600, 100);
        let prompt = "hello world test prompt";
        let hash = SemanticCache::compute_hash("sys", "", prompt);

        cache.store_with_embedding(&hash, prompt, make_response("exact match"));

        let result = cache.lookup(&hash, prompt);
        assert!(result.is_some());
        assert_eq!(result.unwrap().content, "exact match");
    }

    #[test]
    fn ngram_embedding_properties() {
        let emb1 = compute_ngram_embedding("hello world");
        let emb2 = compute_ngram_embedding("hello world");
        assert_eq!(emb1, emb2, "same text should produce identical embeddings");

        let emb3 = compute_ngram_embedding("completely different text");
        assert_ne!(emb1, emb3);

        let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 0.01,
            "embedding should be unit-normalized"
        );
    }

    #[test]
    fn cosine_similarity_properties() {
        // The arithmetic is f32, so the tolerance is f32's epsilon as well.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < f32::EPSILON);

        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < f32::EPSILON);
    }

    #[test]
    fn cache_zero_capacity_still_stores_one() {
        let mut cache = SemanticCache::new(true, 3600, 0);
        let hash = SemanticCache::compute_hash("", "", "q");
        cache.store(&hash, make_response("a"));
        assert_eq!(cache.size(), 1);
        let hit = cache.lookup_exact(&hash);
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().content, "a");
    }

    #[test]
    fn cache_duplicate_key_overwrites() {
        let mut cache = SemanticCache::new(true, 3600, 10);
        let hash = "dup_key".to_string();
        cache.store(&hash, make_response("first"));
        cache.store(&hash, make_response("second"));
        assert_eq!(cache.size(), 1);
        let hit = cache.lookup_exact(&hash);
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().content, "second");
    }

    #[test]
    fn export_entries_produces_valid_data() {
        let mut cache = SemanticCache::new(true, 3600, 10);
        cache.store_with_embedding("hash1", "prompt one", make_response("response one"));
        cache.store("hash2", make_response("response two"));

        let exported = cache.export_entries();
        assert_eq!(exported.len(), 2);

        for (key, entry) in &exported {
            assert!(!key.is_empty());
            assert!(!entry.content.is_empty());
            assert!(entry.ttl_remaining_secs > 0);
        }
    }

    #[test]
    fn import_entries_restores_lookups() {
        let mut cache = SemanticCache::new(true, 3600, 10);
        cache.store("h1", make_response("original"));

        let exported = cache.export_entries();

        let mut fresh = SemanticCache::new(true, 3600, 10);
        assert_eq!(fresh.size(), 0);

        fresh.import_entries(exported);
        assert_eq!(fresh.size(), 1);

        let hit = fresh.lookup_exact("h1");
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().content, "original");
    }

    #[test]
    fn import_skips_expired_entries() {
        // A remaining TTL of zero marks an entry as already expired.
        let entries = vec![(
            "expired".to_string(),
            ExportedCacheEntry {
                content: "old".into(),
                model: "m".into(),
                tokens_saved: 0,
                hits: 0,
                involved_tools: false,
                embedding: None,
                ttl_remaining_secs: 0,
            },
        )];

        let mut cache = SemanticCache::new(true, 3600, 10);
        cache.import_entries(entries);
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn export_import_roundtrip_preserves_embeddings() {
        let mut cache = SemanticCache::new(true, 3600, 10);
        cache.store_with_embedding("emb_hash", "test prompt", make_response("resp"));

        let exported = cache.export_entries();
        let entry = &exported[0].1;
        assert!(entry.embedding.is_some());

        let mut fresh = SemanticCache::new(true, 3600, 10);
        fresh.import_entries(exported);

        let hit = fresh.lookup_semantic("test prompt");
        assert!(hit.is_some());
    }

    #[test]
    fn with_threshold_uses_custom_value() {
        let mut cache = SemanticCache::with_threshold(true, 3600, 100, 0.99);
        let prompt1 = "What is the capital city of France?";
        let hash1 = SemanticCache::compute_hash("sys", "", prompt1);
        cache.store_with_embedding(&hash1, prompt1, make_response("Paris"));

        let similar = "What is the capital of France?";
        let result = cache.lookup_semantic(similar);
        assert!(result.is_none(), "high threshold should reject near-match");
    }

    #[test]
    fn lookup_with_embedding_uses_provided_vector() {
        let mut cache = SemanticCache::new(true, 3600, 100);
        let emb = vec![1.0, 0.0, 0.0, 0.0];
        let mut resp = make_response("answer");
        resp.embedding = Some(emb.clone());
        cache.store("h1", resp);

        // The key is unknown, so only the semantic stage can produce the hit.
        let result = cache.lookup_with_embedding("nonexistent_hash", &emb);
        assert!(result.is_some());
        assert_eq!(result.unwrap().content, "answer");
    }

    #[test]
    fn lookup_with_embedding_prefers_exact() {
        let mut cache = SemanticCache::new(true, 3600, 100);
        cache.store("exact_h", make_response("exact"));

        let emb = vec![1.0, 0.0];
        let result = cache.lookup_with_embedding("exact_h", &emb);
        assert!(result.is_some());
        assert_eq!(result.unwrap().content, "exact");
    }

    #[test]
    fn lookup_strict_does_not_use_semantic_near_match() {
        let mut cache = SemanticCache::new(true, 3600, 100);
        let prompt1 = "What is the capital city of France?";
        let hash1 = SemanticCache::compute_hash("sys", "", prompt1);
        cache.store_with_embedding(&hash1, prompt1, make_response("Paris"));

        let hash2 = SemanticCache::compute_hash("sys", "", "What is the capital of France?");
        let result = cache.lookup_strict(&hash2);
        assert!(result.is_none());
    }

    #[test]
    fn disabled_cache_ignores_import() {
        let entries = vec![(
            "key".to_string(),
            ExportedCacheEntry {
                content: "data".into(),
                model: "m".into(),
                tokens_saved: 10,
                hits: 0,
                involved_tools: false,
                embedding: None,
                ttl_remaining_secs: 3600,
            },
        )];

        let mut cache = SemanticCache::new(false, 3600, 10);
        cache.import_entries(entries);
        assert_eq!(cache.size(), 0);
    }
}