1use std::sync::{Arc, Mutex};
19
20use super::jobs::now_ms;
21use super::persist::{MlPersistence, MlPersistenceResult};
22use crate::json::{Map, Value as JsonValue};
23
/// One cached prompt/response pair plus the metadata used for expiry
/// and capacity eviction.
#[derive(Debug, Clone)]
pub struct SemanticCacheEntry {
    /// Prompt text the response was originally produced for.
    pub prompt: String,
    /// Cached response returned on a similarity hit.
    pub response: String,
    /// Embedding of the prompt; lookups compare against it with cosine similarity.
    pub embedding: Vec<f32>,
    /// Absolute expiry timestamp in ms (same clock as `now_ms()`); 0 means never expires.
    pub expires_at_ms: u64,
    /// Timestamp (ms) of the most recent lookup hit; refreshed by `lookup`.
    pub last_hit_ms: u64,
    /// Insertion timestamp (ms); capacity eviction removes the oldest-inserted entry first.
    pub inserted_at_ms: u64,
}
37
38impl SemanticCacheEntry {
39 pub fn is_expired_at(&self, now_ms_val: u64) -> bool {
40 self.expires_at_ms != 0 && now_ms_val >= self.expires_at_ms
41 }
42}
43
/// Tunables for a [`SemanticCache`].
#[derive(Debug, Clone)]
pub struct SemanticCacheConfig {
    /// Minimum cosine similarity for a lookup to count as a hit.
    pub similarity_threshold: f32,
    /// TTL (ms) applied to inserts with no explicit override; 0 disables expiry.
    pub default_ttl_ms: u64,
    /// Cap on live entries; exceeding it evicts oldest-inserted first. 0 disables the cap.
    pub max_entries: usize,
    /// Logical namespace; persisted rows are stored under `cache:{namespace}`.
    pub namespace: String,
}
61
62impl Default for SemanticCacheConfig {
63 fn default() -> Self {
64 Self {
65 similarity_threshold: 0.95,
66 default_ttl_ms: 24 * 60 * 60 * 1000,
67 max_entries: 10_000,
68 namespace: "default".to_string(),
69 }
70 }
71}
72
/// Point-in-time counters describing cache size and effectiveness.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SemanticCacheStats {
    /// Number of live entries currently held in memory.
    pub entries: usize,
    /// Lookups that returned a cached response.
    pub hits: u64,
    /// Lookups that found no entry at or above the similarity threshold.
    pub misses: u64,
    /// Entries removed because their TTL elapsed.
    pub expired_evictions: u64,
    /// Entries removed to stay within `max_entries`.
    pub capacity_evictions: u64,
}
83
/// Mutable state guarded by the cache's mutex: the entry list plus counters.
struct Inner {
    entries: Vec<SemanticCacheEntry>,
    stats: SemanticCacheStats,
}
88
/// Embedding-similarity response cache with optional write-through
/// persistence. Cheap to clone: clones share the same entry set and
/// stats through the `Arc<Mutex<..>>`.
#[derive(Clone)]
pub struct SemanticCache {
    inner: Arc<Mutex<Inner>>,
    config: SemanticCacheConfig,
    backend: Option<Arc<dyn MlPersistence>>,
}
96
97impl std::fmt::Debug for SemanticCache {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 f.debug_struct("SemanticCache")
100 .field("namespace", &self.config.namespace)
101 .field("similarity_threshold", &self.config.similarity_threshold)
102 .field("max_entries", &self.config.max_entries)
103 .field("persistent", &self.backend.is_some())
104 .finish()
105 }
106}
107
108impl SemanticCache {
109 pub fn new(config: SemanticCacheConfig) -> Self {
111 Self {
112 inner: Arc::new(Mutex::new(Inner {
113 entries: Vec::new(),
114 stats: SemanticCacheStats::default(),
115 })),
116 config,
117 backend: None,
118 }
119 }
120
121 pub fn with_backend(config: SemanticCacheConfig, backend: Arc<dyn MlPersistence>) -> Self {
124 let cache = Self {
125 inner: Arc::new(Mutex::new(Inner {
126 entries: Vec::new(),
127 stats: SemanticCacheStats::default(),
128 })),
129 config,
130 backend: Some(backend),
131 };
132 let _ = cache.load_from_backend();
133 cache
134 }
135
136 fn backend_namespace(&self) -> String {
137 format!("cache:{}", self.config.namespace)
138 }
139
140 fn persist_entry(&self, key: &str, entry: &SemanticCacheEntry) {
141 if let Some(backend) = self.backend.as_ref() {
142 let _ = backend.put(&self.backend_namespace(), key, &encode_entry(entry));
143 }
144 }
145
146 fn forget_entry(&self, key: &str) {
147 if let Some(backend) = self.backend.as_ref() {
148 let _ = backend.delete(&self.backend_namespace(), key);
149 }
150 }
151
152 pub fn load_from_backend(&self) -> MlPersistenceResult<usize> {
155 let Some(backend) = self.backend.as_ref() else {
156 return Ok(0);
157 };
158 let rows = backend.list(&self.backend_namespace())?;
159 let mut loaded = 0usize;
160 let now = now_ms();
161 let mut guard = match self.inner.lock() {
162 Ok(g) => g,
163 Err(p) => p.into_inner(),
164 };
165 guard.entries.clear();
166 for (_, raw) in rows {
167 let Some(entry) = decode_entry(&raw) else {
168 continue;
169 };
170 if entry.is_expired_at(now) {
171 continue;
176 }
177 guard.entries.push(entry);
178 loaded += 1;
179 }
180 guard.stats.entries = guard.entries.len();
181 Ok(loaded)
182 }
183
184 pub fn lookup(&self, embedding: &[f32]) -> Option<String> {
187 if embedding.is_empty() {
188 return None;
189 }
190 let now = now_ms();
191 let mut guard = match self.inner.lock() {
192 Ok(g) => g,
193 Err(p) => p.into_inner(),
194 };
195 let before = guard.entries.len();
197 guard.entries.retain(|e| !e.is_expired_at(now));
198 let evicted = before - guard.entries.len();
199 guard.stats.expired_evictions += evicted as u64;
200
201 let mut best: Option<(usize, f32)> = None;
202 for (idx, entry) in guard.entries.iter().enumerate() {
203 let score = cosine_similarity(embedding, &entry.embedding);
204 if score >= self.config.similarity_threshold {
205 match best {
206 Some((_, best_score)) if best_score >= score => {}
207 _ => best = Some((idx, score)),
208 }
209 }
210 }
211 match best {
212 Some((idx, _)) => {
213 let entry = &mut guard.entries[idx];
214 entry.last_hit_ms = now;
215 let response = entry.response.clone();
216 let persisted = entry.clone();
217 guard.stats.hits += 1;
218 guard.stats.entries = guard.entries.len();
219 drop(guard);
220 let key = cache_key(&persisted);
223 self.persist_entry(&key, &persisted);
224 Some(response)
225 }
226 None => {
227 guard.stats.misses += 1;
228 guard.stats.entries = guard.entries.len();
229 None
230 }
231 }
232 }
233
234 pub fn insert(
237 &self,
238 prompt: impl Into<String>,
239 response: impl Into<String>,
240 embedding: Vec<f32>,
241 ttl_ms_override: Option<u64>,
242 ) {
243 if embedding.is_empty() {
244 return;
245 }
246 let now = now_ms();
247 let ttl = ttl_ms_override.unwrap_or(self.config.default_ttl_ms);
248 let expires_at_ms = if ttl == 0 { 0 } else { now.saturating_add(ttl) };
249 let entry = SemanticCacheEntry {
250 prompt: prompt.into(),
251 response: response.into(),
252 embedding,
253 expires_at_ms,
254 last_hit_ms: now,
255 inserted_at_ms: now,
256 };
257 let evicted_keys: Vec<String>;
258 let stored_key: String;
259 let persist_entry: SemanticCacheEntry;
260 {
261 let mut guard = match self.inner.lock() {
262 Ok(g) => g,
263 Err(p) => p.into_inner(),
264 };
265 let mut pruned: Vec<String> = Vec::new();
269 if self.config.max_entries > 0 {
270 while guard.entries.len() >= self.config.max_entries {
271 if let Some((oldest_idx, _)) = guard
272 .entries
273 .iter()
274 .enumerate()
275 .min_by_key(|(_, e)| e.inserted_at_ms)
276 {
277 let gone = guard.entries.remove(oldest_idx);
278 guard.stats.capacity_evictions += 1;
279 pruned.push(cache_key(&gone));
280 } else {
281 break;
282 }
283 }
284 }
285 guard.entries.push(entry.clone());
286 guard.stats.entries = guard.entries.len();
287 evicted_keys = pruned;
288 stored_key = cache_key(&entry);
289 persist_entry = entry;
290 }
291 for k in &evicted_keys {
292 self.forget_entry(k);
293 }
294 self.persist_entry(&stored_key, &persist_entry);
295 }
296
297 pub fn evict_expired(&self) -> usize {
299 let now = now_ms();
300 let evicted_keys: Vec<String>;
301 let count;
302 {
303 let mut guard = match self.inner.lock() {
304 Ok(g) => g,
305 Err(p) => p.into_inner(),
306 };
307 let mut keep = Vec::with_capacity(guard.entries.len());
308 let mut dropped = Vec::new();
309 for entry in guard.entries.drain(..) {
310 if entry.is_expired_at(now) {
311 dropped.push(cache_key(&entry));
312 } else {
313 keep.push(entry);
314 }
315 }
316 count = dropped.len();
317 guard.entries = keep;
318 guard.stats.expired_evictions += count as u64;
319 guard.stats.entries = guard.entries.len();
320 evicted_keys = dropped;
321 }
322 for k in &evicted_keys {
323 self.forget_entry(k);
324 }
325 count
326 }
327
328 pub fn stats(&self) -> SemanticCacheStats {
330 let guard = match self.inner.lock() {
331 Ok(g) => g,
332 Err(p) => p.into_inner(),
333 };
334 SemanticCacheStats {
335 entries: guard.entries.len(),
336 ..guard.stats.clone()
337 }
338 }
339
340 pub fn config(&self) -> &SemanticCacheConfig {
341 &self.config
342 }
343}
344
345fn cache_key(entry: &SemanticCacheEntry) -> String {
349 const FNV_OFFSET: u64 = 0xcbf29ce484222325;
354 const FNV_PRIME: u64 = 0x100000001b3;
355 let mut h = FNV_OFFSET;
356 for b in entry.prompt.as_bytes() {
357 h ^= *b as u64;
358 h = h.wrapping_mul(FNV_PRIME);
359 }
360 format!("{:020}-{:016x}", entry.inserted_at_ms, h)
361}
362
/// Cosine similarity between two vectors.
///
/// Returns 0.0 for mismatched lengths, empty inputs, or when either
/// vector has zero magnitude (rather than dividing by zero).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Accumulate dot product and both squared norms in one pass, in the
    // same element order as an index loop would.
    let (dot, norm_a, norm_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(d, na, nb), (&x, &y)| (d + x * y, na + x * x, nb + y * y),
    );
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a.sqrt() * norm_b.sqrt())
    }
}
380
381fn encode_entry(entry: &SemanticCacheEntry) -> String {
384 let mut obj = Map::new();
385 obj.insert(
386 "prompt".to_string(),
387 JsonValue::String(entry.prompt.clone()),
388 );
389 obj.insert(
390 "response".to_string(),
391 JsonValue::String(entry.response.clone()),
392 );
393 obj.insert(
394 "embedding".to_string(),
395 JsonValue::Array(
396 entry
397 .embedding
398 .iter()
399 .map(|f| JsonValue::Number(*f as f64))
400 .collect(),
401 ),
402 );
403 obj.insert(
404 "expires_at".to_string(),
405 JsonValue::Number(entry.expires_at_ms as f64),
406 );
407 obj.insert(
408 "last_hit".to_string(),
409 JsonValue::Number(entry.last_hit_ms as f64),
410 );
411 obj.insert(
412 "inserted_at".to_string(),
413 JsonValue::Number(entry.inserted_at_ms as f64),
414 );
415 JsonValue::Object(obj).to_string_compact()
416}
417
418fn decode_entry(raw: &str) -> Option<SemanticCacheEntry> {
419 let parsed = crate::json::parse_json(raw).ok()?;
420 let value = JsonValue::from(parsed);
421 let obj = value.as_object()?;
422 let prompt = obj.get("prompt")?.as_str()?.to_string();
423 let response = obj.get("response")?.as_str()?.to_string();
424 let embedding = obj
425 .get("embedding")?
426 .as_array()?
427 .iter()
428 .filter_map(|v| v.as_f64().map(|f| f as f32))
429 .collect::<Vec<f32>>();
430 let expires_at_ms = obj.get("expires_at")?.as_i64()? as u64;
431 let last_hit_ms = obj.get("last_hit")?.as_i64()? as u64;
432 let inserted_at_ms = obj.get("inserted_at")?.as_i64()? as u64;
433 Some(SemanticCacheEntry {
434 prompt,
435 response,
436 embedding,
437 expires_at_ms,
438 last_hit_ms,
439 inserted_at_ms,
440 })
441}
442
#[cfg(test)]
mod tests {
    use super::super::persist::InMemoryMlPersistence;
    use super::*;

