1use std::cell::RefCell;
11use std::collections::HashMap;
12use std::fmt;
13
14use crate::backend::FtsBackend;
15use crate::posting::Posting;
16
/// Error type reported by the in-memory backend.
///
/// Wraps a human-readable message. Implements [`std::error::Error`] so it can
/// be boxed as `dyn Error` and propagated with `?` alongside other error types.
#[derive(Debug)]
pub struct MemoryError(String);

impl fmt::Display for MemoryError {
    /// Writes the wrapped message prefixed with `memory backend: `.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "memory backend: {}", self.0)
    }
}

// `Debug` + `Display` are all the trait requires; there is no underlying source error.
impl std::error::Error for MemoryError {}
26
/// Map key of the form `(tid, collection, sub)`, as produced by `triple`.
///
/// `sub` is the term, document id, meta subkey, or segment id, depending on
/// which map of `MemoryBackend` is being keyed.
type TripleKey = (u32, String, String);
/// Map key of the form `(tid, collection)`, as produced by `pair`; used for
/// the per-collection statistics map.
type PairKey = (u32, String);
29
/// Backend state held entirely in `HashMap`s.
///
/// Each map is wrapped in a `RefCell` so that the `&self` methods of the
/// `FtsBackend` implementation can mutate it.
#[derive(Debug, Default)]
pub struct MemoryBackend {
    /// Posting lists keyed by `(tid, collection, term)`.
    postings: RefCell<HashMap<TripleKey, Vec<Posting>>>,
    /// Document lengths keyed by `(tid, collection, doc_id)`.
    doc_lengths: RefCell<HashMap<TripleKey, u32>>,
    /// Per-collection counters keyed by `(tid, collection)`: the first element
    /// is bumped by one per `increment_stats` call, the second accumulates the
    /// `doc_len` values passed to it.
    stats: RefCell<HashMap<PairKey, (u32, u64)>>,
    /// Opaque metadata blobs keyed by `(tid, collection, subkey)`.
    meta: RefCell<HashMap<TripleKey, Vec<u8>>>,
    /// Opaque segment blobs keyed by `(tid, collection, segment_id)`.
    segments: RefCell<HashMap<TripleKey, Vec<u8>>>,
}
48
49impl MemoryBackend {
50 pub fn new() -> Self {
51 Self::default()
52 }
53}
54
55fn triple(tid: u32, collection: &str, sub: &str) -> TripleKey {
56 (tid, collection.to_string(), sub.to_string())
57}
58
59fn pair(tid: u32, collection: &str) -> PairKey {
60 (tid, collection.to_string())
61}
62
63impl FtsBackend for MemoryBackend {
64 type Error = MemoryError;
65
66 fn read_postings(
67 &self,
68 tid: u32,
69 collection: &str,
70 term: &str,
71 ) -> Result<Vec<Posting>, Self::Error> {
72 Ok(self
73 .postings
74 .borrow()
75 .get(&triple(tid, collection, term))
76 .cloned()
77 .unwrap_or_default())
78 }
79
80 fn write_postings(
81 &self,
82 tid: u32,
83 collection: &str,
84 term: &str,
85 postings: &[Posting],
86 ) -> Result<(), Self::Error> {
87 let key = triple(tid, collection, term);
88 let mut map = self.postings.borrow_mut();
89 if postings.is_empty() {
90 map.remove(&key);
91 } else {
92 map.insert(key, postings.to_vec());
93 }
94 Ok(())
95 }
96
97 fn remove_postings(&self, tid: u32, collection: &str, term: &str) -> Result<(), Self::Error> {
98 self.postings
99 .borrow_mut()
100 .remove(&triple(tid, collection, term));
101 Ok(())
102 }
103
104 fn read_doc_length(
105 &self,
106 tid: u32,
107 collection: &str,
108 doc_id: &str,
109 ) -> Result<Option<u32>, Self::Error> {
110 Ok(self
111 .doc_lengths
112 .borrow()
113 .get(&triple(tid, collection, doc_id))
114 .copied())
115 }
116
117 fn write_doc_length(
118 &self,
119 tid: u32,
120 collection: &str,
121 doc_id: &str,
122 length: u32,
123 ) -> Result<(), Self::Error> {
124 self.doc_lengths
125 .borrow_mut()
126 .insert(triple(tid, collection, doc_id), length);
127 Ok(())
128 }
129
130 fn remove_doc_length(
131 &self,
132 tid: u32,
133 collection: &str,
134 doc_id: &str,
135 ) -> Result<(), Self::Error> {
136 self.doc_lengths
137 .borrow_mut()
138 .remove(&triple(tid, collection, doc_id));
139 Ok(())
140 }
141
142 fn collection_terms(&self, tid: u32, collection: &str) -> Result<Vec<String>, Self::Error> {
143 Ok(self
144 .postings
145 .borrow()
146 .keys()
147 .filter(|(t, c, _)| *t == tid && c == collection)
148 .map(|(_, _, term)| term.clone())
149 .collect())
150 }
151
152 fn collection_stats(&self, tid: u32, collection: &str) -> Result<(u32, u64), Self::Error> {
153 Ok(self
154 .stats
155 .borrow()
156 .get(&pair(tid, collection))
157 .copied()
158 .unwrap_or((0, 0)))
159 }
160
161 fn increment_stats(&self, tid: u32, collection: &str, doc_len: u32) -> Result<(), Self::Error> {
162 let mut stats = self.stats.borrow_mut();
163 let entry = stats.entry(pair(tid, collection)).or_insert((0, 0));
164 entry.0 += 1;
165 entry.1 += doc_len as u64;
166 Ok(())
167 }
168
169 fn decrement_stats(&self, tid: u32, collection: &str, doc_len: u32) -> Result<(), Self::Error> {
170 let mut stats = self.stats.borrow_mut();
171 let entry = stats.entry(pair(tid, collection)).or_insert((0, 0));
172 entry.0 = entry.0.saturating_sub(1);
173 entry.1 = entry.1.saturating_sub(doc_len as u64);
174 Ok(())
175 }
176
177 fn read_meta(
178 &self,
179 tid: u32,
180 collection: &str,
181 subkey: &str,
182 ) -> Result<Option<Vec<u8>>, Self::Error> {
183 Ok(self
184 .meta
185 .borrow()
186 .get(&triple(tid, collection, subkey))
187 .cloned())
188 }
189
190 fn write_meta(
191 &self,
192 tid: u32,
193 collection: &str,
194 subkey: &str,
195 value: &[u8],
196 ) -> Result<(), Self::Error> {
197 self.meta
198 .borrow_mut()
199 .insert(triple(tid, collection, subkey), value.to_vec());
200 Ok(())
201 }
202
203 fn write_segment(
204 &self,
205 tid: u32,
206 collection: &str,
207 segment_id: &str,
208 data: &[u8],
209 ) -> Result<(), Self::Error> {
210 self.segments
211 .borrow_mut()
212 .insert(triple(tid, collection, segment_id), data.to_vec());
213 Ok(())
214 }
215
216 fn read_segment(
217 &self,
218 tid: u32,
219 collection: &str,
220 segment_id: &str,
221 ) -> Result<Option<Vec<u8>>, Self::Error> {
222 Ok(self
223 .segments
224 .borrow()
225 .get(&triple(tid, collection, segment_id))
226 .cloned())
227 }
228
229 fn list_segments(&self, tid: u32, collection: &str) -> Result<Vec<String>, Self::Error> {
230 Ok(self
231 .segments
232 .borrow()
233 .keys()
234 .filter(|(t, c, _)| *t == tid && c == collection)
235 .map(|(_, _, seg)| seg.clone())
236 .collect())
237 }
238
239 fn remove_segment(
240 &self,
241 tid: u32,
242 collection: &str,
243 segment_id: &str,
244 ) -> Result<(), Self::Error> {
245 self.segments
246 .borrow_mut()
247 .remove(&triple(tid, collection, segment_id));
248 Ok(())
249 }
250
251 fn purge_collection(&self, tid: u32, collection: &str) -> Result<usize, Self::Error> {
252 let match_tc = |(t, c, _): &&TripleKey| *t == tid && c == collection;
253
254 let mut postings = self.postings.borrow_mut();
255 let mut doc_lengths = self.doc_lengths.borrow_mut();
256 let before = postings.len() + doc_lengths.len();
257 postings.retain(|k, _| !(k.0 == tid && k.1 == collection));
258 doc_lengths.retain(|k, _| !(k.0 == tid && k.1 == collection));
259 self.stats.borrow_mut().remove(&pair(tid, collection));
260 self.meta
261 .borrow_mut()
262 .retain(|k, _| !(k.0 == tid && k.1 == collection));
263 self.segments
264 .borrow_mut()
265 .retain(|k, _| !(k.0 == tid && k.1 == collection));
266 let after = postings.len() + doc_lengths.len();
267 let _ = match_tc;
268 Ok(before - after)
269 }
270
271 fn purge_tenant(&self, tid: u32) -> Result<usize, Self::Error> {
272 let mut postings = self.postings.borrow_mut();
273 let mut doc_lengths = self.doc_lengths.borrow_mut();
274 let before = postings.len() + doc_lengths.len();
275 postings.retain(|k, _| k.0 != tid);
276 doc_lengths.retain(|k, _| k.0 != tid);
277 self.stats.borrow_mut().retain(|k, _| k.0 != tid);
278 self.meta.borrow_mut().retain(|k, _| k.0 != tid);
279 self.segments.borrow_mut().retain(|k, _| k.0 != tid);
280 let after = postings.len() + doc_lengths.len();
281 Ok(before - after)
282 }
283}
284
// Unit tests exercising `MemoryBackend` through the `FtsBackend` trait methods.
#[cfg(test)]
mod tests {
    use super::*;

