oxios_kernel/memory/
embedding_cache.rs1use lru::LruCache;
17use parking_lot::RwLock;
18use serde::{Deserialize, Serialize};
19use std::hash::{Hash, Hasher};
20use std::time::{Duration, Instant};
21
/// A cached value together with the bookkeeping needed to expire it.
struct CacheEntry<V> {
    /// The cached payload.
    value: V,
    /// Moment the entry was created; its age is measured from here.
    created_at: Instant,
    /// How long the entry stays valid after `created_at`.
    ttl: Duration,
}

impl<V> CacheEntry<V> {
    /// Returns `true` once strictly more than `ttl` has passed since `created_at`.
    fn is_expired(&self) -> bool {
        let age = self.created_at.elapsed();
        age > self.ttl
    }
}
34
35pub struct EmbeddingCache {
37 inner: RwLock<LruCache<u64, CacheEntry<Vec<f32>>>>,
38 ttl: Duration,
39 max_entries: usize,
40 hits: RwLock<u64>,
41 misses: RwLock<u64>,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct CacheStats {
47 pub hits: u64,
49 pub misses: u64,
51 pub hit_rate: f64,
53 pub size: usize,
55 pub capacity: usize,
57}
58
59impl EmbeddingCache {
60 pub fn new(ttl_secs: u64, max_entries: usize) -> Self {
66 Self {
67 inner: RwLock::new(LruCache::new(
68 std::num::NonZeroUsize::new(max_entries).unwrap_or(std::num::NonZeroUsize::MIN),
69 )),
70 ttl: Duration::from_secs(ttl_secs),
71 max_entries,
72 hits: RwLock::new(0),
73 misses: RwLock::new(0),
74 }
75 }
76
77 pub fn content_hash(content: &str) -> u64 {
79 use std::collections::hash_map::DefaultHasher;
80 let mut hasher = DefaultHasher::new();
81 content.hash(&mut hasher);
82 hasher.finish()
83 }
84
85 pub fn get(&self, content: &str) -> Option<Vec<f32>> {
87 let key = Self::content_hash(content);
88 let mut inner = self.inner.write();
89
90 match inner.get(&key) {
91 Some(entry) if !entry.is_expired() => {
92 *self.hits.write() += 1;
93 Some(entry.value.clone())
94 }
95 Some(_) => {
96 inner.pop(&key);
98 *self.misses.write() += 1;
99 None
100 }
101 None => {
102 *self.misses.write() += 1;
103 None
104 }
105 }
106 }
107
108 pub fn insert(&self, content: &str, embedding: Vec<f32>) {
110 let key = Self::content_hash(content);
111 let mut inner = self.inner.write();
112
113 inner.push(
114 key,
115 CacheEntry {
116 value: embedding,
117 created_at: Instant::now(),
118 ttl: self.ttl,
119 },
120 );
121 }
122
123 pub fn evict_expired(&self) -> usize {
127 let mut inner = self.inner.write();
128 let mut evicted = 0;
129
130 let keys: Vec<_> = inner
131 .iter()
132 .filter(|(_, entry)| entry.is_expired())
133 .map(|(k, _)| *k)
134 .collect();
135
136 for key in keys {
137 inner.pop(&key);
138 evicted += 1;
139 }
140
141 evicted
142 }
143
144 pub fn evict_lru(&self, target_size: usize) -> usize {
148 let mut inner = self.inner.write();
149 let mut evicted = 0;
150
151 while inner.len() > target_size {
152 if inner.pop_lru().is_none() {
153 break;
154 }
155 evicted += 1;
156 }
157
158 evicted
159 }
160
161 pub fn stats(&self) -> CacheStats {
163 let hits = *self.hits.read();
164 let misses = *self.misses.read();
165 let total = hits + misses;
166
167 CacheStats {
168 hits,
169 misses,
170 hit_rate: if total > 0 {
171 hits as f64 / total as f64
172 } else {
173 0.0
174 },
175 size: self.inner.read().len(),
176 capacity: self.max_entries,
177 }
178 }
179
180 pub fn clear(&self) {
182 self.inner.write().clear();
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// A stored embedding comes back unchanged and is counted as one hit.
    #[test]
    fn test_cache_basic() {
        let cache = EmbeddingCache::new(60, 100);
        cache.insert("hello", vec![1.0, 2.0, 3.0]);

        assert_eq!(cache.get("hello"), Some(vec![1.0, 2.0, 3.0]));

        let CacheStats { hits, misses, .. } = cache.stats();
        assert_eq!(hits, 1);
        assert_eq!(misses, 0);
    }

    /// Looking up an unknown key yields nothing and is counted as one miss.
    #[test]
    fn test_cache_miss() {
        let cache = EmbeddingCache::new(60, 100);

        assert!(cache.get("nonexistent").is_none());

        let CacheStats { hits, misses, .. } = cache.stats();
        assert_eq!(hits, 0);
        assert_eq!(misses, 1);
    }

    /// An entry is served while fresh and disappears once its TTL has passed.
    #[test]
    fn test_cache_ttl() {
        let cache = EmbeddingCache::new(1, 100);
        cache.insert("test", vec![1.0]);
        assert!(cache.get("test").is_some());

        // Wait past the one-second TTL.
        thread::sleep(Duration::from_secs(2));

        assert!(cache.get("test").is_none());
    }

    /// With room for two entries, inserting a third pushes out the oldest one.
    #[test]
    fn test_cache_eviction() {
        let cache = EmbeddingCache::new(60, 2);

        cache.insert("a", vec![1.0]);
        cache.insert("b", vec![2.0]);
        cache.insert("c", vec![3.0]);

        assert!(cache.get("a").is_none());
        assert!(cache.get("b").is_some());
        assert!(cache.get("c").is_some());
    }
}