Skip to main content

oxios_kernel/memory/
embedding_cache.rs

//! Embedding cache for reducing API calls.
//!
//! Provides an LRU cache with a per-entry time-to-live (TTL) for embedding
//! vectors, keyed by a hash of the input text.
//!
//! # Example
//!
//! ```
//! use oxios_kernel::memory::EmbeddingCache;
//!
//! let cache = EmbeddingCache::new(3600, 10000);  // 1 hour TTL, 10k max
//! cache.insert("hello", vec![1.0, 2.0, 3.0]);
//! let embedded = cache.get("hello");
//! assert!(embedded.is_some());
//! ```

use lru::LruCache;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::time::{Duration, Instant};

/// A cached value together with its insertion time and lifetime.
struct CacheEntry<V> {
    /// The cached payload.
    value: V,
    /// Moment the entry was stored.
    created_at: Instant,
    /// How long after `created_at` the entry stays usable.
    ttl: Duration,
}

impl<V> CacheEntry<V> {
    /// Returns `true` once strictly more than `ttl` has passed since `created_at`.
    ///
    /// An entry whose age is exactly `ttl` still counts as fresh.
    fn is_expired(&self) -> bool {
        let age = Instant::now().saturating_duration_since(self.created_at);
        self.ttl < age
    }
}

