mentedb_embedding/
cache.rs1use std::collections::{HashMap, VecDeque};
4
/// Point-in-time snapshot of an embedding cache's counters.
///
/// The snapshot is a plain value: it does not update after it has been
/// taken. `PartialEq`/`Eq` are derived so snapshots can be compared directly
/// (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of lookups that found an entry.
    pub hits: u64,
    /// Number of lookups that found nothing.
    pub misses: u64,
    /// Number of entries held when the snapshot was taken.
    pub size: usize,
    /// Configured capacity of the cache.
    pub max_size: usize,
    /// Number of entries removed to make room for new ones.
    pub evictions: u64,
}
14
/// A single cached embedding together with its bookkeeping data.
///
/// `PartialEq` is derived so entries can be compared in tests; `Eq` is not,
/// because the vector holds `f32` values.
#[derive(Debug, Clone, PartialEq)]
pub struct CachedEmbedding {
    /// The embedding vector itself.
    pub embedding: Vec<f32>,
    /// Timestamp recorded when the entry was stored, in microseconds since
    /// the Unix epoch.
    pub created_at: u64,
    /// Number of successful lookups since the entry was (re)stored.
    pub hit_count: u32,
}
22
23pub struct EmbeddingCache {
25 max_size: usize,
26 cache: HashMap<u64, CachedEmbedding>,
27 order: VecDeque<u64>,
28 hits: u64,
29 misses: u64,
30 evictions: u64,
31}
32
33impl EmbeddingCache {
34 pub fn new(max_size: usize) -> Self {
36 Self {
37 max_size,
38 cache: HashMap::with_capacity(max_size.min(1024)),
39 order: VecDeque::with_capacity(max_size.min(1024)),
40 hits: 0,
41 misses: 0,
42 evictions: 0,
43 }
44 }
45
46 pub fn default_size() -> Self {
48 Self::new(10_000)
49 }
50
51 fn cache_key(model: &str, text: &str) -> u64 {
53 let mut hash: u64 = 0xcbf29ce484222325;
55 let prime: u64 = 0x100000001b3;
56
57 for byte in model.as_bytes() {
58 hash ^= *byte as u64;
59 hash = hash.wrapping_mul(prime);
60 }
61 hash ^= 0xff;
63 hash = hash.wrapping_mul(prime);
64
65 for byte in text.as_bytes() {
66 hash ^= *byte as u64;
67 hash = hash.wrapping_mul(prime);
68 }
69
70 hash
71 }
72
73 pub fn get(&mut self, text: &str, model: &str) -> Option<&[f32]> {
75 let key = Self::cache_key(model, text);
76
77 if self.cache.contains_key(&key) {
78 self.hits += 1;
79
80 self.order.retain(|k| *k != key);
82 self.order.push_back(key);
83
84 let entry = self.cache.get_mut(&key).unwrap();
85 entry.hit_count += 1;
86 Some(&entry.embedding)
87 } else {
88 self.misses += 1;
89 None
90 }
91 }
92
93 pub fn put(&mut self, text: &str, model: &str, embedding: Vec<f32>) {
95 let key = Self::cache_key(model, text);
96
97 if self.cache.contains_key(&key) {
99 self.order.retain(|k| *k != key);
100 self.order.push_back(key);
101 self.cache.insert(
102 key,
103 CachedEmbedding {
104 embedding,
105 created_at: Self::now_micros(),
106 hit_count: 0,
107 },
108 );
109 return;
110 }
111
112 while self.cache.len() >= self.max_size {
114 if let Some(evict_key) = self.order.pop_front() {
115 self.cache.remove(&evict_key);
116 self.evictions += 1;
117 } else {
118 break;
119 }
120 }
121
122 self.cache.insert(
123 key,
124 CachedEmbedding {
125 embedding,
126 created_at: Self::now_micros(),
127 hit_count: 0,
128 },
129 );
130 self.order.push_back(key);
131 }
132
133 pub fn stats(&self) -> CacheStats {
135 CacheStats {
136 hits: self.hits,
137 misses: self.misses,
138 size: self.cache.len(),
139 max_size: self.max_size,
140 evictions: self.evictions,
141 }
142 }
143
144 pub fn clear(&mut self) {
146 self.cache.clear();
147 self.order.clear();
148 }
149
150 fn now_micros() -> u64 {
151 std::time::SystemTime::now()
152 .duration_since(std::time::UNIX_EPOCH)
153 .unwrap_or_default()
154 .as_micros() as u64
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// A lookup before any insert is a miss; after an insert it is a hit
    /// that returns the stored vector.
    #[test]
    fn test_cache_hit_miss() {
        let mut cache = EmbeddingCache::new(10);
        assert!(cache.get("hello", "model").is_none());
        assert_eq!(cache.stats().misses, 1);

        cache.put("hello", "model", vec![1.0, 2.0, 3.0]);
        let result = cache.get("hello", "model");
        assert!(result.is_some());
        assert_eq!(result.unwrap(), &[1.0, 2.0, 3.0]);
        assert_eq!(cache.stats().hits, 1);
    }

    /// Inserting into a full cache evicts the oldest untouched entry.
    #[test]
    fn test_lru_eviction() {
        let mut cache = EmbeddingCache::new(3);

        cache.put("a", "m", vec![1.0]);
        cache.put("b", "m", vec![2.0]);
        cache.put("c", "m", vec![3.0]);

        cache.put("d", "m", vec![4.0]);

        assert!(cache.get("a", "m").is_none());
        assert!(cache.get("b", "m").is_some());
        assert!(cache.get("c", "m").is_some());
        assert!(cache.get("d", "m").is_some());
        assert_eq!(cache.stats().evictions, 1);
    }

    /// A successful `get` protects the entry from the next eviction; this is
    /// what distinguishes LRU from plain FIFO.
    #[test]
    fn test_get_refreshes_recency() {
        let mut cache = EmbeddingCache::new(3);

        cache.put("a", "m", vec![1.0]);
        cache.put("b", "m", vec![2.0]);
        cache.put("c", "m", vec![3.0]);

        assert!(cache.get("a", "m").is_some());
        cache.put("d", "m", vec![4.0]);

        assert!(cache.get("b", "m").is_none());
        assert!(cache.get("a", "m").is_some());
        assert_eq!(cache.stats().evictions, 1);
    }

    /// Re-inserting an existing key replaces the value without growing the
    /// cache or evicting anything.
    #[test]
    fn test_put_overwrites_existing() {
        let mut cache = EmbeddingCache::new(2);

        cache.put("a", "m", vec![1.0]);
        cache.put("a", "m", vec![9.0]);

        assert_eq!(cache.stats().size, 1);
        assert_eq!(cache.stats().evictions, 0);
        assert_eq!(cache.get("a", "m").unwrap(), &[9.0]);
    }

    /// `clear` drops the entries but keeps the counters.
    #[test]
    fn test_clear_keeps_counters() {
        let mut cache = EmbeddingCache::new(4);

        cache.put("a", "m", vec![1.0]);
        assert!(cache.get("a", "m").is_some());
        assert!(cache.get("missing", "m").is_none());

        cache.clear();

        let stats = cache.stats();
        assert_eq!(stats.size, 0);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!(cache.get("a", "m").is_none());
    }

    /// The model name is part of the key, and the model/text boundary is not
    /// ambiguous.
    #[test]
    fn test_model_is_part_of_key() {
        let mut cache = EmbeddingCache::new(4);

        cache.put("hello", "model-a", vec![1.0]);
        assert!(cache.get("hello", "model-b").is_none());

        cache.put("bc", "a", vec![2.0]);
        assert!(cache.get("c", "ab").is_none());
        assert_eq!(cache.get("bc", "a").unwrap(), &[2.0]);
    }
}