    // Shorthand config builder so each test reads as (threshold, cap, ttl).
    fn cfg(threshold: f32, max: usize, ttl: u64) -> SemanticCacheConfig {
        SemanticCacheConfig {
            similarity_threshold: threshold,
            default_ttl_ms: ttl,
            max_entries: max,
            namespace: "t".to_string(),
        }
    }

    #[test]
    fn cosine_similarity_is_symmetric_and_bounded() {
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0];
        let c = [1.0, 0.0, 0.0];
        // Identical vectors score ~1, orthogonal vectors ~0.
        assert!((cosine_similarity(&a, &c) - 1.0).abs() < 1e-6);
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
        // Order of arguments must not matter.
        assert!((cosine_similarity(&a, &b) - cosine_similarity(&b, &a)).abs() < 1e-6);
    }

    #[test]
    fn cosine_zero_on_mismatched_dims_or_zero_vec() {
        // Degenerate inputs map to 0.0 rather than NaN or a panic.
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 0.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[0.0, 0.0]), 0.0);
    }

    #[test]
    fn miss_returns_none_and_increments_miss_counter() {
        let c = SemanticCache::new(cfg(0.9, 100, 0));
        assert!(c.lookup(&[1.0, 0.0]).is_none());
        assert_eq!(c.stats().misses, 1);
        assert_eq!(c.stats().hits, 0);
    }

    #[test]
    fn inserted_entry_is_found_on_identical_vector() {
        let c = SemanticCache::new(cfg(0.9, 100, 0));
        c.insert("p", "hello world", vec![1.0, 0.0, 0.0], None);
        let got = c.lookup(&[1.0, 0.0, 0.0]).unwrap();
        assert_eq!(got, "hello world");
        assert_eq!(c.stats().hits, 1);
    }

    #[test]
    fn below_threshold_is_a_miss() {
        // [0.8, 0.6] vs [1, 0] has cosine 0.8, below the 0.99 threshold.
        let c = SemanticCache::new(cfg(0.99, 100, 0));
        c.insert("p", "r", vec![1.0, 0.0, 0.0], None);
        assert!(c.lookup(&[0.8, 0.6, 0.0]).is_none());
    }

    #[test]
    fn expired_entries_are_skipped_and_evicted() {
        // 1 ms TTL plus a 5 ms sleep guarantees expiry before lookup.
        let c = SemanticCache::new(cfg(0.9, 100, 1));
        c.insert("p", "r", vec![1.0, 0.0], None);
        std::thread::sleep(std::time::Duration::from_millis(5));
        assert!(c.lookup(&[1.0, 0.0]).is_none());
        let stats = c.stats();
        assert_eq!(stats.entries, 0);
        assert!(stats.expired_evictions >= 1);
    }

    #[test]
    fn capacity_limit_evicts_oldest_inserted() {
        // Cap of 2: the third insert must push out the first.
        // Sleeps keep inserted_at_ms strictly increasing.
        let c = SemanticCache::new(cfg(0.9, 2, 0));
        c.insert("first", "r1", vec![1.0, 0.0], None);
        std::thread::sleep(std::time::Duration::from_millis(2));
        c.insert("second", "r2", vec![0.0, 1.0], None);
        std::thread::sleep(std::time::Duration::from_millis(2));
        c.insert("third", "r3", vec![1.0, 1.0], None);
        assert_eq!(c.stats().entries, 2);
        assert!(c.stats().capacity_evictions >= 1);
        // "first" may still be approximated by "third" ([1,1] vs [1,0] ~0.707,
        // below 0.9), so the lookup must not return r1 either way.
        assert!(c.lookup(&[1.0, 0.0]).is_none() || c.lookup(&[1.0, 0.0]) != Some("r1".to_string()));
    }

    #[test]
    fn best_candidate_wins_when_multiple_match() {
        // Both entries clear the 0.5 threshold; the closer one must win.
        let c = SemanticCache::new(cfg(0.5, 100, 0));
        c.insert("lo", "LO", vec![0.7, 0.7, 0.1], None);
        c.insert("hi", "HI", vec![1.0, 0.0, 0.0], None);
        let got = c.lookup(&[1.0, 0.0, 0.0]).unwrap();
        assert_eq!(got, "HI");
    }

    #[test]
    fn backend_round_trips_entry() {
        // A second cache sharing the backend must see the first's insert.
        let backend: Arc<dyn MlPersistence> = Arc::new(InMemoryMlPersistence::new());
        let c1 = SemanticCache::with_backend(cfg(0.9, 100, 0), Arc::clone(&backend));
        c1.insert("prompt one", "response one", vec![1.0, 0.0], None);
        let c2 = SemanticCache::with_backend(cfg(0.9, 100, 0), backend);
        let got = c2.lookup(&[1.0, 0.0]).unwrap();
        assert_eq!(got, "response one");
    }

    #[test]
    fn encode_decode_entry_round_trips() {
        let e = SemanticCacheEntry {
            prompt: "why".to_string(),
            response: "because".to_string(),
            embedding: vec![0.1, 0.2, -0.3],
            expires_at_ms: 100,
            last_hit_ms: 50,
            inserted_at_ms: 10,
        };
        let back = decode_entry(&encode_entry(&e)).unwrap();
        assert_eq!(back.prompt, e.prompt);
        assert_eq!(back.response, e.response);
        assert_eq!(back.embedding.len(), e.embedding.len());
        // Embedding floats survive the f32 -> f64 -> f32 round trip.
        for (a, b) in back.embedding.iter().zip(e.embedding.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
        assert_eq!(back.expires_at_ms, e.expires_at_ms);
    }

    #[test]
    fn stats_entries_reflect_live_set() {
        let c = SemanticCache::new(cfg(0.9, 100, 0));
        c.insert("a", "1", vec![1.0, 0.0], None);
        c.insert("b", "2", vec![0.0, 1.0], None);
        assert_eq!(c.stats().entries, 2);
    }
}