1use super::{
6 Cache, CacheEntryMetadata, CacheStats, CachedSearchResult, QueryCacheConfig, QueryCacheEntry,
7};
8use crate::RragResult;
9use std::collections::HashMap;
10use std::time::SystemTime;
11
/// Cache of search results keyed by the raw query string.
///
/// Alongside the entries themselves it keeps a normalized-query index (for
/// lookups that ignore case, punctuation and extra whitespace), a list of
/// query templates seen so far, and per-query access bookkeeping.
pub struct QueryCache {
    /// Settings for this cache; this file reads `max_size`, `ttl`,
    /// `eviction_policy` and `similarity_threshold` from it.
    config: QueryCacheConfig,

    /// Cached entries, keyed by the exact query string they were stored under.
    storage: HashMap<String, QueryCacheEntry>,

    /// Maps a normalized query to the raw query (a key of `storage`) that was
    /// most recently stored with that normalized form.
    normalized_queries: HashMap<String, String>,

    /// Templates derived from cached queries (numbers and quoted strings
    /// replaced by placeholders), each with a count of matching queries.
    query_patterns: Vec<QueryPattern>,

    /// Per-query access bookkeeping, keyed by the raw query string.
    access_stats: HashMap<String, QueryAccessStats>,

    /// Aggregate cache counters; the visible code only updates `evictions`.
    stats: CacheStats,
}
32
/// A query template observed by the cache, together with usage counters.
#[derive(Debug, Clone)]
pub struct QueryPattern {
    /// Identifier of the form `pattern_<n>`, where `<n>` is the number of
    /// patterns already tracked when this one was created.
    pub id: String,

    /// The query text with standalone numbers replaced by `{NUM}` and
    /// double-quoted strings replaced by `{STR}`.
    pub template: String,

    /// Number of cached queries that produced this template.
    pub match_count: u64,

    /// Initialised to `0.0`; not updated anywhere in the visible code.
    pub avg_similarity: f32,

    /// Initialised to `0.0`; not updated anywhere in the visible code.
    pub effectiveness: f32,
}
51
/// Access bookkeeping kept for a single raw query string.
#[derive(Debug, Clone)]
pub struct QueryAccessStats {
    /// Number of times the query has been recorded; the visible code
    /// increments it each time results for the query are cached.
    pub access_count: u64,

    /// Wall-clock time of the most recent recording.
    pub last_access: SystemTime,

    /// Initialised to `0.0`; not updated anywhere in the visible code.
    pub avg_response_time_ms: f32,

    /// Initialised to `0.0`; not updated anywhere in the visible code.
    pub similarity_hit_rate: f32,

