// engram/search/result_cache.rs
//! Search Result Caching with Adaptive Thresholds (Phase 4 - ENG-36)
//!
//! Provides caching for search results with:
//! - Similarity-based cache lookup (not just exact query match)
//! - Adaptive threshold adjustment based on feedback
//! - TTL-based expiration
//! - Cache invalidation on memory changes

use crate::types::{MemoryType, SearchResult};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

/// Filter parameters that affect cache key generation.
///
/// Two searches with identical query text but different filters must not
/// share a cache entry; every field here participates in the cache key via
/// the derived `Hash` impl (see `SearchResultCache::cache_key`).
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CacheFilterParams {
    // Optional workspace filter.
    pub workspace: Option<String>,
    // Optional tier filter.
    pub tier: Option<String>,
    // Optional memory-type filter.
    pub memory_types: Option<Vec<MemoryType>>,
    // Whether the search included archived entries.
    pub include_archived: bool,
    // Whether the search included transcript entries.
    pub include_transcripts: bool,
    // Optional tag filter.
    pub tags: Option<Vec<String>>,
}

/// A cached search result entry.
///
/// Entries are shared behind `Arc` inside the cache, so the mutable counters
/// (`hit_count`, `feedback_score`) are atomics and can be updated through a
/// shared reference.
#[derive(Debug)]
pub struct CachedSearchResult {
    /// Hash of the original query
    pub query_hash: u64,
    /// The query embedding (for similarity matching)
    pub query_embedding: Option<Vec<f32>>,
    /// Filter parameters used for this search
    pub filter_params: CacheFilterParams,
    /// The cached results
    pub results: Vec<SearchResult>,
    /// When this entry was created
    pub created_at: Instant,
    /// Number of times this cache entry was hit
    pub hit_count: AtomicU64,
    /// Feedback score (positive = good results, negative = bad)
    pub feedback_score: AtomicI64,
}

47impl CachedSearchResult {
48    pub fn new(
49        query_hash: u64,
50        query_embedding: Option<Vec<f32>>,
51        filter_params: CacheFilterParams,
52        results: Vec<SearchResult>,
53    ) -> Self {
54        Self {
55            query_hash,
56            query_embedding,
57            filter_params,
58            results,
59            created_at: Instant::now(),
60            hit_count: AtomicU64::new(0),
61            feedback_score: AtomicI64::new(0),
62        }
63    }
64
65    /// Check if this entry is expired
66    pub fn is_expired(&self, ttl: Duration) -> bool {
67        self.created_at.elapsed() > ttl
68    }
69
70    /// Record a cache hit
71    pub fn record_hit(&self) {
72        self.hit_count.fetch_add(1, Ordering::Relaxed);
73    }
74
75    /// Record feedback (positive or negative)
76    pub fn record_feedback(&self, positive: bool) {
77        if positive {
78            self.feedback_score.fetch_add(1, Ordering::Relaxed);
79        } else {
80            self.feedback_score.fetch_sub(1, Ordering::Relaxed);
81        }
82    }
83}
84
/// Configuration for the adaptive cache.
///
/// The adaptive threshold moves between `min_threshold` and `max_threshold`
/// in response to feedback (see `SearchResultCache::adjust_threshold`).
#[derive(Debug, Clone)]
pub struct AdaptiveCacheConfig {
    /// Base similarity threshold for cache hits (default: 0.92)
    pub similarity_threshold: f32,
    /// Minimum similarity threshold (floor: 0.85)
    pub min_threshold: f32,
    /// Maximum similarity threshold (ceiling: 0.98)
    pub max_threshold: f32,
    /// Time-to-live for cache entries (default: 5 minutes)
    pub ttl_seconds: u64,
    /// Maximum number of cache entries (default: 1000)
    pub max_entries: usize,
    /// Enable adaptive threshold adjustment
    pub adaptive_enabled: bool,
}

