//! Embedding dedup cache — issue #277.
//!
//! Optional LRU cache keyed by SHA-256(text) → Vec<f32>.
//! Off by default; opt-in via `runtime.ai.embedding_dedup_enabled = true`.

use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};

/// Global process-wide dedup hit counter (for /metrics).
pub static DEDUP_HITS_TOTAL: AtomicU64 = AtomicU64::new(0);
/// Global process-wide dedup miss counter (for /metrics).
pub static DEDUP_MISSES_TOTAL: AtomicU64 = AtomicU64::new(0);

use lru::LruCache;
use parking_lot::Mutex;
use sha2::{Digest, Sha256};

/// Config key: turns the dedup cache on (off by default per module docs).
pub const CONFIG_DEDUP_ENABLED: &str = "runtime.ai.embedding_dedup_enabled";
/// Config key: entry time-to-live, in milliseconds.
pub const CONFIG_DEDUP_TTL_MS: &str = "runtime.ai.embedding_dedup_ttl_ms";
/// Config key: maximum number of cached embeddings (LRU capacity).
pub const CONFIG_DEDUP_LRU_SIZE: &str = "runtime.ai.embedding_dedup_lru_size";

/// Default TTL (60 s) used when the TTL config key is unset.
pub const DEFAULT_DEDUP_TTL_MS: u64 = 60_000;
/// Default LRU capacity used when the size config key is unset.
pub const DEFAULT_DEDUP_LRU_SIZE: usize = 4096;

/// SHA-256 digest of the input text, used as the cache key.
type HashKey = [u8; 32];

/// One cached embedding plus its insertion timestamp (for TTL expiry).
struct Entry {
    embedding: Vec<f32>,
    // Recorded at insert; compared against the cache TTL on lookup.
    inserted_at: Instant,
}

/// Thread-safe LRU cache of text → embedding, with per-entry TTL.
pub struct EmbeddingDedupCache {
    // LRU map guarded by a parking_lot mutex; all access goes through it.
    inner: Mutex<LruCache<HashKey, Entry>>,
    // Entries older than this are treated as misses and evicted on lookup.
    ttl: Duration,
    // Per-instance counters (global /metrics counters are separate statics).
    hits: AtomicU64,
    misses: AtomicU64,
}

39impl EmbeddingDedupCache {
40    pub fn new(max_size: usize, ttl: Duration) -> Self {
41        let capacity = std::num::NonZeroUsize::new(max_size.max(1)).expect("max_size >= 1");
42        Self {
43            inner: Mutex::new(LruCache::new(capacity)),
44            ttl,
45            hits: AtomicU64::new(0),
46            misses: AtomicU64::new(0),
47        }
48    }
49
50    /// Look up `text` in the cache. Returns `Some(embedding)` on hit.
51    pub fn get(&self, text: &str) -> Option<Vec<f32>> {
52        let key = hash(text);
53        let mut guard = self.inner.lock();
54        match guard.get(&key) {
55            Some(entry) if entry.inserted_at.elapsed() < self.ttl => {
56                self.hits.fetch_add(1, Ordering::Relaxed);
57                DEDUP_HITS_TOTAL.fetch_add(1, Ordering::Relaxed);
58                Some(entry.embedding.clone())
59            }
60            Some(_expired) => {
61                // TTL expired — remove and treat as miss
62                guard.pop(&key);
63                self.misses.fetch_add(1, Ordering::Relaxed);
64                DEDUP_MISSES_TOTAL.fetch_add(1, Ordering::Relaxed);
65                None
66            }
67            None => {
68                self.misses.fetch_add(1, Ordering::Relaxed);
69                DEDUP_MISSES_TOTAL.fetch_add(1, Ordering::Relaxed);
70                None
71            }
72        }
73    }
74
75    /// Insert `embedding` for `text`.
76    pub fn insert(&self, text: &str, embedding: Vec<f32>) {
77        let key = hash(text);
78        self.inner.lock().put(
79            key,
80            Entry {
81                embedding,
82                inserted_at: Instant::now(),
83            },
84        );
85    }
86
87    pub fn hits(&self) -> u64 {
88        self.hits.load(Ordering::Relaxed)
89    }
90
91    pub fn misses(&self) -> u64 {
92        self.misses.load(Ordering::Relaxed)
93    }
94}
95
96fn hash(text: &str) -> HashKey {
97    let mut hasher = Sha256::new();
98    hasher.update(text.as_bytes());
99    hasher.finalize().into()
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor used by every test below.
    fn make(capacity: usize, ttl_ms: u64) -> EmbeddingDedupCache {
        EmbeddingDedupCache::new(capacity, Duration::from_millis(ttl_ms))
    }

    #[test]
    fn miss_then_hit() {
        let cache = make(16, 60_000);
        assert!(cache.get("hello").is_none());
        cache.insert("hello", vec![1.0, 2.0]);
        assert_eq!(cache.get("hello"), Some(vec![1.0, 2.0]));
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 1);
    }

    #[test]
    fn lru_eviction() {
        let cache = make(2, 60_000);
        cache.insert("a", vec![1.0]);
        cache.insert("b", vec![2.0]);
        // Touch "a" so "b" becomes the least recently used entry.
        cache.get("a");
        cache.insert("c", vec![3.0]); // capacity 2 → "b" is evicted
        assert!(cache.get("b").is_none());
        assert!(cache.get("a").is_some());
        assert!(cache.get("c").is_some());
    }

    #[test]
    fn ttl_expired_treated_as_miss() {
        let cache = make(16, 1); // entries expire after 1 ms
        cache.insert("x", vec![9.9]);
        std::thread::sleep(Duration::from_millis(5));
        assert!(cache.get("x").is_none());
    }

    #[test]
    fn dedup_1000_inputs_10_unique() {
        // 1000 lookups over only 10 distinct texts → only 10 "provider calls".
        let cache = make(1024, 60_000);
        let unique: Vec<String> = (0..10).map(|i| format!("text {i}")).collect();

        let mut miss_count = 0usize;
        let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(1000);
        for i in 0..1000 {
            let text = &unique[i % 10];
            match cache.get(text) {
                Some(cached) => embeddings.push(cached),
                None => {
                    miss_count += 1;
                    let emb = vec![miss_count as f32];
                    cache.insert(text, emb.clone());
                    embeddings.push(emb);
                }
            }
        }

        assert_eq!(miss_count, 10, "only 10 unique texts → 10 provider calls");
        assert_eq!(embeddings.len(), 1000);
        assert_eq!(cache.misses(), 10);
        assert_eq!(cache.hits(), 990);
    }
}
167}