1use super::{Cache, CacheStats, SemanticCacheConfig, SemanticCacheEntry};
6use crate::RragResult;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::SystemTime;
10
11pub struct SemanticCache {
13 config: SemanticCacheConfig,
15
16 storage: HashMap<String, SemanticCacheEntry>,
18
19 embeddings: HashMap<String, Vec<f32>>,
21
22 clusters: Vec<SemanticCluster>,
24
25 query_clusters: HashMap<String, usize>,
27
28 stats: CacheStats,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct SemanticCluster {
35 pub id: usize,
37
38 pub centroid: Vec<f32>,
40
41 pub queries: Vec<String>,
43
44 pub representative: String,
46
47 pub cohesion: f32,
49
50 pub last_updated: SystemTime,
52}
53
54#[derive(Debug, Clone)]
56pub struct SimilaritySearchResult {
57 pub query: String,
59
60 pub similarity: f32,
62
63 pub entry: SemanticCacheEntry,
65}
66
/// Names of the clustering strategies a semantic cache may be configured with.
///
/// The enum carries no data, so it is cheap to copy, compare and use as a map key.
// NOTE(review): nothing else in this file reads this enum — confirm where it is consumed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClusteringAlgorithm {
    /// K-means clustering.
    KMeans,
    /// Hierarchical clustering.
    HierarchicalClustering,
    /// DBSCAN (density-based) clustering.
    DBSCAN,
    /// Online (incremental) k-means clustering.
    OnlineKMeans,
}
75
76impl SemanticCache {
77 pub fn new(config: SemanticCacheConfig) -> RragResult<Self> {
79 Ok(Self {
80 config,
81 storage: HashMap::new(),
82 embeddings: HashMap::new(),
83 clusters: Vec::new(),
84 query_clusters: HashMap::new(),
85 stats: CacheStats::default(),
86 })
87 }
88
89 pub fn find_similar(&self, _query: &str, embedding: &[f32]) -> Vec<SimilaritySearchResult> {
91 let mut results = Vec::new();
92
93 for (cached_query, cached_embedding) in &self.embeddings {
94 let similarity = self.compute_similarity(embedding, cached_embedding);
95
96 if similarity >= self.config.similarity_threshold {
97 if let Some(entry) = self.storage.get(cached_query) {
98 results.push(SimilaritySearchResult {
99 query: cached_query.clone(),
100 similarity,
101 entry: entry.clone(),
102 });
103 }
104 }
105 }
106
107 results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
109
110 results.truncate(10);
112 results
113 }
114
115 pub fn get_or_similar(
117 &self,
118 query: &str,
119 embedding: Option<&[f32]>,
120 ) -> Option<SemanticCacheEntry> {
121 if let Some(entry) = self.storage.get(query) {
123 return Some(entry.clone());
124 }
125
126 if let Some(emb) = embedding {
128 let similar = self.find_similar(query, emb);
129 if let Some(best_match) = similar.first() {
130 return Some(best_match.entry.clone());
131 }
132 }
133
134 None
135 }
136
137 pub fn cache_with_clustering(
139 &mut self,
140 query: String,
141 embedding: Vec<f32>,
142 entry: SemanticCacheEntry,
143 ) -> RragResult<()> {
144 self.embeddings.insert(query.clone(), embedding.clone());
146
147 if self.config.clustering_enabled {
149 let cluster_id = self.assign_to_cluster(&query, &embedding)?;
150 self.query_clusters.insert(query.clone(), cluster_id);
151 }
152
153 self.storage.insert(query, entry);
155
156 if self.config.clustering_enabled && self.storage.len() % 10 == 0 {
158 self.update_clusters()?;
159 }
160
161 Ok(())
162 }
163
164 fn assign_to_cluster(&mut self, query: &str, embedding: &[f32]) -> RragResult<usize> {
166 if self.clusters.is_empty() {
167 let cluster = SemanticCluster {
169 id: 0,
170 centroid: embedding.to_vec(),
171 queries: vec![query.to_string()],
172 representative: query.to_string(),
173 cohesion: 1.0,
174 last_updated: SystemTime::now(),
175 };
176 self.clusters.push(cluster);
177 return Ok(0);
178 }
179
180 let mut best_cluster = 0;
182 let mut best_similarity = 0.0;
183
184 for (i, cluster) in self.clusters.iter().enumerate() {
185 let similarity = self.compute_similarity(embedding, &cluster.centroid);
186 if similarity > best_similarity {
187 best_similarity = similarity;
188 best_cluster = i;
189 }
190 }
191
192 if best_similarity < self.config.similarity_threshold {
194 if self.clusters.len() < self.config.max_clusters {
195 let cluster_id = self.clusters.len();
196 let cluster = SemanticCluster {
197 id: cluster_id,
198 centroid: embedding.to_vec(),
199 queries: vec![query.to_string()],
200 representative: query.to_string(),
201 cohesion: 1.0,
202 last_updated: SystemTime::now(),
203 };
204 self.clusters.push(cluster);
205 return Ok(cluster_id);
206 }
207 }
208
209 if let Some(cluster) = self.clusters.get_mut(best_cluster) {
211 cluster.queries.push(query.to_string());
212 cluster.last_updated = SystemTime::now();
213 }
214
215 Ok(best_cluster)
216 }
217
218 fn update_clusters(&mut self) -> RragResult<()> {
220 for cluster in &mut self.clusters {
221 if cluster.queries.is_empty() {
222 continue;
223 }
224
225 let mut centroid = vec![0.0; cluster.centroid.len()];
227 let mut count = 0;
228
229 for query in &cluster.queries {
230 if let Some(embedding) = self.embeddings.get(query) {
231 for (i, &val) in embedding.iter().enumerate() {
232 if i < centroid.len() {
233 centroid[i] += val;
234 }
235 }
236 count += 1;
237 }
238 }
239
240 if count > 0 {
241 for val in &mut centroid {
242 *val /= count as f32;
243 }
244 cluster.centroid = centroid;
245 }
246
247 let mut best_query = cluster.representative.clone();
249 let mut best_similarity = 0.0;
250
251 for query in &cluster.queries {
252 if let Some(embedding) = self.embeddings.get(query) {
253 let dot_product: f32 = cluster
255 .centroid
256 .iter()
257 .zip(embedding.iter())
258 .map(|(x, y)| x * y)
259 .sum();
260 let norm_a: f32 = cluster.centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
261 let norm_b: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
262 let similarity = if norm_a == 0.0 || norm_b == 0.0 {
263 0.0
264 } else {
265 dot_product / (norm_a * norm_b)
266 };
267
268 if similarity > best_similarity {
269 best_similarity = similarity;
270 best_query = query.clone();
271 }
272 }
273 }
274
275 cluster.representative = best_query;
276 cluster.cohesion = best_similarity;
277 }
278
279 Ok(())
280 }
281
282 fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
284 if a.len() != b.len() || a.is_empty() {
285 return 0.0;
286 }
287
288 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
289 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
290 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
291
292 if norm_a == 0.0 || norm_b == 0.0 {
293 return 0.0;
294 }
295
296 dot_product / (norm_a * norm_b)
297 }
298
299 pub fn get_clusters(&self) -> &[SemanticCluster] {
301 &self.clusters
302 }
303
304 pub fn get_insights(&self) -> SemanticCacheInsights {
306 let total_queries = self.storage.len();
307 let total_clusters = self.clusters.len();
308 let avg_cluster_size = if total_clusters > 0 {
309 total_queries as f32 / total_clusters as f32
310 } else {
311 0.0
312 };
313
314 let cluster_cohesions: Vec<f32> = self.clusters.iter().map(|c| c.cohesion).collect();
315 let avg_cohesion = if !cluster_cohesions.is_empty() {
316 cluster_cohesions.iter().sum::<f32>() / cluster_cohesions.len() as f32
317 } else {
318 0.0
319 };
320
321 SemanticCacheInsights {
322 total_queries,
323 total_clusters,
324 avg_cluster_size,
325 avg_cohesion,
326 similarity_threshold: self.config.similarity_threshold,
327 clustering_enabled: self.config.clustering_enabled,
328 }
329 }
330}
331
332impl Cache<String, SemanticCacheEntry> for SemanticCache {
333 fn get(&self, key: &String) -> Option<SemanticCacheEntry> {
334 self.storage.get(key).cloned()
335 }
336
337 fn put(&mut self, key: String, value: SemanticCacheEntry) -> RragResult<()> {
338 if self.storage.len() >= self.config.max_size {
340 self.evict_entry()?;
341 }
342
343 self.storage.insert(key, value);
344 Ok(())
345 }
346
347 fn remove(&mut self, key: &String) -> Option<SemanticCacheEntry> {
348 let entry = self.storage.remove(key);
349 self.embeddings.remove(key);
350
351 if let Some(cluster_id) = self.query_clusters.remove(key) {
353 if let Some(cluster) = self.clusters.get_mut(cluster_id) {
354 cluster.queries.retain(|q| q != key);
355 }
356 }
357
358 entry
359 }
360
361 fn contains(&self, key: &String) -> bool {
362 self.storage.contains_key(key)
363 }
364
365 fn clear(&mut self) {
366 self.storage.clear();
367 self.embeddings.clear();
368 self.clusters.clear();
369 self.query_clusters.clear();
370 self.stats = CacheStats::default();
371 }
372
373 fn size(&self) -> usize {
374 self.storage.len()
375 }
376
377 fn stats(&self) -> CacheStats {
378 self.stats.clone()
379 }
380}
381
382impl SemanticCache {
383 fn evict_entry(&mut self) -> RragResult<()> {
385 if self.storage.is_empty() {
386 return Ok(());
387 }
388
389 let mut candidate_key: Option<String> = None;
391 let mut min_score = f32::INFINITY;
392
393 for (key, entry) in &self.storage {
394 let access_score = entry.metadata.access_count as f32;
396 let time_score = entry
397 .metadata
398 .last_accessed
399 .elapsed()
400 .unwrap_or_default()
401 .as_secs() as f32;
402
403 let cluster_score = if let Some(&cluster_id) = self.query_clusters.get(key) {
405 if let Some(cluster) = self.clusters.get(cluster_id) {
406 cluster.queries.len() as f32
407 } else {
408 1.0
409 }
410 } else {
411 1.0
412 };
413
414 let eviction_score = access_score / (time_score + 1.0) / cluster_score;
416
417 if eviction_score < min_score {
418 min_score = eviction_score;
419 candidate_key = Some(key.clone());
420 }
421 }
422
423 if let Some(key) = candidate_key {
424 self.remove(&key);
425 self.stats.evictions += 1;
426 }
427
428 Ok(())
429 }
430}
431
432#[derive(Debug, Clone, Serialize, Deserialize)]
434pub struct SemanticCacheInsights {
435 pub total_queries: usize,
437
438 pub total_clusters: usize,
440
441 pub avg_cluster_size: f32,
443
444 pub avg_cohesion: f32,
446
447 pub similarity_threshold: f32,
449
450 pub clustering_enabled: bool,
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    // `use super::*` only re-exports what the parent module itself imports, and the
    // parent does not import these two types.
    // NOTE(review): assumes both are defined next to `SemanticCacheEntry` in the
    // module two levels up — confirm the path.
    use super::super::{CacheEntryMetadata, CachedSearchResult};
    use std::collections::HashMap;

    /// Configuration shared by all tests: clustering on, threshold 0.8.
    fn create_test_config() -> SemanticCacheConfig {
        SemanticCacheConfig {
            enabled: true,
            max_size: 100,
            ttl: std::time::Duration::from_secs(3600),
            similarity_threshold: 0.8,
            clustering_enabled: true,
            max_clusters: 10,
        }
    }

    /// A minimal cache entry holding a single search result.
    fn create_test_entry() -> SemanticCacheEntry {
        SemanticCacheEntry {
            representative: "test query".to_string(),
            cluster_id: None,
            similar_entries: vec![],
            results: vec![CachedSearchResult {
                document_id: "doc1".to_string(),
                content: "test content".to_string(),
                score: 0.9,
                rank: 0,
                metadata: HashMap::new(),
            }],
            metadata: CacheEntryMetadata::new(),
        }
    }

    #[test]
    fn test_semantic_cache_creation() {
        let config = create_test_config();
        let cache = SemanticCache::new(config).unwrap();

        assert_eq!(cache.size(), 0);
        assert_eq!(cache.clusters.len(), 0);
    }

    #[test]
    fn test_basic_cache_operations() {
        let config = create_test_config();
        let mut cache = SemanticCache::new(config).unwrap();

        let entry = create_test_entry();
        let key = "test_query".to_string();

        cache.put(key.clone(), entry.clone()).unwrap();
        assert_eq!(cache.size(), 1);

        let retrieved = cache.get(&key);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().representative, entry.representative);

        let removed = cache.remove(&key);
        assert!(removed.is_some());
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn test_similarity_computation() {
        let config = create_test_config();
        let cache = SemanticCache::new(config).unwrap();

        let vec_a = vec![1.0, 0.0, 0.0];
        let vec_b = vec![1.0, 0.0, 0.0];
        let vec_c = vec![0.0, 1.0, 0.0];

        // Identical vectors are maximally similar.
        let similarity = cache.compute_similarity(&vec_a, &vec_b);
        assert!((similarity - 1.0).abs() < 0.001);

        // Orthogonal vectors have zero similarity.
        let similarity = cache.compute_similarity(&vec_a, &vec_c);
        assert!((similarity - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_similarity_degenerate_inputs() {
        let config = create_test_config();
        let cache = SemanticCache::new(config).unwrap();

        let empty: [f32; 0] = [];
        assert_eq!(cache.compute_similarity(&empty, &empty), 0.0);
        assert_eq!(cache.compute_similarity(&[1.0_f32, 0.0], &[1.0_f32]), 0.0);
        assert_eq!(
            cache.compute_similarity(&[0.0_f32, 0.0], &[1.0_f32, 0.0]),
            0.0
        );
    }

    #[test]
    fn test_clustering() {
        let config = create_test_config();
        let mut cache = SemanticCache::new(config).unwrap();

        let entry = create_test_entry();
        let embedding = vec![1.0, 0.0, 0.0];

        cache
            .cache_with_clustering("test query".to_string(), embedding, entry)
            .unwrap();

        assert_eq!(cache.clusters.len(), 1);
        assert_eq!(cache.clusters[0].queries.len(), 1);
    }

    #[test]
    fn test_find_similar_respects_threshold() {
        let config = create_test_config();
        let mut cache = SemanticCache::new(config).unwrap();

        cache
            .cache_with_clustering("q1".to_string(), vec![1.0, 0.0, 0.0], create_test_entry())
            .unwrap();

        let hits = cache.find_similar("other", &[1.0_f32, 0.0, 0.0]);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].query, "q1");
        assert!((hits[0].similarity - 1.0).abs() < 0.001);

        // An orthogonal embedding scores 0.0, below the 0.8 threshold.
        assert!(cache.find_similar("other", &[0.0_f32, 1.0, 0.0]).is_empty());
    }

    #[test]
    fn test_remove_clears_cluster_membership() {
        let config = create_test_config();
        let mut cache = SemanticCache::new(config).unwrap();
        let key = "q1".to_string();

        cache
            .cache_with_clustering(key.clone(), vec![1.0, 0.0, 0.0], create_test_entry())
            .unwrap();
        assert!(cache.remove(&key).is_some());

        assert_eq!(cache.size(), 0);
        assert!(cache.clusters[0].queries.is_empty());
        assert!(cache.get_or_similar("q1", Some(&[1.0_f32, 0.0, 0.0])).is_none());
    }

    #[test]
    fn test_cache_insights() {
        let config = create_test_config();
        let cache = SemanticCache::new(config).unwrap();

        let insights = cache.get_insights();
        assert_eq!(insights.total_queries, 0);
        assert_eq!(insights.total_clusters, 0);
        assert!(insights.clustering_enabled);
    }
}