1use async_trait::async_trait;
4use lru::LruCache;
5use parking_lot::RwLock;
6use sha2::{Digest, Sha256};
7use std::num::NonZeroUsize;
8use std::sync::atomic::{AtomicU64, Ordering};
9
10use crate::{Cache, CacheConfig, CacheEntry, CacheError, CacheStats};
11
/// In-memory, TTL-aware LRU cache mapping (query, context) pairs to
/// [`CacheEntry`] values.
pub struct MemoryCache {
    // LRU map keyed by a SHA-256 hex digest of the query plus a
    // truncated context (see `make_key`).
    cache: RwLock<LruCache<String, CacheEntry>>,
    // Construction-time settings; `ttl` drives expiry checks on `get`.
    config: CacheConfig,
    // Lock-free counters updated through a shared reference.
    stats: MemoryCacheStats,
}
18
/// Activity counters for a [`MemoryCache`]; atomics so they can be
/// bumped without taking the cache lock exclusively.
struct MemoryCacheStats {
    hits: AtomicU64,      // lookups that returned a live entry
    misses: AtomicU64,    // lookups that found nothing or an expired entry
    stores: AtomicU64,    // successful `store` calls
    evictions: AtomicU64, // entries dropped for TTL expiry or LRU capacity
}
25
26impl Default for MemoryCacheStats {
27 fn default() -> Self {
28 Self {
29 hits: AtomicU64::new(0),
30 misses: AtomicU64::new(0),
31 stores: AtomicU64::new(0),
32 evictions: AtomicU64::new(0),
33 }
34 }
35}
36
37impl MemoryCache {
38 pub fn new(config: CacheConfig) -> Self {
40 let capacity =
41 NonZeroUsize::new(config.max_entries).unwrap_or(NonZeroUsize::new(1000).unwrap());
42
43 Self {
44 cache: RwLock::new(LruCache::new(capacity)),
45 config,
46 stats: MemoryCacheStats::default(),
47 }
48 }
49
50 pub fn with_defaults() -> Self {
52 Self::new(CacheConfig::default())
53 }
54
55 fn make_key(query: &str, context: &str) -> String {
57 let mut hasher = Sha256::new();
58 hasher.update(query.as_bytes());
59 if !context.is_empty() {
60 hasher.update(b"|ctx:");
61 let ctx_bytes = context.as_bytes();
63 let limit = ctx_bytes.len().min(500);
64 hasher.update(&ctx_bytes[..limit]);
65 }
66 format!("{:x}", hasher.finalize())
67 }
68
69 pub fn len(&self) -> usize {
71 self.cache.read().len()
72 }
73
74 pub fn is_empty(&self) -> bool {
76 self.cache.read().is_empty()
77 }
78
79 pub fn capacity(&self) -> usize {
81 self.config.max_entries
82 }
83}
84
85#[async_trait]
86impl Cache for MemoryCache {
87 async fn get(&self, query: &str, context: &str) -> Result<Option<CacheEntry>, CacheError> {
88 let key = Self::make_key(query, context);
89 let ttl_secs = self.config.ttl.as_secs() as i64;
90
91 let mut cache = self.cache.write();
92
93 if let Some(entry) = cache.get_mut(&key) {
94 if entry.is_expired(ttl_secs) {
96 cache.pop(&key);
97 self.stats.misses.fetch_add(1, Ordering::Relaxed);
98 self.stats.evictions.fetch_add(1, Ordering::Relaxed);
99 return Ok(None);
100 }
101
102 entry.record_hit();
103 self.stats.hits.fetch_add(1, Ordering::Relaxed);
104
105 tracing::debug!(
106 query = %query,
107 hits = %entry.hit_count,
108 "Cache HIT"
109 );
110
111 return Ok(Some(entry.clone()));
112 }
113
114 self.stats.misses.fetch_add(1, Ordering::Relaxed);
115 Ok(None)
116 }
117
118 async fn store(
119 &self,
120 query: &str,
121 context: &str,
122 response: &str,
123 function_calls: Vec<String>,
124 ) -> Result<(), CacheError> {
125 let key = Self::make_key(query, context);
126 let entry = CacheEntry::new(query, context, response, function_calls);
127
128 let mut cache = self.cache.write();
129
130 if cache.len() >= self.config.max_entries && !cache.contains(&key) {
132 self.stats.evictions.fetch_add(1, Ordering::Relaxed);
133 }
134
135 cache.put(key, entry);
136 self.stats.stores.fetch_add(1, Ordering::Relaxed);
137
138 tracing::debug!(
139 query = %query,
140 "Cache STORE"
141 );
142
143 Ok(())
144 }
145
146 async fn delete(&self, query: &str, context: &str) -> Result<bool, CacheError> {
147 let key = Self::make_key(query, context);
148 let mut cache = self.cache.write();
149 Ok(cache.pop(&key).is_some())
150 }
151
152 async fn clear(&self) -> Result<usize, CacheError> {
153 let mut cache = self.cache.write();
154 let count = cache.len();
155 cache.clear();
156 Ok(count)
157 }
158
159 async fn stats(&self) -> Result<CacheStats, CacheError> {
160 Ok(CacheStats {
161 entries: self.cache.read().len(),
162 hits: self.stats.hits.load(Ordering::Relaxed),
163 misses: self.stats.misses.load(Ordering::Relaxed),
164 stores: self.stats.stores.load(Ordering::Relaxed),
165 evictions: self.stats.evictions.load(Ordering::Relaxed),
166 })
167 }
168}
169
/// A [`MemoryCache`] augmented with per-key embeddings so lookups can
/// also match on cosine similarity instead of exact key equality.
pub struct SemanticMemoryCache {
    // Exact-key cache holding the actual entries and stats.
    inner: MemoryCache,
    // Embedding per cache key, bounded by its own LRU capacity; may
    // retain keys whose entry was evicted from `inner` (and vice versa).
    embeddings: RwLock<LruCache<String, Vec<f32>>>,
}
175
176impl SemanticMemoryCache {
177 pub fn new(config: CacheConfig) -> Self {
179 let capacity =
180 NonZeroUsize::new(config.max_entries).unwrap_or(NonZeroUsize::new(1000).unwrap());
181
182 Self {
183 inner: MemoryCache::new(config),
184 embeddings: RwLock::new(LruCache::new(capacity)),
185 }
186 }
187
188 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
190 if a.len() != b.len() || a.is_empty() {
191 return 0.0;
192 }
193
194 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
195 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
196 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
197
198 if norm_a == 0.0 || norm_b == 0.0 {
199 return 0.0;
200 }
201
202 dot / (norm_a * norm_b)
203 }
204
205 pub async fn find_similar(
207 &self,
208 query_embedding: &[f32],
209 threshold: f32,
210 ) -> Option<(CacheEntry, f32)> {
211 let embeddings = self.embeddings.read();
212 let cache = self.inner.cache.read();
213 let ttl_secs = self.inner.config.ttl.as_secs() as i64;
214
215 let mut best_match: Option<(String, f32)> = None;
216
217 for (key, embedding) in embeddings.iter() {
218 let similarity = Self::cosine_similarity(query_embedding, embedding);
219
220 if similarity >= threshold
221 && (best_match.is_none() || similarity > best_match.as_ref().unwrap().1)
222 {
223 best_match = Some((key.clone(), similarity));
224 }
225 }
226
227 if let Some((key, similarity)) = best_match {
228 if let Some(entry) = cache.peek(&key) {
229 if !entry.is_expired(ttl_secs) {
230 return Some((entry.clone(), similarity));
231 }
232 }
233 }
234
235 None
236 }
237
238 pub async fn store_with_embedding(
240 &self,
241 query: &str,
242 context: &str,
243 response: &str,
244 function_calls: Vec<String>,
245 embedding: Vec<f32>,
246 ) -> Result<(), CacheError> {
247 let key = MemoryCache::make_key(query, context);
248
249 self.embeddings.write().put(key.clone(), embedding);
251
252 self.inner
254 .store(query, context, response, function_calls)
255 .await
256 }
257}
258
259#[async_trait]
260impl Cache for SemanticMemoryCache {
261 async fn get(&self, query: &str, context: &str) -> Result<Option<CacheEntry>, CacheError> {
262 self.inner.get(query, context).await
263 }
264
265 async fn store(
266 &self,
267 query: &str,
268 context: &str,
269 response: &str,
270 function_calls: Vec<String>,
271 ) -> Result<(), CacheError> {
272 self.inner
273 .store(query, context, response, function_calls)
274 .await
275 }
276
277 async fn delete(&self, query: &str, context: &str) -> Result<bool, CacheError> {
278 let key = MemoryCache::make_key(query, context);
279 self.embeddings.write().pop(&key);
280 self.inner.delete(query, context).await
281 }
282
283 async fn clear(&self) -> Result<usize, CacheError> {
284 self.embeddings.write().clear();
285 self.inner.clear().await
286 }
287
288 async fn stats(&self) -> Result<CacheStats, CacheError> {
289 self.inner.stats().await
290 }
291}
292
293#[async_trait]
294impl crate::SemanticCache for SemanticMemoryCache {
295 async fn find_similar_by_embedding(
296 &self,
297 query_embedding: &[f32],
298 threshold: f32,
299 ) -> Result<Option<(CacheEntry, f32)>, CacheError> {
300 Ok(self.find_similar(query_embedding, threshold).await)
301 }
302
303 async fn store_with_embedding(
304 &self,
305 query: &str,
306 context: &str,
307 response: &str,
308 function_calls: Vec<String>,
309 embedding: Vec<f32>,
310 ) -> Result<(), CacheError> {
311 SemanticMemoryCache::store_with_embedding(
312 self,
313 query,
314 context,
315 response,
316 function_calls,
317 embedding,
318 )
319 .await
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn test_memory_cache_basic() {
        let cache = MemoryCache::with_defaults();

        cache
            .store("query1", "ctx", "response1", vec![])
            .await
            .unwrap();

        // Stored entry round-trips.
        let entry = cache.get("query1", "ctx").await.unwrap();
        assert!(entry.is_some());
        assert_eq!(entry.unwrap().response, "response1");

        // Unknown key is a miss.
        assert!(cache.get("nonexistent", "ctx").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_hit_count() {
        let cache = MemoryCache::with_defaults();

        cache
            .store("query", "ctx", "response", vec![])
            .await
            .unwrap();

        // Each successful lookup bumps the per-entry hit counter.
        for expected_hits in 1..=3 {
            let entry = cache.get("query", "ctx").await.unwrap().unwrap();
            assert_eq!(entry.hit_count, expected_hits);
        }
    }

    #[tokio::test]
    async fn test_memory_cache_expiry() {
        let config = CacheConfig {
            ttl: Duration::from_secs(1),
            ..Default::default()
        };
        let cache = MemoryCache::new(config);

        cache
            .store("query", "ctx", "response", vec![])
            .await
            .unwrap();

        // Fresh entry is visible...
        assert!(cache.get("query", "ctx").await.unwrap().is_some());

        // ...and gone once the 1-second TTL has elapsed.
        tokio::time::sleep(Duration::from_millis(2100)).await;
        assert!(cache.get("query", "ctx").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_lru_eviction() {
        let config = CacheConfig {
            max_entries: 3,
            ..Default::default()
        };
        let cache = MemoryCache::new(config);

        cache.store("q1", "", "r1", vec![]).await.unwrap();
        cache.store("q2", "", "r2", vec![]).await.unwrap();
        cache.store("q3", "", "r3", vec![]).await.unwrap();
        assert_eq!(cache.len(), 3);

        // Touch q1 so q2 becomes the least-recently-used entry.
        cache.get("q1", "").await.unwrap();

        // Storing a fourth entry evicts q2, not q1.
        cache.store("q4", "", "r4", vec![]).await.unwrap();
        assert_eq!(cache.len(), 3);
        assert!(cache.get("q1", "").await.unwrap().is_some());
        assert!(cache.get("q2", "").await.unwrap().is_none());
        assert!(cache.get("q3", "").await.unwrap().is_some());
        assert!(cache.get("q4", "").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_memory_cache_stats() {
        let cache = MemoryCache::with_defaults();

        cache.store("q1", "", "r1", vec![]).await.unwrap();
        cache.store("q2", "", "r2", vec![]).await.unwrap();

        // Two hits on q1, one miss on q3.
        cache.get("q1", "").await.unwrap();
        cache.get("q1", "").await.unwrap();
        cache.get("q3", "").await.unwrap();

        let stats = cache.stats().await.unwrap();
        assert_eq!(stats.entries, 2);
        assert_eq!(stats.stores, 2);
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
    }

    #[tokio::test]
    async fn test_memory_cache_delete() {
        let cache = MemoryCache::with_defaults();

        cache
            .store("query", "ctx", "response", vec![])
            .await
            .unwrap();
        assert!(cache.get("query", "ctx").await.unwrap().is_some());

        assert!(cache.delete("query", "ctx").await.unwrap());
        assert!(cache.get("query", "ctx").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_clear() {
        let cache = MemoryCache::with_defaults();

        cache.store("q1", "", "r1", vec![]).await.unwrap();
        cache.store("q2", "", "r2", vec![]).await.unwrap();

        assert_eq!(cache.clear().await.unwrap(), 2);
        assert!(cache.is_empty());
    }

    #[tokio::test]
    async fn test_semantic_cache_similarity() {
        let cache = SemanticMemoryCache::new(CacheConfig::default());

        cache
            .store_with_embedding("q1", "", "r1", vec![], vec![1.0, 0.0, 0.0])
            .await
            .unwrap();

        // Identical vector: similarity is ~1.0.
        let result = cache.find_similar(&[1.0, 0.0, 0.0], 0.9).await;
        assert!(result.is_some());
        let (entry, similarity) = result.unwrap();
        assert_eq!(entry.response, "r1");
        assert!((similarity - 1.0).abs() < 0.001);

        // A nearby vector still clears the 0.9 threshold.
        assert!(cache.find_similar(&[0.9, 0.1, 0.0], 0.9).await.is_some());

        // An orthogonal vector does not.
        assert!(cache.find_similar(&[0.0, 1.0, 0.0], 0.9).await.is_none());
    }
}