Skip to main content

nodedb_fts/backend/
memory.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! In-memory FTS backend for Lite and WASM deployments.
4//!
5//! All data lives in HashMaps behind `RefCell` for interior mutability,
6//! matching the `&self` trait signature. Rebuilt from documents on cold
7//! start — acceptable for edge-scale datasets.
8//!
9//! Keys are fully structural tuples `(tid, collection, …)` — tenant
10//! isolation never depends on lexical-prefix ordering.
11
12use std::cell::RefCell;
13use std::collections::HashMap;
14use std::fmt;
15
16use nodedb_types::Surrogate;
17
18use crate::backend::FtsBackend;
19use crate::posting::Posting;
20
/// In-memory backend error.
///
/// The wrapped message is private, so only this module can construct one,
/// and no operation in this backend currently does: every `FtsBackend`
/// method here returns `Ok`. The type exists because the trait requires an
/// associated error type.
#[derive(Debug)]
pub struct MemoryError(String);

impl fmt::Display for MemoryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "memory backend: {}", self.0)
    }
}

/// Lets `MemoryError` be used as `dyn Error` (boxing, `?` into
/// `Box<dyn Error>`, error-chain reporting). There is no underlying cause,
/// so `source()` keeps its default `None`.
impl std::error::Error for MemoryError {}
30
/// `(tid, collection, sub)` key shared by the postings (`sub` = term),
/// meta (`sub` = subkey) and segments (`sub` = segment id) maps.
type TripleKey = (u64, String, String);
/// `(tid, collection, doc_id)` key for per-document token counts.
type DocLenKey = (u64, String, Surrogate);
/// `(tid, collection)` key for per-collection statistics.
type PairKey = (u64, String);
34
/// In-memory FTS backend: one `HashMap` per logical table, each keyed by a
/// `(tid, collection, …)` tuple so tenant and collection are compared
/// structurally.
///
/// Every map sits behind a `RefCell` so the `FtsBackend` trait can use
/// `&self` uniformly (redb has its own transactional isolation). Because
/// `RefCell` is not `Sync`, a shared `&MemoryBackend` cannot cross threads.
#[derive(Debug, Default)]
pub struct MemoryBackend {
    /// `(tid, collection, term) → posting list`. Never holds an empty list:
    /// writing an empty slice removes the key instead.
    postings: RefCell<HashMap<TripleKey, Vec<Posting>>>,
    /// `(tid, collection, doc_id) → token count`.
    doc_lengths: RefCell<HashMap<DocLenKey, u32>>,
    /// `(tid, collection) → (doc_count, total_token_sum)`. A missing entry
    /// reads back as `(0, 0)`.
    stats: RefCell<HashMap<PairKey, (u32, u64)>>,
    /// `(tid, collection, subkey) → blob` for docmap, fieldnorms, analyzer, language.
    meta: RefCell<HashMap<TripleKey, Vec<u8>>>,
    /// `(tid, collection, segment_id) → compressed segment bytes`.
    segments: RefCell<HashMap<TripleKey, Vec<u8>>>,
}
53
54impl MemoryBackend {
55    pub fn new() -> Self {
56        Self::default()
57    }
58}
59
60fn triple(tid: u64, collection: &str, sub: &str) -> TripleKey {
61    (tid, collection.to_string(), sub.to_string())
62}
63
64fn doc_len_key(tid: u64, collection: &str, doc_id: Surrogate) -> DocLenKey {
65    (tid, collection.to_string(), doc_id)
66}
67
68fn pair(tid: u64, collection: &str) -> PairKey {
69    (tid, collection.to_string())
70}
71
72impl FtsBackend for MemoryBackend {
73    type Error = MemoryError;
74
75    fn read_postings(
76        &self,
77        tid: u64,
78        collection: &str,
79        term: &str,
80    ) -> Result<Vec<Posting>, Self::Error> {
81        Ok(self
82            .postings
83            .borrow()
84            .get(&triple(tid, collection, term))
85            .cloned()
86            .unwrap_or_default())
87    }
88
89    fn write_postings(
90        &self,
91        tid: u64,
92        collection: &str,
93        term: &str,
94        postings: &[Posting],
95    ) -> Result<(), Self::Error> {
96        let key = triple(tid, collection, term);
97        let mut map = self.postings.borrow_mut();
98        if postings.is_empty() {
99            map.remove(&key);
100        } else {
101            map.insert(key, postings.to_vec());
102        }
103        Ok(())
104    }
105
106    fn remove_postings(&self, tid: u64, collection: &str, term: &str) -> Result<(), Self::Error> {
107        self.postings
108            .borrow_mut()
109            .remove(&triple(tid, collection, term));
110        Ok(())
111    }
112
113    fn read_doc_length(
114        &self,
115        tid: u64,
116        collection: &str,
117        doc_id: Surrogate,
118    ) -> Result<Option<u32>, Self::Error> {
119        Ok(self
120            .doc_lengths
121            .borrow()
122            .get(&doc_len_key(tid, collection, doc_id))
123            .copied())
124    }
125
126    fn write_doc_length(
127        &self,
128        tid: u64,
129        collection: &str,
130        doc_id: Surrogate,
131        length: u32,
132    ) -> Result<(), Self::Error> {
133        self.doc_lengths
134            .borrow_mut()
135            .insert(doc_len_key(tid, collection, doc_id), length);
136        Ok(())
137    }
138
139    fn remove_doc_length(
140        &self,
141        tid: u64,
142        collection: &str,
143        doc_id: Surrogate,
144    ) -> Result<(), Self::Error> {
145        self.doc_lengths
146            .borrow_mut()
147            .remove(&doc_len_key(tid, collection, doc_id));
148        Ok(())
149    }
150
151    fn collection_terms(&self, tid: u64, collection: &str) -> Result<Vec<String>, Self::Error> {
152        Ok(self
153            .postings
154            .borrow()
155            .keys()
156            .filter(|(t, c, _)| *t == tid && c == collection)
157            .map(|(_, _, term)| term.clone())
158            .collect())
159    }
160
161    fn collection_stats(&self, tid: u64, collection: &str) -> Result<(u32, u64), Self::Error> {
162        Ok(self
163            .stats
164            .borrow()
165            .get(&pair(tid, collection))
166            .copied()
167            .unwrap_or((0, 0)))
168    }
169
170    fn increment_stats(&self, tid: u64, collection: &str, doc_len: u32) -> Result<(), Self::Error> {
171        let mut stats = self.stats.borrow_mut();
172        let entry = stats.entry(pair(tid, collection)).or_insert((0, 0));
173        entry.0 += 1;
174        entry.1 += doc_len as u64;
175        Ok(())
176    }
177
178    fn decrement_stats(&self, tid: u64, collection: &str, doc_len: u32) -> Result<(), Self::Error> {
179        let mut stats = self.stats.borrow_mut();
180        let entry = stats.entry(pair(tid, collection)).or_insert((0, 0));
181        entry.0 = entry.0.saturating_sub(1);
182        entry.1 = entry.1.saturating_sub(doc_len as u64);
183        Ok(())
184    }
185
186    fn read_meta(
187        &self,
188        tid: u64,
189        collection: &str,
190        subkey: &str,
191    ) -> Result<Option<Vec<u8>>, Self::Error> {
192        Ok(self
193            .meta
194            .borrow()
195            .get(&triple(tid, collection, subkey))
196            .cloned())
197    }
198
199    fn write_meta(
200        &self,
201        tid: u64,
202        collection: &str,
203        subkey: &str,
204        value: &[u8],
205    ) -> Result<(), Self::Error> {
206        self.meta
207            .borrow_mut()
208            .insert(triple(tid, collection, subkey), value.to_vec());
209        Ok(())
210    }
211
212    fn write_segment(
213        &self,
214        tid: u64,
215        collection: &str,
216        segment_id: &str,
217        data: &[u8],
218    ) -> Result<(), Self::Error> {
219        self.segments
220            .borrow_mut()
221            .insert(triple(tid, collection, segment_id), data.to_vec());
222        Ok(())
223    }
224
225    fn read_segment(
226        &self,
227        tid: u64,
228        collection: &str,
229        segment_id: &str,
230    ) -> Result<Option<Vec<u8>>, Self::Error> {
231        Ok(self
232            .segments
233            .borrow()
234            .get(&triple(tid, collection, segment_id))
235            .cloned())
236    }
237
238    fn list_segments(&self, tid: u64, collection: &str) -> Result<Vec<String>, Self::Error> {
239        Ok(self
240            .segments
241            .borrow()
242            .keys()
243            .filter(|(t, c, _)| *t == tid && c == collection)
244            .map(|(_, _, seg)| seg.clone())
245            .collect())
246    }
247
248    fn remove_segment(
249        &self,
250        tid: u64,
251        collection: &str,
252        segment_id: &str,
253    ) -> Result<(), Self::Error> {
254        self.segments
255            .borrow_mut()
256            .remove(&triple(tid, collection, segment_id));
257        Ok(())
258    }
259
260    fn purge_collection(&self, tid: u64, collection: &str) -> Result<usize, Self::Error> {
261        let match_tc = |(t, c, _): &&TripleKey| *t == tid && c == collection;
262
263        let mut postings = self.postings.borrow_mut();
264        let mut doc_lengths = self.doc_lengths.borrow_mut();
265        let before = postings.len() + doc_lengths.len();
266        postings.retain(|k, _| !(k.0 == tid && k.1 == collection));
267        doc_lengths.retain(|k, _| !(k.0 == tid && k.1 == collection));
268        self.stats.borrow_mut().remove(&pair(tid, collection));
269        self.meta
270            .borrow_mut()
271            .retain(|k, _| !(k.0 == tid && k.1 == collection));
272        self.segments
273            .borrow_mut()
274            .retain(|k, _| !(k.0 == tid && k.1 == collection));
275        let after = postings.len() + doc_lengths.len();
276        let _ = match_tc;
277        Ok(before - after)
278    }
279
280    fn purge_tenant(&self, tid: u64) -> Result<usize, Self::Error> {
281        let mut postings = self.postings.borrow_mut();
282        let mut doc_lengths = self.doc_lengths.borrow_mut();
283        let before = postings.len() + doc_lengths.len();
284        postings.retain(|k, _| k.0 != tid);
285        doc_lengths.retain(|k, _| k.0 != tid);
286        self.stats.borrow_mut().retain(|k, _| k.0 != tid);
287        self.meta.borrow_mut().retain(|k, _| k.0 != tid);
288        self.segments.borrow_mut().retain(|k, _| k.0 != tid);
289        let after = postings.len() + doc_lengths.len();
290        Ok(before - after)
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    const T: u64 = 1;

    /// One occurrence of `Surrogate(1)` at position 0. This is the posting
    /// most tests need merely to exist.
    fn hit() -> Posting {
        Posting {
            doc_id: Surrogate(1),
            term_freq: 1,
            positions: vec![0],
        }
    }

    #[test]
    fn roundtrip_postings() {
        let backend = MemoryBackend::new();
        let postings = vec![Posting {
            doc_id: Surrogate(1),
            term_freq: 2,
            positions: vec![0, 5],
        }];
        backend
            .write_postings(T, "col", "hello", &postings)
            .unwrap();

        let read = backend.read_postings(T, "col", "hello").unwrap();
        assert_eq!(read.len(), 1);
        assert_eq!(read[0].doc_id, Surrogate(1));
    }

    #[test]
    fn empty_write_and_remove_drop_the_term() {
        let backend = MemoryBackend::new();
        backend.write_postings(T, "col", "a", &[hit()]).unwrap();
        backend.write_postings(T, "col", "b", &[hit()]).unwrap();

        // Writing an empty list deletes the key rather than storing `[]`.
        backend.write_postings(T, "col", "a", &[]).unwrap();
        backend.remove_postings(T, "col", "b").unwrap();

        assert!(backend.read_postings(T, "col", "a").unwrap().is_empty());
        assert!(backend.read_postings(T, "col", "b").unwrap().is_empty());
        assert!(backend.collection_terms(T, "col").unwrap().is_empty());
    }

    #[test]
    fn roundtrip_doc_lengths() {
        let backend = MemoryBackend::new();
        backend
            .write_doc_length(T, "col", Surrogate(1), 42)
            .unwrap();
        assert_eq!(
            backend.read_doc_length(T, "col", Surrogate(1)).unwrap(),
            Some(42)
        );

        backend.remove_doc_length(T, "col", Surrogate(1)).unwrap();
        assert_eq!(
            backend.read_doc_length(T, "col", Surrogate(1)).unwrap(),
            None
        );
    }

    #[test]
    fn incremental_stats() {
        let backend = MemoryBackend::new();
        backend.increment_stats(T, "col", 10).unwrap();
        backend.increment_stats(T, "col", 20).unwrap();
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (2, 30));

        backend.decrement_stats(T, "col", 10).unwrap();
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (1, 20));
    }

    #[test]
    fn stats_saturating_sub() {
        let backend = MemoryBackend::new();
        backend.decrement_stats(T, "col", 100).unwrap();
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (0, 0));
    }

    #[test]
    fn meta_roundtrip_and_purge_isolation() {
        let backend = MemoryBackend::new();
        backend
            .write_meta(T, "col", "analyzer", b"standard")
            .unwrap();
        backend
            .write_meta(T, "other", "analyzer", b"keyword")
            .unwrap();

        assert_eq!(
            backend.read_meta(T, "col", "analyzer").unwrap(),
            Some(b"standard".to_vec())
        );
        assert_eq!(backend.read_meta(T, "col", "language").unwrap(), None);

        backend.purge_collection(T, "col").unwrap();
        assert_eq!(backend.read_meta(T, "col", "analyzer").unwrap(), None);
        assert_eq!(
            backend.read_meta(T, "other", "analyzer").unwrap(),
            Some(b"keyword".to_vec())
        );
    }

    #[test]
    fn purge_clears_stats_and_isolates_collections() {
        let backend = MemoryBackend::new();
        backend.increment_stats(T, "col", 10).unwrap();
        backend
            .write_doc_length(T, "col", Surrogate(1), 10)
            .unwrap();
        backend
            .write_postings(T, "col", "hello", &[hit()])
            .unwrap();

        backend.increment_stats(T, "other", 7).unwrap();
        backend
            .write_doc_length(T, "other", Surrogate(1), 7)
            .unwrap();
        backend
            .write_postings(T, "other", "world", &[hit()])
            .unwrap();

        // One posting list + one doc-length entry belonged to "col".
        assert_eq!(backend.purge_collection(T, "col").unwrap(), 2);
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (0, 0));
        assert!(backend.read_postings(T, "col", "hello").unwrap().is_empty());
        assert_eq!(
            backend.read_doc_length(T, "col", Surrogate(1)).unwrap(),
            None
        );

        assert_eq!(backend.collection_stats(T, "other").unwrap(), (1, 7));
        assert_eq!(backend.read_postings(T, "other", "world").unwrap().len(), 1);
        assert_eq!(
            backend.read_doc_length(T, "other", Surrogate(1)).unwrap(),
            Some(7)
        );

        // Purging something that was never written removes nothing.
        assert_eq!(backend.purge_collection(T, "missing").unwrap(), 0);
    }

    #[test]
    fn collection_terms() {
        let backend = MemoryBackend::new();
        backend
            .write_postings(T, "col", "hello", &[hit()])
            .unwrap();
        backend
            .write_postings(
                T,
                "col",
                "world",
                &[Posting {
                    positions: vec![1],
                    ..hit()
                }],
            )
            .unwrap();

        let mut terms = backend.collection_terms(T, "col").unwrap();
        terms.sort();
        assert_eq!(terms, vec!["hello", "world"]);
    }

    #[test]
    fn segment_roundtrip() {
        let backend = MemoryBackend::new();
        let data = b"compressed segment bytes";
        backend.write_segment(T, "col", "id1", data).unwrap();
        assert_eq!(
            backend.read_segment(T, "col", "id1").unwrap(),
            Some(data.to_vec())
        );
        assert_eq!(backend.read_segment(T, "col", "missing").unwrap(), None);
    }

    #[test]
    fn segment_list_filters_by_collection() {
        let backend = MemoryBackend::new();
        backend.write_segment(T, "col", "a", b"a").unwrap();
        backend.write_segment(T, "col", "b", b"b").unwrap();
        backend.write_segment(T, "other", "c", b"c").unwrap();

        let mut segs = backend.list_segments(T, "col").unwrap();
        segs.sort();
        assert_eq!(segs, vec!["a", "b"]);

        let other = backend.list_segments(T, "other").unwrap();
        assert_eq!(other, vec!["c"]);
    }

    #[test]
    fn segment_remove() {
        let backend = MemoryBackend::new();
        backend.write_segment(T, "col", "id1", b"data").unwrap();
        backend.remove_segment(T, "col", "id1").unwrap();
        assert_eq!(backend.read_segment(T, "col", "id1").unwrap(), None);
    }

    #[test]
    fn purge_clears_segments() {
        let backend = MemoryBackend::new();
        backend.write_segment(T, "col", "a", b"a").unwrap();
        backend.write_segment(T, "other", "b", b"b").unwrap();

        backend.purge_collection(T, "col").unwrap();
        assert!(backend.list_segments(T, "col").unwrap().is_empty());
        assert_eq!(backend.list_segments(T, "other").unwrap().len(), 1);
    }

    #[test]
    fn purge_tenant_isolates_tenants() {
        let backend = MemoryBackend::new();
        for (tid, len) in [(1u64, 5u32), (2, 7)] {
            backend.increment_stats(tid, "col", len).unwrap();
            backend.write_postings(tid, "col", "t", &[hit()]).unwrap();
            backend
                .write_doc_length(tid, "col", Surrogate(1), len)
                .unwrap();
            backend.write_meta(tid, "col", "analyzer", b"std").unwrap();
            backend.write_segment(tid, "col", "seg", b"bytes").unwrap();
        }

        backend.purge_tenant(1).unwrap();

        assert_eq!(backend.collection_stats(1, "col").unwrap(), (0, 0));
        assert!(backend.read_postings(1, "col", "t").unwrap().is_empty());
        assert_eq!(backend.read_doc_length(1, "col", Surrogate(1)).unwrap(), None);
        assert_eq!(backend.read_meta(1, "col", "analyzer").unwrap(), None);
        assert!(backend.list_segments(1, "col").unwrap().is_empty());

        assert_eq!(backend.collection_stats(2, "col").unwrap(), (1, 7));
        assert_eq!(backend.read_postings(2, "col", "t").unwrap().len(), 1);
        assert_eq!(
            backend.read_doc_length(2, "col", Surrogate(1)).unwrap(),
            Some(7)
        );
        assert_eq!(
            backend.read_meta(2, "col", "analyzer").unwrap(),
            Some(b"std".to_vec())
        );
        assert_eq!(backend.list_segments(2, "col").unwrap(), vec!["seg"]);
    }
}