// reddb_server/storage/ml/semantic_cache.rs

//! Semantic cache for LLM responses (ML Feature 3, MVP).
//!
//! Caches `(prompt, response)` pairs keyed by the prompt's embedding
//! vector. A lookup returns the cached response when any entry has
//! cosine similarity to the query embedding ≥ the caller's threshold
//! **and** has not expired.
//!
//! This is a linear-scan implementation — fine for caches up to
//! ~10k entries. A future sprint swaps the scan for the existing
//! HNSW index when entry counts make that worth the added
//! complexity. The external surface stays the same.
//!
//! Eviction is lazy: `lookup` drops any expired entry it skips
//! over, and [`SemanticCache::evict_expired`] sweeps the whole set
//! on demand. Bounded size is enforced on insert by dropping the
//! oldest-inserted entries first (FIFO by insertion age) once
//! `max_entries` is reached.
18use std::sync::{Arc, Mutex};
19
20use super::jobs::now_ms;
21use super::persist::{MlPersistence, MlPersistenceResult};
22use crate::json::{Map, Value as JsonValue};
23
24/// One cached entry.
25#[derive(Debug, Clone)]
26pub struct SemanticCacheEntry {
27    pub prompt: String,
28    pub response: String,
29    pub embedding: Vec<f32>,
30    /// Epoch millis; `0` means "never expires".
31    pub expires_at_ms: u64,
32    /// Epoch millis of last read hit — used for LRU eviction.
33    pub last_hit_ms: u64,
34    /// Epoch millis the entry landed in the cache.
35    pub inserted_at_ms: u64,
36}
37
38impl SemanticCacheEntry {
39    pub fn is_expired_at(&self, now_ms_val: u64) -> bool {
40        self.expires_at_ms != 0 && now_ms_val >= self.expires_at_ms
41    }
42}
43
44/// Compile-time tuneables for a cache instance.
45#[derive(Debug, Clone)]
46pub struct SemanticCacheConfig {
47    /// Cosine similarity threshold above which a candidate counts
48    /// as a hit. Typical values: 0.90–0.98.
49    pub similarity_threshold: f32,
50    /// Default TTL applied to freshly-inserted entries. `0` =
51    /// entries never expire. Callers can still pass a per-insert
52    /// override.
53    pub default_ttl_ms: u64,
54    /// Maximum live entries. Oldest inserted entry is evicted once
55    /// this limit is reached. `0` = unbounded.
56    pub max_entries: usize,
57    /// Persistence namespace suffix. Allows multiple named caches
58    /// to share one `MlPersistence` backend without colliding.
59    pub namespace: String,
60}
61
62impl Default for SemanticCacheConfig {
63    fn default() -> Self {
64        Self {
65            similarity_threshold: 0.95,
66            default_ttl_ms: 24 * 60 * 60 * 1000,
67            max_entries: 10_000,
68            namespace: "default".to_string(),
69        }
70    }
71}
72
73/// Runtime statistics — exposed via `SELECT * FROM ML_CACHE_STATS`
74/// later, and useful in tests now.
75#[derive(Debug, Clone, Default, PartialEq, Eq)]
76pub struct SemanticCacheStats {
77    pub entries: usize,
78    pub hits: u64,
79    pub misses: u64,
80    pub expired_evictions: u64,
81    pub capacity_evictions: u64,
82}
83
84struct Inner {
85    entries: Vec<SemanticCacheEntry>,
86    stats: SemanticCacheStats,
87}
88
89/// The cache itself. Cloning shares state via the inner `Arc`.
90#[derive(Clone)]
91pub struct SemanticCache {
92    inner: Arc<Mutex<Inner>>,
93    config: SemanticCacheConfig,
94    backend: Option<Arc<dyn MlPersistence>>,
95}
96
97impl std::fmt::Debug for SemanticCache {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        f.debug_struct("SemanticCache")
100            .field("namespace", &self.config.namespace)
101            .field("similarity_threshold", &self.config.similarity_threshold)
102            .field("max_entries", &self.config.max_entries)
103            .field("persistent", &self.backend.is_some())
104            .finish()
105    }
106}
107
108impl SemanticCache {
109    /// Build an in-process cache with no durable backend.
110    pub fn new(config: SemanticCacheConfig) -> Self {
111        Self {
112            inner: Arc::new(Mutex::new(Inner {
113                entries: Vec::new(),
114                stats: SemanticCacheStats::default(),
115            })),
116            config,
117            backend: None,
118        }
119    }
120
121    /// Build a cache that persists every mutation to `backend` under
122    /// `cache:{namespace}`. Entries rehydrate on construction.
123    pub fn with_backend(config: SemanticCacheConfig, backend: Arc<dyn MlPersistence>) -> Self {
124        let cache = Self {
125            inner: Arc::new(Mutex::new(Inner {
126                entries: Vec::new(),
127                stats: SemanticCacheStats::default(),
128            })),
129            config,
130            backend: Some(backend),
131        };
132        let _ = cache.load_from_backend();
133        cache
134    }
135
136    fn backend_namespace(&self) -> String {
137        format!("cache:{}", self.config.namespace)
138    }
139
140    fn persist_entry(&self, key: &str, entry: &SemanticCacheEntry) {
141        if let Some(backend) = self.backend.as_ref() {
142            let _ = backend.put(&self.backend_namespace(), key, &encode_entry(entry));
143        }
144    }
145
146    fn forget_entry(&self, key: &str) {
147        if let Some(backend) = self.backend.as_ref() {
148            let _ = backend.delete(&self.backend_namespace(), key);
149        }
150    }
151
152    /// Re-read persisted entries into memory. Malformed rows are
153    /// skipped rather than crashing startup.
154    pub fn load_from_backend(&self) -> MlPersistenceResult<usize> {
155        let Some(backend) = self.backend.as_ref() else {
156            return Ok(0);
157        };
158        let rows = backend.list(&self.backend_namespace())?;
159        let mut loaded = 0usize;
160        let now = now_ms();
161        let mut guard = match self.inner.lock() {
162            Ok(g) => g,
163            Err(p) => p.into_inner(),
164        };
165        guard.entries.clear();
166        for (_, raw) in rows {
167            let Some(entry) = decode_entry(&raw) else {
168                continue;
169            };
170            if entry.is_expired_at(now) {
171                // Skip stale entries rather than loading-then-evicting —
172                // saves a pass. Their rows are left in the backend
173                // and removed on next insert of the same prompt or
174                // on explicit evict.
175                continue;
176            }
177            guard.entries.push(entry);
178            loaded += 1;
179        }
180        guard.stats.entries = guard.entries.len();
181        Ok(loaded)
182    }
183
184    /// Look up by embedding. Returns the cached response on hit;
185    /// updates `last_hit_ms` and hit counter as a side-effect.
186    pub fn lookup(&self, embedding: &[f32]) -> Option<String> {
187        if embedding.is_empty() {
188            return None;
189        }
190        let now = now_ms();
191        let mut guard = match self.inner.lock() {
192            Ok(g) => g,
193            Err(p) => p.into_inner(),
194        };
195        // Drop expired entries in-place so the scan cost stays low.
196        let before = guard.entries.len();
197        guard.entries.retain(|e| !e.is_expired_at(now));
198        let evicted = before - guard.entries.len();
199        guard.stats.expired_evictions += evicted as u64;
200
201        let mut best: Option<(usize, f32)> = None;
202        for (idx, entry) in guard.entries.iter().enumerate() {
203            let score = cosine_similarity(embedding, &entry.embedding);
204            if score >= self.config.similarity_threshold {
205                match best {
206                    Some((_, best_score)) if best_score >= score => {}
207                    _ => best = Some((idx, score)),
208                }
209            }
210        }
211        match best {
212            Some((idx, _)) => {
213                let entry = &mut guard.entries[idx];
214                entry.last_hit_ms = now;
215                let response = entry.response.clone();
216                let persisted = entry.clone();
217                guard.stats.hits += 1;
218                guard.stats.entries = guard.entries.len();
219                drop(guard);
220                // Persist the updated last_hit so LRU eviction
221                // respects read traffic across restarts.
222                let key = cache_key(&persisted);
223                self.persist_entry(&key, &persisted);
224                Some(response)
225            }
226            None => {
227                guard.stats.misses += 1;
228                guard.stats.entries = guard.entries.len();
229                None
230            }
231        }
232    }
233
234    /// Insert `(prompt, response)` keyed by `embedding`.
235    /// If `ttl_ms_override` is `None` the config default applies.
236    pub fn insert(
237        &self,
238        prompt: impl Into<String>,
239        response: impl Into<String>,
240        embedding: Vec<f32>,
241        ttl_ms_override: Option<u64>,
242    ) {
243        if embedding.is_empty() {
244            return;
245        }
246        let now = now_ms();
247        let ttl = ttl_ms_override.unwrap_or(self.config.default_ttl_ms);
248        let expires_at_ms = if ttl == 0 { 0 } else { now.saturating_add(ttl) };
249        let entry = SemanticCacheEntry {
250            prompt: prompt.into(),
251            response: response.into(),
252            embedding,
253            expires_at_ms,
254            last_hit_ms: now,
255            inserted_at_ms: now,
256        };
257        let evicted_keys: Vec<String>;
258        let stored_key: String;
259        let persist_entry: SemanticCacheEntry;
260        {
261            let mut guard = match self.inner.lock() {
262                Ok(g) => g,
263                Err(p) => p.into_inner(),
264            };
265            // Enforce capacity first. Oldest `inserted_at` loses —
266            // simple and deterministic; swap for LRU-on-read later
267            // if needed.
268            let mut pruned: Vec<String> = Vec::new();
269            if self.config.max_entries > 0 {
270                while guard.entries.len() >= self.config.max_entries {
271                    if let Some((oldest_idx, _)) = guard
272                        .entries
273                        .iter()
274                        .enumerate()
275                        .min_by_key(|(_, e)| e.inserted_at_ms)
276                    {
277                        let gone = guard.entries.remove(oldest_idx);
278                        guard.stats.capacity_evictions += 1;
279                        pruned.push(cache_key(&gone));
280                    } else {
281                        break;
282                    }
283                }
284            }
285            guard.entries.push(entry.clone());
286            guard.stats.entries = guard.entries.len();
287            evicted_keys = pruned;
288            stored_key = cache_key(&entry);
289            persist_entry = entry;
290        }
291        for k in &evicted_keys {
292            self.forget_entry(k);
293        }
294        self.persist_entry(&stored_key, &persist_entry);
295    }
296
297    /// Manually force a sweep. Returns number of entries dropped.
298    pub fn evict_expired(&self) -> usize {
299        let now = now_ms();
300        let evicted_keys: Vec<String>;
301        let count;
302        {
303            let mut guard = match self.inner.lock() {
304                Ok(g) => g,
305                Err(p) => p.into_inner(),
306            };
307            let mut keep = Vec::with_capacity(guard.entries.len());
308            let mut dropped = Vec::new();
309            for entry in guard.entries.drain(..) {
310                if entry.is_expired_at(now) {
311                    dropped.push(cache_key(&entry));
312                } else {
313                    keep.push(entry);
314                }
315            }
316            count = dropped.len();
317            guard.entries = keep;
318            guard.stats.expired_evictions += count as u64;
319            guard.stats.entries = guard.entries.len();
320            evicted_keys = dropped;
321        }
322        for k in &evicted_keys {
323            self.forget_entry(k);
324        }
325        count
326    }
327
328    /// Snapshot of counters.
329    pub fn stats(&self) -> SemanticCacheStats {
330        let guard = match self.inner.lock() {
331            Ok(g) => g,
332            Err(p) => p.into_inner(),
333        };
334        SemanticCacheStats {
335            entries: guard.entries.len(),
336            ..guard.stats.clone()
337        }
338    }
339
340    pub fn config(&self) -> &SemanticCacheConfig {
341        &self.config
342    }
343}
344
345/// Deterministic key per entry — the inserted_at timestamp plus the
346/// first 16 bytes of the prompt gives a collision-resistant, sortable
347/// identifier without pulling sha2 into the ML module.
348fn cache_key(entry: &SemanticCacheEntry) -> String {
349    // Hash the prompt with a small FNV-1a so the key is stable across
350    // processes. Avoids depending on `sha2` here; collisions between
351    // two distinct prompts inserted in the exact same millisecond are
352    // acceptable (the second insert overwrites the first).
353    const FNV_OFFSET: u64 = 0xcbf29ce484222325;
354    const FNV_PRIME: u64 = 0x100000001b3;
355    let mut h = FNV_OFFSET;
356    for b in entry.prompt.as_bytes() {
357        h ^= *b as u64;
358        h = h.wrapping_mul(FNV_PRIME);
359    }
360    format!("{:020}-{:016x}", entry.inserted_at_ms, h)
361}
362
363fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
364    if a.len() != b.len() || a.is_empty() {
365        return 0.0;
366    }
367    let mut dot = 0.0f32;
368    let mut na = 0.0f32;
369    let mut nb = 0.0f32;
370    for i in 0..a.len() {
371        dot += a[i] * b[i];
372        na += a[i] * a[i];
373        nb += b[i] * b[i];
374    }
375    if na == 0.0 || nb == 0.0 {
376        return 0.0;
377    }
378    dot / (na.sqrt() * nb.sqrt())
379}
380
381// ---- JSON (en|de)coding of a single entry -------------------------------
382
383fn encode_entry(entry: &SemanticCacheEntry) -> String {
384    let mut obj = Map::new();
385    obj.insert(
386        "prompt".to_string(),
387        JsonValue::String(entry.prompt.clone()),
388    );
389    obj.insert(
390        "response".to_string(),
391        JsonValue::String(entry.response.clone()),
392    );
393    obj.insert(
394        "embedding".to_string(),
395        JsonValue::Array(
396            entry
397                .embedding
398                .iter()
399                .map(|f| JsonValue::Number(*f as f64))
400                .collect(),
401        ),
402    );
403    obj.insert(
404        "expires_at".to_string(),
405        JsonValue::Number(entry.expires_at_ms as f64),
406    );
407    obj.insert(
408        "last_hit".to_string(),
409        JsonValue::Number(entry.last_hit_ms as f64),
410    );
411    obj.insert(
412        "inserted_at".to_string(),
413        JsonValue::Number(entry.inserted_at_ms as f64),
414    );
415    JsonValue::Object(obj).to_string_compact()
416}
417
418fn decode_entry(raw: &str) -> Option<SemanticCacheEntry> {
419    let parsed = crate::json::parse_json(raw).ok()?;
420    let value = JsonValue::from(parsed);
421    let obj = value.as_object()?;
422    let prompt = obj.get("prompt")?.as_str()?.to_string();
423    let response = obj.get("response")?.as_str()?.to_string();
424    let embedding = obj
425        .get("embedding")?
426        .as_array()?
427        .iter()
428        .filter_map(|v| v.as_f64().map(|f| f as f32))
429        .collect::<Vec<f32>>();
430    let expires_at_ms = obj.get("expires_at")?.as_i64()? as u64;
431    let last_hit_ms = obj.get("last_hit")?.as_i64()? as u64;
432    let inserted_at_ms = obj.get("inserted_at")?.as_i64()? as u64;
433    Some(SemanticCacheEntry {
434        prompt,
435        response,
436        embedding,
437        expires_at_ms,
438        last_hit_ms,
439        inserted_at_ms,
440    })
441}
442
443#[cfg(test)]
444mod tests {
445    use super::super::persist::InMemoryMlPersistence;
446    use super::*;
447
448    fn cfg(threshold: f32, max: usize, ttl: u64) -> SemanticCacheConfig {
449        SemanticCacheConfig {
450            similarity_threshold: threshold,
451            default_ttl_ms: ttl,
452            max_entries: max,
453            namespace: "t".to_string(),
454        }
455    }
456
457    #[test]
458    fn cosine_similarity_is_symmetric_and_bounded() {
459        let a = [1.0, 0.0, 0.0];
460        let b = [0.0, 1.0, 0.0];
461        let c = [1.0, 0.0, 0.0];
462        assert!((cosine_similarity(&a, &c) - 1.0).abs() < 1e-6);
463        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
464        assert!((cosine_similarity(&a, &b) - cosine_similarity(&b, &a)).abs() < 1e-6);
465    }
466
467    #[test]
468    fn cosine_zero_on_mismatched_dims_or_zero_vec() {
469        assert_eq!(cosine_similarity(&[1.0], &[1.0, 0.0]), 0.0);
470        assert_eq!(cosine_similarity(&[0.0, 0.0], &[0.0, 0.0]), 0.0);
471    }
472
473    #[test]
474    fn miss_returns_none_and_increments_miss_counter() {
475        let c = SemanticCache::new(cfg(0.9, 100, 0));
476        assert!(c.lookup(&[1.0, 0.0]).is_none());
477        assert_eq!(c.stats().misses, 1);
478        assert_eq!(c.stats().hits, 0);
479    }
480
481    #[test]
482    fn inserted_entry_is_found_on_identical_vector() {
483        let c = SemanticCache::new(cfg(0.9, 100, 0));
484        c.insert("p", "hello world", vec![1.0, 0.0, 0.0], None);
485        let got = c.lookup(&[1.0, 0.0, 0.0]).unwrap();
486        assert_eq!(got, "hello world");
487        assert_eq!(c.stats().hits, 1);
488    }
489
490    #[test]
491    fn below_threshold_is_a_miss() {
492        let c = SemanticCache::new(cfg(0.99, 100, 0));
493        c.insert("p", "r", vec![1.0, 0.0, 0.0], None);
494        // Cosine of [1,0,0] vs [0.8, 0.6, 0] = 0.8 < 0.99
495        assert!(c.lookup(&[0.8, 0.6, 0.0]).is_none());
496    }
497
498    #[test]
499    fn expired_entries_are_skipped_and_evicted() {
500        let c = SemanticCache::new(cfg(0.9, 100, 1));
501        c.insert("p", "r", vec![1.0, 0.0], None);
502        std::thread::sleep(std::time::Duration::from_millis(5));
503        assert!(c.lookup(&[1.0, 0.0]).is_none());
504        let stats = c.stats();
505        assert_eq!(stats.entries, 0);
506        assert!(stats.expired_evictions >= 1);
507    }
508
509    #[test]
510    fn capacity_limit_evicts_oldest_inserted() {
511        let c = SemanticCache::new(cfg(0.9, 2, 0));
512        c.insert("first", "r1", vec![1.0, 0.0], None);
513        std::thread::sleep(std::time::Duration::from_millis(2));
514        c.insert("second", "r2", vec![0.0, 1.0], None);
515        std::thread::sleep(std::time::Duration::from_millis(2));
516        c.insert("third", "r3", vec![1.0, 1.0], None);
517        assert_eq!(c.stats().entries, 2);
518        assert!(c.stats().capacity_evictions >= 1);
519        // first should have been evicted
520        assert!(c.lookup(&[1.0, 0.0]).is_none() || c.lookup(&[1.0, 0.0]) != Some("r1".to_string()));
521    }
522
523    #[test]
524    fn best_candidate_wins_when_multiple_match() {
525        let c = SemanticCache::new(cfg(0.5, 100, 0));
526        c.insert("lo", "LO", vec![0.7, 0.7, 0.1], None);
527        c.insert("hi", "HI", vec![1.0, 0.0, 0.0], None);
528        let got = c.lookup(&[1.0, 0.0, 0.0]).unwrap();
529        assert_eq!(got, "HI");
530    }
531
532    #[test]
533    fn backend_round_trips_entry() {
534        let backend: Arc<dyn MlPersistence> = Arc::new(InMemoryMlPersistence::new());
535        let c1 = SemanticCache::with_backend(cfg(0.9, 100, 0), Arc::clone(&backend));
536        c1.insert("prompt one", "response one", vec![1.0, 0.0], None);
537        let c2 = SemanticCache::with_backend(cfg(0.9, 100, 0), backend);
538        let got = c2.lookup(&[1.0, 0.0]).unwrap();
539        assert_eq!(got, "response one");
540    }
541
542    #[test]
543    fn encode_decode_entry_round_trips() {
544        let e = SemanticCacheEntry {
545            prompt: "why".to_string(),
546            response: "because".to_string(),
547            embedding: vec![0.1, 0.2, -0.3],
548            expires_at_ms: 100,
549            last_hit_ms: 50,
550            inserted_at_ms: 10,
551        };
552        let back = decode_entry(&encode_entry(&e)).unwrap();
553        assert_eq!(back.prompt, e.prompt);
554        assert_eq!(back.response, e.response);
555        assert_eq!(back.embedding.len(), e.embedding.len());
556        for (a, b) in back.embedding.iter().zip(e.embedding.iter()) {
557            assert!((a - b).abs() < 1e-6);
558        }
559        assert_eq!(back.expires_at_ms, e.expires_at_ms);
560    }
561
562    #[test]
563    fn stats_entries_reflect_live_set() {
564        let c = SemanticCache::new(cfg(0.9, 100, 0));
565        c.insert("a", "1", vec![1.0, 0.0], None);
566        c.insert("b", "2", vec![0.0, 1.0], None);
567        assert_eq!(c.stats().entries, 2);
568    }
569}