    /// Starts empty; not populated anywhere in the visible code.
    /// NOTE(review): presumably meant to hold alternative phrasings of the
    /// query — confirm against the intended design.
    pub variations: Vec<String>,
}
70
71impl QueryCache {
72 pub fn new(config: QueryCacheConfig) -> RragResult<Self> {
74 Ok(Self {
75 config,
76 storage: HashMap::new(),
77 normalized_queries: HashMap::new(),
78 query_patterns: Vec::new(),
79 access_stats: HashMap::new(),
80 stats: CacheStats::default(),
81 })
82 }
83
84 pub fn get_results(&self, query: &str) -> Option<QueryCacheEntry> {
86 if let Some(entry) = self.storage.get(query) {
88 if !entry.metadata.is_expired() {
89 return Some(entry.clone());
90 }
91 }
92
93 let normalized = self.normalize_query(query);
95 if let Some(canonical) = self.normalized_queries.get(&normalized) {
96 if let Some(entry) = self.storage.get(canonical) {
97 if !entry.metadata.is_expired() {
98 return Some(entry.clone());
99 }
100 }
101 }
102
103 if self.config.similarity_threshold > 0.0 {
105 return self.find_similar_query(query);
106 }
107
108 None
109 }
110
111 pub fn cache_results(
113 &mut self,
114 query: String,
115 results: Vec<CachedSearchResult>,
116 generated_answer: Option<String>,
117 embedding_hash: String,
118 ) -> RragResult<()> {
119 if self.storage.len() >= self.config.max_size {
121 self.evict_entry()?;
122 }
123
124 let mut metadata = CacheEntryMetadata::new();
126 metadata.ttl = Some(self.config.ttl);
127
128 let entry = QueryCacheEntry {
129 query: query.clone(),
130 embedding_hash,
131 results,
132 generated_answer,
133 metadata,
134 };
135
136 let normalized = self.normalize_query(&query);
138 self.normalized_queries.insert(normalized, query.clone());
139 self.storage.insert(query.clone(), entry);
140
141 self.update_patterns(&query);
143
144 self.update_access_stats(&query);
146
147 Ok(())
148 }
149
150 fn normalize_query(&self, query: &str) -> String {
152 query
153 .to_lowercase()
154 .trim()
155 .chars()
156 .filter(|c| c.is_alphanumeric() || c.is_whitespace())
157 .collect::<String>()
158 .split_whitespace()
159 .collect::<Vec<_>>()
160 .join(" ")
161 }
162
163 fn find_similar_query(&self, query: &str) -> Option<QueryCacheEntry> {
165 let normalized = self.normalize_query(query);
166 let query_tokens: Vec<&str> = normalized.split_whitespace().collect();
167
168 let mut best_match: Option<(&String, &QueryCacheEntry, f32)> = None;
169
170 for (cached_query, entry) in &self.storage {
171 if entry.metadata.is_expired() {
172 continue;
173 }
174
175 let cached_normalized = self.normalize_query(cached_query);
176 let cached_tokens: Vec<&str> = cached_normalized.split_whitespace().collect();
177
178 let intersection = query_tokens
180 .iter()
181 .filter(|t| cached_tokens.contains(t))
182 .count();
183 let union = (query_tokens.len() + cached_tokens.len() - intersection).max(1);
184 let similarity = intersection as f32 / union as f32;
185
186 if similarity >= self.config.similarity_threshold {
187 if best_match.is_none() || similarity > best_match.as_ref().unwrap().2 {
188 best_match = Some((cached_query, entry, similarity));
189 }
190 }
191 }
192
193 best_match.map(|(_, entry, _)| entry.clone())
194 }
195
196 fn update_patterns(&mut self, query: &str) {
198 let pattern = self.extract_pattern(query);
200
201 if let Some(existing) = self
203 .query_patterns
204 .iter_mut()
205 .find(|p| p.template == pattern)
206 {
207 existing.match_count += 1;
208 } else if self.query_patterns.len() < 100 {
209 self.query_patterns.push(QueryPattern {
211 id: format!("pattern_{}", self.query_patterns.len()),
212 template: pattern,
213 match_count: 1,
214 avg_similarity: 0.0,
215 effectiveness: 0.0,
216 });
217 }
218 }
219
220 fn extract_pattern(&self, query: &str) -> String {
222 let mut pattern = query.to_string();
224
225 pattern = regex::Regex::new(r"\b\d+\b")
227 .unwrap_or_else(|_| regex::Regex::new("").unwrap())
228 .replace_all(&pattern, "{NUM}")
229 .to_string();
230
231 pattern = regex::Regex::new(r#""[^"]*""#)
233 .unwrap_or_else(|_| regex::Regex::new("").unwrap())
234 .replace_all(&pattern, "{STR}")
235 .to_string();
236
237 pattern
238 }
239
240 fn update_access_stats(&mut self, query: &str) {
242 let stats = self
243 .access_stats
244 .entry(query.to_string())
245 .or_insert_with(|| QueryAccessStats {
246 access_count: 0,
247 last_access: SystemTime::now(),
248 avg_response_time_ms: 0.0,
249 similarity_hit_rate: 0.0,
250 variations: Vec::new(),
251 });
252
253 stats.access_count += 1;
254 stats.last_access = SystemTime::now();
255 }
256
257 fn evict_entry(&mut self) -> RragResult<()> {
259 use super::EvictionPolicy;
260
261 match self.config.eviction_policy {
262 EvictionPolicy::LRU => self.evict_lru(),
263 EvictionPolicy::LFU => self.evict_lfu(),
264 EvictionPolicy::TTL => self.evict_expired(),
265 _ => self.evict_lru(), }
267 }
268
269 fn evict_lru(&mut self) -> RragResult<()> {
271 if let Some((key, _)) = self
272 .storage
273 .iter()
274 .min_by_key(|(_, entry)| entry.metadata.last_accessed)
275 {
276 let key = key.clone();
277 self.storage.remove(&key);
278 self.stats.evictions += 1;
279 }
280 Ok(())
281 }
282
283 fn evict_lfu(&mut self) -> RragResult<()> {
285 if let Some((key, _)) = self
286 .storage
287 .iter()
288 .min_by_key(|(_, entry)| entry.metadata.access_count)
289 {
290 let key = key.clone();
291 self.storage.remove(&key);
292 self.stats.evictions += 1;
293 }
294 Ok(())
295 }
296
297 fn evict_expired(&mut self) -> RragResult<()> {
299 let _now = SystemTime::now();
300 let before_count = self.storage.len();
301
302 self.storage.retain(|_, entry| !entry.metadata.is_expired());
303
304 let evicted = before_count - self.storage.len();
305 self.stats.evictions += evicted as u64;
306
307 if self.storage.len() >= self.config.max_size {
309 self.evict_lru()?;
310 }
311
312 Ok(())
313 }
314
315 pub fn get_insights(&self) -> QueryCacheInsights {
317 let total_queries = self.storage.len();
318 let expired_queries = self
319 .storage
320 .values()
321 .filter(|e| e.metadata.is_expired())
322 .count();
323
324 let avg_results_per_query = if total_queries > 0 {
325 self.storage
326 .values()
327 .map(|e| e.results.len())
328 .sum::<usize>() as f32
329 / total_queries as f32
330 } else {
331 0.0
332 };
333
334 let top_patterns: Vec<String> = self
335 .query_patterns
336 .iter()
337 .filter(|p| p.match_count > 1)
338 .take(5)
339 .map(|p| p.template.clone())
340 .collect();
341
342 QueryCacheInsights {
343 total_queries,
344 expired_queries,
345 avg_results_per_query,
346 top_patterns,
347 similarity_threshold: self.config.similarity_threshold,
348 }
349 }
350}
351
352impl Cache<String, QueryCacheEntry> for QueryCache {
353 fn get(&self, key: &String) -> Option<QueryCacheEntry> {
354 self.get_results(key)
355 }
356
357 fn put(&mut self, key: String, value: QueryCacheEntry) -> RragResult<()> {
358 if self.storage.len() >= self.config.max_size {
359 self.evict_entry()?;
360 }
361
362 let normalized = self.normalize_query(&key);
363 self.normalized_queries.insert(normalized, key.clone());
364 self.storage.insert(key, value);
365 Ok(())
366 }
367
368 fn remove(&mut self, key: &String) -> Option<QueryCacheEntry> {
369 let entry = self.storage.remove(key);
370
371 let normalized = self.normalize_query(key);
373 self.normalized_queries.remove(&normalized);
374
375 self.access_stats.remove(key);
377
378 entry
379 }
380
381 fn contains(&self, key: &String) -> bool {
382 self.storage.contains_key(key)
383 && !self
384 .storage
385 .get(key)
386 .map_or(true, |e| e.metadata.is_expired())
387 }
388
389 fn clear(&mut self) {
390 self.storage.clear();
391 self.normalized_queries.clear();
392 self.query_patterns.clear();
393 self.access_stats.clear();
394 self.stats = CacheStats::default();
395 }
396
397 fn size(&self) -> usize {
398 self.storage
399 .values()
400 .filter(|e| !e.metadata.is_expired())
401 .count()
402 }
403
404 fn stats(&self) -> CacheStats {
405 self.stats.clone()
406 }
407}
408
/// Point-in-time summary of a query cache's contents.
#[derive(Debug, Clone)]
pub struct QueryCacheInsights {
    /// Number of entries held in storage, including expired ones that have
    /// not been evicted yet.
    pub total_queries: usize,

    /// Number of stored entries whose metadata reports them as expired.
    pub expired_queries: usize,

    /// Mean number of search results per stored entry; `0.0` when the cache
    /// holds no entries.
    pub avg_results_per_query: f32,

    /// Up to five query templates that have been matched more than once.
    pub top_patterns: Vec<String>,

    /// The similarity threshold from the cache configuration at the time the
    /// summary was produced.
    pub similarity_threshold: f32,
}
427
#[cfg(test)]
mod tests {
    use super::*;
    // `Duration` is not imported by the parent module (only `SystemTime` is),
    // so the glob import above does not provide it.
    use std::time::Duration;

    /// Configuration shared by the tests: room for 100 entries, one-hour
    /// TTL, LRU eviction and a 0.8 similarity threshold.
    fn create_test_config() -> QueryCacheConfig {
        QueryCacheConfig {
            enabled: true,
            max_size: 100,
            ttl: Duration::from_secs(3600),
            eviction_policy: super::super::EvictionPolicy::LRU,
            similarity_threshold: 0.8,
        }
    }

    /// A single-element result list used as the cached payload.
    fn create_test_results() -> Vec<CachedSearchResult> {
        vec![CachedSearchResult {
            document_id: "doc1".to_string(),
            content: "test content".to_string(),
            score: 0.9,
            rank: 0,
            metadata: HashMap::new(),
        }]
    }

    #[test]
    fn test_query_cache_creation() {
        let config = create_test_config();
        let cache = QueryCache::new(config).unwrap();

        assert_eq!(cache.size(), 0);
        assert_eq!(cache.query_patterns.len(), 0);
    }

    #[test]
    fn test_basic_caching() {
        let config = create_test_config();
        let mut cache = QueryCache::new(config).unwrap();

        let query = "test query".to_string();
        let results = create_test_results();

        cache
            .cache_results(query.clone(), results.clone(), None, "hash123".to_string())
            .unwrap();

        assert_eq!(cache.size(), 1);

        let cached = cache.get_results(&query);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().results.len(), 1);
    }

    #[test]
    fn test_query_normalization() {
        let config = create_test_config();
        let cache = QueryCache::new(config).unwrap();

        // Case, surrounding whitespace and punctuation must not matter.
        let query1 = "  What is Rust?  ";
        let query2 = "what is rust";
        let query3 = "What is Rust???";

        let norm1 = cache.normalize_query(query1);
        let norm2 = cache.normalize_query(query2);
        let norm3 = cache.normalize_query(query3);

        assert_eq!(norm1, norm2);
        assert_eq!(norm2, norm3);
    }

    #[test]
    fn test_similarity_matching() {
        let config = create_test_config();
        let mut cache = QueryCache::new(config).unwrap();

        let query1 = "how to learn rust programming".to_string();
        let results = create_test_results();

        cache
            .cache_results(query1.clone(), results.clone(), None, "hash1".to_string())
            .unwrap();

        // Same words in a different order: the normalized strings differ, so
        // a hit here can only come from token-overlap similarity.
        let query2 = "learn rust programming how to";
        let cached = cache.get_results(query2);
        assert!(cached.is_some());
    }

    #[test]
    fn test_pattern_extraction() {
        let config = create_test_config();
        let cache = QueryCache::new(config).unwrap();

        let query1 = "get user 123 details";
        let query2 = "get user 456 details";

        let pattern1 = cache.extract_pattern(query1);
        let pattern2 = cache.extract_pattern(query2);

        assert_eq!(pattern1, pattern2);
        assert!(pattern1.contains("{NUM}"));
    }

    #[test]
    fn test_eviction() {
        let mut config = create_test_config();
        config.max_size = 2;
        let mut cache = QueryCache::new(config).unwrap();

        let results = create_test_results();

        cache
            .cache_results(
                "query1".to_string(),
                results.clone(),
                None,
                "h1".to_string(),
            )
            .unwrap();
        cache
            .cache_results(
                "query2".to_string(),
                results.clone(),
                None,
                "h2".to_string(),
            )
            .unwrap();

        assert_eq!(cache.size(), 2);

        // A third distinct query exceeds `max_size`, forcing one eviction.
        cache
            .cache_results(
                "query3".to_string(),
                results.clone(),
                None,
                "h3".to_string(),
            )
            .unwrap();

        assert_eq!(cache.size(), 2);
        assert_eq!(cache.stats.evictions, 1);
    }
}