// reddb_server/runtime/ai/dedup_cache.rs
use std::sync::atomic::{AtomicU64, Ordering};
7use std::time::{Duration, Instant};
8
/// Process-wide number of lookups, across every `EmbeddingDedupCache`
/// instance, that were answered from a fresh cached entry.
pub static DEDUP_HITS_TOTAL: AtomicU64 = AtomicU64::new(0);

/// Process-wide number of lookups, across every `EmbeddingDedupCache`
/// instance, that found no usable entry (absent or older than the TTL).
pub static DEDUP_MISSES_TOTAL: AtomicU64 = AtomicU64::new(0);

14use lru::LruCache;
15use parking_lot::Mutex;
16use sha2::{Digest, Sha256};
17
// Configuration keys for the embedding dedup cache. They are not read in this
// file; presumably the runtime config loader consults them — confirm there.

/// Config key, presumably a boolean switch turning embedding dedup on/off.
pub const CONFIG_DEDUP_ENABLED: &str = "runtime.ai.embedding_dedup_enabled";
/// Config key, presumably the entry TTL in milliseconds (see `DEFAULT_DEDUP_TTL_MS`).
pub const CONFIG_DEDUP_TTL_MS: &str = "runtime.ai.embedding_dedup_ttl_ms";
/// Config key, presumably the LRU capacity (see `DEFAULT_DEDUP_LRU_SIZE`).
pub const CONFIG_DEDUP_LRU_SIZE: &str = "runtime.ai.embedding_dedup_lru_size";

/// Default time-to-live for a cached embedding, in milliseconds (one minute).
pub const DEFAULT_DEDUP_TTL_MS: u64 = 60 * 1_000;
/// Default maximum number of cached embeddings (2^12 entries).
pub const DEFAULT_DEDUP_LRU_SIZE: usize = 1 << 12;

/// Cache key: the 32-byte SHA-256 digest of the input text, so the cache never
/// stores the raw text itself.
type HashKey = [u8; 32];

/// One cached embedding together with its insertion time.
struct Entry {
    /// Vector handed back (as a clone) on a cache hit.
    embedding: Vec<f32>,
    /// When the entry was stored; lookups compare its age against the TTL.
    inserted_at: Instant,
}

32pub struct EmbeddingDedupCache {
33 inner: Mutex<LruCache<HashKey, Entry>>,
34 ttl: Duration,
35 hits: AtomicU64,
36 misses: AtomicU64,
37}
38
39impl EmbeddingDedupCache {
40 pub fn new(max_size: usize, ttl: Duration) -> Self {
41 let capacity = std::num::NonZeroUsize::new(max_size.max(1)).expect("max_size >= 1");
42 Self {
43 inner: Mutex::new(LruCache::new(capacity)),
44 ttl,
45 hits: AtomicU64::new(0),
46 misses: AtomicU64::new(0),
47 }
48 }
49
50 pub fn get(&self, text: &str) -> Option<Vec<f32>> {
52 let key = hash(text);
53 let mut guard = self.inner.lock();
54 match guard.get(&key) {
55 Some(entry) if entry.inserted_at.elapsed() < self.ttl => {
56 self.hits.fetch_add(1, Ordering::Relaxed);
57 DEDUP_HITS_TOTAL.fetch_add(1, Ordering::Relaxed);
58 Some(entry.embedding.clone())
59 }
60 Some(_expired) => {
61 guard.pop(&key);
63 self.misses.fetch_add(1, Ordering::Relaxed);
64 DEDUP_MISSES_TOTAL.fetch_add(1, Ordering::Relaxed);
65 None
66 }
67 None => {
68 self.misses.fetch_add(1, Ordering::Relaxed);
69 DEDUP_MISSES_TOTAL.fetch_add(1, Ordering::Relaxed);
70 None
71 }
72 }
73 }
74
75 pub fn insert(&self, text: &str, embedding: Vec<f32>) {
77 let key = hash(text);
78 self.inner.lock().put(
79 key,
80 Entry {
81 embedding,
82 inserted_at: Instant::now(),
83 },
84 );
85 }
86
87 pub fn hits(&self) -> u64 {
88 self.hits.load(Ordering::Relaxed)
89 }
90
91 pub fn misses(&self) -> u64 {
92 self.misses.load(Ordering::Relaxed)
93 }
94}
95
96fn hash(text: &str) -> HashKey {
97 let mut hasher = Sha256::new();
98 hasher.update(text.as_bytes());
99 hasher.finalize().into()
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache with `size` slots and a TTL of `ttl_ms` milliseconds.
    fn cache(size: usize, ttl_ms: u64) -> EmbeddingDedupCache {
        EmbeddingDedupCache::new(size, Duration::from_millis(ttl_ms))
    }

    #[test]
    fn miss_then_hit() {
        let c = cache(16, 60_000);
        assert!(c.get("hello").is_none());
        c.insert("hello", vec![1.0, 2.0]);
        let v = c.get("hello").unwrap();
        assert_eq!(v, vec![1.0, 2.0]);
        assert_eq!(c.hits(), 1);
        assert_eq!(c.misses(), 1);
    }

    #[test]
    fn lru_eviction() {
        let c = cache(2, 60_000);
        c.insert("a", vec![1.0]);
        c.insert("b", vec![2.0]);
        // Touch "a" so that "b" becomes the least-recently-used entry.
        c.get("a");
        // A third entry in a 2-slot cache must push out "b".
        c.insert("c", vec![3.0]);
        assert!(c.get("b").is_none());
        assert!(c.get("a").is_some());
        assert!(c.get("c").is_some());
    }

    #[test]
    fn zero_capacity_is_clamped_to_one() {
        let c = cache(0, 60_000);
        c.insert("a", vec![1.0]);
        c.insert("b", vec![2.0]);
        // Capacity 1: the second insert evicts the first.
        assert!(c.get("a").is_none());
        assert_eq!(c.get("b"), Some(vec![2.0]));
    }

    #[test]
    fn reinsert_replaces_value() {
        let c = cache(16, 60_000);
        c.insert("k", vec![1.0]);
        c.insert("k", vec![2.0]);
        assert_eq!(c.get("k"), Some(vec![2.0]));
        assert_eq!(c.hits(), 1);
        assert_eq!(c.misses(), 0);
    }

    #[test]
    fn ttl_expired_treated_as_miss() {
        let c = cache(16, 1);
        c.insert("x", vec![9.9]);
        std::thread::sleep(Duration::from_millis(5));
        assert!(c.get("x").is_none());
        // An expired entry counts as a miss, never as a hit.
        assert_eq!(c.hits(), 0);
        assert_eq!(c.misses(), 1);
    }

    #[test]
    fn dedup_1000_inputs_10_unique() {
        let c = cache(1024, 60_000);
        let unique: Vec<String> = (0..10).map(|i| format!("text {i}")).collect();
        let inputs: Vec<String> = (0..1000).map(|i| unique[i % 10].clone()).collect();

        // Simulate the provider: every miss "computes" a fresh embedding.
        let mut miss_count = 0usize;
        let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(inputs.len());
        for text in &inputs {
            if let Some(cached) = c.get(text) {
                embeddings.push(cached);
            } else {
                miss_count += 1;
                let emb = vec![miss_count as f32];
                c.insert(text, emb.clone());
                embeddings.push(emb);
            }
        }

        assert_eq!(miss_count, 10, "only 10 unique texts → 10 provider calls");
        assert_eq!(embeddings.len(), 1000);
        assert_eq!(c.misses(), 10);
        assert_eq!(c.hits(), 990);
    }
}