102impl Default for AdaptiveCacheConfig {
103    fn default() -> Self {
104        Self {
105            similarity_threshold: 0.92,
106            min_threshold: 0.85,
107            max_threshold: 0.98,
108            ttl_seconds: 300, // 5 minutes
109            max_entries: 1000,
110            adaptive_enabled: true,
111        }
112    }
113}
114
/// Search result cache with adaptive thresholds.
///
/// All interior state is concurrency-safe (`DashMap` plus atomics), so the
/// cache can be shared behind an `Arc` without external locking.
pub struct SearchResultCache {
    /// Cached entries keyed by cache key (query_hash + filter_hash)
    entries: DashMap<String, Arc<CachedSearchResult>>,
    /// Configuration
    config: AdaptiveCacheConfig,
    /// Current adaptive threshold, stored as the raw `f32` bit pattern in an
    /// `AtomicU32` so it can be read and updated lock-free.
    current_threshold: std::sync::atomic::AtomicU32,
    /// Cache statistics
    stats: CacheStats,
}

/// Cache statistics (monotonic counters, updated with relaxed atomics).
#[derive(Debug, Default)]
pub struct CacheStats {
    // Lookups served from cache (exact or similarity match).
    pub hits: AtomicU64,
    // Lookups that found nothing usable.
    pub misses: AtomicU64,
    // Entries removed by explicit invalidation or `clear`.
    pub invalidations: AtomicU64,
    // Entries removed by capacity eviction.
    pub evictions: AtomicU64,
}

136impl CacheStats {
137    pub fn hit_rate(&self) -> f64 {
138        let hits = self.hits.load(Ordering::Relaxed);
139        let misses = self.misses.load(Ordering::Relaxed);
140        let total = hits + misses;
141        if total == 0 {
142            0.0
143        } else {
144            hits as f64 / total as f64
145        }
146    }
147}
148
/// Serializable snapshot of cache statistics, as returned by
/// `SearchResultCache::stats` (the original comment said "Cache lookup
/// result", which did not match the contents).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStatsResponse {
    // Number of live entries at snapshot time.
    pub entries: usize,
    pub hits: u64,
    pub misses: u64,
    // hits / (hits + misses); 0.0 when there were no lookups.
    pub hit_rate: f64,
    pub invalidations: u64,
    pub evictions: u64,
    // Current adaptive similarity threshold.
    pub current_threshold: f32,
    pub ttl_seconds: u64,
}

