1use crate::{EmbeddingProvider, EmbeddingRequest, LlmProvider, LlmRequest, LlmResponse, Result};
38use async_trait::async_trait;
39use std::sync::Arc;
40use tokio::sync::Mutex;
41
/// A cosine-similarity cutoff, guaranteed to lie in the inclusive range `[0.0, 1.0]`.
#[derive(Debug, Clone, Copy)]
pub struct SimilarityThreshold(f32);

impl SimilarityThreshold {
    /// Builds a threshold from `threshold`.
    ///
    /// # Panics
    /// Panics if `threshold` is outside `[0.0, 1.0]` (NaN also panics,
    /// since a NaN never satisfies the range check).
    pub fn new(threshold: f32) -> Self {
        let in_range = (0.0..=1.0).contains(&threshold);
        assert!(in_range, "Threshold must be between 0.0 and 1.0");
        Self(threshold)
    }

    /// Returns the raw threshold value.
    pub fn value(&self) -> f32 {
        self.0
    }
}

impl Default for SimilarityThreshold {
    /// Default cutoff: prompts must be at least 0.85 cosine-similar.
    fn default() -> Self {
        Self(0.85)
    }
}
70
/// Aggregate counters describing semantic-cache behaviour.
#[derive(Debug, Clone, Default)]
pub struct SemanticCacheStats {
    /// Lookups answered from the cache.
    pub hits: u64,
    /// Lookups that found no sufficiently similar entry.
    pub misses: u64,
    /// Failed embedding-generation attempts (on lookup or insert).
    pub embedding_errors: u64,
    /// Running mean cosine similarity over all cache hits.
    pub avg_similarity: f32,
    /// Entry count at the time the snapshot was taken.
    pub cached_entries: usize,
}

