// engram/search/semantic_cache.rs

//! Semantic Query Cache — RML-1229
//!
//! Caches search results and matches queries by embedding similarity rather
//! than exact string match. Complements `result_cache.rs` (which caches by
//! exact query hash). A new query is a cache hit when its embedding is
//! sufficiently close (cosine similarity ≥ threshold) to a previously cached
//! query.
8
9use dashmap::DashMap;
10use serde_json::Value;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::{Duration, Instant};
13
14// ---------------------------------------------------------------------------
15// Configuration
16// ---------------------------------------------------------------------------
17
/// Tunables for the semantic query cache.
///
/// `Default` yields a 0.92 similarity threshold, room for 1000 entries and a
/// five-minute time-to-live.
#[derive(Clone, Debug)]
pub struct SemanticCacheConfig {
    /// Cosine similarity a stored query must reach (inclusive) to be a hit.
    pub similarity_threshold: f32,
    /// Capacity; once reached, inserting a new entry evicts the oldest one.
    pub max_entries: usize,
    /// Lifetime of each newly stored entry, in seconds.
    pub default_ttl_secs: u64,
}

impl Default for SemanticCacheConfig {
    fn default() -> Self {
        let similarity_threshold = 0.92;
        let max_entries = 1000;
        // Five minutes.
        let default_ttl_secs = 5 * 60;
        Self {
            similarity_threshold,
            max_entries,
            default_ttl_secs,
        }
    }
}
38
39// ---------------------------------------------------------------------------
40// Internal entry
41// ---------------------------------------------------------------------------
42
43struct CacheEntry {
44    query_embedding: Vec<f32>,
45    /// Stored for observability / future logging; not read in hot paths.
46    #[allow(dead_code)]
47    query_text: String,
48    results: Value,
49    created_at: Instant,
50    ttl: Duration,
51    hit_count: u64,
52}
53
54impl CacheEntry {
55    fn is_expired(&self) -> bool {
56        self.created_at.elapsed() > self.ttl
57    }
58}
59
60// ---------------------------------------------------------------------------
61// Public stats
62// ---------------------------------------------------------------------------
63
/// Point-in-time copy of the cache's counters, as produced by
/// `SemanticCache::stats`.
#[derive(Clone, Debug, Default)]
pub struct SemanticCacheStats {
    /// Entries currently held in the map (expired ones count until removed).
    pub entries: usize,
    /// Lookups answered from the cache.
    pub hits: u64,
    /// Lookups that found no usable entry.
    pub misses: u64,
    /// Entries dropped to make room for new ones.
    pub evictions: u64,
    /// Entries dropped because their results referenced an invalidated memory id.
    pub invalidations: u64,
}
73
74// ---------------------------------------------------------------------------
75// SemanticCache
76// ---------------------------------------------------------------------------
77
78/// A cache that matches queries by embedding similarity rather than exact
79/// string match.
80pub struct SemanticCache {
81    /// Live entries. Key = `embedding_hash(embedding)`.
82    entries: DashMap<u64, CacheEntry>,
83    config: SemanticCacheConfig,
84    hits: AtomicU64,
85    misses: AtomicU64,
86    evictions: AtomicU64,
87    invalidations: AtomicU64,
88}
89
90impl SemanticCache {
91    /// Create a new cache with the supplied configuration.
92    pub fn new(config: SemanticCacheConfig) -> Self {
93        Self {
94            entries: DashMap::new(),
95            config,
96            hits: AtomicU64::new(0),
97            misses: AtomicU64::new(0),
98            evictions: AtomicU64::new(0),
99            invalidations: AtomicU64::new(0),
100        }
101    }
102
103    // -----------------------------------------------------------------------
104    // Public API
105    // -----------------------------------------------------------------------
106
107    /// Try to find a cached result for `query_embedding`.
108    ///
109    /// Performs a linear scan over all entries comparing cosine similarity.
110    /// Returns the result of the best-matching non-expired entry whose
111    /// similarity is at or above the configured threshold.
112    pub fn get(&self, query_embedding: &[f32]) -> Option<Value> {
113        let threshold = self.config.similarity_threshold;
114
115        let mut best_similarity = -1.0_f32;
116        let mut best_result: Option<Value> = None;
117        let mut best_key: Option<u64> = None;
118
119        for mut entry_ref in self.entries.iter_mut() {
120            if entry_ref.is_expired() {
121                continue;
122            }
123
124            let sim = cosine_similarity(query_embedding, &entry_ref.query_embedding);
125            if sim >= threshold && sim > best_similarity {
126                best_similarity = sim;
127                best_result = Some(entry_ref.results.clone());
128                best_key = Some(*entry_ref.key());
129                // Increment hit_count in place while we hold the write lock.
130                entry_ref.hit_count = entry_ref.hit_count.saturating_add(1);
131            }
132        }
133
134        // If we found the best key but had to iterate past it to finish, we
135        // already incremented hit_count above — nothing more needed.
136        let _ = best_key; // suppress unused warning
137
138        if best_result.is_some() {
139            self.hits.fetch_add(1, Ordering::Relaxed);
140        } else {
141            self.misses.fetch_add(1, Ordering::Relaxed);
142        }
143
144        best_result
145    }
146
147    /// Store `results` under `query_embedding`.
148    ///
149    /// If the cache is at capacity, the oldest entry (by `created_at`) is
150    /// evicted first.
151    pub fn put(&self, query_embedding: Vec<f32>, query_text: String, results: Value) {
152        if self.entries.len() >= self.config.max_entries {
153            self.evict_oldest();
154        }
155
156        let key = embedding_hash(&query_embedding);
157        let entry = CacheEntry {
158            query_embedding,
159            query_text,
160            results,
161            created_at: Instant::now(),
162            ttl: Duration::from_secs(self.config.default_ttl_secs),
163            hit_count: 0,
164        };
165
166        self.entries.insert(key, entry);
167    }
168
169    /// Remove all entries whose result JSON contains `memory_id`.
170    ///
171    /// The JSON is checked for any occurrence of the integer value at any
172    /// position in the document.
173    pub fn invalidate_memory(&self, memory_id: i64) {
174        let target = Value::Number(memory_id.into());
175
176        self.entries.retain(|_, entry| {
177            let contains = json_contains(&entry.results, &target);
178            if contains {
179                self.invalidations.fetch_add(1, Ordering::Relaxed);
180            }
181            !contains
182        });
183    }
184
185    /// Remove all cached entries.
186    pub fn clear(&self) {
187        self.entries.clear();
188    }
189
190    /// Return a snapshot of cache statistics.
191    pub fn stats(&self) -> SemanticCacheStats {
192        SemanticCacheStats {
193            entries: self.entries.len(),
194            hits: self.hits.load(Ordering::Relaxed),
195            misses: self.misses.load(Ordering::Relaxed),
196            evictions: self.evictions.load(Ordering::Relaxed),
197            invalidations: self.invalidations.load(Ordering::Relaxed),
198        }
199    }
200
201    // -----------------------------------------------------------------------
202    // Private helpers
203    // -----------------------------------------------------------------------
204
205    /// Evict the entry with the smallest `created_at`.
206    fn evict_oldest(&self) {
207        // Collect (key, created_at) pairs to find the oldest without holding
208        // any individual shard lock longer than necessary.
209        let mut oldest_key: Option<u64> = None;
210        let mut oldest_time: Option<Instant> = None;
211
212        for entry_ref in self.entries.iter() {
213            let t = entry_ref.created_at;
214            match oldest_time {
215                None => {
216                    oldest_time = Some(t);
217                    oldest_key = Some(*entry_ref.key());
218                }
219                Some(ot) if t < ot => {
220                    oldest_time = Some(t);
221                    oldest_key = Some(*entry_ref.key());
222                }
223                _ => {}
224            }
225        }
226
227        if let Some(key) = oldest_key {
228            self.entries.remove(&key);
229            self.evictions.fetch_add(1, Ordering::Relaxed);
230        }
231    }
232}
233
234// ---------------------------------------------------------------------------
235// Free-standing helpers (pub for testing / benchmarking)
236// ---------------------------------------------------------------------------
237
/// Cosine similarity of two embedding vectors: `a·b / (‖a‖·‖b‖)`.
///
/// Yields `0.0` when the slices have different lengths, when they are empty,
/// or when either one has zero magnitude (instead of dividing by zero).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // A single pass accumulates the dot product and both squared norms.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0_f32, 0.0_f32, 0.0_f32), |(dot, sq_a, sq_b), (&x, &y)| {
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        });

    if sq_a == 0.0 || sq_b == 0.0 {
        0.0
    } else {
        dot / (sq_a.sqrt() * sq_b.sqrt())
    }
}
262
/// Map an embedding to the `u64` key it is stored under in the cache map.
///
/// Runs 64-bit FNV-1a over the little-endian bytes of at most the first eight
/// components, so embeddings that agree on those components share a key. The
/// key only gives `O(1)` insertion; `SemanticCache::get` still scans every
/// entry for semantic matching.
pub fn embedding_hash(embedding: &[f32]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    const PREFIX_LEN: usize = 8;

    embedding
        .iter()
        .take(PREFIX_LEN)
        .flat_map(|component| component.to_le_bytes())
        .fold(FNV_OFFSET_BASIS, |hash, byte| {
            (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
}
279
280/// Recursively check whether `json` contains `target` anywhere in its tree.
281fn json_contains(json: &Value, target: &Value) -> bool {
282    if json == target {
283        return true;
284    }
285    match json {
286        Value::Array(arr) => arr.iter().any(|v| json_contains(v, target)),
287        Value::Object(map) => map.values().any(|v| json_contains(v, target)),
288        _ => false,
289    }
290}
291
292// ---------------------------------------------------------------------------
293// Tests
294// ---------------------------------------------------------------------------
295
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::thread;
    use std::time::Duration;

    fn default_cache() -> SemanticCache {
        SemanticCache::new(SemanticCacheConfig::default())
    }

    /// One-hot vector of length `dim` with `1.0` at index `hot`.
    fn unit_vec(dim: usize, hot: usize) -> Vec<f32> {
        let mut v = vec![0.0_f32; dim];
        v[hot] = 1.0;
        v
    }

    // -----------------------------------------------------------------------

    #[test]
    fn test_cache_hit() {
        let cache = default_cache();
        let emb = unit_vec(4, 0);
        let results = json!({"memories": [{"id": 42}]});

        cache.put(emb.clone(), "query".into(), results.clone());
        let got = cache.get(&emb);
        assert_eq!(got, Some(results));
    }

    #[test]
    fn test_cache_miss_below_threshold() {
        let cache = default_cache();
        // Perpendicular vectors: similarity = 0.0
        let emb_stored = unit_vec(4, 0);
        let emb_query = unit_vec(4, 1);

        cache.put(emb_stored, "query A".into(), json!({"memories": []}));
        let got = cache.get(&emb_query);
        assert!(got.is_none(), "perpendicular vectors must not hit");
    }

    #[test]
    fn test_cache_hit_similar() {
        // Slightly perturb the stored embedding; similarity must remain ≥ 0.92.
        let cache = SemanticCache::new(SemanticCacheConfig {
            similarity_threshold: 0.92,
            ..Default::default()
        });

        let emb_stored = vec![1.0_f32, 0.0, 0.0, 0.0];
        let results = json!({"memories": [{"id": 7}]});
        cache.put(emb_stored, "original".into(), results.clone());

        // Small perturbation: similarity = 1 / sqrt(1 + 0.01²) ≈ 0.99995
        let emb_similar = vec![1.0_f32, 0.01, 0.0, 0.0];
        let got = cache.get(&emb_similar);
        assert_eq!(
            got,
            Some(results),
            "similar embedding must return the cached results"
        );
    }

    #[test]
    fn test_get_returns_best_match() {
        let cache = default_cache();
        cache.put(vec![1.0, 0.0, 0.0, 0.0], "a".into(), json!("a"));
        cache.put(vec![1.0, 0.1, 0.0, 0.0], "b".into(), json!("b"));

        // Both stored embeddings clear the 0.92 threshold for this query
        // (≈ 0.996 vs ≈ 0.99995); the closer one must win.
        let got = cache.get(&[1.0, 0.09, 0.0, 0.0]);
        assert_eq!(got, Some(json!("b")));
    }

    #[test]
    fn test_ttl_expiration() {
        let cache = SemanticCache::new(SemanticCacheConfig {
            default_ttl_secs: 0, // expires immediately
            ..Default::default()
        });

        let emb = unit_vec(4, 2);
        cache.put(emb.clone(), "q".into(), json!({"ok": true}));

        // Even 1 ms is enough for `created_at.elapsed() > Duration::ZERO`.
        thread::sleep(Duration::from_millis(5));
        let got = cache.get(&emb);
        assert!(got.is_none(), "entry should have expired");
    }

    #[test]
    fn test_invalidate_memory() {
        let cache = default_cache();
        let emb = unit_vec(4, 0);
        let results = json!([{"id": 99, "content": "hello"}]);

        cache.put(emb.clone(), "q".into(), results);

        // Confirm it's cached.
        assert!(cache.get(&emb).is_some());

        // Invalidate by memory id 99.
        cache.invalidate_memory(99);
        assert!(
            cache.get(&emb).is_none(),
            "entry containing id 99 must be removed"
        );

        let stats = cache.stats();
        assert_eq!(stats.invalidations, 1);
    }

    #[test]
    fn test_invalidate_memory_does_not_remove_unrelated() {
        let cache = default_cache();
        let emb1 = unit_vec(4, 0);
        let emb2 = unit_vec(4, 1);

        cache.put(emb1.clone(), "q1".into(), json!([{"id": 1}]));
        cache.put(emb2.clone(), "q2".into(), json!([{"id": 2}]));

        cache.invalidate_memory(1);

        assert_eq!(cache.stats().invalidations, 1);
        assert!(cache.get(&emb1).is_none());
        assert!(cache.get(&emb2).is_some());
    }

    #[test]
    fn test_clear() {
        let cache = default_cache();
        cache.put(unit_vec(4, 0), "a".into(), json!(1));
        cache.put(unit_vec(4, 1), "b".into(), json!(2));

        cache.clear();
        assert_eq!(cache.stats().entries, 0);
    }

    #[test]
    fn test_stats() {
        let cache = default_cache();
        let emb = unit_vec(4, 3);

        // Miss
        cache.get(&emb);

        // Put + hit twice
        cache.put(emb.clone(), "q".into(), json!({"x": 1}));
        cache.get(&emb);
        cache.get(&emb);

        let stats = cache.stats();
        assert_eq!(stats.hits, 2, "expected 2 hits");
        assert_eq!(stats.misses, 1, "expected 1 miss");
        assert_eq!(stats.entries, 1);
    }

    #[test]
    fn test_capacity_eviction() {
        let cache = SemanticCache::new(SemanticCacheConfig {
            max_entries: 2,
            ..Default::default()
        });

        let emb0 = unit_vec(4, 0);
        let emb1 = unit_vec(4, 1);
        let emb2 = unit_vec(4, 2);

        cache.put(emb0.clone(), "first".into(), json!("first"));
        // Small sleep to ensure distinct Instants (monotonic clock granularity).
        thread::sleep(Duration::from_millis(1));
        cache.put(emb1.clone(), "second".into(), json!("second"));
        thread::sleep(Duration::from_millis(1));
        // Third insert must evict the oldest (emb0 / "first").
        cache.put(emb2.clone(), "third".into(), json!("third"));

        assert_eq!(cache.stats().entries, 2);
        assert_eq!(cache.stats().evictions, 1);

        // "first" should be gone, "second" and "third" should survive.
        assert!(
            cache.get(&emb0).is_none(),
            "oldest entry must have been evicted"
        );
        assert!(cache.get(&emb1).is_some());
        assert!(cache.get(&emb2).is_some());
    }

    // -----------------------------------------------------------------------
    // cosine_similarity corner cases
    // -----------------------------------------------------------------------

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0_f32, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_cosine_similarity_empty() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn test_cosine_similarity_length_mismatch() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![1.0_f32];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let zero = vec![0.0_f32, 0.0];
        let v = vec![1.0_f32, 2.0];
        assert_eq!(cosine_similarity(&zero, &v), 0.0);
        assert_eq!(cosine_similarity(&v, &zero), 0.0);
    }

    // -----------------------------------------------------------------------
    // embedding_hash
    // -----------------------------------------------------------------------

    #[test]
    fn test_embedding_hash_uses_only_first_eight_dims() {
        let base: Vec<f32> = (0..10).map(|i| i as f32).collect();

        let mut tail_changed = base.clone();
        tail_changed[9] = 100.0;
        assert_eq!(embedding_hash(&base), embedding_hash(&tail_changed));

        let mut head_changed = base.clone();
        head_changed[0] = 100.0;
        assert_ne!(embedding_hash(&base), embedding_hash(&head_changed));
    }
}