// cortexai_cache/memory.rs
//! In-memory LRU cache implementation

use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};

use async_trait::async_trait;
use lru::LruCache;
use parking_lot::RwLock;
use sha2::{Digest, Sha256};

use crate::{Cache, CacheConfig, CacheEntry, CacheError, CacheStats};

12/// In-memory LRU cache with TTL support
13pub struct MemoryCache {
14    cache: RwLock<LruCache<String, CacheEntry>>,
15    config: CacheConfig,
16    stats: MemoryCacheStats,
17}
18
/// Internal hit/miss/store/eviction counters.
///
/// Atomics (`Relaxed`) are used so the counters can be bumped from both
/// read and write paths without extra locking; exact cross-counter
/// consistency is not required for statistics.
struct MemoryCacheStats {
    hits: AtomicU64,
    misses: AtomicU64,
    stores: AtomicU64,
    evictions: AtomicU64,
}

26impl Default for MemoryCacheStats {
27    fn default() -> Self {
28        Self {
29            hits: AtomicU64::new(0),
30            misses: AtomicU64::new(0),
31            stores: AtomicU64::new(0),
32            evictions: AtomicU64::new(0),
33        }
34    }
35}
36
37impl MemoryCache {
38    /// Create a new memory cache
39    pub fn new(config: CacheConfig) -> Self {
40        let capacity =
41            NonZeroUsize::new(config.max_entries).unwrap_or(NonZeroUsize::new(1000).unwrap());
42
43        Self {
44            cache: RwLock::new(LruCache::new(capacity)),
45            config,
46            stats: MemoryCacheStats::default(),
47        }
48    }
49
50    /// Create with default config
51    pub fn with_defaults() -> Self {
52        Self::new(CacheConfig::default())
53    }
54
55    /// Generate cache key from query and context
56    fn make_key(query: &str, context: &str) -> String {
57        let mut hasher = Sha256::new();
58        hasher.update(query.as_bytes());
59        if !context.is_empty() {
60            hasher.update(b"|ctx:");
61            // Limit context to prevent huge keys
62            let ctx_bytes = context.as_bytes();
63            let limit = ctx_bytes.len().min(500);
64            hasher.update(&ctx_bytes[..limit]);
65        }
66        format!("{:x}", hasher.finalize())
67    }
68
69    /// Get current entry count
70    pub fn len(&self) -> usize {
71        self.cache.read().len()
72    }
73
74    /// Check if cache is empty
75    pub fn is_empty(&self) -> bool {
76        self.cache.read().is_empty()
77    }
78
79    /// Get cache capacity
80    pub fn capacity(&self) -> usize {
81        self.config.max_entries
82    }
83}
84
85#[async_trait]
86impl Cache for MemoryCache {
87    async fn get(&self, query: &str, context: &str) -> Result<Option<CacheEntry>, CacheError> {
88        let key = Self::make_key(query, context);
89        let ttl_secs = self.config.ttl.as_secs() as i64;
90
91        let mut cache = self.cache.write();
92
93        if let Some(entry) = cache.get_mut(&key) {
94            // Check if expired
95            if entry.is_expired(ttl_secs) {
96                cache.pop(&key);
97                self.stats.misses.fetch_add(1, Ordering::Relaxed);
98                self.stats.evictions.fetch_add(1, Ordering::Relaxed);
99                return Ok(None);
100            }
101
102            entry.record_hit();
103            self.stats.hits.fetch_add(1, Ordering::Relaxed);
104
105            tracing::debug!(
106                query = %query,
107                hits = %entry.hit_count,
108                "Cache HIT"
109            );
110
111            return Ok(Some(entry.clone()));
112        }
113
114        self.stats.misses.fetch_add(1, Ordering::Relaxed);
115        Ok(None)
116    }
117
118    async fn store(
119        &self,
120        query: &str,
121        context: &str,
122        response: &str,
123        function_calls: Vec<String>,
124    ) -> Result<(), CacheError> {
125        let key = Self::make_key(query, context);
126        let entry = CacheEntry::new(query, context, response, function_calls);
127
128        let mut cache = self.cache.write();
129
130        // Check if we're at capacity and will evict
131        if cache.len() >= self.config.max_entries && !cache.contains(&key) {
132            self.stats.evictions.fetch_add(1, Ordering::Relaxed);
133        }
134
135        cache.put(key, entry);
136        self.stats.stores.fetch_add(1, Ordering::Relaxed);
137
138        tracing::debug!(
139            query = %query,
140            "Cache STORE"
141        );
142
143        Ok(())
144    }
145
146    async fn delete(&self, query: &str, context: &str) -> Result<bool, CacheError> {
147        let key = Self::make_key(query, context);
148        let mut cache = self.cache.write();
149        Ok(cache.pop(&key).is_some())
150    }
151
152    async fn clear(&self) -> Result<usize, CacheError> {
153        let mut cache = self.cache.write();
154        let count = cache.len();
155        cache.clear();
156        Ok(count)
157    }
158
159    async fn stats(&self) -> Result<CacheStats, CacheError> {
160        Ok(CacheStats {
161            entries: self.cache.read().len(),
162            hits: self.stats.hits.load(Ordering::Relaxed),
163            misses: self.stats.misses.load(Ordering::Relaxed),
164            stores: self.stats.stores.load(Ordering::Relaxed),
165            evictions: self.stats.evictions.load(Ordering::Relaxed),
166        })
167    }
168}
169
170/// Memory cache with semantic similarity support
171pub struct SemanticMemoryCache {
172    inner: MemoryCache,
173    embeddings: RwLock<LruCache<String, Vec<f32>>>,
174}
175
176impl SemanticMemoryCache {
177    /// Create a new semantic memory cache
178    pub fn new(config: CacheConfig) -> Self {
179        let capacity =
180            NonZeroUsize::new(config.max_entries).unwrap_or(NonZeroUsize::new(1000).unwrap());
181
182        Self {
183            inner: MemoryCache::new(config),
184            embeddings: RwLock::new(LruCache::new(capacity)),
185        }
186    }
187
188    /// Cosine similarity between two vectors
189    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
190        if a.len() != b.len() || a.is_empty() {
191            return 0.0;
192        }
193
194        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
195        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
196        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
197
198        if norm_a == 0.0 || norm_b == 0.0 {
199            return 0.0;
200        }
201
202        dot / (norm_a * norm_b)
203    }
204
205    /// Find most similar entry
206    pub async fn find_similar(
207        &self,
208        query_embedding: &[f32],
209        threshold: f32,
210    ) -> Option<(CacheEntry, f32)> {
211        let embeddings = self.embeddings.read();
212        let cache = self.inner.cache.read();
213        let ttl_secs = self.inner.config.ttl.as_secs() as i64;
214
215        let mut best_match: Option<(String, f32)> = None;
216
217        for (key, embedding) in embeddings.iter() {
218            let similarity = Self::cosine_similarity(query_embedding, embedding);
219
220            if similarity >= threshold
221                && (best_match.is_none() || similarity > best_match.as_ref().unwrap().1)
222            {
223                best_match = Some((key.clone(), similarity));
224            }
225        }
226
227        if let Some((key, similarity)) = best_match {
228            if let Some(entry) = cache.peek(&key) {
229                if !entry.is_expired(ttl_secs) {
230                    return Some((entry.clone(), similarity));
231                }
232            }
233        }
234
235        None
236    }
237
238    /// Store with embedding
239    pub async fn store_with_embedding(
240        &self,
241        query: &str,
242        context: &str,
243        response: &str,
244        function_calls: Vec<String>,
245        embedding: Vec<f32>,
246    ) -> Result<(), CacheError> {
247        let key = MemoryCache::make_key(query, context);
248
249        // Store embedding
250        self.embeddings.write().put(key.clone(), embedding);
251
252        // Store entry
253        self.inner
254            .store(query, context, response, function_calls)
255            .await
256    }
257}
258
259#[async_trait]
260impl Cache for SemanticMemoryCache {
261    async fn get(&self, query: &str, context: &str) -> Result<Option<CacheEntry>, CacheError> {
262        self.inner.get(query, context).await
263    }
264
265    async fn store(
266        &self,
267        query: &str,
268        context: &str,
269        response: &str,
270        function_calls: Vec<String>,
271    ) -> Result<(), CacheError> {
272        self.inner
273            .store(query, context, response, function_calls)
274            .await
275    }
276
277    async fn delete(&self, query: &str, context: &str) -> Result<bool, CacheError> {
278        let key = MemoryCache::make_key(query, context);
279        self.embeddings.write().pop(&key);
280        self.inner.delete(query, context).await
281    }
282
283    async fn clear(&self) -> Result<usize, CacheError> {
284        self.embeddings.write().clear();
285        self.inner.clear().await
286    }
287
288    async fn stats(&self) -> Result<CacheStats, CacheError> {
289        self.inner.stats().await
290    }
291}
292
293#[async_trait]
294impl crate::SemanticCache for SemanticMemoryCache {
295    async fn find_similar_by_embedding(
296        &self,
297        query_embedding: &[f32],
298        threshold: f32,
299    ) -> Result<Option<(CacheEntry, f32)>, CacheError> {
300        Ok(self.find_similar(query_embedding, threshold).await)
301    }
302
303    async fn store_with_embedding(
304        &self,
305        query: &str,
306        context: &str,
307        response: &str,
308        function_calls: Vec<String>,
309        embedding: Vec<f32>,
310    ) -> Result<(), CacheError> {
311        SemanticMemoryCache::store_with_embedding(
312            self,
313            query,
314            context,
315            response,
316            function_calls,
317            embedding,
318        )
319        .await
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn test_memory_cache_basic() {
        let cache = MemoryCache::with_defaults();

        // Store
        cache
            .store("query1", "ctx", "response1", vec![])
            .await
            .unwrap();

        // Get
        let entry = cache.get("query1", "ctx").await.unwrap();
        assert!(entry.is_some());
        assert_eq!(entry.unwrap().response, "response1");

        // Miss
        let miss = cache.get("nonexistent", "ctx").await.unwrap();
        assert!(miss.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_hit_count() {
        let cache = MemoryCache::with_defaults();

        cache
            .store("query", "ctx", "response", vec![])
            .await
            .unwrap();

        // Each get should increment the per-entry hit count.
        for i in 1..=3 {
            let entry = cache.get("query", "ctx").await.unwrap().unwrap();
            assert_eq!(entry.hit_count, i);
        }
    }

    #[tokio::test]
    async fn test_memory_cache_expiry() {
        // Use a 1-second TTL since is_expired() has seconds granularity.
        let config = CacheConfig {
            ttl: Duration::from_secs(1),
            ..Default::default()
        };
        let cache = MemoryCache::new(config);

        cache
            .store("query", "ctx", "response", vec![])
            .await
            .unwrap();

        // Should exist immediately.
        assert!(cache.get("query", "ctx").await.unwrap().is_some());

        // is_expired uses strictly-greater-than, so wait clearly past 1s.
        // Real wall-clock sleep is required: expiry is timestamp-based, so
        // tokio's mock time cannot shortcut it.
        tokio::time::sleep(Duration::from_millis(2100)).await;

        // Should now be expired.
        assert!(cache.get("query", "ctx").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_lru_eviction() {
        let config = CacheConfig {
            max_entries: 3,
            ..Default::default()
        };
        let cache = MemoryCache::new(config);

        // Fill the cache to capacity.
        cache.store("q1", "", "r1", vec![]).await.unwrap();
        cache.store("q2", "", "r2", vec![]).await.unwrap();
        cache.store("q3", "", "r3", vec![]).await.unwrap();

        assert_eq!(cache.len(), 3);

        // Touch q1 so it is most recently used; q2 becomes the LRU.
        cache.get("q1", "").await.unwrap();

        // Inserting q4 should evict q2.
        cache.store("q4", "", "r4", vec![]).await.unwrap();

        assert_eq!(cache.len(), 3);
        assert!(cache.get("q1", "").await.unwrap().is_some());
        assert!(cache.get("q2", "").await.unwrap().is_none()); // Evicted
        assert!(cache.get("q3", "").await.unwrap().is_some());
        assert!(cache.get("q4", "").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_memory_cache_stats() {
        let cache = MemoryCache::with_defaults();

        cache.store("q1", "", "r1", vec![]).await.unwrap();
        cache.store("q2", "", "r2", vec![]).await.unwrap();

        cache.get("q1", "").await.unwrap(); // Hit
        cache.get("q1", "").await.unwrap(); // Hit
        cache.get("q3", "").await.unwrap(); // Miss

        let stats = cache.stats().await.unwrap();
        assert_eq!(stats.entries, 2);
        assert_eq!(stats.stores, 2);
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
    }

    #[tokio::test]
    async fn test_memory_cache_delete() {
        let cache = MemoryCache::with_defaults();

        cache
            .store("query", "ctx", "response", vec![])
            .await
            .unwrap();
        assert!(cache.get("query", "ctx").await.unwrap().is_some());

        let deleted = cache.delete("query", "ctx").await.unwrap();
        assert!(deleted);

        assert!(cache.get("query", "ctx").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_clear() {
        let cache = MemoryCache::with_defaults();

        cache.store("q1", "", "r1", vec![]).await.unwrap();
        cache.store("q2", "", "r2", vec![]).await.unwrap();

        let cleared = cache.clear().await.unwrap();
        assert_eq!(cleared, 2);
        assert!(cache.is_empty());
    }

    #[tokio::test]
    async fn test_semantic_cache_similarity() {
        let cache = SemanticMemoryCache::new(CacheConfig::default());

        // Store with an axis-aligned embedding.
        let embedding1 = vec![1.0, 0.0, 0.0];
        cache
            .store_with_embedding("q1", "", "r1", vec![], embedding1)
            .await
            .unwrap();

        // Exact match: cosine similarity should be ~1.0.
        let query_embedding = vec![1.0, 0.0, 0.0];
        let result = cache.find_similar(&query_embedding, 0.9).await;
        assert!(result.is_some());
        let (entry, similarity) = result.unwrap();
        assert_eq!(entry.response, "r1");
        assert!((similarity - 1.0).abs() < 0.001);

        // Partial match still clears the 0.9 threshold.
        let query_embedding2 = vec![0.9, 0.1, 0.0];
        let result2 = cache.find_similar(&query_embedding2, 0.9).await;
        assert!(result2.is_some());

        // Orthogonal vector: similarity 0.0, below threshold.
        let query_embedding3 = vec![0.0, 1.0, 0.0];
        let result3 = cache.find_similar(&query_embedding3, 0.9).await;
        assert!(result3.is_none());
    }
}