impl SemanticCacheStats {
    /// Fraction of lookups that were hits, or 0.0 when no lookups occurred.
    pub fn hit_rate(&self) -> f32 {
        match self.hits + self.misses {
            0 => 0.0,
            total => self.hits as f32 / total as f32,
        }
    }
}
97
/// A single cached prompt/response pair together with its embedding.
#[derive(Clone)]
struct CacheEntry {
    /// Original prompt text; not read anywhere in this module, kept for
    /// debugging/inspection only (hence the `dead_code` allow).
    #[allow(dead_code)]
    prompt: String,
    /// Embedding of `prompt`, compared against queries via cosine similarity.
    embedding: Vec<f32>,
    /// The LLM response returned on a cache hit.
    response: LlmResponse,
    /// Number of lookups this entry has served; used as the LFU eviction key.
    access_count: u64,
}
107
/// An LLM response cache keyed by semantic similarity of prompts rather
/// than exact string equality: a lookup hits when some stored prompt's
/// embedding is at least `threshold`-cosine-similar to the query's.
pub struct SemanticCache {
    /// Provider used to embed prompts for similarity comparison.
    embedding_provider: Arc<Box<dyn EmbeddingProvider>>,
    /// Minimum cosine similarity for a lookup to count as a hit.
    threshold: SimilarityThreshold,
    /// Maximum number of entries retained before LFU eviction kicks in.
    max_size: usize,
    /// Cached entries, scanned linearly on every lookup.
    entries: Arc<Mutex<Vec<CacheEntry>>>,
    /// Hit/miss/error counters.
    stats: Arc<Mutex<SemanticCacheStats>>,
}
116
117impl SemanticCache {
118 pub fn new(
125 embedding_provider: Box<dyn EmbeddingProvider>,
126 threshold: SimilarityThreshold,
127 max_size: usize,
128 ) -> Self {
129 Self {
130 embedding_provider: Arc::new(embedding_provider),
131 threshold,
132 max_size,
133 entries: Arc::new(Mutex::new(Vec::new())),
134 stats: Arc::new(Mutex::new(SemanticCacheStats::default())),
135 }
136 }
137
138 pub async fn stats(&self) -> SemanticCacheStats {
140 let stats = self.stats.lock().await;
141 let entries = self.entries.lock().await;
142 let mut stats_copy = stats.clone();
143 stats_copy.cached_entries = entries.len();
144 stats_copy
145 }
146
147 pub async fn clear(&self) {
149 let mut entries = self.entries.lock().await;
150 entries.clear();
151 let mut stats = self.stats.lock().await;
152 *stats = SemanticCacheStats::default();
153 }
154
155 async fn generate_embedding(&self, prompt: &str) -> Result<Vec<f32>> {
157 let request = EmbeddingRequest {
158 texts: vec![prompt.to_string()],
159 model: None,
160 };
161 let response = self.embedding_provider.embed(request).await?;
162 Ok(response.embeddings.into_iter().next().unwrap())
163 }
164
165 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
167 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
168 let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
169 let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
170
171 if magnitude_a == 0.0 || magnitude_b == 0.0 {
172 0.0
173 } else {
174 dot_product / (magnitude_a * magnitude_b)
175 }
176 }
177
178 pub async fn get(&self, prompt: &str) -> Option<LlmResponse> {
180 let query_embedding = match self.generate_embedding(prompt).await {
182 Ok(emb) => emb,
183 Err(_) => {
184 let mut stats = self.stats.lock().await;
185 stats.embedding_errors += 1;
186 return None;
187 }
188 };
189
190 let mut entries = self.entries.lock().await;
192 let mut best_match: Option<(usize, f32)> = None;
193
194 for (idx, entry) in entries.iter().enumerate() {
195 let similarity = Self::cosine_similarity(&query_embedding, &entry.embedding);
196 if similarity >= self.threshold.value() {
197 if let Some((_, best_sim)) = best_match {
198 if similarity > best_sim {
199 best_match = Some((idx, similarity));
200 }
201 } else {
202 best_match = Some((idx, similarity));
203 }
204 }
205 }
206
207 if let Some((idx, similarity)) = best_match {
208 entries[idx].access_count += 1;
210 let response = entries[idx].response.clone();
211
212 let mut stats = self.stats.lock().await;
213 stats.hits += 1;
214 let total_hits = stats.hits;
216 stats.avg_similarity =
217 ((stats.avg_similarity * (total_hits - 1) as f32) + similarity) / total_hits as f32;
218
219 tracing::debug!(
220 "Semantic cache hit: similarity={:.3}, prompt='{}'",
221 similarity,
222 prompt
223 );
224
225 Some(response)
226 } else {
227 let mut stats = self.stats.lock().await;
229 stats.misses += 1;
230
231 tracing::debug!("Semantic cache miss: prompt='{}'", prompt);
232
233 None
234 }
235 }
236
237 pub async fn put(&self, prompt: String, response: LlmResponse) {
239 let embedding = match self.generate_embedding(&prompt).await {
241 Ok(emb) => emb,
242 Err(_) => {
243 let mut stats = self.stats.lock().await;
244 stats.embedding_errors += 1;
245 return;
246 }
247 };
248
249 let mut entries = self.entries.lock().await;
250
251 entries.push(CacheEntry {
253 prompt,
254 embedding,
255 response,
256 access_count: 1,
257 });
258
259 if entries.len() > self.max_size {
261 let min_idx = entries
263 .iter()
264 .enumerate()
265 .min_by_key(|(_, e)| e.access_count)
266 .map(|(idx, _)| idx)
267 .unwrap();
268 entries.remove(min_idx);
269 }
270 }
271}
272
/// Decorator that puts a [`SemanticCache`] in front of any [`LlmProvider`]:
/// completions are served from the cache when a semantically similar prompt
/// has been answered before.
pub struct SemanticCachedProvider<P> {
    /// The wrapped provider, consulted only on cache misses.
    provider: Arc<P>,
    /// Shared semantic cache of previous completions.
    cache: Arc<SemanticCache>,
}
278
279impl<P> SemanticCachedProvider<P> {
280 pub fn new(provider: P, cache: SemanticCache) -> Self {
282 Self {
283 provider: Arc::new(provider),
284 cache: Arc::new(cache),
285 }
286 }
287
288 pub async fn cache_stats(&self) -> SemanticCacheStats {
290 self.cache.stats().await
291 }
292
293 pub async fn clear_cache(&self) {
295 self.cache.clear().await
296 }
297}
298
#[async_trait]
impl<P: LlmProvider> LlmProvider for SemanticCachedProvider<P> {
    /// Completes `request`, consulting the semantic cache first.
    ///
    /// NOTE(review): the cache is keyed on `request.prompt` alone — a
    /// similar prompt with a different `system_prompt`, `temperature`,
    /// `max_tokens`, or tool set will still produce a cache hit. Confirm
    /// this is intended before using it with parameter-sensitive calls.
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        // Cache hit: return the stored response without calling the provider.
        if let Some(cached_response) = self.cache.get(&request.prompt).await {
            return Ok(cached_response);
        }

