Skip to main content

oxios_memory/memory/
embedding_cache.rs

1//! Embedding cache for reducing API calls.
2//!
3//! Provides LRU cache with TTL for embedding vectors.
4//!
5//! # Example
6//!
7//! ```
8//! use oxios_memory::EmbeddingCache;
9//!
10//! let cache = EmbeddingCache::new(3600, 10000);  // 1 hour TTL, 10k max
11//! cache.insert("hello", vec![1.0, 2.0, 3.0]);
12//! let embedded = cache.get("hello");
13//! assert!(embedded.is_some());
14//! ```
15
16use lru::LruCache;
17use parking_lot::RwLock;
18use serde::{Deserialize, Serialize};
19use std::time::{Duration, Instant};
20
/// A cached value together with the bookkeeping needed to expire it.
struct CacheEntry<V> {
    /// The cached payload.
    value: V,
    /// Moment the entry was created; expiry is measured from here.
    created_at: Instant,
    /// How long the entry stays valid after `created_at`.
    ttl: Duration,
}

impl<V> CacheEntry<V> {
    /// Returns `true` once strictly more than `ttl` has elapsed since
    /// `created_at`.
    fn is_expired(&self) -> bool {
        let age = self.created_at.elapsed();
        self.ttl < age
    }
}
33
34/// Content-addressable embedding cache with TTL and LRU eviction.
35pub struct EmbeddingCache {
36    inner: RwLock<LruCache<u64, CacheEntry<Vec<f32>>>>,
37    ttl: Duration,
38    max_entries: usize,
39    hits: RwLock<u64>,
40    misses: RwLock<u64>,
41}
42
43/// Cache statistics for monitoring.
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct CacheStats {
46    /// Number of cache hits.
47    pub hits: u64,
48    /// Number of cache misses.
49    pub misses: u64,
50    /// Hit rate as a fraction (0.0 to 1.0).
51    pub hit_rate: f64,
52    /// Current number of entries in cache.
53    pub size: usize,
54    /// Maximum capacity of cache.
55    pub capacity: usize,
56}
57
58impl EmbeddingCache {
59    /// Create a new cache with TTL and capacity.
60    ///
61    /// # Arguments
62    /// * `ttl_secs` - Time-to-live for cached entries in seconds
63    /// * `max_entries` - Maximum number of entries to cache
64    pub fn new(ttl_secs: u64, max_entries: usize) -> Self {
65        Self {
66            inner: RwLock::new(LruCache::new(
67                std::num::NonZeroUsize::new(max_entries).unwrap_or(std::num::NonZeroUsize::MIN),
68            )),
69            ttl: Duration::from_secs(ttl_secs),
70            max_entries,
71            hits: RwLock::new(0),
72            misses: RwLock::new(0),
73        }
74    }
75
76    /// Hash content to cache key.
77    ///
78    /// Delegates to the stable FNV-1a `types::content_hash` so that in-memory
79    /// and persisted keys share the same version-independent algorithm.
80    pub fn content_hash(content: &str) -> u64 {
81        crate::memory::types::content_hash(content)
82    }
83
84    /// Get cached embedding if exists and not expired.
85    pub fn get(&self, content: &str) -> Option<Vec<f32>> {
86        let key = Self::content_hash(content);
87        let mut inner = self.inner.write();
88
89        match inner.get(&key) {
90            Some(entry) if !entry.is_expired() => {
91                *self.hits.write() += 1;
92                Some(entry.value.clone())
93            }
94            Some(_) => {
95                // Expired — remove
96                inner.pop(&key);
97                *self.misses.write() += 1;
98                None
99            }
100            None => {
101                *self.misses.write() += 1;
102                None
103            }
104        }
105    }
106
107    /// Cache an embedding.
108    pub fn insert(&self, content: &str, embedding: Vec<f32>) {
109        let key = Self::content_hash(content);
110        let mut inner = self.inner.write();
111
112        inner.push(
113            key,
114            CacheEntry {
115                value: embedding,
116                created_at: Instant::now(),
117                ttl: self.ttl,
118            },
119        );
120    }
121
122    /// Evict expired entries.
123    ///
124    /// Returns the number of entries evicted.
125    pub fn evict_expired(&self) -> usize {
126        let mut inner = self.inner.write();
127        let mut evicted = 0;
128
129        let keys: Vec<_> = inner
130            .iter()
131            .filter(|(_, entry)| entry.is_expired())
132            .map(|(k, _)| *k)
133            .collect();
134
135        for key in keys {
136            inner.pop(&key);
137            evicted += 1;
138        }
139
140        evicted
141    }
142
143    /// Evict least recently used entries to free space.
144    ///
145    /// Returns the number of entries evicted.
146    pub fn evict_lru(&self, target_size: usize) -> usize {
147        let mut inner = self.inner.write();
148        let mut evicted = 0;
149
150        while inner.len() > target_size {
151            if inner.pop_lru().is_none() {
152                break;
153            }
154            evicted += 1;
155        }
156
157        evicted
158    }
159
160    /// Cache statistics.
161    pub fn stats(&self) -> CacheStats {
162        let hits = *self.hits.read();
163        let misses = *self.misses.read();
164        let total = hits + misses;
165
166        CacheStats {
167            hits,
168            misses,
169            hit_rate: if total > 0 {
170                hits as f64 / total as f64
171            } else {
172                0.0
173            },
174            size: self.inner.read().len(),
175            capacity: self.max_entries,
176        }
177    }
178
179    /// Clear the cache.
180    pub fn clear(&self) {
181        self.inner.write().clear();
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// An inserted embedding comes back unchanged and counts as one hit.
    #[test]
    fn test_cache_basic() {
        let cache = EmbeddingCache::new(60, 100);
        cache.insert("hello", vec![1.0, 2.0, 3.0]);

        let fetched = cache.get("hello");
        assert_eq!(fetched, Some(vec![1.0, 2.0, 3.0]));

        let CacheStats { hits, misses, .. } = cache.stats();
        assert_eq!(hits, 1);
        assert_eq!(misses, 0);
    }

    /// Looking up unknown content yields `None` and counts as one miss.
    #[test]
    fn test_cache_miss() {
        let cache = EmbeddingCache::new(60, 100);

        assert!(cache.get("nonexistent").is_none());

        let CacheStats { hits, misses, .. } = cache.stats();
        assert_eq!(hits, 0);
        assert_eq!(misses, 1);
    }

    /// An entry is served while fresh and dropped once its TTL has passed.
    #[test]
    fn test_cache_ttl() {
        // One-second TTL.
        let cache = EmbeddingCache::new(1, 100);
        cache.insert("test", vec![1.0]);
        assert!(cache.get("test").is_some());

        // Sleep just past the TTL so the entry expires.
        thread::sleep(Duration::from_millis(1_100));

        assert!(cache.get("test").is_none());
    }

    /// Inserting beyond capacity pushes out the oldest entry.
    #[test]
    fn test_cache_eviction() {
        let cache = EmbeddingCache::new(60, 2);

        // Capacity is two, so the third insert evicts "a".
        for (text, value) in [("a", 1.0), ("b", 2.0), ("c", 3.0)] {
            cache.insert(text, vec![value]);
        }

        assert!(cache.get("a").is_none());

        // The two most recent inserts are still present.
        assert!(cache.get("b").is_some());
        assert!(cache.get("c").is_some());
    }
}