rrag/caching/
semantic_cache.rs

1//! # Semantic Cache Implementation
2//!
3//! Intelligent caching based on semantic similarity for RAG applications.
4
5use super::{Cache, CacheStats, SemanticCacheConfig, SemanticCacheEntry};
6use crate::RragResult;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::SystemTime;
10
11/// Semantic cache with similarity-based retrieval
12pub struct SemanticCache {
13    /// Configuration
14    config: SemanticCacheConfig,
15
16    /// Main storage indexed by query hash
17    storage: HashMap<String, SemanticCacheEntry>,
18
19    /// Embedding vectors for similarity computation
20    embeddings: HashMap<String, Vec<f32>>,
21
22    /// Semantic clusters for efficient search
23    clusters: Vec<SemanticCluster>,
24
25    /// Query to cluster mapping
26    query_clusters: HashMap<String, usize>,
27
28    /// Cache statistics
29    stats: CacheStats,
30}
31
32/// Semantic cluster for grouping similar queries
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct SemanticCluster {
35    /// Cluster ID
36    pub id: usize,
37
38    /// Cluster centroid (average embedding)
39    pub centroid: Vec<f32>,
40
41    /// Queries in this cluster
42    pub queries: Vec<String>,
43
44    /// Representative query (closest to centroid)
45    pub representative: String,
46
47    /// Cluster quality metrics
48    pub cohesion: f32,
49
50    /// Last updated
51    pub last_updated: SystemTime,
52}
53
54/// Similarity search result
55#[derive(Debug, Clone)]
56pub struct SimilaritySearchResult {
57    /// Query text
58    pub query: String,
59
60    /// Similarity score
61    pub similarity: f32,
62
63    /// Cached entry
64    pub entry: SemanticCacheEntry,
65}
66
/// Clustering strategies that can be selected for grouping cached queries.
///
/// This enum is not referenced by `SemanticCache` in this module; the cache's
/// own clustering is a nearest-centroid assignment with periodic centroid
/// refreshes.
///
/// The enum is fieldless, so it is `Copy` and can be compared or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClusteringAlgorithm {
    /// K-means clustering.
    KMeans,
    /// Hierarchical clustering.
    HierarchicalClustering,
    /// DBSCAN (density-based clustering).
    DBSCAN,
    /// Online (incremental) k-means.
    OnlineKMeans,
}
75
76impl SemanticCache {
77    /// Create new semantic cache
78    pub fn new(config: SemanticCacheConfig) -> RragResult<Self> {
79        Ok(Self {
80            config,
81            storage: HashMap::new(),
82            embeddings: HashMap::new(),
83            clusters: Vec::new(),
84            query_clusters: HashMap::new(),
85            stats: CacheStats::default(),
86        })
87    }
88
89    /// Find semantically similar cached entries
90    pub fn find_similar(&self, _query: &str, embedding: &[f32]) -> Vec<SimilaritySearchResult> {
91        let mut results = Vec::new();
92
93        for (cached_query, cached_embedding) in &self.embeddings {
94            let similarity = self.compute_similarity(embedding, cached_embedding);
95
96            if similarity >= self.config.similarity_threshold {
97                if let Some(entry) = self.storage.get(cached_query) {
98                    results.push(SimilaritySearchResult {
99                        query: cached_query.clone(),
100                        similarity,
101                        entry: entry.clone(),
102                    });
103                }
104            }
105        }
106
107        // Sort by similarity descending
108        results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
109
110        // Limit results
111        results.truncate(10);
112        results
113    }
114
115    /// Get or find semantically similar entry
116    pub fn get_or_similar(
117        &self,
118        query: &str,
119        embedding: Option<&[f32]>,
120    ) -> Option<SemanticCacheEntry> {
121        // Direct hit first
122        if let Some(entry) = self.storage.get(query) {
123            return Some(entry.clone());
124        }
125
126        // Semantic similarity search
127        if let Some(emb) = embedding {
128            let similar = self.find_similar(query, emb);
129            if let Some(best_match) = similar.first() {
130                return Some(best_match.entry.clone());
131            }
132        }
133
134        None
135    }
136
137    /// Cache entry with semantic clustering
138    pub fn cache_with_clustering(
139        &mut self,
140        query: String,
141        embedding: Vec<f32>,
142        entry: SemanticCacheEntry,
143    ) -> RragResult<()> {
144        // Store embedding
145        self.embeddings.insert(query.clone(), embedding.clone());
146
147        // Find best cluster or create new one
148        if self.config.clustering_enabled {
149            let cluster_id = self.assign_to_cluster(&query, &embedding)?;
150            self.query_clusters.insert(query.clone(), cluster_id);
151        }
152
153        // Store entry
154        self.storage.insert(query, entry);
155
156        // Update clusters if needed
157        if self.config.clustering_enabled && self.storage.len() % 10 == 0 {
158            self.update_clusters()?;
159        }
160
161        Ok(())
162    }
163
164    /// Assign query to best cluster
165    fn assign_to_cluster(&mut self, query: &str, embedding: &[f32]) -> RragResult<usize> {
166        if self.clusters.is_empty() {
167            // Create first cluster
168            let cluster = SemanticCluster {
169                id: 0,
170                centroid: embedding.to_vec(),
171                queries: vec![query.to_string()],
172                representative: query.to_string(),
173                cohesion: 1.0,
174                last_updated: SystemTime::now(),
175            };
176            self.clusters.push(cluster);
177            return Ok(0);
178        }
179
180        // Find best cluster by centroid similarity
181        let mut best_cluster = 0;
182        let mut best_similarity = 0.0;
183
184        for (i, cluster) in self.clusters.iter().enumerate() {
185            let similarity = self.compute_similarity(embedding, &cluster.centroid);
186            if similarity > best_similarity {
187                best_similarity = similarity;
188                best_cluster = i;
189            }
190        }
191
192        // Create new cluster if similarity is too low
193        if best_similarity < self.config.similarity_threshold {
194            if self.clusters.len() < self.config.max_clusters {
195                let cluster_id = self.clusters.len();
196                let cluster = SemanticCluster {
197                    id: cluster_id,
198                    centroid: embedding.to_vec(),
199                    queries: vec![query.to_string()],
200                    representative: query.to_string(),
201                    cohesion: 1.0,
202                    last_updated: SystemTime::now(),
203                };
204                self.clusters.push(cluster);
205                return Ok(cluster_id);
206            }
207        }
208
209        // Add to best cluster
210        if let Some(cluster) = self.clusters.get_mut(best_cluster) {
211            cluster.queries.push(query.to_string());
212            cluster.last_updated = SystemTime::now();
213        }
214
215        Ok(best_cluster)
216    }
217
218    /// Update cluster centroids and representatives
219    fn update_clusters(&mut self) -> RragResult<()> {
220        for cluster in &mut self.clusters {
221            if cluster.queries.is_empty() {
222                continue;
223            }
224
225            // Compute new centroid
226            let mut centroid = vec![0.0; cluster.centroid.len()];
227            let mut count = 0;
228
229            for query in &cluster.queries {
230                if let Some(embedding) = self.embeddings.get(query) {
231                    for (i, &val) in embedding.iter().enumerate() {
232                        if i < centroid.len() {
233                            centroid[i] += val;
234                        }
235                    }
236                    count += 1;
237                }
238            }
239
240            if count > 0 {
241                for val in &mut centroid {
242                    *val /= count as f32;
243                }
244                cluster.centroid = centroid;
245            }
246
247            // Find new representative (closest to centroid)
248            let mut best_query = cluster.representative.clone();
249            let mut best_similarity = 0.0;
250
251            for query in &cluster.queries {
252                if let Some(embedding) = self.embeddings.get(query) {
253                    // Inline cosine similarity calculation to avoid borrowing self
254                    let dot_product: f32 = cluster
255                        .centroid
256                        .iter()
257                        .zip(embedding.iter())
258                        .map(|(x, y)| x * y)
259                        .sum();
260                    let norm_a: f32 = cluster.centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
261                    let norm_b: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
262                    let similarity = if norm_a == 0.0 || norm_b == 0.0 {
263                        0.0
264                    } else {
265                        dot_product / (norm_a * norm_b)
266                    };
267
268                    if similarity > best_similarity {
269                        best_similarity = similarity;
270                        best_query = query.clone();
271                    }
272                }
273            }
274
275            cluster.representative = best_query;
276            cluster.cohesion = best_similarity;
277        }
278
279        Ok(())
280    }
281
282    /// Compute cosine similarity between two embeddings
283    fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
284        if a.len() != b.len() || a.is_empty() {
285            return 0.0;
286        }
287
288        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
289        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
290        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
291
292        if norm_a == 0.0 || norm_b == 0.0 {
293            return 0.0;
294        }
295
296        dot_product / (norm_a * norm_b)
297    }
298
299    /// Get cluster information
300    pub fn get_clusters(&self) -> &[SemanticCluster] {
301        &self.clusters
302    }
303
304    /// Get cache insights
305    pub fn get_insights(&self) -> SemanticCacheInsights {
306        let total_queries = self.storage.len();
307        let total_clusters = self.clusters.len();
308        let avg_cluster_size = if total_clusters > 0 {
309            total_queries as f32 / total_clusters as f32
310        } else {
311            0.0
312        };
313
314        let cluster_cohesions: Vec<f32> = self.clusters.iter().map(|c| c.cohesion).collect();
315        let avg_cohesion = if !cluster_cohesions.is_empty() {
316            cluster_cohesions.iter().sum::<f32>() / cluster_cohesions.len() as f32
317        } else {
318            0.0
319        };
320
321        SemanticCacheInsights {
322            total_queries,
323            total_clusters,
324            avg_cluster_size,
325            avg_cohesion,
326            similarity_threshold: self.config.similarity_threshold,
327            clustering_enabled: self.config.clustering_enabled,
328        }
329    }
330}
331
332impl Cache<String, SemanticCacheEntry> for SemanticCache {
333    fn get(&self, key: &String) -> Option<SemanticCacheEntry> {
334        self.storage.get(key).cloned()
335    }
336
337    fn put(&mut self, key: String, value: SemanticCacheEntry) -> RragResult<()> {
338        // Check capacity
339        if self.storage.len() >= self.config.max_size {
340            self.evict_entry()?;
341        }
342
343        self.storage.insert(key, value);
344        Ok(())
345    }
346
347    fn remove(&mut self, key: &String) -> Option<SemanticCacheEntry> {
348        let entry = self.storage.remove(key);
349        self.embeddings.remove(key);
350
351        // Remove from cluster
352        if let Some(cluster_id) = self.query_clusters.remove(key) {
353            if let Some(cluster) = self.clusters.get_mut(cluster_id) {
354                cluster.queries.retain(|q| q != key);
355            }
356        }
357
358        entry
359    }
360
361    fn contains(&self, key: &String) -> bool {
362        self.storage.contains_key(key)
363    }
364
365    fn clear(&mut self) {
366        self.storage.clear();
367        self.embeddings.clear();
368        self.clusters.clear();
369        self.query_clusters.clear();
370        self.stats = CacheStats::default();
371    }
372
373    fn size(&self) -> usize {
374        self.storage.len()
375    }
376
377    fn stats(&self) -> CacheStats {
378        self.stats.clone()
379    }
380}
381
382impl SemanticCache {
383    /// Evict entry using semantic-aware policy
384    fn evict_entry(&mut self) -> RragResult<()> {
385        if self.storage.is_empty() {
386            return Ok(());
387        }
388
389        // Find entry with lowest access frequency in largest cluster
390        let mut candidate_key: Option<String> = None;
391        let mut min_score = f32::INFINITY;
392
393        for (key, entry) in &self.storage {
394            // Calculate eviction score based on access patterns and cluster size
395            let access_score = entry.metadata.access_count as f32;
396            let time_score = entry
397                .metadata
398                .last_accessed
399                .elapsed()
400                .unwrap_or_default()
401                .as_secs() as f32;
402
403            // Prefer evicting from larger clusters
404            let cluster_score = if let Some(&cluster_id) = self.query_clusters.get(key) {
405                if let Some(cluster) = self.clusters.get(cluster_id) {
406                    cluster.queries.len() as f32
407                } else {
408                    1.0
409                }
410            } else {
411                1.0
412            };
413
414            // Combined score (lower is better for eviction)
415            let eviction_score = access_score / (time_score + 1.0) / cluster_score;
416
417            if eviction_score < min_score {
418                min_score = eviction_score;
419                candidate_key = Some(key.clone());
420            }
421        }
422
423        if let Some(key) = candidate_key {
424            self.remove(&key);
425            self.stats.evictions += 1;
426        }
427
428        Ok(())
429    }
430}
431
432/// Semantic cache insights
433#[derive(Debug, Clone, Serialize, Deserialize)]
434pub struct SemanticCacheInsights {
435    /// Total cached queries
436    pub total_queries: usize,
437
438    /// Total clusters
439    pub total_clusters: usize,
440
441    /// Average cluster size
442    pub avg_cluster_size: f32,
443
444    /// Average cluster cohesion
445    pub avg_cohesion: f32,
446
447    /// Configured similarity threshold
448    pub similarity_threshold: f32,
449
450    /// Whether clustering is enabled
451    pub clustering_enabled: bool,
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    // The parent module does not import these, so `use super::*` cannot provide them;
    // they come from the enclosing `caching` module alongside `SemanticCacheEntry`.
    use super::super::{CacheEntryMetadata, CachedSearchResult};
    use std::collections::HashMap;

    /// Configuration shared by the tests: room for 100 entries, 0.8 threshold,
    /// clustering on with up to 10 clusters.
    fn create_test_config() -> SemanticCacheConfig {
        SemanticCacheConfig {
            enabled: true,
            max_size: 100,
            ttl: std::time::Duration::from_secs(3600),
            similarity_threshold: 0.8,
            clustering_enabled: true,
            max_clusters: 10,
        }
    }

    /// A minimal entry with a single cached search result.
    fn create_test_entry() -> SemanticCacheEntry {
        SemanticCacheEntry {
            representative: "test query".to_string(),
            cluster_id: None,
            similar_entries: vec![],
            results: vec![CachedSearchResult {
                document_id: "doc1".to_string(),
                content: "test content".to_string(),
                score: 0.9,
                rank: 0,
                metadata: HashMap::new(),
            }],
            metadata: CacheEntryMetadata::new(),
        }
    }

    #[test]
    fn test_semantic_cache_creation() {
        let config = create_test_config();
        let cache = SemanticCache::new(config).unwrap();

        assert_eq!(cache.size(), 0);
        assert_eq!(cache.clusters.len(), 0);
    }

    #[test]
    fn test_basic_cache_operations() {
        let config = create_test_config();
        let mut cache = SemanticCache::new(config).unwrap();

        let entry = create_test_entry();
        let key = "test_query".to_string();

        cache.put(key.clone(), entry.clone()).unwrap();
        assert_eq!(cache.size(), 1);

        let retrieved = cache.get(&key);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().representative, entry.representative);

        let removed = cache.remove(&key);
        assert!(removed.is_some());
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn test_put_evicts_when_full() {
        let config = SemanticCacheConfig {
            max_size: 2,
            ..create_test_config()
        };
        let mut cache = SemanticCache::new(config).unwrap();

        for key in ["a", "b", "c"] {
            cache.put(key.to_string(), create_test_entry()).unwrap();
        }

        assert_eq!(cache.size(), 2);
        assert_eq!(cache.stats().evictions, 1);
    }

    #[test]
    fn test_similarity_computation() {
        let config = create_test_config();
        let cache = SemanticCache::new(config).unwrap();

        let vec_a = vec![1.0, 0.0, 0.0];
        let vec_b = vec![1.0, 0.0, 0.0];
        let vec_c = vec![0.0, 1.0, 0.0];

        // Identical vectors
        let similarity = cache.compute_similarity(&vec_a, &vec_b);
        assert!((similarity - 1.0).abs() < 0.001);

        // Orthogonal vectors
        let similarity = cache.compute_similarity(&vec_a, &vec_c);
        assert!((similarity - 0.0).abs() < 0.001);

        // Mismatched lengths and empty input are treated as dissimilar
        assert_eq!(cache.compute_similarity(&vec_a, &[1.0, 0.0]), 0.0);
        assert_eq!(cache.compute_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn test_find_similar_orders_and_filters() {
        let mut cache = SemanticCache::new(create_test_config()).unwrap();
        cache
            .cache_with_clustering("exact".to_string(), vec![1.0, 0.0, 0.0], create_test_entry())
            .unwrap();
        cache
            .cache_with_clustering("close".to_string(), vec![0.9, 0.1, 0.0], create_test_entry())
            .unwrap();
        cache
            .cache_with_clustering("far".to_string(), vec![0.0, 1.0, 0.0], create_test_entry())
            .unwrap();

        let results = cache.find_similar("probe", &[1.0, 0.0, 0.0]);
        let queries: Vec<&str> = results.iter().map(|r| r.query.as_str()).collect();
        assert_eq!(queries, vec!["exact", "close"]);
        assert!(results[0].similarity >= results[1].similarity);
    }

    #[test]
    fn test_get_or_similar() {
        let mut cache = SemanticCache::new(create_test_config()).unwrap();
        cache
            .cache_with_clustering("cached".to_string(), vec![1.0, 0.0], create_test_entry())
            .unwrap();

        let near = [0.99_f32, 0.01];
        let orthogonal = [0.0_f32, 1.0];

        assert!(cache.get_or_similar("cached", None).is_some());
        assert!(cache.get_or_similar("other", Some(&near[..])).is_some());
        assert!(cache.get_or_similar("other", Some(&orthogonal[..])).is_none());
        assert!(cache.get_or_similar("other", None).is_none());
    }

    #[test]
    fn test_clustering() {
        let config = create_test_config();
        let mut cache = SemanticCache::new(config).unwrap();

        let entry = create_test_entry();
        let embedding = vec![1.0, 0.0, 0.0];

        cache
            .cache_with_clustering("test query".to_string(), embedding, entry)
            .unwrap();

        assert_eq!(cache.clusters.len(), 1);
        assert_eq!(cache.clusters[0].queries.len(), 1);
    }

    #[test]
    fn test_cache_insights() {
        let config = create_test_config();
        let cache = SemanticCache::new(config).unwrap();

        let insights = cache.get_insights();
        assert_eq!(insights.total_queries, 0);
        assert_eq!(insights.total_clusters, 0);
        assert!(insights.clustering_enabled);
    }
}