        // Cache miss: delegate to the wrapped provider. `request` is cloned
        // because `complete` consumes it but `request.prompt` is still
        // needed as the cache key afterwards.
        let response = self.provider.complete(request.clone()).await?;

        self.cache.put(request.prompt, response.clone()).await;

        Ok(response)
    }
}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Usage;

    /// Deterministic embedding stub: folds character codes into 128 buckets
    /// (round-robin by position) and L2-normalizes, so identical strings
    /// always map to identical unit vectors.
    struct MockEmbeddingProvider;

    #[async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        async fn embed(&self, request: EmbeddingRequest) -> Result<crate::EmbeddingResponse> {
            let embeddings: Vec<Vec<f32>> = request
                .texts
                .iter()
                .map(|text| {
                    let mut embedding = vec![0.0; 128];
                    // Accumulate scaled character codes across the buckets.
                    for (i, ch) in text.chars().enumerate() {
                        embedding[i % 128] += (ch as u32 as f32) / 1000.0;
                    }
                    // Normalize to unit length so cosine similarity reduces
                    // to a dot product.
                    let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
                    if magnitude > 0.0 {
                        embedding.iter_mut().for_each(|x| *x /= magnitude);
                    }
                    embedding
                })
                .collect();

            Ok(crate::EmbeddingResponse {
                embeddings,
                model: "mock".to_string(),
                usage: None,
            })
        }
    }

    /// LLM stub that echoes the prompt back in its response content.
    struct MockLlmProvider;

    #[async_trait]
    impl LlmProvider for MockLlmProvider {
        async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
            Ok(LlmResponse {
                content: format!("Response to: {}", request.prompt),
                model: "mock".to_string(),
                usage: Some(Usage {
                    prompt_tokens: 10,
                    completion_tokens: 20,
                    total_tokens: 30,
                }),
                tool_calls: Vec::new(),
            })
        }
    }

    // Constructor and Default produce the expected threshold values.
    #[tokio::test]
    async fn test_similarity_threshold() {
        let threshold = SimilarityThreshold::new(0.85);
        assert_eq!(threshold.value(), 0.85);

        let default_threshold = SimilarityThreshold::default();
        assert_eq!(default_threshold.value(), 0.85);
    }

    // Out-of-range thresholds panic with the documented message.
    #[tokio::test]
    #[should_panic(expected = "Threshold must be between 0.0 and 1.0")]
    async fn test_invalid_threshold() {
        let _threshold = SimilarityThreshold::new(1.5);
    }

    // Identical vectors score 1.0, orthogonal vectors 0.0.
    #[tokio::test]
    async fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert_eq!(SemanticCache::cosine_similarity(&a, &b), 1.0);

        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert_eq!(SemanticCache::cosine_similarity(&a, &b), 0.0);

        let a = vec![1.0, 1.0];
        let b = vec![1.0, 1.0];
        assert!((SemanticCache::cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
    }

    // A lookup on an empty cache is a miss and is counted as one.
    #[tokio::test]
    async fn test_semantic_cache_miss() {
        let cache = SemanticCache::new(
            Box::new(MockEmbeddingProvider),
            SimilarityThreshold::new(0.9),
            10,
        );

        let result = cache.get("test query").await;
        assert!(result.is_none());

        let stats = cache.stats().await;
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
    }

    // An identical prompt (identical mock embedding) must hit.
    #[tokio::test]
    async fn test_semantic_cache_hit() {
        let cache = SemanticCache::new(
            Box::new(MockEmbeddingProvider),
            SimilarityThreshold::new(0.9),
            10,
        );

        let response = LlmResponse {
            content: "test response".to_string(),
            model: "test".to_string(),
            usage: None,
            tool_calls: Vec::new(),
        };
        cache.put("test query".to_string(), response.clone()).await;

        let result = cache.get("test query").await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().content, "test response");

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
    }

    // NOTE(review): despite the name, this queries the *identical* prompt,
    // not a merely similar one — the mock embedding makes truly-similar
    // phrasing hard to guarantee above the threshold. Consider renaming or
    // adding a genuinely paraphrased query.
    #[tokio::test]
    async fn test_semantic_cache_similar_queries() {
        let cache = SemanticCache::new(
            Box::new(MockEmbeddingProvider),
            SimilarityThreshold::new(0.7), 10,
        );

        let response = LlmResponse {
            content: "Rust is a systems programming language".to_string(),
            model: "test".to_string(),
            usage: None,
            tool_calls: Vec::new(),
        };
        cache
            .put("What is Rust?".to_string(), response.clone())
            .await;

        let result = cache.get("What is Rust?").await;
        assert!(result.is_some());
    }

    // Inserting 3 entries into a max_size=2 cache leaves exactly 2.
    #[tokio::test]
    async fn test_semantic_cache_eviction() {
        let cache = SemanticCache::new(
            Box::new(MockEmbeddingProvider),
            SimilarityThreshold::new(0.9),
            2, );

        for i in 1..=3 {
            let response = LlmResponse {
                content: format!("response {}", i),
                model: "test".to_string(),
                usage: None,
                tool_calls: Vec::new(),
            };
            cache.put(format!("query {}", i), response).await;
        }

        let stats = cache.stats().await;
        assert_eq!(stats.cached_entries, 2);
    }

    // First call misses and populates; the second (identical) call hits,
    // and both return the same content.
    #[tokio::test]
    async fn test_cached_provider() {
        let cache = SemanticCache::new(
            Box::new(MockEmbeddingProvider),
            SimilarityThreshold::new(0.9),
            10,
        );

        let provider = MockLlmProvider;
        let cached_provider = SemanticCachedProvider::new(provider, cache);

        let request = LlmRequest {
            prompt: "test query".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        };
        let response1 = cached_provider.complete(request.clone()).await.unwrap();

        let response2 = cached_provider.complete(request).await.unwrap();

        assert_eq!(response1.content, response2.content);

        let stats = cached_provider.cache_stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    // hit_rate is 0.0 with no traffic and 0.5 after one hit + one miss.
    #[tokio::test]
    async fn test_cache_stats() {
        let cache = SemanticCache::new(
            Box::new(MockEmbeddingProvider),
            SimilarityThreshold::new(0.9),
            10,
        );

        let stats = cache.stats().await;
        assert_eq!(stats.hit_rate(), 0.0);

        let response = LlmResponse {
            content: "test".to_string(),
            model: "test".to_string(),
            usage: None,
            tool_calls: Vec::new(),
        };
        cache.put("query".to_string(), response).await;
        cache.get("query").await;
        cache.get("other query").await;
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate(), 0.5);
    }

    // clear() empties the entries and zeroes all counters.
    #[tokio::test]
    async fn test_clear_cache() {
        let cache = SemanticCache::new(
            Box::new(MockEmbeddingProvider),
            SimilarityThreshold::new(0.9),
            10,
        );

        let response = LlmResponse {
            content: "test".to_string(),
            model: "test".to_string(),
            usage: None,
            tool_calls: Vec::new(),
        };
        cache.put("query".to_string(), response).await;

        cache.clear().await;

        let stats = cache.stats().await;
        assert_eq!(stats.cached_entries, 0);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }
}