1use dashmap::DashMap;
10use serde_json::Value;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::{Duration, Instant};
13
/// Tuning knobs for a [`SemanticCache`].
#[derive(Clone, Debug)]
pub struct SemanticCacheConfig {
    /// Minimum cosine similarity (inclusive) between a query embedding and a stored
    /// embedding for a lookup to count as a hit.
    pub similarity_threshold: f32,
    /// Upper bound on stored entries, enforced by evicting the oldest entry on insert.
    pub max_entries: usize,
    /// Lifetime, in seconds, given to every entry when it is stored.
    pub default_ttl_secs: u64,
}

impl Default for SemanticCacheConfig {
    /// Threshold 0.92, room for 1 000 entries, five-minute TTL.
    fn default() -> Self {
        let five_minutes = 5 * 60;
        Self {
            max_entries: 1_000,
            default_ttl_secs: five_minutes,
            similarity_threshold: 0.92,
        }
    }
}
38
39struct CacheEntry {
44 query_embedding: Vec<f32>,
45 #[allow(dead_code)]
47 query_text: String,
48 results: Value,
49 created_at: Instant,
50 ttl: Duration,
51 hit_count: u64,
52}
53
54impl CacheEntry {
55 fn is_expired(&self) -> bool {
56 self.created_at.elapsed() > self.ttl
57 }
58}
59
/// Point-in-time counters reported by `SemanticCache::stats`.
#[derive(Clone, Debug, Default)]
pub struct SemanticCacheStats {
    /// Number of entries currently stored (live or not yet purged).
    pub entries: usize,
    /// Lookups that returned a cached result.
    pub hits: u64,
    /// Lookups that found no sufficiently similar live entry.
    pub misses: u64,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// Entries removed by memory-id invalidation.
    pub invalidations: u64,
}
73
74pub struct SemanticCache {
81 entries: DashMap<u64, CacheEntry>,
83 config: SemanticCacheConfig,
84 hits: AtomicU64,
85 misses: AtomicU64,
86 evictions: AtomicU64,
87 invalidations: AtomicU64,
88}
89
90impl SemanticCache {
91 pub fn new(config: SemanticCacheConfig) -> Self {
93 Self {
94 entries: DashMap::new(),
95 config,
96 hits: AtomicU64::new(0),
97 misses: AtomicU64::new(0),
98 evictions: AtomicU64::new(0),
99 invalidations: AtomicU64::new(0),
100 }
101 }
102
103 pub fn get(&self, query_embedding: &[f32]) -> Option<Value> {
113 let threshold = self.config.similarity_threshold;
114
115 let mut best_similarity = -1.0_f32;
116 let mut best_result: Option<Value> = None;
117 let mut best_key: Option<u64> = None;
118
119 for mut entry_ref in self.entries.iter_mut() {
120 if entry_ref.is_expired() {
121 continue;
122 }
123
124 let sim = cosine_similarity(query_embedding, &entry_ref.query_embedding);
125 if sim >= threshold && sim > best_similarity {
126 best_similarity = sim;
127 best_result = Some(entry_ref.results.clone());
128 best_key = Some(*entry_ref.key());
129 entry_ref.hit_count = entry_ref.hit_count.saturating_add(1);
131 }
132 }
133
134 let _ = best_key; if best_result.is_some() {
139 self.hits.fetch_add(1, Ordering::Relaxed);
140 } else {
141 self.misses.fetch_add(1, Ordering::Relaxed);
142 }
143
144 best_result
145 }
146
147 pub fn put(&self, query_embedding: Vec<f32>, query_text: String, results: Value) {
152 if self.entries.len() >= self.config.max_entries {
153 self.evict_oldest();
154 }
155
156 let key = embedding_hash(&query_embedding);
157 let entry = CacheEntry {
158 query_embedding,
159 query_text,
160 results,
161 created_at: Instant::now(),
162 ttl: Duration::from_secs(self.config.default_ttl_secs),
163 hit_count: 0,
164 };
165
166 self.entries.insert(key, entry);
167 }
168
169 pub fn invalidate_memory(&self, memory_id: i64) {
174 let target = Value::Number(memory_id.into());
175
176 self.entries.retain(|_, entry| {
177 let contains = json_contains(&entry.results, &target);
178 if contains {
179 self.invalidations.fetch_add(1, Ordering::Relaxed);
180 }
181 !contains
182 });
183 }
184
185 pub fn clear(&self) {
187 self.entries.clear();
188 }
189
190 pub fn stats(&self) -> SemanticCacheStats {
192 SemanticCacheStats {
193 entries: self.entries.len(),
194 hits: self.hits.load(Ordering::Relaxed),
195 misses: self.misses.load(Ordering::Relaxed),
196 evictions: self.evictions.load(Ordering::Relaxed),
197 invalidations: self.invalidations.load(Ordering::Relaxed),
198 }
199 }
200
201 fn evict_oldest(&self) {
207 let mut oldest_key: Option<u64> = None;
210 let mut oldest_time: Option<Instant> = None;
211
212 for entry_ref in self.entries.iter() {
213 let t = entry_ref.created_at;
214 match oldest_time {
215 None => {
216 oldest_time = Some(t);
217 oldest_key = Some(*entry_ref.key());
218 }
219 Some(ot) if t < ot => {
220 oldest_time = Some(t);
221 oldest_key = Some(*entry_ref.key());
222 }
223 _ => {}
224 }
225 }
226
227 if let Some(key) = oldest_key {
228 self.entries.remove(&key);
229 self.evictions.fetch_add(1, Ordering::Relaxed);
230 }
231 }
232}
233
/// Cosine similarity between `a` and `b`, in `[-1.0, 1.0]` up to rounding.
///
/// Returns `0.0` instead of dividing when the slices differ in length, are empty, or
/// either one has zero magnitude.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    // Accumulate the dot product and both squared norms in a single pass.
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    match (sq_norm_a == 0.0, sq_norm_b == 0.0) {
        (false, false) => dot / (sq_norm_a.sqrt() * sq_norm_b.sqrt()),
        _ => 0.0,
    }
}
262
/// Hashes an embedding to the `u64` key under which `SemanticCache` stores it.
///
/// Runs 64-bit FNV-1a over the little-endian bytes of *every* component, so embeddings
/// that differ anywhere get different keys, barring ordinary hash collisions. Hashing only
/// a fixed-length prefix gives sparse or one-hot embeddings that agree on that prefix the
/// same key, and one `put` would then silently overwrite an unrelated entry.
///
/// Raw bit patterns are hashed, so `0.0` and `-0.0` (and distinct NaN payloads) produce
/// different keys. The empty slice hashes to the FNV offset basis.
pub fn embedding_hash(embedding: &[f32]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    embedding
        .iter()
        .flat_map(|component| component.to_le_bytes())
        .fold(FNV_OFFSET_BASIS, |hash, byte| {
            (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
}
279
280fn json_contains(json: &Value, target: &Value) -> bool {
282 if json == target {
283 return true;
284 }
285 match json {
286 Value::Array(arr) => arr.iter().any(|v| json_contains(v, target)),
287 Value::Object(map) => map.values().any(|v| json_contains(v, target)),
288 _ => false,
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::thread;
    use std::time::Duration;

    /// Builds a cache with the stock configuration.
    fn default_cache() -> SemanticCache {
        let config = SemanticCacheConfig::default();
        SemanticCache::new(config)
    }

    /// One-hot `f32` vector of length `dim` with `1.0` at index `hot` (panics if `hot >= dim`).
    fn unit_vec(dim: usize, hot: usize) -> Vec<f32> {
        let mut one_hot = vec![0.0_f32; dim];
        one_hot[hot] = 1.0;
        one_hot
    }

    // Looking an embedding up with itself returns exactly what was stored.
    #[test]
    fn test_cache_hit() {
        let cache = default_cache();
        let key_vec = unit_vec(4, 0);
        let expected = json!({"memories": [{"id": 42}]});

        cache.put(key_vec.clone(), String::from("query"), expected.clone());

        assert_eq!(cache.get(&key_vec), Some(expected));
    }

    // Orthogonal embeddings have similarity 0.0, far below the default threshold.
    #[test]
    fn test_cache_miss_below_threshold() {
        let cache = default_cache();
        let stored = unit_vec(4, 0);
        let probe = unit_vec(4, 1);

        cache.put(stored, String::from("query A"), json!({"memories": []}));
        let lookup = cache.get(&probe);

        assert!(lookup.is_none(), "perpendicular vectors must not hit");
    }

    // A slightly perturbed embedding still clears a 0.92 threshold.
    #[test]
    fn test_cache_hit_similar() {
        let config = SemanticCacheConfig {
            similarity_threshold: 0.92,
            ..Default::default()
        };
        let cache = SemanticCache::new(config);

        let stored = vec![1.0_f32, 0.0, 0.0, 0.0];
        let payload = json!({"memories": [{"id": 7}]});
        cache.put(stored, String::from("original"), payload.clone());

        let nearby = vec![1.0_f32, 0.01, 0.0, 0.0];
        let lookup = cache.get(&nearby);
        assert!(lookup.is_some(), "similar embedding must be a cache hit");
    }

    // With a zero-second TTL, any measurable elapsed time expires the entry.
    #[test]
    fn test_ttl_expiration() {
        let cache = SemanticCache::new(SemanticCacheConfig {
            default_ttl_secs: 0,
            ..Default::default()
        });

        let key_vec = unit_vec(4, 2);
        cache.put(key_vec.clone(), String::from("q"), json!({"ok": true}));

        thread::sleep(Duration::from_millis(5));

        let lookup = cache.get(&key_vec);
        assert!(lookup.is_none(), "entry should have expired");
    }

    // Invalidating an id present in the results removes the entry and counts it.
    #[test]
    fn test_invalidate_memory() {
        let cache = default_cache();
        let key_vec = unit_vec(4, 0);
        let payload = json!([{"id": 99, "content": "hello"}]);

        cache.put(key_vec.clone(), String::from("q"), payload);
        assert!(cache.get(&key_vec).is_some());

        cache.invalidate_memory(99);

        assert!(
            cache.get(&key_vec).is_none(),
            "entry containing id 99 must be removed"
        );
        let snapshot = cache.stats();
        assert_eq!(snapshot.invalidations, 1);
    }

    // Entries whose results do not mention the id survive invalidation.
    #[test]
    fn test_invalidate_memory_does_not_remove_unrelated() {
        let cache = default_cache();
        let (first, second) = (unit_vec(4, 0), unit_vec(4, 1));

        cache.put(first.clone(), String::from("q1"), json!([{"id": 1}]));
        cache.put(second.clone(), String::from("q2"), json!([{"id": 2}]));

        cache.invalidate_memory(1);

        assert!(cache.get(&first).is_none());
        assert!(cache.get(&second).is_some());
    }

    #[test]
    fn test_clear() {
        let cache = default_cache();
        cache.put(unit_vec(4, 0), String::from("a"), json!(1));
        cache.put(unit_vec(4, 1), String::from("b"), json!(2));

        cache.clear();

        assert_eq!(cache.stats().entries, 0);
    }

    #[test]
    fn test_stats() {
        let cache = default_cache();
        let key_vec = unit_vec(4, 3);

        let _ = cache.get(&key_vec); // miss: nothing stored yet

        cache.put(key_vec.clone(), String::from("q"), json!({"x": 1}));
        let _ = cache.get(&key_vec); // hit
        let _ = cache.get(&key_vec); // hit

        let snapshot = cache.stats();
        assert_eq!(snapshot.hits, 2, "expected 2 hits");
        assert_eq!(snapshot.misses, 1, "expected 1 miss");
        assert_eq!(snapshot.entries, 1);
    }

    // At capacity, inserting a new key evicts the entry with the earliest `created_at`.
    #[test]
    fn test_capacity_eviction() {
        let cache = SemanticCache::new(SemanticCacheConfig {
            max_entries: 2,
            ..Default::default()
        });
        let (oldest, middle, newest) = (unit_vec(4, 0), unit_vec(4, 1), unit_vec(4, 2));

        // Sleep between inserts so insertion times strictly increase.
        cache.put(oldest.clone(), String::from("first"), json!("first"));
        thread::sleep(Duration::from_millis(1));
        cache.put(middle.clone(), String::from("second"), json!("second"));
        thread::sleep(Duration::from_millis(1));
        cache.put(newest.clone(), String::from("third"), json!("third"));

        let snapshot = cache.stats();
        assert_eq!(snapshot.entries, 2);
        assert_eq!(snapshot.evictions, 1);

        assert!(
            cache.get(&oldest).is_none(),
            "oldest entry must have been evicted"
        );
        assert!(cache.get(&middle).is_some());
        assert!(cache.get(&newest).is_some());
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0_f32, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let (x_axis, y_axis) = ([1.0_f32, 0.0], [0.0_f32, 1.0]);
        assert_eq!(cosine_similarity(&x_axis, &y_axis), 0.0);
    }

    #[test]
    fn test_cosine_similarity_empty() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn test_cosine_similarity_length_mismatch() {
        let pair = [1.0_f32, 0.0];
        let single = [1.0_f32];
        assert_eq!(cosine_similarity(&pair, &single), 0.0);
    }
}