1use std::sync::Mutex;
2use dashmap::DashMap;
3use fastembed::{TextEmbedding, InitOptions, EmbeddingModel};
4use crate::Result;
5use tracing::{debug, info};
6
/// Common interface for prompt/response caches.
///
/// Implementations must be shareable across threads (`Send + Sync`),
/// since lookups and inserts may happen concurrently.
pub trait Cache: Send + Sync {
    /// Returns the cached response for `prompt`, or `None` on a miss.
    fn get(&self, prompt: &str) -> Option<String>;

    /// Stores `response` as the cached answer for `prompt`.
    fn set(&self, prompt: &str, response: String);
}
15
/// Cache that matches prompts by embedding similarity rather than exact text.
pub struct SemanticCache {
    // Local embedding model; embedding requires `&mut`, so it is wrapped in
    // a `Mutex` to keep the cache `Send + Sync`.
    model: Mutex<TextEmbedding>,
    // Maps the original prompt text to its (embedding, response) pair.
    storage: DashMap<String, (Vec<f32>, String)>,
    // Minimum cosine similarity for a lookup to count as a hit.
    threshold: f32,
}
25
26impl SemanticCache {
27 pub fn new() -> Result<Self> {
29 info!("Initializing semantic cache with local embedding model...");
30 let model = TextEmbedding::try_new(
31 InitOptions::new(EmbeddingModel::AllMiniLML6V2)
32 .with_show_download_progress(true)
33 ).map_err(|e| crate::AetherError::InjectionError(e.to_string()))?;
34
35 Ok(Self {
36 model: Mutex::new(model),
37 storage: DashMap::new(),
38 threshold: 0.90, })
40 }
41
42 pub fn with_threshold(mut self, threshold: f32) -> Self {
44 self.threshold = threshold;
45 self
46 }
47
48 fn cosine_similarity(v1: &[f32], v2: &[f32]) -> f32 {
49 let dot_product: f32 = v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum();
50 let norm_v1: f32 = v1.iter().map(|v| v * v).sum::<f32>().sqrt();
51 let norm_v2: f32 = v2.iter().map(|v| v * v).sum::<f32>().sqrt();
52 dot_product / (norm_v1 * norm_v2)
53 }
54}
55
56impl Cache for SemanticCache {
57 fn get(&self, prompt: &str) -> Option<String> {
58 let mut model = self.model.lock().ok()?;
59 let embedding = model.embed(vec![prompt], None).ok()?.first()?.clone();
60
61 let mut best_match: Option<(f32, String)> = None;
63
64 for entry in self.storage.iter() {
65 let (stored_embedding, response) = entry.value();
66 let similarity = Self::cosine_similarity(&embedding, stored_embedding);
67
68 if similarity >= self.threshold {
69 if best_match.as_ref().map_or(true, |(score, _)| similarity > *score) {
70 best_match = Some((similarity, response.clone()));
71 }
72 }
73 }
74
75 if let Some((score, response)) = best_match {
76 debug!("Semantic cache hit! Similarity: {:.2}", score);
77 Some(response)
78 } else {
79 None
80 }
81 }
82
83 fn set(&self, prompt: &str, response: String) {
84 let mut model = match self.model.lock() {
85 Ok(m) => m,
86 Err(_) => return,
87 };
88 if let Ok(embeddings) = model.embed(vec![prompt], None) {
89 if let Some(embedding) = embeddings.first() {
90 self.storage.insert(prompt.to_string(), (embedding.clone(), response));
91 }
92 }
93 }
94}
95
/// Cache keyed on the exact prompt text (no embeddings involved).
pub struct ExactCache {
    // prompt -> response
    storage: DashMap<String, String>,
}
100
101impl ExactCache {
102 pub fn new() -> Self {
103 Self { storage: DashMap::new() }
104 }
105}
106
107impl Cache for ExactCache {
108 fn get(&self, prompt: &str) -> Option<String> {
109 self.storage.get(prompt).map(|v| v.value().clone())
110 }
111
112 fn set(&self, prompt: &str, response: String) {
113 self.storage.insert(prompt.to_string(), response);
114 }
115}
116
/// Two-tier cache: a cheap exact-match tier in front of the semantic tier.
pub struct TieredCache {
    // Tier 1: exact prompt lookup (no embedding cost).
    exact: ExactCache,
    // Tier 2: similarity lookup via embeddings.
    semantic: SemanticCache,
}
122
123impl TieredCache {
124 pub fn new() -> Result<Self> {
126 Ok(Self {
127 exact: ExactCache::new(),
128 semantic: SemanticCache::new()?,
129 })
130 }
131}
132
133impl Cache for TieredCache {
134 fn get(&self, prompt: &str) -> Option<String> {
135 if let Some(res) = self.exact.get(prompt) {
137 return Some(res);
138 }
139
140 self.semantic.get(prompt)
142 }
143
144 fn set(&self, prompt: &str, response: String) {
145 self.exact.set(prompt, response.clone());
147 self.semantic.set(prompt, response);
148 }
149}