162impl SearchResultCache {
163    pub fn new(config: AdaptiveCacheConfig) -> Self {
164        let threshold_bits = config.similarity_threshold.to_bits();
165        Self {
166            entries: DashMap::new(),
167            current_threshold: std::sync::atomic::AtomicU32::new(threshold_bits),
168            config,
169            stats: CacheStats::default(),
170        }
171    }
172
173    /// Get current similarity threshold
174    pub fn current_threshold(&self) -> f32 {
175        f32::from_bits(self.current_threshold.load(Ordering::Relaxed))
176    }
177
178    /// Generate cache key from query hash and filter params
179    fn cache_key(query_hash: u64, filters: &CacheFilterParams) -> String {
180        let mut hasher = std::collections::hash_map::DefaultHasher::new();
181        query_hash.hash(&mut hasher);
182        filters.hash(&mut hasher);
183        format!("{:016x}", hasher.finish())
184    }
185
186    /// Hash a query string
187    pub fn hash_query(query: &str) -> u64 {
188        let mut hasher = std::collections::hash_map::DefaultHasher::new();
189        query.to_lowercase().trim().hash(&mut hasher);
190        hasher.finish()
191    }
192
193    /// Calculate cosine similarity between two embeddings
194    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
195        if a.len() != b.len() || a.is_empty() {
196            return 0.0;
197        }
198
199        let mut dot = 0.0f32;
200        let mut norm_a = 0.0f32;
201        let mut norm_b = 0.0f32;
202
203        for (x, y) in a.iter().zip(b.iter()) {
204            dot += x * y;
205            norm_a += x * x;
206            norm_b += y * y;
207        }
208
209        if norm_a == 0.0 || norm_b == 0.0 {
210            return 0.0;
211        }
212
213        dot / (norm_a.sqrt() * norm_b.sqrt())
214    }
215
216    /// Try to get cached results for a query
217    pub fn get(
218        &self,
219        query: &str,
220        query_embedding: Option<&[f32]>,
221        filters: &CacheFilterParams,
222    ) -> Option<Vec<SearchResult>> {
223        let query_hash = Self::hash_query(query);
224        let cache_key = Self::cache_key(query_hash, filters);
225
226        // First try exact match
227        if let Some(entry) = self.entries.get(&cache_key) {
228            if !entry.is_expired(Duration::from_secs(self.config.ttl_seconds)) {
229                entry.record_hit();
230                self.stats.hits.fetch_add(1, Ordering::Relaxed);
231                return Some(entry.results.clone());
232            } else {
233                // Remove expired entry
234                drop(entry);
235                self.entries.remove(&cache_key);
236            }
237        }
238
239        // Try similarity-based lookup if we have an embedding
240        if let Some(embedding) = query_embedding {
241            let threshold = self.current_threshold();
242
243            for entry in self.entries.iter() {
244                if entry.filter_params != *filters {
245                    continue;
246                }
247
248                if entry.is_expired(Duration::from_secs(self.config.ttl_seconds)) {
249                    continue;
250                }
251
252                if let Some(ref cached_embedding) = entry.query_embedding {
253                    let similarity = Self::cosine_similarity(embedding, cached_embedding);
254                    if similarity >= threshold {
255                        entry.record_hit();
256                        self.stats.hits.fetch_add(1, Ordering::Relaxed);
257                        return Some(entry.results.clone());
258                    }
259                }
260            }
261        }
262
263        self.stats.misses.fetch_add(1, Ordering::Relaxed);
264        None
265    }
266
267    /// Store search results in cache
268    pub fn put(
269        &self,
270        query: &str,
271        query_embedding: Option<Vec<f32>>,
272        filters: CacheFilterParams,
273        results: Vec<SearchResult>,
274    ) {
275        let query_hash = Self::hash_query(query);
276        let cache_key = Self::cache_key(query_hash, &filters);
277
278        // Evict if at capacity
279        if self.entries.len() >= self.config.max_entries {
280            self.evict_oldest();
281        }
282
283        let entry = CachedSearchResult::new(query_hash, query_embedding, filters, results);
284        self.entries.insert(cache_key, Arc::new(entry));
285    }
286
287    /// Evict the oldest entry
288    fn evict_oldest(&self) {
289        let mut oldest_key: Option<String> = None;
290        let mut oldest_time = Instant::now();
291
292        for entry in self.entries.iter() {
293            if entry.created_at < oldest_time {
294                oldest_time = entry.created_at;
295                oldest_key = Some(entry.key().clone());
296            }
297        }
298
299        if let Some(key) = oldest_key {
300            self.entries.remove(&key);
301            self.stats.evictions.fetch_add(1, Ordering::Relaxed);
302        }
303    }
304
305    /// Remove expired entries
306    pub fn remove_expired(&self) {
307        let ttl = Duration::from_secs(self.config.ttl_seconds);
308        self.entries.retain(|_, v| !v.is_expired(ttl));
309    }
310
311    /// Invalidate cache entries for a specific workspace
312    pub fn invalidate_for_workspace(&self, workspace: Option<&str>) {
313        self.entries.retain(|_, v| {
314            let should_keep = v.filter_params.workspace.as_deref() != workspace;
315            if !should_keep {
316                self.stats.invalidations.fetch_add(1, Ordering::Relaxed);
317            }
318            should_keep
319        });
320    }
321
322    /// Invalidate cache entries that might contain a specific memory
323    pub fn invalidate_for_memory(&self, memory_id: i64) {
324        // Since we don't track which memories are in which cache entries,
325        // we invalidate entries that could potentially contain this memory.
326        // For now, we do a simple approach: invalidate all entries older than
327        // a certain threshold or just clear all.
328        // A more sophisticated approach would track memory IDs in each entry.
329        self.entries.retain(|_, v| {
330            // Check if any result contains this memory ID
331            let contains_memory = v.results.iter().any(|r| r.memory.id == memory_id);
332            if contains_memory {
333                self.stats.invalidations.fetch_add(1, Ordering::Relaxed);
334            }
335            !contains_memory
336        });
337    }
338
339    /// Clear all cache entries
340    pub fn clear(&self) {
341        let count = self.entries.len();
342        self.entries.clear();
343        self.stats
344            .invalidations
345            .fetch_add(count as u64, Ordering::Relaxed);
346    }
347
348    /// Record feedback for a query (adjusts adaptive threshold)
349    pub fn record_feedback(&self, query: &str, filters: &CacheFilterParams, positive: bool) {
350        let query_hash = Self::hash_query(query);
351        let cache_key = Self::cache_key(query_hash, filters);
352
353        if let Some(entry) = self.entries.get(&cache_key) {
354            entry.record_feedback(positive);
355        }
356
357        // Adjust threshold based on feedback
358        if self.config.adaptive_enabled {
359            self.adjust_threshold(positive);
360        }
361    }
362
363    /// Adjust the similarity threshold based on feedback
364    fn adjust_threshold(&self, positive: bool) {
365        let current = self.current_threshold();
366        let adjustment = 0.01; // 1% adjustment per feedback
367
368        let new_threshold = if positive {
369            // Positive feedback: can be more lenient (lower threshold)
370            (current - adjustment).max(self.config.min_threshold)
371        } else {
372            // Negative feedback: be more strict (higher threshold)
373            (current + adjustment).min(self.config.max_threshold)
374        };
375
376        self.current_threshold
377            .store(new_threshold.to_bits(), Ordering::Relaxed);
378    }
379
380    /// Get cache statistics
381    pub fn stats(&self) -> CacheStatsResponse {
382        CacheStatsResponse {
383            entries: self.entries.len(),
384            hits: self.stats.hits.load(Ordering::Relaxed),
385            misses: self.stats.misses.load(Ordering::Relaxed),
386            hit_rate: self.stats.hit_rate(),
387            invalidations: self.stats.invalidations.load(Ordering::Relaxed),
388            evictions: self.stats.evictions.load(Ordering::Relaxed),
389            current_threshold: self.current_threshold(),
390            ttl_seconds: self.config.ttl_seconds,
391        }
392    }
393
394    /// Start background expiration worker (call from main thread)
395    pub fn start_expiration_worker(cache: Arc<Self>, interval_secs: u64) {
396        std::thread::spawn(move || loop {
397            std::thread::sleep(Duration::from_secs(interval_secs));
398            cache.remove_expired();
399        });
400    }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;

    // Minimal Memory fixture: only `id` and `content` vary; everything else
    // uses neutral defaults so tests focus on cache behavior.
    fn make_test_memory(id: i64, content: &str) -> crate::types::Memory {
        crate::types::Memory {
            id,
            content: content.to_string(),
            memory_type: MemoryType::Note,
            importance: 0.5,
            tags: vec![],
            access_count: 0,
            created_at: chrono::Utc::now(),
            updated_at: chrono::Utc::now(),
            last_accessed_at: None,
            owner_id: None,
            visibility: Default::default(),
            version: 1,
            has_embedding: false,
            metadata: Default::default(),
            scope: crate::types::MemoryScope::Global,
            workspace: "default".to_string(),
            tier: crate::types::MemoryTier::Permanent,
            expires_at: None,
            content_hash: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            procedure_success_count: 0,
            procedure_failure_count: 0,
            summary_of_id: None,
            lifecycle_state: crate::types::LifecycleState::Active,
        }
    }

    // Wraps a fixture memory in a SearchResult with the given score.
    fn make_test_result(id: i64, content: &str, score: f32) -> SearchResult {
        SearchResult {
            memory: make_test_memory(id, content),
            score,
            match_info: crate::types::MatchInfo {
                strategy: crate::types::SearchStrategy::Hybrid,
                matched_terms: vec![],
                highlights: vec![],
                semantic_score: None,
                keyword_score: Some(score),
            },
        }
    }

    // Round-trip: a stored entry is retrievable with the same query+filters.
    #[test]
    fn test_cache_put_get() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        let results = vec![make_test_result(1, "test content", 0.9)];

        cache.put(
            "test query",
            None,
            CacheFilterParams::default(),
            results.clone(),
        );

        let cached = cache.get("test query", None, &CacheFilterParams::default());
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().len(), 1);
    }

    // A never-stored query must miss.
    #[test]
    fn test_cache_miss() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());

        let cached = cache.get("nonexistent", None, &CacheFilterParams::default());
        assert!(cached.is_none());
    }

    // invalidate_for_memory removes entries whose results contain that id.
    #[test]
    fn test_cache_invalidation() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        let results = vec![make_test_result(1, "test", 0.9)];

        cache.put("query", None, CacheFilterParams::default(), results);

        // Verify it's cached
        assert!(cache
            .get("query", None, &CacheFilterParams::default())
            .is_some());

        // Invalidate for memory ID 1
        cache.invalidate_for_memory(1);

        // Should be gone
        assert!(cache
            .get("query", None, &CacheFilterParams::default())
            .is_none());
    }

    // Same query text with different filters must occupy distinct entries.
    #[test]
    fn test_different_filters_different_cache() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        let results1 = vec![make_test_result(1, "result 1", 0.9)];
        let results2 = vec![make_test_result(2, "result 2", 0.8)];

        let filters1 = CacheFilterParams {
            workspace: Some("ws1".to_string()),
            ..Default::default()
        };
        let filters2 = CacheFilterParams {
            workspace: Some("ws2".to_string()),
            ..Default::default()
        };

        cache.put("query", None, filters1.clone(), results1);
        cache.put("query", None, filters2.clone(), results2);

        let cached1 = cache.get("query", None, &filters1);
        let cached2 = cache.get("query", None, &filters2);

        assert!(cached1.is_some());
        assert!(cached2.is_some());
        assert_eq!(cached1.unwrap()[0].memory.id, 1);
        assert_eq!(cached2.unwrap()[0].memory.id, 2);
    }

    // Similarity path: different query text can still hit when the
    // embedding is close enough to a cached one.
    #[test]
    fn test_similarity_lookup() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig {
            similarity_threshold: 0.9,
            ..Default::default()
        });

        let embedding = vec![1.0, 0.0, 0.0];
        let results = vec![make_test_result(1, "test", 0.9)];

        cache.put(
            "original query",
            Some(embedding.clone()),
            CacheFilterParams::default(),
            results,
        );

        // Same embedding should hit
        let cached = cache.get(
            "different query",
            Some(&embedding),
            &CacheFilterParams::default(),
        );
        assert!(cached.is_some());

        // Very similar embedding should hit
        let similar = vec![0.99, 0.1, 0.0];
        let cached = cache.get(
            "another query",
            Some(&similar),
            &CacheFilterParams::default(),
        );
        assert!(cached.is_some());

        // Different embedding should miss
        let different = vec![0.0, 1.0, 0.0];
        let cached = cache.get(
            "yet another",
            Some(&different),
            &CacheFilterParams::default(),
        );
        assert!(cached.is_none());
    }

    // Counters: one miss, then two hits after a put.
    #[test]
    fn test_stats() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        let results = vec![make_test_result(1, "test", 0.9)];

        // Miss
        cache.get("query", None, &CacheFilterParams::default());

        // Put
        cache.put("query", None, CacheFilterParams::default(), results);

        // Hit
        cache.get("query", None, &CacheFilterParams::default());
        cache.get("query", None, &CacheFilterParams::default());

        let stats = cache.stats();
        assert_eq!(stats.entries, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 2);
        assert!(stats.hit_rate > 0.6);
    }
}