oxify_connect_llm/semantic_cache.rs

//! Semantic Caching for LLM Requests
//!
//! This module provides semantic caching that recognizes semantically similar requests
//! rather than requiring exact matches. This significantly improves cache hit rates and
//! reduces costs by catching variations of the same query.
//!
//! # Examples
//!
//! ```
//! use oxify_connect_llm::{
//!     LlmProvider, LlmRequest, OpenAIProvider, OllamaProvider,
//!     SemanticCache, SemanticCachedProvider, SimilarityThreshold,
//! };
//!
//! # async fn example() -> oxify_connect_llm::Result<()> {
//! // Create an embedding provider for semantic similarity
//! let embedding_provider = OllamaProvider::for_embeddings("nomic-embed-text".to_string());
//!
//! // Create a semantic cache with 0.85 similarity threshold
//! let cache = SemanticCache::new(
//!     Box::new(embedding_provider),
//!     SimilarityThreshold::new(0.85),
//!     100, // max cache size
//! );
//!
//! // Wrap your LLM provider with semantic caching
//! let provider = OpenAIProvider::new("key".to_string(), "gpt-4".to_string());
//! let cached_provider = SemanticCachedProvider::new(provider, cache);
//!
//! // These queries will be recognized as semantically similar:
//! // "What is Rust?" and "Can you explain Rust?"
//! // "How to learn Python" and "Best way to study Python"
//! # Ok(())
//! # }
//! ```

use crate::{EmbeddingProvider, EmbeddingRequest, LlmProvider, LlmRequest, LlmResponse, Result};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::Mutex;

/// Similarity threshold for semantic cache matching (0.0 to 1.0)
#[derive(Debug, Clone, Copy)]
pub struct SimilarityThreshold(f32);

impl SimilarityThreshold {
    /// Create a new similarity threshold
    ///
    /// # Panics
    /// Panics if threshold is not between 0.0 and 1.0
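    ///
    /// # Examples
    ///
    /// A minimal construction sketch, using the crate-root re-export shown in the
    /// module-level example:
    ///
    /// ```
    /// use oxify_connect_llm::SimilarityThreshold;
    ///
    /// let strict = SimilarityThreshold::new(0.95);
    /// assert_eq!(strict.value(), 0.95);
    /// ```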
    pub fn new(threshold: f32) -> Self {
        assert!(
            (0.0..=1.0).contains(&threshold),
            "Threshold must be between 0.0 and 1.0"
        );
        Self(threshold)
    }

    /// Get the threshold value
    pub fn value(&self) -> f32 {
        self.0
    }
}

impl Default for SimilarityThreshold {
    fn default() -> Self {
        Self(0.85) // Default to 85% similarity
    }
}

/// Statistics for semantic cache performance
#[derive(Debug, Clone, Default)]
pub struct SemanticCacheStats {
    /// Number of cache hits (semantically similar queries found)
    pub hits: u64,
    /// Number of cache misses (no similar query found)
    pub misses: u64,
    /// Number of embedding generation failures
    pub embedding_errors: u64,
    /// Average similarity score for cache hits
    pub avg_similarity: f32,
    /// Total number of cached entries
    pub cached_entries: usize,
}

impl SemanticCacheStats {
    /// Calculate the cache hit rate (0.0 to 1.0)
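    ///
    /// # Examples
    ///
    /// A small usage sketch; it assumes `SemanticCacheStats` is re-exported at the
    /// crate root alongside the other cache types:
    ///
    /// ```
    /// use oxify_connect_llm::SemanticCacheStats;
    ///
    /// let stats = SemanticCacheStats { hits: 3, misses: 1, ..Default::default() };
    /// assert!((stats.hit_rate() - 0.75).abs() < f32::EPSILON);
    /// ```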
    pub fn hit_rate(&self) -> f32 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f32 / total as f32
        }
    }
}

/// Entry in the semantic cache
#[derive(Clone)]
struct CacheEntry {
    #[allow(dead_code)]
    prompt: String,
    embedding: Vec<f32>,
    response: LlmResponse,
    access_count: u64,
}

/// Semantic cache using embeddings for similarity matching
pub struct SemanticCache {
    embedding_provider: Arc<Box<dyn EmbeddingProvider>>,
    threshold: SimilarityThreshold,
    max_size: usize,
    entries: Arc<Mutex<Vec<CacheEntry>>>,
    stats: Arc<Mutex<SemanticCacheStats>>,
}

impl SemanticCache {
    /// Create a new semantic cache
    ///
    /// # Arguments
    /// * `embedding_provider` - Provider for generating embeddings
    /// * `threshold` - Minimum similarity score for cache hit (0.0 to 1.0)
    /// * `max_size` - Maximum number of entries to cache
    pub fn new(
        embedding_provider: Box<dyn EmbeddingProvider>,
        threshold: SimilarityThreshold,
        max_size: usize,
    ) -> Self {
        Self {
            embedding_provider: Arc::new(embedding_provider),
            threshold,
            max_size,
            entries: Arc::new(Mutex::new(Vec::new())),
            stats: Arc::new(Mutex::new(SemanticCacheStats::default())),
        }
    }

    /// Get cache statistics
    pub async fn stats(&self) -> SemanticCacheStats {
        // Lock `entries` before `stats`, matching the order used in `get` and `clear`,
        // so concurrent callers cannot deadlock on an inverted lock order.
        let entries = self.entries.lock().await;
        let mut stats_copy = self.stats.lock().await.clone();
        stats_copy.cached_entries = entries.len();
        stats_copy
    }

    /// Clear the cache and reset statistics
    pub async fn clear(&self) {
        let mut entries = self.entries.lock().await;
        entries.clear();
        let mut stats = self.stats.lock().await;
        *stats = SemanticCacheStats::default();
    }

    /// Generate embedding for a prompt
    async fn generate_embedding(&self, prompt: &str) -> Result<Vec<f32>> {
        let request = EmbeddingRequest {
            texts: vec![prompt.to_string()],
            model: None,
        };
        let response = self.embedding_provider.embed(request).await?;
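        // Exactly one text was submitted, so the provider is expected to return
        // exactly one embedding.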
        Ok(response.embeddings.into_iter().next().unwrap())
    }

    /// Calculate cosine similarity between two embeddings
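    ///
    /// Computed as `dot(a, b) / (|a| * |b|)`; returns 0.0 if either vector has zero magnitude.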
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if magnitude_a == 0.0 || magnitude_b == 0.0 {
            0.0
        } else {
            dot_product / (magnitude_a * magnitude_b)
        }
    }

    /// Try to get a cached response for a prompt
    pub async fn get(&self, prompt: &str) -> Option<LlmResponse> {
        // Generate embedding for the query
        let query_embedding = match self.generate_embedding(prompt).await {
            Ok(emb) => emb,
            Err(_) => {
                let mut stats = self.stats.lock().await;
                stats.embedding_errors += 1;
                return None;
            }
        };

        // Search for similar entries
        let mut entries = self.entries.lock().await;
        let mut best_match: Option<(usize, f32)> = None;

        for (idx, entry) in entries.iter().enumerate() {
            let similarity = Self::cosine_similarity(&query_embedding, &entry.embedding);
            if similarity >= self.threshold.value() {
                if let Some((_, best_sim)) = best_match {
                    if similarity > best_sim {
                        best_match = Some((idx, similarity));
                    }
                } else {
                    best_match = Some((idx, similarity));
                }
            }
        }

        if let Some((idx, similarity)) = best_match {
            // Cache hit - update access count and stats
            entries[idx].access_count += 1;
            let response = entries[idx].response.clone();

            let mut stats = self.stats.lock().await;
            stats.hits += 1;
            // Update running average of similarity scores
            let total_hits = stats.hits;
            stats.avg_similarity =
                ((stats.avg_similarity * (total_hits - 1) as f32) + similarity) / total_hits as f32;

            tracing::debug!(
                "Semantic cache hit: similarity={:.3}, prompt='{}'",
                similarity,
                prompt
            );

            Some(response)
        } else {
            // Cache miss
            let mut stats = self.stats.lock().await;
            stats.misses += 1;

            tracing::debug!("Semantic cache miss: prompt='{}'", prompt);

            None
        }
    }

    /// Store a response in the cache
    pub async fn put(&self, prompt: String, response: LlmResponse) {
        // Generate embedding for the prompt
        let embedding = match self.generate_embedding(&prompt).await {
            Ok(emb) => emb,
            Err(_) => {
                let mut stats = self.stats.lock().await;
                stats.embedding_errors += 1;
                return;
            }
        };

        let mut entries = self.entries.lock().await;

        // Add new entry
        entries.push(CacheEntry {
            prompt,
            embedding,
            response,
            access_count: 1,
        });

        // Evict least accessed entry if cache is full
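        // (A simple least-frequently-used policy: the entry with the smallest
        // access_count is removed; on ties, min_by_key removes the earliest-inserted one.)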
        if entries.len() > self.max_size {
            // Find entry with lowest access count
            let min_idx = entries
                .iter()
                .enumerate()
                .min_by_key(|(_, e)| e.access_count)
                .map(|(idx, _)| idx)
                .unwrap();
            entries.remove(min_idx);
        }
    }
}

/// Provider wrapper that adds semantic caching
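///
/// Note that cache lookups key on the request `prompt` alone; other request fields
/// (system prompt, temperature, tools, images) do not affect cache matching.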
pub struct SemanticCachedProvider<P> {
    provider: Arc<P>,
    cache: Arc<SemanticCache>,
}

impl<P> SemanticCachedProvider<P> {
    /// Create a new semantic cached provider
    pub fn new(provider: P, cache: SemanticCache) -> Self {
        Self {
            provider: Arc::new(provider),
            cache: Arc::new(cache),
        }
    }

    /// Get cache statistics
    pub async fn cache_stats(&self) -> SemanticCacheStats {
        self.cache.stats().await
    }

    /// Clear the cache
    pub async fn clear_cache(&self) {
        self.cache.clear().await
    }
}

#[async_trait]
impl<P: LlmProvider> LlmProvider for SemanticCachedProvider<P> {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        // Try to get from cache
        if let Some(cached_response) = self.cache.get(&request.prompt).await {
            return Ok(cached_response);
        }

        // Cache miss - call provider
        let response = self.provider.complete(request.clone()).await?;

        // Store in cache
        self.cache.put(request.prompt, response.clone()).await;

        Ok(response)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Usage;

    // Mock embedding provider that returns simple embeddings
    struct MockEmbeddingProvider;

    #[async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        async fn embed(&self, request: EmbeddingRequest) -> Result<crate::EmbeddingResponse> {
            // Create simple embeddings based on text length and content
            let embeddings: Vec<Vec<f32>> = request
                .texts
                .iter()
                .map(|text| {
                    let mut embedding = vec![0.0; 128];
                    // Simple hash-like embedding based on characters
                    for (i, ch) in text.chars().enumerate() {
                        embedding[i % 128] += (ch as u32 as f32) / 1000.0;
                    }
                    // Normalize
                    let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
                    if magnitude > 0.0 {
                        embedding.iter_mut().for_each(|x| *x /= magnitude);
                    }
                    embedding
                })
                .collect();

            Ok(crate::EmbeddingResponse {
                embeddings,
                model: "mock".to_string(),
                usage: None,
            })
        }
    }

    // Mock LLM provider
    struct MockLlmProvider;

    #[async_trait]
    impl LlmProvider for MockLlmProvider {
        async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
            Ok(LlmResponse {
                content: format!("Response to: {}", request.prompt),
                model: "mock".to_string(),
                usage: Some(Usage {
                    prompt_tokens: 10,
                    completion_tokens: 20,
                    total_tokens: 30,
                }),
                tool_calls: Vec::new(),
            })
        }
    }

    #[tokio::test]
    async fn test_similarity_threshold() {
        let threshold = SimilarityThreshold::new(0.85);
        assert_eq!(threshold.value(), 0.85);

        let default_threshold = SimilarityThreshold::default();
        assert_eq!(default_threshold.value(), 0.85);
    }

    #[tokio::test]
    #[should_panic(expected = "Threshold must be between 0.0 and 1.0")]
    async fn test_invalid_threshold() {
        let _threshold = SimilarityThreshold::new(1.5);
    }

    #[tokio::test]
    async fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert_eq!(SemanticCache::cosine_similarity(&a, &b), 1.0);

        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert_eq!(SemanticCache::cosine_similarity(&a, &b), 0.0);

        let a = vec![1.0, 1.0];
        let b = vec![1.0, 1.0];
        assert!((SemanticCache::cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn test_semantic_cache_miss() {
        let cache = SemanticCache::new(
            Box::new(MockEmbeddingProvider),
            SimilarityThreshold::new(0.9),
            10,
        );

        let result = cache.get("test query").await;
        assert!(result.is_none());

        let stats = cache.stats().await;
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
    }

    #[tokio::test]
    async fn test_semantic_cache_hit() {
        let cache = SemanticCache::new(
            Box::new(MockEmbeddingProvider),
            SimilarityThreshold::new(0.9),
            10,
        );

        // Store a response
        let response = LlmResponse {
            content: "test response".to_string(),
            model: "test".to_string(),
            usage: None,
            tool_calls: Vec::new(),
        };
        cache.put("test query".to_string(), response.clone()).await;

        // Try to get the exact same query
        let result = cache.get("test query").await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().content, "test response");

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
    }

    #[tokio::test]
    async fn test_semantic_cache_similar_queries() {
        let cache = SemanticCache::new(
            Box::new(MockEmbeddingProvider),
            SimilarityThreshold::new(0.7), // Lower threshold to catch similar queries
            10,
        );

        // Store a response for one query
        let response = LlmResponse {
            content: "Rust is a systems programming language".to_string(),
            model: "test".to_string(),
            usage: None,
            tool_calls: Vec::new(),
        };
        cache
            .put("What is Rust?".to_string(), response.clone())
            .await;

        // Try a similar (but not identical) query - with the mock embeddings this
        // scores well above the 0.7 threshold, so it should hit
        let result = cache.get("What is Rust").await;
        assert!(result.is_some());
    }

    #[tokio::test]
    async fn test_semantic_cache_eviction() {
        let cache = SemanticCache::new(
            Box::new(MockEmbeddingProvider),
            SimilarityThreshold::new(0.9),
            2, // Small cache size
        );

        // Add 3 entries (will evict one)
        for i in 1..=3 {
            let response = LlmResponse {
                content: format!("response {}", i),
                model: "test".to_string(),
                usage: None,
                tool_calls: Vec::new(),
            };
            cache.put(format!("query {}", i), response).await;
        }

        let stats = cache.stats().await;
        assert_eq!(stats.cached_entries, 2);
    }

    #[tokio::test]
    async fn test_cached_provider() {
        let cache = SemanticCache::new(
            Box::new(MockEmbeddingProvider),
            SimilarityThreshold::new(0.9),
            10,
        );

        let provider = MockLlmProvider;
        let cached_provider = SemanticCachedProvider::new(provider, cache);

        // First request - cache miss
        let request = LlmRequest {
            prompt: "test query".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        };
        let response1 = cached_provider.complete(request.clone()).await.unwrap();

        // Second request - cache hit
        let response2 = cached_provider.complete(request).await.unwrap();

        assert_eq!(response1.content, response2.content);

        let stats = cached_provider.cache_stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let cache = SemanticCache::new(
            Box::new(MockEmbeddingProvider),
            SimilarityThreshold::new(0.9),
            10,
        );

        let stats = cache.stats().await;
        assert_eq!(stats.hit_rate(), 0.0);

        // Add some hits and misses
        let response = LlmResponse {
            content: "test".to_string(),
            model: "test".to_string(),
            usage: None,
            tool_calls: Vec::new(),
        };
        cache.put("query".to_string(), response).await;
        cache.get("query").await; // hit
        cache.get("other query").await; // miss

        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate(), 0.5);
    }

    #[tokio::test]
    async fn test_clear_cache() {
        let cache = SemanticCache::new(
            Box::new(MockEmbeddingProvider),
            SimilarityThreshold::new(0.9),
            10,
        );

        let response = LlmResponse {
            content: "test".to_string(),
            model: "test".to_string(),
            usage: None,
            tool_calls: Vec::new(),
        };
        cache.put("query".to_string(), response).await;

        cache.clear().await;

        let stats = cache.stats().await;
        assert_eq!(stats.cached_entries, 0);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }
}