35/// Content-addressable embedding cache with TTL and LRU eviction.
36pub struct EmbeddingCache {
37    inner: RwLock<LruCache<u64, CacheEntry<Vec<f32>>>>,
38    ttl: Duration,
39    max_entries: usize,
40    hits: RwLock<u64>,
41    misses: RwLock<u64>,
42}
43
44/// Cache statistics for monitoring.
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct CacheStats {
47    /// Number of cache hits.
48    pub hits: u64,
49    /// Number of cache misses.
50    pub misses: u64,
51    /// Hit rate as a fraction (0.0 to 1.0).
52    pub hit_rate: f64,
53    /// Current number of entries in cache.
54    pub size: usize,
55    /// Maximum capacity of cache.
56    pub capacity: usize,
57}
58
59impl EmbeddingCache {
60    /// Create a new cache with TTL and capacity.
61    ///
62    /// # Arguments
63    /// * `ttl_secs` - Time-to-live for cached entries in seconds
64    /// * `max_entries` - Maximum number of entries to cache
65    pub fn new(ttl_secs: u64, max_entries: usize) -> Self {
66        Self {
67            inner: RwLock::new(LruCache::new(
68                std::num::NonZeroUsize::new(max_entries).unwrap_or(std::num::NonZeroUsize::MIN),
69            )),
70            ttl: Duration::from_secs(ttl_secs),
71            max_entries,
72            hits: RwLock::new(0),
73            misses: RwLock::new(0),
74        }
75    }
76
77    /// Hash content to cache key.
78    pub fn content_hash(content: &str) -> u64 {
79        use std::collections::hash_map::DefaultHasher;
80        let mut hasher = DefaultHasher::new();
81        content.hash(&mut hasher);
82        hasher.finish()
83    }
84
85    /// Get cached embedding if exists and not expired.
86    pub fn get(&self, content: &str) -> Option<Vec<f32>> {
87        let key = Self::content_hash(content);
88        let mut inner = self.inner.write();
89
90        match inner.get(&key) {
91            Some(entry) if !entry.is_expired() => {
92                *self.hits.write() += 1;
93                Some(entry.value.clone())
94            }
95            Some(_) => {
96                // Expired — remove
97                inner.pop(&key);
98                *self.misses.write() += 1;
99                None
100            }
101            None => {
102                *self.misses.write() += 1;
103                None
104            }
105        }
106    }
107
108    /// Cache an embedding.
109    pub fn insert(&self, content: &str, embedding: Vec<f32>) {
110        let key = Self::content_hash(content);
111        let mut inner = self.inner.write();
112
113        inner.push(
114            key,
115            CacheEntry {
116                value: embedding,
117                created_at: Instant::now(),
118                ttl: self.ttl,
119            },
120        );
121    }
122
123    /// Evict expired entries.
124    ///
125    /// Returns the number of entries evicted.
126    pub fn evict_expired(&self) -> usize {
127        let mut inner = self.inner.write();
128        let mut evicted = 0;
129
130        let keys: Vec<_> = inner
131            .iter()
132            .filter(|(_, entry)| entry.is_expired())
133            .map(|(k, _)| *k)
134            .collect();
135
136        for key in keys {
137            inner.pop(&key);
138            evicted += 1;
139        }
140
141        evicted
142    }
143
144    /// Evict least recently used entries to free space.
145    ///
146    /// Returns the number of entries evicted.
147    pub fn evict_lru(&self, target_size: usize) -> usize {
148        let mut inner = self.inner.write();
149        let mut evicted = 0;
150
151        while inner.len() > target_size {
152            if inner.pop_lru().is_none() {
153                break;
154            }
155            evicted += 1;
156        }
157
158        evicted
159    }
160
161    /// Cache statistics.
162    pub fn stats(&self) -> CacheStats {
163        let hits = *self.hits.read();
164        let misses = *self.misses.read();
165        let total = hits + misses;
166
167        CacheStats {
168            hits,
169            misses,
170            hit_rate: if total > 0 {
171                hits as f64 / total as f64
172            } else {
173                0.0
174            },
175            size: self.inner.read().len(),
176            capacity: self.max_entries,
177        }
178    }
179
180    /// Clear the cache.
181    pub fn clear(&self) {
182        self.inner.write().clear();
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_cache_basic() {
        let cache = EmbeddingCache::new(60, 100);

        // Insert
        cache.insert("hello", vec![1.0, 2.0, 3.0]);

        // Get
        let result = cache.get("hello");
        assert!(result.is_some());
        assert_eq!(result.unwrap(), vec![1.0, 2.0, 3.0]);

        // Stats
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_cache_miss() {
        let cache = EmbeddingCache::new(60, 100);

        let result = cache.get("nonexistent");
        assert!(result.is_none());

        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_cache_ttl() {
        let cache = EmbeddingCache::new(1, 100); // 1 second TTL

        cache.insert("test", vec![1.0]);
        cache.insert("other", vec![2.0]);
        assert!(cache.get("test").is_some());

        // `sleep` waits at least the requested time, so both entries are now
        // strictly older than the 1 s TTL; a full 2 s wait is unnecessary.
        thread::sleep(Duration::from_millis(1100));

        // Expired entries are dropped on lookup...
        assert!(cache.get("test").is_none());
        // ...and the untouched one is removed by an explicit sweep.
        assert_eq!(cache.evict_expired(), 1);
        assert_eq!(cache.stats().size, 0);
    }

    #[test]
    fn test_cache_eviction() {
        let cache = EmbeddingCache::new(60, 2);

        cache.insert("a", vec![1.0]);
        cache.insert("b", vec![2.0]);
        cache.insert("c", vec![3.0]); // Should evict oldest

        // a should be evicted
        assert!(cache.get("a").is_none());

        // b and c should exist
        assert!(cache.get("b").is_some());
        assert!(cache.get("c").is_some());
    }

    #[test]
    fn test_hit_rate_and_capacity() {
        let cache = EmbeddingCache::new(60, 100);
        let empty = cache.stats();
        assert_eq!(empty.hit_rate, 0.0);
        assert_eq!(empty.capacity, 100);

        cache.insert("x", vec![0.5]);
        assert!(cache.get("x").is_some());
        assert!(cache.get("y").is_none());

        let stats = cache.stats();
        assert_eq!((stats.hits, stats.misses), (1, 1));
        assert_eq!(stats.hit_rate, 0.5);
        assert_eq!(stats.size, 1);
    }

    #[test]
    fn test_evict_lru_keeps_most_recent() {
        let cache = EmbeddingCache::new(60, 10);
        cache.insert("a", vec![1.0]);
        cache.insert("b", vec![2.0]);
        cache.insert("c", vec![3.0]);
        // Touch "a" so it becomes the most recently used entry.
        assert!(cache.get("a").is_some());

        assert_eq!(cache.evict_lru(1), 2);
        assert!(cache.get("a").is_some());
        assert!(cache.get("b").is_none());
        assert!(cache.get("c").is_none());

        // Already at or below the target: nothing to evict.
        assert_eq!(cache.evict_lru(5), 0);
    }

    #[test]
    fn test_insert_overwrites_and_clear() {
        let cache = EmbeddingCache::new(60, 10);
        cache.insert("k", vec![1.0]);
        cache.insert("k", vec![9.0]);
        assert_eq!(cache.get("k"), Some(vec![9.0]));
        assert_eq!(cache.stats().size, 1);

        cache.clear();
        assert_eq!(cache.stats().size, 0);
        assert!(cache.get("k").is_none());
    }
}