Skip to main content

oxios_kernel/memory/
embedding_cache.rs

//! Embedding cache for reducing API calls.
//!
//! Provides an LRU cache with per-entry TTL for embedding vectors, keyed by a
//! hash of the input text.
//!
//! # Example
//!
//! ```
//! use oxios_kernel::memory::EmbeddingCache;
//!
//! let cache = EmbeddingCache::new(3600, 10000);  // 1 hour TTL, 10k max
//! cache.insert("hello", vec![1.0, 2.0, 3.0]);
//! let embedded = cache.get("hello");
//! assert!(embedded.is_some());
//! ```

use lru::LruCache;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

/// A cached value plus the bookkeeping needed to decide when it expires.
struct CacheEntry<V> {
    /// The cached payload.
    value: V,
    /// Moment the entry was inserted; its age is measured from here.
    created_at: Instant,
    /// How long after `created_at` the entry remains valid.
    ttl: Duration,
}

30impl<V> CacheEntry<V> {
31    fn is_expired(&self) -> bool {
32        self.created_at.elapsed() > self.ttl
33    }
34}
35
36/// Content-addressable embedding cache with TTL and LRU eviction.
37pub struct EmbeddingCache {
38    inner: RwLock<LruCache<u64, CacheEntry<Vec<f32>>>>,
39    ttl: Duration,
40    max_entries: usize,
41    hits: RwLock<u64>,
42    misses: RwLock<u64>,
43}
44
45/// Cache statistics for monitoring.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct CacheStats {
48    /// Number of cache hits.
49    pub hits: u64,
50    /// Number of cache misses.
51    pub misses: u64,
52    /// Hit rate as a fraction (0.0 to 1.0).
53    pub hit_rate: f64,
54    /// Current number of entries in cache.
55    pub size: usize,
56    /// Maximum capacity of cache.
57    pub capacity: usize,
58}
59
60impl EmbeddingCache {
61    /// Create a new cache with TTL and capacity.
62    ///
63    /// # Arguments
64    /// * `ttl_secs` - Time-to-live for cached entries in seconds
65    /// * `max_entries` - Maximum number of entries to cache
66    pub fn new(ttl_secs: u64, max_entries: usize) -> Self {
67        Self {
68            inner: RwLock::new(LruCache::new(
69                std::num::NonZeroUsize::new(max_entries).unwrap_or(std::num::NonZeroUsize::MIN),
70            )),
71            ttl: Duration::from_secs(ttl_secs),
72            max_entries,
73            hits: RwLock::new(0),
74            misses: RwLock::new(0),
75        }
76    }
77
78    /// Hash content to cache key.
79    pub fn content_hash(content: &str) -> u64 {
80        use std::collections::hash_map::DefaultHasher;
81        let mut hasher = DefaultHasher::new();
82        content.hash(&mut hasher);
83        hasher.finish()
84    }
85
86    /// Get cached embedding if exists and not expired.
87    pub fn get(&self, content: &str) -> Option<Vec<f32>> {
88        let key = Self::content_hash(content);
89        let mut inner = self.inner.write();
90
91        match inner.get(&key) {
92            Some(entry) if !entry.is_expired() => {
93                *self.hits.write() += 1;
94                Some(entry.value.clone())
95            }
96            Some(_) => {
97                // Expired — remove
98                inner.pop(&key);
99                *self.misses.write() += 1;
100                None
101            }
102            None => {
103                *self.misses.write() += 1;
104                None
105            }
106        }
107    }
108
109    /// Cache an embedding.
110    pub fn insert(&self, content: &str, embedding: Vec<f32>) {
111        let key = Self::content_hash(content);
112        let mut inner = self.inner.write();
113
114        inner.push(
115            key,
116            CacheEntry {
117                value: embedding,
118                created_at: Instant::now(),
119                ttl: self.ttl,
120            },
121        );
122    }
123
124    /// Evict expired entries.
125    ///
126    /// Returns the number of entries evicted.
127    pub fn evict_expired(&self) -> usize {
128        let mut inner = self.inner.write();
129        let mut evicted = 0;
130
131        let keys: Vec<_> = inner
132            .iter()
133            .filter(|(_, entry)| entry.is_expired())
134            .map(|(k, _)| *k)
135            .collect();
136
137        for key in keys {
138            inner.pop(&key);
139            evicted += 1;
140        }
141
142        evicted
143    }
144
145    /// Evict least recently used entries to free space.
146    ///
147    /// Returns the number of entries evicted.
148    pub fn evict_lru(&self, target_size: usize) -> usize {
149        let mut inner = self.inner.write();
150        let mut evicted = 0;
151
152        while inner.len() > target_size {
153            if inner.pop_lru().is_none() {
154                break;
155            }
156            evicted += 1;
157        }
158
159        evicted
160    }
161
162    /// Cache statistics.
163    pub fn stats(&self) -> CacheStats {
164        let hits = *self.hits.read();
165        let misses = *self.misses.read();
166        let total = hits + misses;
167
168        CacheStats {
169            hits,
170            misses,
171            hit_rate: if total > 0 { hits as f64 / total as f64 } else { 0.0 },
172            size: self.inner.read().len(),
173            capacity: self.max_entries,
174        }
175    }
176
177    /// Clear the cache.
178    pub fn clear(&self) {
179        self.inner.write().clear();
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_cache_basic() {
        let cache = EmbeddingCache::new(60, 100);

        // Insert
        cache.insert("hello", vec![1.0, 2.0, 3.0]);

        // Get
        let result = cache.get("hello");
        assert!(result.is_some());
        assert_eq!(result.unwrap(), vec![1.0, 2.0, 3.0]);

        // Stats
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate, 1.0);
    }

    #[test]
    fn test_cache_miss() {
        let cache = EmbeddingCache::new(60, 100);

        let result = cache.get("nonexistent");
        assert!(result.is_none());

        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate, 0.0);
    }

    #[test]
    fn test_cache_ttl() {
        let cache = EmbeddingCache::new(1, 100); // 1 second TTL

        cache.insert("test", vec![1.0]);
        cache.insert("other", vec![2.0]);
        assert!(cache.get("test").is_some());

        // Expiry is strict (`elapsed > ttl`), so sleeping just past the TTL is
        // enough and keeps the suite fast.
        thread::sleep(Duration::from_millis(1100));

        // An expired lookup returns None, counts as a miss and drops the entry.
        assert!(cache.get("test").is_none());
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.size, 1);

        // The remaining expired entry is removed by an explicit sweep.
        assert_eq!(cache.evict_expired(), 1);
        assert_eq!(cache.stats().size, 0);
    }

    #[test]
    fn test_cache_eviction() {
        let cache = EmbeddingCache::new(60, 2);

        cache.insert("a", vec![1.0]);
        cache.insert("b", vec![2.0]);
        cache.insert("c", vec![3.0]); // Should evict oldest

        // a should be evicted
        assert!(cache.get("a").is_none());

        // b and c should exist
        assert!(cache.get("b").is_some());
        assert!(cache.get("c").is_some());
    }

    #[test]
    fn test_get_refreshes_recency() {
        let cache = EmbeddingCache::new(60, 2);

        cache.insert("a", vec![1.0]);
        cache.insert("b", vec![2.0]);
        // Touching "a" makes "b" the least recently used entry.
        assert!(cache.get("a").is_some());
        cache.insert("c", vec![3.0]);

        assert!(cache.get("b").is_none());
        assert!(cache.get("a").is_some());
        assert!(cache.get("c").is_some());
    }

    #[test]
    fn test_evict_lru_and_clear() {
        let cache = EmbeddingCache::new(60, 10);
        for (i, word) in ["a", "b", "c", "d", "e"].iter().enumerate() {
            cache.insert(word, vec![i as f32]);
        }

        assert_eq!(cache.evict_lru(2), 3);
        assert_eq!(cache.stats().size, 2);
        // The two most recently inserted entries survive.
        assert_eq!(cache.get("d"), Some(vec![3.0]));
        assert_eq!(cache.get("e"), Some(vec![4.0]));
        assert!(cache.get("a").is_none());
        // Already at or below the target: nothing to evict.
        assert_eq!(cache.evict_lru(5), 0);

        cache.clear();
        assert_eq!(cache.stats().size, 0);
        assert!(cache.get("d").is_none());
    }
}