rexis_rag/caching/
query_cache.rs

1//! # Query Cache Implementation
2//!
3//! Caching for query results with intelligent similarity matching.
4
5use super::{
6    Cache, CacheEntryMetadata, CacheStats, CachedSearchResult, QueryCacheConfig, QueryCacheEntry,
7};
8use crate::RragResult;
9use std::collections::HashMap;
10use std::time::SystemTime;
11
/// Query-result cache that can answer a lookup by exact query text, by
/// normalized query text, or by token-overlap similarity to a cached query.
pub struct QueryCache {
    /// Cache settings. The code in this file reads `max_size`, `ttl`,
    /// `eviction_policy` and `similarity_threshold` from it.
    config: QueryCacheConfig,

    /// Main storage, keyed by the query string exactly as it was cached.
    storage: HashMap<String, QueryCacheEntry>,

    /// Maps a normalized query (see `normalize_query`) to the original query
    /// string most recently cached under that normalized form, i.e. to a key
    /// of `storage`.
    normalized_queries: HashMap<String, String>,

    /// Query templates derived from cached queries (see `extract_pattern`).
    /// `update_patterns` stops adding new templates once 100 are stored.
    query_patterns: Vec<QueryPattern>,

    /// Per-query access statistics, keyed by the original query string.
    access_stats: HashMap<String, QueryAccessStats>,

    /// Aggregate cache statistics. The code in this file only updates the
    /// `evictions` counter and resets the whole value in `clear`.
    stats: CacheStats,
}
32
/// A query template recorded by the cache, with volatile parts of the query
/// replaced by placeholders.
#[derive(Debug, Clone)]
pub struct QueryPattern {
    /// Pattern ID; `QueryCache` assigns `pattern_<n>`, where `<n>` is the
    /// number of patterns stored before this one was added.
    pub id: String,

    /// Pattern template. `QueryCache` writes `{NUM}` for standalone numbers
    /// and `{STR}` for double-quoted strings.
    pub template: String,

    /// How many cached queries have produced this template (starts at 1).
    pub match_count: u64,

    /// Average result similarity. Initialised to 0.0; nothing in this file
    /// updates it afterwards.
    pub avg_similarity: f32,

    /// Pattern effectiveness score. Initialised to 0.0; nothing in this file
    /// updates it afterwards.
    pub effectiveness: f32,
}
51
/// Per-query bookkeeping kept by the cache, keyed by the original query text.
#[derive(Debug, Clone)]
pub struct QueryAccessStats {
    /// Number of times `QueryCache` has recorded this query. In this file
    /// that happens once per `cache_results` call for the query.
    pub access_count: u64,

    /// Wall-clock time of the most recent recording.
    pub last_access: SystemTime,

    /// Average response time in milliseconds. Initialised to 0.0; nothing in
    /// this file updates it afterwards.
    pub avg_response_time_ms: f32,

    /// Cache hit rate for similar queries. Initialised to 0.0; nothing in
    /// this file updates it afterwards.
    pub similarity_hit_rate: f32,

