1use std::collections::HashMap;
8use std::sync::RwLock;
9
10use crate::posting::{PostingEntry, PostingList};
11use crate::trigram::Trigram;
12
13fn estimate_posting_size(list: &PostingList) -> usize {
14 let mut size: usize = 16;
15 for entry in &list.entries {
16 size += 4 + entry.offsets.len() * 4 + 16;
17 }
18 size += list.entries.capacity() * std::mem::size_of::<PostingEntry>();
19 size
20}
21
/// Counters describing how a [`PostingCache`] has been used so far.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of `get` calls that found a cached posting list.
    pub hits: u64,
    /// Number of `get` calls that found nothing.
    pub misses: u64,
    /// Number of entries removed by `insert` to make room for a new one.
    pub evictions: u64,
}
32
33#[derive(Debug)]
39pub struct PostingCache {
40 map: RwLock<HashMap<Trigram, PostingList>>,
41 order: RwLock<Vec<Trigram>>,
42 memory_used: RwLock<usize>,
43 memory_ceiling: usize,
44 admit: RwLock<bool>,
45 stats: RwLock<CacheStats>,
46}
47
48impl PostingCache {
49 #[must_use]
51 pub fn new(memory_ceiling: usize) -> Self {
52 Self {
53 map: RwLock::new(HashMap::new()),
54 order: RwLock::new(Vec::new()),
55 memory_used: RwLock::new(0),
56 memory_ceiling,
57 admit: RwLock::new(true),
58 stats: RwLock::new(CacheStats {
59 hits: 0,
60 misses: 0,
61 evictions: 0,
62 }),
63 }
64 }
65
66 pub fn set_admit(&self, allow: bool) {
75 *self
76 .admit
77 .write()
78 .unwrap_or_else(std::sync::PoisonError::into_inner) = allow;
79 }
80
81 pub fn get(&self, trigram: Trigram) -> Option<PostingList> {
85 let map = self
86 .map
87 .read()
88 .unwrap_or_else(std::sync::PoisonError::into_inner);
89 if let Some(list) = map.get(&trigram) {
90 let result = list.clone();
91 drop(map);
92 let mut stats = self
93 .stats
94 .write()
95 .unwrap_or_else(std::sync::PoisonError::into_inner);
96 stats.hits += 1;
97 Some(result)
98 } else {
99 drop(map);
100 let mut stats = self
101 .stats
102 .write()
103 .unwrap_or_else(std::sync::PoisonError::into_inner);
104 stats.misses += 1;
105 None
106 }
107 }
108
109 pub fn insert(&self, trigram: Trigram, list: PostingList) {
113 if !*self
114 .admit
115 .read()
116 .unwrap_or_else(std::sync::PoisonError::into_inner)
117 {
118 return;
119 }
120
121 let entry_size = estimate_posting_size(&list);
122
123 if self.memory_ceiling == 0 {
124 return;
125 }
126
127 let existing_size = {
128 let map = self
129 .map
130 .read()
131 .unwrap_or_else(std::sync::PoisonError::into_inner);
132 map.get(&trigram).map_or(0, estimate_posting_size)
133 };
134
135 if existing_size > 0 {
136 let mut map = self
137 .map
138 .write()
139 .unwrap_or_else(std::sync::PoisonError::into_inner);
140 let mut order = self
141 .order
142 .write()
143 .unwrap_or_else(std::sync::PoisonError::into_inner);
144 let mut mem = self
145 .memory_used
146 .write()
147 .unwrap_or_else(std::sync::PoisonError::into_inner);
148 map.remove(&trigram);
149 order.retain(|t| *t != trigram);
150 *mem = mem.saturating_sub(existing_size);
151 map.insert(trigram, list);
152 order.push(trigram);
153 *mem += entry_size;
154 return;
155 }
156
157 {
158 let mut mem = self
159 .memory_used
160 .write()
161 .unwrap_or_else(std::sync::PoisonError::into_inner);
162 let mut map = self
163 .map
164 .write()
165 .unwrap_or_else(std::sync::PoisonError::into_inner);
166 let mut order = self
167 .order
168 .write()
169 .unwrap_or_else(std::sync::PoisonError::into_inner);
170 let mut stats = self
171 .stats
172 .write()
173 .unwrap_or_else(std::sync::PoisonError::into_inner);
174
175 while *mem + entry_size > self.memory_ceiling {
176 if let Some(evict) = order.first().copied() {
177 if evict == trigram {
178 break;
179 }
180 order.remove(0);
181 if let Some(removed) = map.remove(&evict) {
182 let removed_size = estimate_posting_size(&removed);
183 *mem = mem.saturating_sub(removed_size);
184 }
185 stats.evictions += 1;
186 } else {
187 break;
188 }
189 }
190
191 map.insert(trigram, list);
192 order.push(trigram);
193 *mem += entry_size;
194 }
195 }
196
197 pub fn invalidate(&self, trigram: Trigram) {
201 let mut map = self
202 .map
203 .write()
204 .unwrap_or_else(std::sync::PoisonError::into_inner);
205 let mut order = self
206 .order
207 .write()
208 .unwrap_or_else(std::sync::PoisonError::into_inner);
209 let mut mem = self
210 .memory_used
211 .write()
212 .unwrap_or_else(std::sync::PoisonError::into_inner);
213 if let Some(removed) = map.remove(&trigram) {
214 let removed_size = estimate_posting_size(&removed);
215 *mem = mem.saturating_sub(removed_size);
216 order.retain(|t| *t != trigram);
217 }
218 }
219
220 pub fn invalidate_all(&self) {
224 let mut map = self
225 .map
226 .write()
227 .unwrap_or_else(std::sync::PoisonError::into_inner);
228 let mut order = self
229 .order
230 .write()
231 .unwrap_or_else(std::sync::PoisonError::into_inner);
232 let mut mem = self
233 .memory_used
234 .write()
235 .unwrap_or_else(std::sync::PoisonError::into_inner);
236 map.clear();
237 order.clear();
238 *mem = 0;
239 }
240
241 #[must_use]
245 pub fn len(&self) -> usize {
246 self.map
247 .read()
248 .unwrap_or_else(std::sync::PoisonError::into_inner)
249 .len()
250 }
251
252 #[must_use]
256 pub fn is_empty(&self) -> bool {
257 self.len() == 0
258 }
259
260 #[must_use]
264 pub fn stats(&self) -> CacheStats {
265 self.stats
266 .read()
267 .unwrap_or_else(std::sync::PoisonError::into_inner)
268 .clone()
269 }
270
271 #[must_use]
275 pub fn memory_used(&self) -> usize {
276 *self
277 .memory_used
278 .read()
279 .unwrap_or_else(std::sync::PoisonError::into_inner)
280 }
281}
282
283impl Default for PostingCache {
284 fn default() -> Self {
285 Self::new(64 * 1024 * 1024)
286 }
287}
288
#[cfg(test)]
#[allow(clippy::as_conversions, clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Shorthand for building a trigram from three raw bytes.
    fn trigram(a: u8, b: u8, c: u8) -> Trigram {
        crate::trigram::from_bytes(a, b, c)
    }

    /// Posting list with one entry (file 1) that has a single offset.
    fn single_offset_list() -> PostingList {
        PostingList {
            entries: vec![PostingEntry {
                file_id: 1,
                offsets: vec![10],
            }],
        }
    }

    /// A lookup before insertion is a miss; after insertion it is a hit
    /// that returns the stored list.
    #[test]
    fn basic_insert_get() -> Result<(), Box<dyn std::error::Error>> {
        let cache = PostingCache::new(1024 * 1024);
        let key = trigram(b'a', b'b', b'c');
        let expected = PostingList {
            entries: vec![PostingEntry {
                file_id: 1,
                offsets: vec![10, 20],
            }],
        };

        assert!(cache.get(key).is_none());
        let before = cache.stats();
        assert_eq!(before.misses, 1);
        assert_eq!(before.hits, 0);

        cache.insert(key, expected.clone());
        let fetched = cache.get(key);
        assert!(fetched.is_some());
        assert_eq!(fetched.ok_or("expected cached posting")?, expected);

        let after = cache.stats();
        assert_eq!(after.hits, 1);
        assert_eq!(after.misses, 1);
        Ok(())
    }

    /// With room for two entries (plus slack), a third insert must evict.
    #[test]
    fn eviction_under_memory_pressure() {
        let posting = PostingList {
            entries: vec![PostingEntry {
                file_id: 0,
                offsets: vec![10, 20],
            }],
        };
        let per_entry = estimate_posting_size(&posting);
        let cache = PostingCache::new(per_entry * 2 + 10);

        let [first, second, third] = [
            trigram(b'a', b'b', b'c'),
            trigram(b'd', b'e', b'f'),
            trigram(b'g', b'h', b'i'),
        ];

        cache.insert(first, posting.clone());
        cache.insert(second, posting.clone());
        assert!(cache.get(first).is_some());
        assert!(cache.get(second).is_some());

        cache.insert(third, posting);
        let stats = cache.stats();
        assert!(stats.evictions > 0, "evictions should have occurred");
        assert!(cache.len() <= 2);
    }

    /// `invalidate_all` drops every entry and zeroes the memory accounting.
    #[test]
    fn invalidate_all() {
        let cache = PostingCache::new(1024 * 1024);
        let first = trigram(b'a', b'b', b'c');
        let second = trigram(b'd', b'e', b'f');

        let posting = single_offset_list();

        cache.insert(first, posting.clone());
        cache.insert(second, posting);
        assert!(cache.get(first).is_some());
        assert!(cache.get(second).is_some());

        cache.invalidate_all();
        assert!(cache.get(first).is_none());
        assert!(cache.get(second).is_none());
        assert_eq!(cache.memory_used(), 0);
    }

    /// `invalidate` removes only the requested trigram.
    #[test]
    fn invalidate_single() {
        let cache = PostingCache::new(1024 * 1024);
        let removed = trigram(b'a', b'b', b'c');
        let kept = trigram(b'd', b'e', b'f');

        let posting = single_offset_list();

        cache.insert(removed, posting.clone());
        cache.insert(kept, posting);
        cache.invalidate(removed);
        assert!(cache.get(removed).is_none());
        assert!(cache.get(kept).is_some());
    }

    /// A cache created with ceiling 0 never stores anything.
    #[test]
    fn zero_ceiling_rejects_all() {
        let cache = PostingCache::new(0);
        let key = trigram(b'a', b'b', b'c');

        cache.insert(key, single_offset_list());
        assert!(cache.get(key).is_none());
    }
}