    // Tenant id shared by the single-tenant tests below.
    const T: u32 = 1;

    // Postings written for a term can be read back unchanged.
    #[test]
    fn roundtrip_postings() {
        let backend = MemoryBackend::new();
        let postings = vec![Posting {
            doc_id: "d1".into(),
            term_freq: 2,
            positions: vec![0, 5],
        }];
        backend
            .write_postings(T, "col", "hello", &postings)
            .unwrap();

        let read = backend.read_postings(T, "col", "hello").unwrap();
        assert_eq!(read.len(), 1);
        assert_eq!(read[0].doc_id, "d1");
    }

    // A document length can be written, read, and removed again.
    #[test]
    fn roundtrip_doc_lengths() {
        let backend = MemoryBackend::new();
        backend.write_doc_length(T, "col", "d1", 42).unwrap();
        assert_eq!(backend.read_doc_length(T, "col", "d1").unwrap(), Some(42));

        backend.remove_doc_length(T, "col", "d1").unwrap();
        assert_eq!(backend.read_doc_length(T, "col", "d1").unwrap(), None);
    }

    // Increments and decrements adjust both counters of the stats pair.
    #[test]
    fn incremental_stats() {
        let backend = MemoryBackend::new();
        backend.increment_stats(T, "col", 10).unwrap();
        backend.increment_stats(T, "col", 20).unwrap();
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (2, 30));

