1use std::collections::HashMap;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::{Arc, Mutex};
15
/// Point-in-time snapshot of cache counters and occupancy.
///
/// Derives `PartialEq` and `Default` so snapshots can be compared directly
/// (e.g. in tests or change detection) and an all-zero snapshot can be built
/// without spelling out every field.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct EmbeddingCacheStats {
    /// Number of lookups that found an entry.
    pub hits: u64,
    /// Number of lookups that found nothing.
    pub misses: u64,
    /// Number of entries currently stored.
    pub entries: usize,
    /// Bytes of embedding payload currently stored (keys are not counted).
    pub bytes_used: usize,
    /// Configured byte budget of the cache.
    pub max_bytes: usize,
    /// Hit rate as a percentage in `0.0..=100.0`; `0.0` when no lookups happened yet.
    pub hit_rate: f64,
}
32
/// One cached embedding plus its position in the recency list.
///
/// The list is threaded through the map: instead of pointers, `prev` and
/// `next` store the map keys of the neighbouring entries.
struct LruNode {
    /// Shared embedding data; cloning the `Arc` hands out the same allocation.
    embedding: Arc<[f32]>,
    /// Bytes this entry contributes to `CacheState::bytes_used`.
    size_bytes: usize,
    /// Key of the neighbour towards the head (more recently used); `None` at the head.
    prev: Option<String>,
    /// Key of the neighbour towards the tail (less recently used); `None` at the tail.
    next: Option<String>,
}

/// Entry map combined with a doubly linked recency list.
///
/// `head` names the most recently used entry and `tail` the least recently
/// used one; both are `None` while the list is empty.
struct CacheState {
    entries: HashMap<String, LruNode>,
    head: Option<String>,
    tail: Option<String>,
    bytes_used: usize,
}

impl CacheState {
    /// Creates an empty state with no entries and zero bytes accounted.
    fn new() -> Self {
        Self {
            entries: HashMap::new(),
            head: None,
            tail: None,
            bytes_used: 0,
        }
    }

    /// Marks `key` as the most recently used entry.
    ///
    /// Does nothing when `key` is already at the head or is not present in
    /// `entries`. The latter matters: relinking an unknown key would leave
    /// `head` naming an entry that does not exist.
    fn move_to_front(&mut self, key: &str) {
        if self.head.as_deref() == Some(key) {
            return;
        }
        if self.detach(key) {
            self.attach_front(key);
        }
    }

    /// Removes the least recently used entry.
    ///
    /// Returns the number of bytes released, or `None` when there is no tail
    /// to evict.
    fn evict_lru(&mut self) -> Option<usize> {
        let tail_key = self.tail.take()?;
        let node = self.entries.remove(&tail_key)?;

        // The evicted node's predecessor becomes the new tail.
        self.tail = node.prev;
        if let Some(new_tail_key) = self.tail.as_deref() {
            if let Some(new_tail) = self.entries.get_mut(new_tail_key) {
                new_tail.next = None;
            }
        }

        // Evicting the only entry empties the list.
        if self.head.as_deref() == Some(tail_key.as_str()) {
            self.head = None;
        }

        self.bytes_used -= node.size_bytes;
        Some(node.size_bytes)
    }

    /// Unlinks `key` from the recency list while keeping it in `entries`.
    ///
    /// Returns `false` (and changes nothing) when `key` is not stored.
    fn detach(&mut self, key: &str) -> bool {
        let (prev, next) = match self.entries.get_mut(key) {
            Some(node) => (node.prev.take(), node.next.take()),
            None => return false,
        };

        // Join the two neighbours directly to each other.
        if let Some(prev_key) = prev.as_deref() {
            if let Some(prev_node) = self.entries.get_mut(prev_key) {
                prev_node.next = next.clone();
            }
        }
        if let Some(next_key) = next.as_deref() {
            if let Some(next_node) = self.entries.get_mut(next_key) {
                next_node.prev = prev.clone();
            }
        }

        if self.head.as_deref() == Some(key) {
            self.head = next;
        }
        if self.tail.as_deref() == Some(key) {
            self.tail = prev;
        }
        true
    }

