oxios_memory/memory/
embedding_cache.rs1use lru::LruCache;
17use parking_lot::RwLock;
18use serde::{Deserialize, Serialize};
19use std::time::{Duration, Instant};
20
/// A cached value together with the bookkeeping needed to expire it.
struct CacheEntry<V> {
    /// The cached payload.
    value: V,
    /// Moment at which this entry was created.
    created_at: Instant,
    /// How long after `created_at` the entry is still considered valid.
    ttl: Duration,
}

impl<V> CacheEntry<V> {
    /// Returns `true` once strictly more than `ttl` has elapsed since the
    /// entry was created.
    fn is_expired(&self) -> bool {
        let age = self.created_at.elapsed();
        age > self.ttl
    }
}
33
34pub struct EmbeddingCache {
36 inner: RwLock<LruCache<u64, CacheEntry<Vec<f32>>>>,
37 ttl: Duration,
38 max_entries: usize,
39 hits: RwLock<u64>,
40 misses: RwLock<u64>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct CacheStats {
46 pub hits: u64,
48 pub misses: u64,
50 pub hit_rate: f64,
52 pub size: usize,
54 pub capacity: usize,
56}
57
58impl EmbeddingCache {
59 pub fn new(ttl_secs: u64, max_entries: usize) -> Self {
65 Self {
66 inner: RwLock::new(LruCache::new(
67 std::num::NonZeroUsize::new(max_entries).unwrap_or(std::num::NonZeroUsize::MIN),
68 )),
69 ttl: Duration::from_secs(ttl_secs),
70 max_entries,
71 hits: RwLock::new(0),
72 misses: RwLock::new(0),
73 }
74 }
75
76 pub fn content_hash(content: &str) -> u64 {
81 crate::memory::types::content_hash(content)
82 }
83
84 pub fn get(&self, content: &str) -> Option<Vec<f32>> {
86 let key = Self::content_hash(content);
87 let mut inner = self.inner.write();
88
89 match inner.get(&key) {
90 Some(entry) if !entry.is_expired() => {
91 *self.hits.write() += 1;
92 Some(entry.value.clone())
93 }
94 Some(_) => {
95 inner.pop(&key);
97 *self.misses.write() += 1;
98 None
99 }
100 None => {
101 *self.misses.write() += 1;
102 None
103 }
104 }
105 }
106
107 pub fn insert(&self, content: &str, embedding: Vec<f32>) {
109 let key = Self::content_hash(content);
110 let mut inner = self.inner.write();
111
112 inner.push(
113 key,
114 CacheEntry {
115 value: embedding,
116 created_at: Instant::now(),
117 ttl: self.ttl,
118 },
119 );
120 }
121
122 pub fn evict_expired(&self) -> usize {
126 let mut inner = self.inner.write();
127 let mut evicted = 0;
128
129 let keys: Vec<_> = inner
130 .iter()
131 .filter(|(_, entry)| entry.is_expired())
132 .map(|(k, _)| *k)
133 .collect();
134
135 for key in keys {
136 inner.pop(&key);
137 evicted += 1;
138 }
139
140 evicted
141 }
142
143 pub fn evict_lru(&self, target_size: usize) -> usize {
147 let mut inner = self.inner.write();
148 let mut evicted = 0;
149
150 while inner.len() > target_size {
151 if inner.pop_lru().is_none() {
152 break;
153 }
154 evicted += 1;
155 }
156
157 evicted
158 }
159
160 pub fn stats(&self) -> CacheStats {
162 let hits = *self.hits.read();
163 let misses = *self.misses.read();
164 let total = hits + misses;
165
166 CacheStats {
167 hits,
168 misses,
169 hit_rate: if total > 0 {
170 hits as f64 / total as f64
171 } else {
172 0.0
173 },
174 size: self.inner.read().len(),
175 capacity: self.max_entries,
176 }
177 }
178
179 pub fn clear(&self) {
181 self.inner.write().clear();
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// An inserted embedding is returned unchanged and counted as one hit.
    #[test]
    fn test_cache_basic() {
        let cache = EmbeddingCache::new(60, 100);
        cache.insert("hello", vec![1.0, 2.0, 3.0]);

        assert_eq!(cache.get("hello"), Some(vec![1.0, 2.0, 3.0]));

        let CacheStats { hits, misses, .. } = cache.stats();
        assert_eq!(hits, 1);
        assert_eq!(misses, 0);
    }

    /// Looking up unknown content yields `None` and counts as one miss.
    #[test]
    fn test_cache_miss() {
        let cache = EmbeddingCache::new(60, 100);

        assert_eq!(cache.get("nonexistent"), None);

        let CacheStats { hits, misses, .. } = cache.stats();
        assert_eq!(hits, 0);
        assert_eq!(misses, 1);
    }

    /// An entry is served while fresh and disappears once its TTL has passed.
    #[test]
    fn test_cache_ttl() {
        // One-second TTL: the smallest non-zero value `new` accepts.
        let cache = EmbeddingCache::new(1, 100);
        cache.insert("test", vec![1.0]);
        assert!(cache.get("test").is_some());

        // Wait slightly longer than the TTL.
        thread::sleep(Duration::from_millis(1_100));

        assert!(cache.get("test").is_none());
    }

    /// With capacity two, inserting a third entry evicts the oldest one.
    #[test]
    fn test_cache_eviction() {
        let cache = EmbeddingCache::new(60, 2);

        for (content, value) in [("a", 1.0), ("b", 2.0), ("c", 3.0)] {
            cache.insert(content, vec![value]);
        }

        // "a" was the least recently used entry when "c" arrived.
        assert!(cache.get("a").is_none());
        assert!(cache.get("b").is_some());
        assert!(cache.get("c").is_some());
    }
}