        backend.decrement_stats(T, "col", 10).unwrap();
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (1, 20));
    }

    // Decrementing a collection with no stats must not underflow.
    #[test]
    fn stats_saturating_sub() {
        let backend = MemoryBackend::new();
        backend.decrement_stats(T, "col", 100).unwrap();
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (0, 0));
    }

    // Purging one collection wipes its data and leaves a sibling collection intact.
    #[test]
    fn purge_clears_stats_and_isolates_collections() {
        let backend = MemoryBackend::new();
        backend.increment_stats(T, "col", 10).unwrap();
        backend.write_doc_length(T, "col", "d1", 10).unwrap();
        backend
            .write_postings(
                T,
                "col",
                "hello",
                &[Posting {
                    doc_id: "d1".into(),
                    term_freq: 1,
                    positions: vec![0],
                }],
            )
            .unwrap();

        backend.increment_stats(T, "other", 7).unwrap();
        backend.write_doc_length(T, "other", "d1", 7).unwrap();
        backend
            .write_postings(
                T,
                "other",
                "world",
                &[Posting {
                    doc_id: "d1".into(),
                    term_freq: 1,
                    positions: vec![0],
                }],
            )
            .unwrap();

        backend.purge_collection(T, "col").unwrap();
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (0, 0));
        assert!(backend.read_postings(T, "col", "hello").unwrap().is_empty());
        assert_eq!(backend.read_doc_length(T, "col", "d1").unwrap(), None);

        assert_eq!(backend.collection_stats(T, "other").unwrap(), (1, 7));
        assert_eq!(backend.read_postings(T, "other", "world").unwrap().len(), 1);
        assert_eq!(backend.read_doc_length(T, "other", "d1").unwrap(), Some(7));
    }

    // `collection_terms` lists every term with postings; sorted here because
    // the backend returns them in map-iteration order.
    #[test]
    fn collection_terms() {
        let backend = MemoryBackend::new();
        backend
            .write_postings(
                T,
                "col",
                "hello",
                &[Posting {
                    doc_id: "d1".into(),
                    term_freq: 1,
                    positions: vec![0],
                }],
            )
            .unwrap();
        backend
            .write_postings(
                T,
                "col",
                "world",
                &[Posting {
                    doc_id: "d1".into(),
                    term_freq: 1,
                    positions: vec![1],
                }],
            )
            .unwrap();

        let mut terms = backend.collection_terms(T, "col").unwrap();
        terms.sort();
        assert_eq!(terms, vec!["hello", "world"]);
    }

    // Segment bytes can be read back; an unknown segment id yields `None`.
    #[test]
    fn segment_roundtrip() {
        let backend = MemoryBackend::new();
        let data = b"compressed segment bytes";
        backend.write_segment(T, "col", "id1", data).unwrap();
        assert_eq!(
            backend.read_segment(T, "col", "id1").unwrap(),
            Some(data.to_vec())
        );
        assert_eq!(backend.read_segment(T, "col", "missing").unwrap(), None);
    }

    // `list_segments` only reports segments of the requested collection.
    #[test]
    fn segment_list_filters_by_collection() {
        let backend = MemoryBackend::new();
        backend.write_segment(T, "col", "a", b"a").unwrap();
        backend.write_segment(T, "col", "b", b"b").unwrap();
        backend.write_segment(T, "other", "c", b"c").unwrap();

        let mut segs = backend.list_segments(T, "col").unwrap();
        segs.sort();
        assert_eq!(segs, vec!["a", "b"]);

        let other = backend.list_segments(T, "other").unwrap();
        assert_eq!(other, vec!["c"]);
    }

    // A removed segment can no longer be read.
    #[test]
    fn segment_remove() {
        let backend = MemoryBackend::new();
        backend.write_segment(T, "col", "id1", b"data").unwrap();
        backend.remove_segment(T, "col", "id1").unwrap();
        assert_eq!(backend.read_segment(T, "col", "id1").unwrap(), None);
    }

    // Purging a collection also drops its segments, but not those of other collections.
    #[test]
    fn purge_clears_segments() {
        let backend = MemoryBackend::new();
        backend.write_segment(T, "col", "a", b"a").unwrap();
        backend.write_segment(T, "other", "b", b"b").unwrap();

        backend.purge_collection(T, "col").unwrap();
        assert!(backend.list_segments(T, "col").unwrap().is_empty());
        assert_eq!(backend.list_segments(T, "other").unwrap().len(), 1);
    }

    // Purging tenant 1 must leave tenant 2's identically-named collection untouched.
    #[test]
    fn purge_tenant_isolates_tenants() {
        let backend = MemoryBackend::new();
        backend.increment_stats(1, "col", 5).unwrap();
        backend.increment_stats(2, "col", 7).unwrap();
        backend
            .write_postings(
                1,
                "col",
                "t",
                &[Posting {
                    doc_id: "d".into(),
                    term_freq: 1,
                    positions: vec![0],
                }],
            )
            .unwrap();
        backend
            .write_postings(
                2,
                "col",
                "t",
                &[Posting {
                    doc_id: "d".into(),
                    term_freq: 1,
                    positions: vec![0],
                }],
            )
            .unwrap();

        backend.purge_tenant(1).unwrap();
        assert_eq!(backend.collection_stats(1, "col").unwrap(), (0, 0));
        assert!(backend.read_postings(1, "col", "t").unwrap().is_empty());
        assert_eq!(backend.collection_stats(2, "col").unwrap(), (1, 7));
        assert_eq!(backend.read_postings(2, "col", "t").unwrap().len(), 1);
    }
}