1use std::cell::RefCell;
13use std::collections::HashMap;
14use std::fmt;
15
16use nodedb_types::Surrogate;
17
18use crate::backend::FtsBackend;
19use crate::posting::Posting;
20
/// Error type reported by [`MemoryBackend`].
///
/// Every backend operation defined in this file returns `Ok`, so this type is
/// currently only needed to fill the `FtsBackend::Error` associated type. It
/// wraps a human-readable message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryError(String);

impl fmt::Display for MemoryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "memory backend: {}", self.0)
    }
}

/// Lets `MemoryError` be boxed as `dyn Error` and propagated with `?` into
/// generic error types, like any other library error.
impl std::error::Error for MemoryError {}
30
31type TripleKey = (u64, String, String);
32type DocLenKey = (u64, String, Surrogate);
33type PairKey = (u64, String);
34
35#[derive(Debug, Default)]
41pub struct MemoryBackend {
42 postings: RefCell<HashMap<TripleKey, Vec<Posting>>>,
44 doc_lengths: RefCell<HashMap<DocLenKey, u32>>,
46 stats: RefCell<HashMap<PairKey, (u32, u64)>>,
48 meta: RefCell<HashMap<TripleKey, Vec<u8>>>,
50 segments: RefCell<HashMap<TripleKey, Vec<u8>>>,
52}
53
54impl MemoryBackend {
55 pub fn new() -> Self {
56 Self::default()
57 }
58}
59
60fn triple(tid: u64, collection: &str, sub: &str) -> TripleKey {
61 (tid, collection.to_string(), sub.to_string())
62}
63
64fn doc_len_key(tid: u64, collection: &str, doc_id: Surrogate) -> DocLenKey {
65 (tid, collection.to_string(), doc_id)
66}
67
68fn pair(tid: u64, collection: &str) -> PairKey {
69 (tid, collection.to_string())
70}
71
72impl FtsBackend for MemoryBackend {
73 type Error = MemoryError;
74
75 fn read_postings(
76 &self,
77 tid: u64,
78 collection: &str,
79 term: &str,
80 ) -> Result<Vec<Posting>, Self::Error> {
81 Ok(self
82 .postings
83 .borrow()
84 .get(&triple(tid, collection, term))
85 .cloned()
86 .unwrap_or_default())
87 }
88
89 fn write_postings(
90 &self,
91 tid: u64,
92 collection: &str,
93 term: &str,
94 postings: &[Posting],
95 ) -> Result<(), Self::Error> {
96 let key = triple(tid, collection, term);
97 let mut map = self.postings.borrow_mut();
98 if postings.is_empty() {
99 map.remove(&key);
100 } else {
101 map.insert(key, postings.to_vec());
102 }
103 Ok(())
104 }
105
106 fn remove_postings(&self, tid: u64, collection: &str, term: &str) -> Result<(), Self::Error> {
107 self.postings
108 .borrow_mut()
109 .remove(&triple(tid, collection, term));
110 Ok(())
111 }
112
113 fn read_doc_length(
114 &self,
115 tid: u64,
116 collection: &str,
117 doc_id: Surrogate,
118 ) -> Result<Option<u32>, Self::Error> {
119 Ok(self
120 .doc_lengths
121 .borrow()
122 .get(&doc_len_key(tid, collection, doc_id))
123 .copied())
124 }
125
126 fn write_doc_length(
127 &self,
128 tid: u64,
129 collection: &str,
130 doc_id: Surrogate,
131 length: u32,
132 ) -> Result<(), Self::Error> {
133 self.doc_lengths
134 .borrow_mut()
135 .insert(doc_len_key(tid, collection, doc_id), length);
136 Ok(())
137 }
138
139 fn remove_doc_length(
140 &self,
141 tid: u64,
142 collection: &str,
143 doc_id: Surrogate,
144 ) -> Result<(), Self::Error> {
145 self.doc_lengths
146 .borrow_mut()
147 .remove(&doc_len_key(tid, collection, doc_id));
148 Ok(())
149 }
150
151 fn collection_terms(&self, tid: u64, collection: &str) -> Result<Vec<String>, Self::Error> {
152 Ok(self
153 .postings
154 .borrow()
155 .keys()
156 .filter(|(t, c, _)| *t == tid && c == collection)
157 .map(|(_, _, term)| term.clone())
158 .collect())
159 }
160
161 fn collection_stats(&self, tid: u64, collection: &str) -> Result<(u32, u64), Self::Error> {
162 Ok(self
163 .stats
164 .borrow()
165 .get(&pair(tid, collection))
166 .copied()
167 .unwrap_or((0, 0)))
168 }
169
170 fn increment_stats(&self, tid: u64, collection: &str, doc_len: u32) -> Result<(), Self::Error> {
171 let mut stats = self.stats.borrow_mut();
172 let entry = stats.entry(pair(tid, collection)).or_insert((0, 0));
173 entry.0 += 1;
174 entry.1 += doc_len as u64;
175 Ok(())
176 }
177
178 fn decrement_stats(&self, tid: u64, collection: &str, doc_len: u32) -> Result<(), Self::Error> {
179 let mut stats = self.stats.borrow_mut();
180 let entry = stats.entry(pair(tid, collection)).or_insert((0, 0));
181 entry.0 = entry.0.saturating_sub(1);
182 entry.1 = entry.1.saturating_sub(doc_len as u64);
183 Ok(())
184 }
185
186 fn read_meta(
187 &self,
188 tid: u64,
189 collection: &str,
190 subkey: &str,
191 ) -> Result<Option<Vec<u8>>, Self::Error> {
192 Ok(self
193 .meta
194 .borrow()
195 .get(&triple(tid, collection, subkey))
196 .cloned())
197 }
198
199 fn write_meta(
200 &self,
201 tid: u64,
202 collection: &str,
203 subkey: &str,
204 value: &[u8],
205 ) -> Result<(), Self::Error> {
206 self.meta
207 .borrow_mut()
208 .insert(triple(tid, collection, subkey), value.to_vec());
209 Ok(())
210 }
211
212 fn write_segment(
213 &self,
214 tid: u64,
215 collection: &str,
216 segment_id: &str,
217 data: &[u8],
218 ) -> Result<(), Self::Error> {
219 self.segments
220 .borrow_mut()
221 .insert(triple(tid, collection, segment_id), data.to_vec());
222 Ok(())
223 }
224
225 fn read_segment(
226 &self,
227 tid: u64,
228 collection: &str,
229 segment_id: &str,
230 ) -> Result<Option<Vec<u8>>, Self::Error> {
231 Ok(self
232 .segments
233 .borrow()
234 .get(&triple(tid, collection, segment_id))
235 .cloned())
236 }
237
238 fn list_segments(&self, tid: u64, collection: &str) -> Result<Vec<String>, Self::Error> {
239 Ok(self
240 .segments
241 .borrow()
242 .keys()
243 .filter(|(t, c, _)| *t == tid && c == collection)
244 .map(|(_, _, seg)| seg.clone())
245 .collect())
246 }
247
248 fn remove_segment(
249 &self,
250 tid: u64,
251 collection: &str,
252 segment_id: &str,
253 ) -> Result<(), Self::Error> {
254 self.segments
255 .borrow_mut()
256 .remove(&triple(tid, collection, segment_id));
257 Ok(())
258 }
259
260 fn purge_collection(&self, tid: u64, collection: &str) -> Result<usize, Self::Error> {
261 let match_tc = |(t, c, _): &&TripleKey| *t == tid && c == collection;
262
263 let mut postings = self.postings.borrow_mut();
264 let mut doc_lengths = self.doc_lengths.borrow_mut();
265 let before = postings.len() + doc_lengths.len();
266 postings.retain(|k, _| !(k.0 == tid && k.1 == collection));
267 doc_lengths.retain(|k, _| !(k.0 == tid && k.1 == collection));
268 self.stats.borrow_mut().remove(&pair(tid, collection));
269 self.meta
270 .borrow_mut()
271 .retain(|k, _| !(k.0 == tid && k.1 == collection));
272 self.segments
273 .borrow_mut()
274 .retain(|k, _| !(k.0 == tid && k.1 == collection));
275 let after = postings.len() + doc_lengths.len();
276 let _ = match_tc;
277 Ok(before - after)
278 }
279
280 fn purge_tenant(&self, tid: u64) -> Result<usize, Self::Error> {
281 let mut postings = self.postings.borrow_mut();
282 let mut doc_lengths = self.doc_lengths.borrow_mut();
283 let before = postings.len() + doc_lengths.len();
284 postings.retain(|k, _| k.0 != tid);
285 doc_lengths.retain(|k, _| k.0 != tid);
286 self.stats.borrow_mut().retain(|k, _| k.0 != tid);
287 self.meta.borrow_mut().retain(|k, _| k.0 != tid);
288 self.segments.borrow_mut().retain(|k, _| k.0 != tid);
289 let after = postings.len() + doc_lengths.len();
290 Ok(before - after)
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    const T: u64 = 1;

    /// Builds a posting with one occurrence of the term at position 0.
    fn single_posting(doc_id: Surrogate) -> Posting {
        Posting {
            doc_id,
            term_freq: 1,
            positions: vec![0],
        }
    }

    #[test]
    fn roundtrip_postings() {
        let backend = MemoryBackend::new();
        let postings = vec![Posting {
            doc_id: Surrogate(1),
            term_freq: 2,
            positions: vec![0, 5],
        }];
        backend
            .write_postings(T, "col", "hello", &postings)
            .unwrap();

        let read = backend.read_postings(T, "col", "hello").unwrap();
        assert_eq!(read.len(), 1);
        assert_eq!(read[0].doc_id, Surrogate(1));
    }

    #[test]
    fn write_empty_postings_removes_term() {
        let backend = MemoryBackend::new();
        backend
            .write_postings(T, "col", "hello", &[single_posting(Surrogate(1))])
            .unwrap();
        backend.write_postings(T, "col", "hello", &[]).unwrap();

        assert!(backend.read_postings(T, "col", "hello").unwrap().is_empty());
        assert!(backend.collection_terms(T, "col").unwrap().is_empty());
    }

    #[test]
    fn remove_postings_deletes_term() {
        let backend = MemoryBackend::new();
        backend
            .write_postings(T, "col", "hello", &[single_posting(Surrogate(1))])
            .unwrap();
        backend.remove_postings(T, "col", "hello").unwrap();
        assert!(backend.read_postings(T, "col", "hello").unwrap().is_empty());

        // Removing an absent term is a no-op, not an error.
        backend.remove_postings(T, "col", "hello").unwrap();
    }

    #[test]
    fn roundtrip_doc_lengths() {
        let backend = MemoryBackend::new();
        backend
            .write_doc_length(T, "col", Surrogate(1), 42)
            .unwrap();
        assert_eq!(
            backend.read_doc_length(T, "col", Surrogate(1)).unwrap(),
            Some(42)
        );

        backend.remove_doc_length(T, "col", Surrogate(1)).unwrap();
        assert_eq!(
            backend.read_doc_length(T, "col", Surrogate(1)).unwrap(),
            None
        );
    }

    #[test]
    fn incremental_stats() {
        let backend = MemoryBackend::new();
        backend.increment_stats(T, "col", 10).unwrap();
        backend.increment_stats(T, "col", 20).unwrap();
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (2, 30));

        backend.decrement_stats(T, "col", 10).unwrap();
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (1, 20));
    }

    #[test]
    fn stats_saturating_sub() {
        let backend = MemoryBackend::new();
        backend.decrement_stats(T, "col", 100).unwrap();
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (0, 0));
    }

    #[test]
    fn meta_roundtrip_overwrite_and_purge() {
        let backend = MemoryBackend::new();
        backend.write_meta(T, "col", "k", b"v1").unwrap();
        assert_eq!(
            backend.read_meta(T, "col", "k").unwrap(),
            Some(b"v1".to_vec())
        );

        backend.write_meta(T, "col", "k", b"v2").unwrap();
        assert_eq!(
            backend.read_meta(T, "col", "k").unwrap(),
            Some(b"v2".to_vec())
        );
        assert_eq!(backend.read_meta(T, "col", "missing").unwrap(), None);

        backend.write_meta(T, "other", "k", b"x").unwrap();
        backend.purge_collection(T, "col").unwrap();
        assert_eq!(backend.read_meta(T, "col", "k").unwrap(), None);
        assert_eq!(
            backend.read_meta(T, "other", "k").unwrap(),
            Some(b"x".to_vec())
        );
    }

    #[test]
    fn purge_clears_stats_and_isolates_collections() {
        let backend = MemoryBackend::new();
        backend.increment_stats(T, "col", 10).unwrap();
        backend
            .write_doc_length(T, "col", Surrogate(1), 10)
            .unwrap();
        backend
            .write_postings(T, "col", "hello", &[single_posting(Surrogate(1))])
            .unwrap();

        backend.increment_stats(T, "other", 7).unwrap();
        backend
            .write_doc_length(T, "other", Surrogate(1), 7)
            .unwrap();
        backend
            .write_postings(T, "other", "world", &[single_posting(Surrogate(1))])
            .unwrap();

        backend.purge_collection(T, "col").unwrap();
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (0, 0));
        assert!(backend.read_postings(T, "col", "hello").unwrap().is_empty());
        assert_eq!(
            backend.read_doc_length(T, "col", Surrogate(1)).unwrap(),
            None
        );

        assert_eq!(backend.collection_stats(T, "other").unwrap(), (1, 7));
        assert_eq!(backend.read_postings(T, "other", "world").unwrap().len(), 1);
        assert_eq!(
            backend.read_doc_length(T, "other", Surrogate(1)).unwrap(),
            Some(7)
        );
    }

    #[test]
    fn purge_collection_counts_postings_and_doc_lengths() {
        let backend = MemoryBackend::new();
        backend
            .write_postings(T, "col", "hello", &[single_posting(Surrogate(1))])
            .unwrap();
        backend
            .write_doc_length(T, "col", Surrogate(1), 3)
            .unwrap();
        // Meta and segments are purged but not included in the count.
        backend.write_meta(T, "col", "k", b"v").unwrap();
        backend.write_segment(T, "col", "s", b"s").unwrap();
        backend
            .write_postings(T, "other", "hello", &[single_posting(Surrogate(1))])
            .unwrap();

        assert_eq!(backend.purge_collection(T, "col").unwrap(), 2);
        assert_eq!(backend.purge_collection(T, "col").unwrap(), 0);
        assert_eq!(backend.read_postings(T, "other", "hello").unwrap().len(), 1);
    }

    #[test]
    fn collection_terms() {
        let backend = MemoryBackend::new();
        backend
            .write_postings(T, "col", "hello", &[single_posting(Surrogate(1))])
            .unwrap();
        backend
            .write_postings(
                T,
                "col",
                "world",
                &[Posting {
                    doc_id: Surrogate(1),
                    term_freq: 1,
                    positions: vec![1],
                }],
            )
            .unwrap();

        let mut terms = backend.collection_terms(T, "col").unwrap();
        terms.sort();
        assert_eq!(terms, vec!["hello", "world"]);
    }

    #[test]
    fn segment_roundtrip() {
        let backend = MemoryBackend::new();
        let data = b"compressed segment bytes";
        backend.write_segment(T, "col", "id1", data).unwrap();
        assert_eq!(
            backend.read_segment(T, "col", "id1").unwrap(),
            Some(data.to_vec())
        );
        assert_eq!(backend.read_segment(T, "col", "missing").unwrap(), None);
    }

    #[test]
    fn segment_list_filters_by_collection() {
        let backend = MemoryBackend::new();
        backend.write_segment(T, "col", "a", b"a").unwrap();
        backend.write_segment(T, "col", "b", b"b").unwrap();
        backend.write_segment(T, "other", "c", b"c").unwrap();

        let mut segs = backend.list_segments(T, "col").unwrap();
        segs.sort();
        assert_eq!(segs, vec!["a", "b"]);

        let other = backend.list_segments(T, "other").unwrap();
        assert_eq!(other, vec!["c"]);
    }

    #[test]
    fn segment_remove() {
        let backend = MemoryBackend::new();
        backend.write_segment(T, "col", "id1", b"data").unwrap();
        backend.remove_segment(T, "col", "id1").unwrap();
        assert_eq!(backend.read_segment(T, "col", "id1").unwrap(), None);
    }

    #[test]
    fn purge_clears_segments() {
        let backend = MemoryBackend::new();
        backend.write_segment(T, "col", "a", b"a").unwrap();
        backend.write_segment(T, "other", "b", b"b").unwrap();

        backend.purge_collection(T, "col").unwrap();
        assert!(backend.list_segments(T, "col").unwrap().is_empty());
        assert_eq!(backend.list_segments(T, "other").unwrap().len(), 1);
    }

    #[test]
    fn purge_tenant_isolates_tenants() {
        let backend = MemoryBackend::new();
        backend.increment_stats(1, "col", 5).unwrap();
        backend.increment_stats(2, "col", 7).unwrap();
        backend
            .write_postings(1, "col", "t", &[single_posting(Surrogate(1))])
            .unwrap();
        backend
            .write_postings(2, "col", "t", &[single_posting(Surrogate(1))])
            .unwrap();

        backend.purge_tenant(1).unwrap();
        assert_eq!(backend.collection_stats(1, "col").unwrap(), (0, 0));
        assert!(backend.read_postings(1, "col", "t").unwrap().is_empty());
        assert_eq!(backend.collection_stats(2, "col").unwrap(), (1, 7));
        assert_eq!(backend.read_postings(2, "col", "t").unwrap().len(), 1);
    }

    #[test]
    fn purge_tenant_clears_meta_and_segments() {
        let backend = MemoryBackend::new();
        backend.write_meta(1, "col", "k", b"a").unwrap();
        backend.write_segment(1, "col", "s", b"a").unwrap();
        backend
            .write_doc_length(1, "col", Surrogate(1), 3)
            .unwrap();
        backend.write_meta(2, "col", "k", b"b").unwrap();
        backend.write_segment(2, "col", "s", b"b").unwrap();

        // Only the doc-length entry counts toward the returned total.
        assert_eq!(backend.purge_tenant(1).unwrap(), 1);
        assert_eq!(backend.read_meta(1, "col", "k").unwrap(), None);
        assert!(backend.list_segments(1, "col").unwrap().is_empty());
        assert_eq!(
            backend.read_meta(2, "col", "k").unwrap(),
            Some(b"b".to_vec())
        );
        assert_eq!(backend.list_segments(2, "col").unwrap(), vec!["s"]);
    }
}