    /// Links an already stored, currently unlinked `key` in as the new head.
    fn attach_front(&mut self, key: &str) {
        let old_head = self.head.replace(key.to_string());

        if let Some(old_head_key) = old_head.as_deref() {
            if let Some(old_head_node) = self.entries.get_mut(old_head_key) {
                old_head_node.prev = Some(key.to_string());
            }
        }
        if let Some(node) = self.entries.get_mut(key) {
            node.prev = None;
            node.next = old_head;
        }

        // A list that was empty now has a single node that is both ends.
        if self.tail.is_none() {
            self.tail = Some(key.to_string());
        }
    }
}
140
141pub struct EmbeddingCache {
143 state: Mutex<CacheState>,
145 max_bytes: usize,
147 hits: AtomicU64,
149 misses: AtomicU64,
151}
152
153impl EmbeddingCache {
154 pub fn new(max_bytes: usize) -> Self {
161 Self {
162 state: Mutex::new(CacheState::new()),
163 max_bytes,
164 hits: AtomicU64::new(0),
165 misses: AtomicU64::new(0),
166 }
167 }
168
169 pub fn default_capacity() -> Self {
171 Self::new(100 * 1024 * 1024) }
173
174 pub fn get(&self, key: &str) -> Option<Arc<[f32]>> {
178 let mut state = self.state.lock().unwrap();
179
180 if state.entries.contains_key(key) {
181 state.move_to_front(key);
182 self.hits.fetch_add(1, Ordering::Relaxed);
183 state.entries.get(key).map(|n| n.embedding.clone())
184 } else {
185 self.misses.fetch_add(1, Ordering::Relaxed);
186 None
187 }
188 }
189
190 pub fn put(&self, key: String, embedding: Vec<f32>) {
195 let size_bytes = embedding.len() * std::mem::size_of::<f32>();
196
197 if size_bytes > self.max_bytes {
199 return;
200 }
201
202 let arc: Arc<[f32]> = embedding.into();
203 let mut state = self.state.lock().unwrap();
204
205 if let Some(old_node) = state.entries.remove(&key) {
207 state.bytes_used -= old_node.size_bytes;
208
209 if let Some(ref prev_key) = old_node.prev {
211 if let Some(prev_node) = state.entries.get_mut(prev_key) {
212 prev_node.next = old_node.next.clone();
213 }
214 }
215 if let Some(ref next_key) = old_node.next {
216 if let Some(next_node) = state.entries.get_mut(next_key) {
217 next_node.prev = old_node.prev.clone();
218 }
219 }
220 if state.head.as_deref() == Some(&key) {
221 state.head = old_node.next.clone();
222 }
223 if state.tail.as_deref() == Some(&key) {
224 state.tail = old_node.prev.clone();
225 }
226 }
227
228 while state.bytes_used + size_bytes > self.max_bytes {
230 if state.evict_lru().is_none() {
231 break;
232 }
233 }
234
235 let old_head = state.head.clone();
237 let node = LruNode {
238 embedding: arc,
239 size_bytes,
240 prev: None,
241 next: old_head.clone(),
242 };
243
244 if let Some(ref old_head_key) = old_head {
246 if let Some(head_node) = state.entries.get_mut(old_head_key) {
247 head_node.prev = Some(key.clone());
248 }
249 }
250
251 state.entries.insert(key.clone(), node);
252 state.bytes_used += size_bytes;
253 state.head = Some(key);
254
255 if state.tail.is_none() {
256 state.tail = state.head.clone();
257 }
258 }
259
260 pub fn stats(&self) -> EmbeddingCacheStats {
262 let state = self.state.lock().unwrap();
263 let hits = self.hits.load(Ordering::Relaxed);
264 let misses = self.misses.load(Ordering::Relaxed);
265 let total = hits + misses;
266
267 EmbeddingCacheStats {
268 hits,
269 misses,
270 entries: state.entries.len(),
271 bytes_used: state.bytes_used,
272 max_bytes: self.max_bytes,
273 hit_rate: if total > 0 {
274 (hits as f64 / total as f64) * 100.0
275 } else {
276 0.0
277 },
278 }
279 }
280
281 pub fn clear(&self) {
283 let mut state = self.state.lock().unwrap();
284 state.entries.clear();
285 state.head = None;
286 state.tail = None;
287 state.bytes_used = 0;
288 }
290
291 pub fn len(&self) -> usize {
293 self.state.lock().unwrap().entries.len()
294 }
295
296 pub fn is_empty(&self) -> bool {
298 self.len() == 0
299 }
300}
301
302impl Default for EmbeddingCache {
303 fn default() -> Self {
304 Self::default_capacity()
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips one entry and checks hit/miss/entry counters.
    #[test]
    fn test_basic_operations() {
        let cache = EmbeddingCache::new(1024 * 1024);
        let embedding = vec![1.0, 2.0, 3.0];
        cache.put("test-key".to_string(), embedding);

        let retrieved = cache.get("test-key").unwrap();
        assert_eq!(&*retrieved, &[1.0, 2.0, 3.0]);

        assert!(cache.get("nonexistent").is_none());

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.entries, 1);
    }

    /// A fourth 16-byte entry in a 48-byte cache evicts the oldest one.
    #[test]
    fn test_lru_eviction() {
        let cache = EmbeddingCache::new(48);

        cache.put("a".to_string(), vec![1.0, 2.0, 3.0, 4.0]);
        cache.put("b".to_string(), vec![5.0, 6.0, 7.0, 8.0]);
        cache.put("c".to_string(), vec![9.0, 10.0, 11.0, 12.0]);

        assert_eq!(cache.len(), 3);

        cache.put("d".to_string(), vec![13.0, 14.0, 15.0, 16.0]);

        assert_eq!(cache.len(), 3);
        assert!(cache.get("a").is_none());
        assert!(cache.get("b").is_some());
        assert!(cache.get("c").is_some());
        assert!(cache.get("d").is_some());
    }

    /// Reading an entry protects it from the next eviction.
    #[test]
    fn test_access_updates_lru() {
        let cache = EmbeddingCache::new(32);

        cache.put("a".to_string(), vec![1.0, 2.0, 3.0, 4.0]);
        cache.put("b".to_string(), vec![5.0, 6.0, 7.0, 8.0]);

        let _ = cache.get("a");

        cache.put("c".to_string(), vec![9.0, 10.0, 11.0, 12.0]);

        assert!(cache.get("a").is_some());
        assert!(cache.get("b").is_none());
        assert!(cache.get("c").is_some());
    }

    /// `clear` removes all entries and resets byte accounting.
    #[test]
    fn test_clear() {
        let cache = EmbeddingCache::new(1024 * 1024);

        cache.put("a".to_string(), vec![1.0, 2.0, 3.0]);
        cache.put("b".to_string(), vec![4.0, 5.0, 6.0]);

        assert_eq!(cache.len(), 2);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert!(cache.get("a").is_none());
        assert!(cache.get("b").is_none());

        let stats = cache.stats();
        assert_eq!(stats.entries, 0);
        assert_eq!(stats.bytes_used, 0);
    }

    /// Writing an existing key replaces its value without adding an entry.
    #[test]
    fn test_update_existing() {
        let cache = EmbeddingCache::new(1024 * 1024);

        cache.put("key".to_string(), vec![1.0, 2.0, 3.0]);
        let v1 = cache.get("key").unwrap();
        assert_eq!(&*v1, &[1.0, 2.0, 3.0]);

        cache.put("key".to_string(), vec![4.0, 5.0, 6.0, 7.0]);
        let v2 = cache.get("key").unwrap();
        assert_eq!(&*v2, &[4.0, 5.0, 6.0, 7.0]);

        assert_eq!(cache.len(), 1);
    }

    /// Two lookups of the same key return handles to the same allocation.
    #[test]
    fn test_zero_copy() {
        let cache = EmbeddingCache::new(1024 * 1024);

        cache.put("key".to_string(), vec![1.0, 2.0, 3.0]);

        let ref1 = cache.get("key").unwrap();
        let ref2 = cache.get("key").unwrap();

        assert!(Arc::ptr_eq(&ref1, &ref2));
    }

    /// Counters start at zero and the hit rate is reported as a percentage.
    #[test]
    fn test_stats_tracking() {
        let cache = EmbeddingCache::new(1024 * 1024);

        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate, 0.0);

        cache.put("a".to_string(), vec![1.0, 2.0]);

        cache.get("a");
        cache.get("nonexistent");
        cache.get("a");

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 66.666).abs() < 1.0);
    }

    /// An embedding larger than the whole budget is never stored.
    #[test]
    fn test_oversized_entry_rejected() {
        let cache = EmbeddingCache::new(8);

        cache.put("big".to_string(), vec![1.0, 2.0, 3.0, 4.0]);

        assert!(cache.is_empty());
        assert!(cache.get("big").is_none());
        assert_eq!(cache.stats().bytes_used, 0);
    }

    /// Replacing a value swaps its size in the byte accounting.
    #[test]
    fn test_replace_updates_bytes_used() {
        let cache = EmbeddingCache::new(1024 * 1024);

        cache.put("key".to_string(), vec![1.0, 2.0, 3.0]);
        assert_eq!(cache.stats().bytes_used, 12);

        cache.put("key".to_string(), vec![4.0, 5.0, 6.0, 7.0]);
        let stats = cache.stats();
        assert_eq!(stats.bytes_used, 16);
        assert_eq!(stats.entries, 1);
    }

    /// `Default` uses the 100 MiB budget of `default_capacity`.
    #[test]
    fn test_default_capacity() {
        let cache = EmbeddingCache::default();

        assert!(cache.is_empty());
        assert_eq!(cache.stats().max_bytes, 100 * 1024 * 1024);
    }
}