    /// Query variations seen. Starts empty; nothing in this file appends to it.
    pub variations: Vec<String>,
}
70
71impl QueryCache {
72    /// Create new query cache
73    pub fn new(config: QueryCacheConfig) -> RragResult<Self> {
74        Ok(Self {
75            config,
76            storage: HashMap::new(),
77            normalized_queries: HashMap::new(),
78            query_patterns: Vec::new(),
79            access_stats: HashMap::new(),
80            stats: CacheStats::default(),
81        })
82    }
83
84    /// Get cached results for query
85    pub fn get_results(&self, query: &str) -> Option<QueryCacheEntry> {
86        // Direct lookup
87        if let Some(entry) = self.storage.get(query) {
88            if !entry.metadata.is_expired() {
89                return Some(entry.clone());
90            }
91        }
92
93        // Try normalized query
94        let normalized = self.normalize_query(query);
95        if let Some(canonical) = self.normalized_queries.get(&normalized) {
96            if let Some(entry) = self.storage.get(canonical) {
97                if !entry.metadata.is_expired() {
98                    return Some(entry.clone());
99                }
100            }
101        }
102
103        // Try similarity matching if threshold is set
104        if self.config.similarity_threshold > 0.0 {
105            return self.find_similar_query(query);
106        }
107
108        None
109    }
110
111    /// Cache query results with intelligent deduplication
112    pub fn cache_results(
113        &mut self,
114        query: String,
115        results: Vec<CachedSearchResult>,
116        generated_answer: Option<String>,
117        embedding_hash: String,
118    ) -> RragResult<()> {
119        // Check capacity
120        if self.storage.len() >= self.config.max_size {
121            self.evict_entry()?;
122        }
123
124        // Create cache entry
125        let mut metadata = CacheEntryMetadata::new();
126        metadata.ttl = Some(self.config.ttl);
127
128        let entry = QueryCacheEntry {
129            query: query.clone(),
130            embedding_hash,
131            results,
132            generated_answer,
133            metadata,
134        };
135
136        // Store with normalization
137        let normalized = self.normalize_query(&query);
138        self.normalized_queries.insert(normalized, query.clone());
139        self.storage.insert(query.clone(), entry);
140
141        // Update patterns
142        self.update_patterns(&query);
143
144        // Update access stats
145        self.update_access_stats(&query);
146
147        Ok(())
148    }
149
150    /// Normalize query for better cache hits
151    fn normalize_query(&self, query: &str) -> String {
152        query
153            .to_lowercase()
154            .trim()
155            .chars()
156            .filter(|c| c.is_alphanumeric() || c.is_whitespace())
157            .collect::<String>()
158            .split_whitespace()
159            .collect::<Vec<_>>()
160            .join(" ")
161    }
162
163    /// Find similar cached query
164    fn find_similar_query(&self, query: &str) -> Option<QueryCacheEntry> {
165        let normalized = self.normalize_query(query);
166        let query_tokens: Vec<&str> = normalized.split_whitespace().collect();
167
168        let mut best_match: Option<(&String, &QueryCacheEntry, f32)> = None;
169
170        for (cached_query, entry) in &self.storage {
171            if entry.metadata.is_expired() {
172                continue;
173            }
174
175            let cached_normalized = self.normalize_query(cached_query);
176            let cached_tokens: Vec<&str> = cached_normalized.split_whitespace().collect();
177
178            // Calculate Jaccard similarity
179            let intersection = query_tokens
180                .iter()
181                .filter(|t| cached_tokens.contains(t))
182                .count();
183            let union = (query_tokens.len() + cached_tokens.len() - intersection).max(1);
184            let similarity = intersection as f32 / union as f32;
185
186            if similarity >= self.config.similarity_threshold {
187                if best_match.is_none() || similarity > best_match.as_ref().unwrap().2 {
188                    best_match = Some((cached_query, entry, similarity));
189                }
190            }
191        }
192
193        best_match.map(|(_, entry, _)| entry.clone())
194    }
195
196    /// Update query patterns
197    fn update_patterns(&mut self, query: &str) {
198        // Extract potential pattern from query
199        let pattern = self.extract_pattern(query);
200
201        // Check if pattern exists
202        if let Some(existing) = self
203            .query_patterns
204            .iter_mut()
205            .find(|p| p.template == pattern)
206        {
207            existing.match_count += 1;
208        } else if self.query_patterns.len() < 100 {
209            // Limit patterns
210            self.query_patterns.push(QueryPattern {
211                id: format!("pattern_{}", self.query_patterns.len()),
212                template: pattern,
213                match_count: 1,
214                avg_similarity: 0.0,
215                effectiveness: 0.0,
216            });
217        }
218    }
219
220    /// Extract pattern from query
221    fn extract_pattern(&self, query: &str) -> String {
222        // Simple pattern extraction - replace numbers and quoted strings
223        let mut pattern = query.to_string();
224
225        // Replace numbers with placeholder
226        pattern = regex::Regex::new(r"\b\d+\b")
227            .unwrap_or_else(|_| regex::Regex::new("").unwrap())
228            .replace_all(&pattern, "{NUM}")
229            .to_string();
230
231        // Replace quoted strings with placeholder
232        pattern = regex::Regex::new(r#""[^"]*""#)
233            .unwrap_or_else(|_| regex::Regex::new("").unwrap())
234            .replace_all(&pattern, "{STR}")
235            .to_string();
236
237        pattern
238    }
239
240    /// Update access statistics
241    fn update_access_stats(&mut self, query: &str) {
242        let stats = self
243            .access_stats
244            .entry(query.to_string())
245            .or_insert_with(|| QueryAccessStats {
246                access_count: 0,
247                last_access: SystemTime::now(),
248                avg_response_time_ms: 0.0,
249                similarity_hit_rate: 0.0,
250                variations: Vec::new(),
251            });
252
253        stats.access_count += 1;
254        stats.last_access = SystemTime::now();
255    }
256
257    /// Evict entry based on policy
258    fn evict_entry(&mut self) -> RragResult<()> {
259        use super::EvictionPolicy;
260
261        match self.config.eviction_policy {
262            EvictionPolicy::LRU => self.evict_lru(),
263            EvictionPolicy::LFU => self.evict_lfu(),
264            EvictionPolicy::TTL => self.evict_expired(),
265            _ => self.evict_lru(), // Default to LRU
266        }
267    }
268
269    /// Evict least recently used entry
270    fn evict_lru(&mut self) -> RragResult<()> {
271        if let Some((key, _)) = self
272            .storage
273            .iter()
274            .min_by_key(|(_, entry)| entry.metadata.last_accessed)
275        {
276            let key = key.clone();
277            self.storage.remove(&key);
278            self.stats.evictions += 1;
279        }
280        Ok(())
281    }
282
283    /// Evict least frequently used entry
284    fn evict_lfu(&mut self) -> RragResult<()> {
285        if let Some((key, _)) = self
286            .storage
287            .iter()
288            .min_by_key(|(_, entry)| entry.metadata.access_count)
289        {
290            let key = key.clone();
291            self.storage.remove(&key);
292            self.stats.evictions += 1;
293        }
294        Ok(())
295    }
296
297    /// Evict expired entries
298    fn evict_expired(&mut self) -> RragResult<()> {
299        let _now = SystemTime::now();
300        let before_count = self.storage.len();
301
302        self.storage.retain(|_, entry| !entry.metadata.is_expired());
303
304        let evicted = before_count - self.storage.len();
305        self.stats.evictions += evicted as u64;
306
307        // If still over capacity, evict oldest
308        if self.storage.len() >= self.config.max_size {
309            self.evict_lru()?;
310        }
311
312        Ok(())
313    }
314
315    /// Get cache insights
316    pub fn get_insights(&self) -> QueryCacheInsights {
317        let total_queries = self.storage.len();
318        let expired_queries = self
319            .storage
320            .values()
321            .filter(|e| e.metadata.is_expired())
322            .count();
323
324        let avg_results_per_query = if total_queries > 0 {
325            self.storage
326                .values()
327                .map(|e| e.results.len())
328                .sum::<usize>() as f32
329                / total_queries as f32
330        } else {
331            0.0
332        };
333
334        let top_patterns: Vec<String> = self
335            .query_patterns
336            .iter()
337            .filter(|p| p.match_count > 1)
338            .take(5)
339            .map(|p| p.template.clone())
340            .collect();
341
342        QueryCacheInsights {
343            total_queries,
344            expired_queries,
345            avg_results_per_query,
346            top_patterns,
347            similarity_threshold: self.config.similarity_threshold,
348        }
349    }
350}
351
352impl Cache<String, QueryCacheEntry> for QueryCache {
353    fn get(&self, key: &String) -> Option<QueryCacheEntry> {
354        self.get_results(key)
355    }
356
357    fn put(&mut self, key: String, value: QueryCacheEntry) -> RragResult<()> {
358        if self.storage.len() >= self.config.max_size {
359            self.evict_entry()?;
360        }
361
362        let normalized = self.normalize_query(&key);
363        self.normalized_queries.insert(normalized, key.clone());
364        self.storage.insert(key, value);
365        Ok(())
366    }
367
368    fn remove(&mut self, key: &String) -> Option<QueryCacheEntry> {
369        let entry = self.storage.remove(key);
370
371        // Remove from normalized queries
372        let normalized = self.normalize_query(key);
373        self.normalized_queries.remove(&normalized);
374
375        // Remove from access stats
376        self.access_stats.remove(key);
377
378        entry
379    }
380
381    fn contains(&self, key: &String) -> bool {
382        self.storage.contains_key(key)
383            && !self
384                .storage
385                .get(key)
386                .map_or(true, |e| e.metadata.is_expired())
387    }
388
389    fn clear(&mut self) {
390        self.storage.clear();
391        self.normalized_queries.clear();
392        self.query_patterns.clear();
393        self.access_stats.clear();
394        self.stats = CacheStats::default();
395    }
396
397    fn size(&self) -> usize {
398        self.storage
399            .values()
400            .filter(|e| !e.metadata.is_expired())
401            .count()
402    }
403
404    fn stats(&self) -> CacheStats {
405        self.stats.clone()
406    }
407}
408
/// Point-in-time summary of a `QueryCache`, produced by
/// `QueryCache::get_insights`.
#[derive(Debug, Clone)]
pub struct QueryCacheInsights {
    /// Number of entries held in storage, expired ones included.
    pub total_queries: usize,

    /// Number of stored entries that are expired.
    pub expired_queries: usize,

    /// Mean number of results per stored entry; 0.0 when nothing is stored.
    pub avg_results_per_query: f32,

    /// Templates of up to five query patterns that matched more than once.
    pub top_patterns: Vec<String>,

    /// Configured similarity threshold, copied from the cache configuration.
    pub similarity_threshold: f32,
}
427
428#[cfg(test)]
429mod tests {
430    use super::*;
431
432    fn create_test_config() -> QueryCacheConfig {
433        QueryCacheConfig {
434            enabled: true,
435            max_size: 100,
436            ttl: Duration::from_secs(3600),
437            eviction_policy: super::super::EvictionPolicy::LRU,
438            similarity_threshold: 0.8,
439        }
440    }
441
442    fn create_test_results() -> Vec<CachedSearchResult> {
443        vec![CachedSearchResult {
444            document_id: "doc1".to_string(),
445            content: "test content".to_string(),
446            score: 0.9,
447            rank: 0,
448            metadata: HashMap::new(),
449        }]
450    }
451
452    #[test]
453    fn test_query_cache_creation() {
454        let config = create_test_config();
455        let cache = QueryCache::new(config).unwrap();
456
457        assert_eq!(cache.size(), 0);
458        assert_eq!(cache.query_patterns.len(), 0);
459    }
460
461    #[test]
462    fn test_basic_caching() {
463        let config = create_test_config();
464        let mut cache = QueryCache::new(config).unwrap();
465
466        let query = "test query".to_string();
467        let results = create_test_results();
468
469        cache
470            .cache_results(query.clone(), results.clone(), None, "hash123".to_string())
471            .unwrap();
472
473        assert_eq!(cache.size(), 1);
474
475        let cached = cache.get_results(&query);
476        assert!(cached.is_some());
477        assert_eq!(cached.unwrap().results.len(), 1);
478    }
479
480    #[test]
481    fn test_query_normalization() {
482        let config = create_test_config();
483        let cache = QueryCache::new(config).unwrap();
484
485        let query1 = "  What is   Rust?  ";
486        let query2 = "what is rust";
487        let query3 = "What is Rust???";
488
489        let norm1 = cache.normalize_query(query1);
490        let norm2 = cache.normalize_query(query2);
491        let norm3 = cache.normalize_query(query3);
492
493        assert_eq!(norm1, norm2);
494        assert_eq!(norm2, norm3);
495    }
496
497    #[test]
498    fn test_similarity_matching() {
499        let config = create_test_config();
500        let mut cache = QueryCache::new(config).unwrap();
501
502        let query1 = "how to learn rust programming".to_string();
503        let results = create_test_results();
504
505        cache
506            .cache_results(query1.clone(), results.clone(), None, "hash1".to_string())
507            .unwrap();
508
509        // Similar query should find cached results
510        let query2 = "learn rust programming how to";
511        let cached = cache.get_results(query2);
512        assert!(cached.is_some());
513    }
514
515    #[test]
516    fn test_pattern_extraction() {
517        let config = create_test_config();
518        let cache = QueryCache::new(config).unwrap();
519
520        let query1 = "get user 123 details";
521        let query2 = "get user 456 details";
522
523        let pattern1 = cache.extract_pattern(query1);
524        let pattern2 = cache.extract_pattern(query2);
525
526        assert_eq!(pattern1, pattern2);
527        assert!(pattern1.contains("{NUM}"));
528    }
529
530    #[test]
531    fn test_eviction() {
532        let mut config = create_test_config();
533        config.max_size = 2;
534        let mut cache = QueryCache::new(config).unwrap();
535
536        let results = create_test_results();
537
538        cache
539            .cache_results(
540                "query1".to_string(),
541                results.clone(),
542                None,
543                "h1".to_string(),
544            )
545            .unwrap();
546        cache
547            .cache_results(
548                "query2".to_string(),
549                results.clone(),
550                None,
551                "h2".to_string(),
552            )
553            .unwrap();
554
555        assert_eq!(cache.size(), 2);
556
557        // This should trigger eviction
558        cache
559            .cache_results(
560                "query3".to_string(),
561                results.clone(),
562                None,
563                "h3".to_string(),
564            )
565            .unwrap();
566
567        assert_eq!(cache.size(), 2);
568        assert_eq!(cache.stats.evictions, 1);
569    }
570}