Skip to main content

nodedb_fts/backend/
memory.rs

1//! In-memory FTS backend for Lite and WASM deployments.
2//!
3//! All data lives in HashMaps behind `RefCell` for interior mutability,
4//! matching the `&self` trait signature. Rebuilt from documents on cold
5//! start — acceptable for edge-scale datasets.
6//!
7//! Keys are fully structural tuples `(tid, collection, …)` — tenant
8//! isolation never depends on lexical-prefix ordering.
9
10use std::cell::RefCell;
11use std::collections::HashMap;
12use std::fmt;
13
14use crate::backend::FtsBackend;
15use crate::posting::Posting;
16
/// Error type reported by the in-memory backend.
///
/// Every operation in this file returns `Ok`, so the type exists to satisfy
/// the `FtsBackend::Error` associated type. The wrapped string is the
/// human-readable detail printed after the `memory backend: ` prefix.
#[derive(Debug)]
pub struct MemoryError(String);

impl fmt::Display for MemoryError {
    /// Writes the message as `memory backend: <detail>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "memory backend: {}", self.0)
    }
}

/// Lets callers box the error as `dyn Error` or propagate it with `?` into
/// error types that accept any standard error. There is no underlying
/// source, so the default method implementations are sufficient.
impl std::error::Error for MemoryError {}
26
/// Three-part map key `(tid, collection, sub)`. The last component is a
/// term, doc id, meta subkey, or segment id depending on the map.
type TripleKey = (u32, String, String);
/// Two-part map key `(tid, collection)`, used for per-collection stats.
type PairKey = (u32, String);
29
30/// In-memory FTS backend backed by HashMaps keyed by `(tid, collection, …)`
31/// tuples.
32///
33/// Uses `RefCell` for interior mutability so the `FtsBackend` trait
34/// can use `&self` uniformly (redb has its own transactional isolation).
35#[derive(Debug, Default)]
36pub struct MemoryBackend {
37    /// `(tid, collection, term) → posting list`.
38    postings: RefCell<HashMap<TripleKey, Vec<Posting>>>,
39    /// `(tid, collection, doc_id) → token count`.
40    doc_lengths: RefCell<HashMap<TripleKey, u32>>,
41    /// `(tid, collection) → (doc_count, total_token_sum)`.
42    stats: RefCell<HashMap<PairKey, (u32, u64)>>,
43    /// `(tid, collection, subkey) → blob` for docmap, fieldnorms, analyzer, language.
44    meta: RefCell<HashMap<TripleKey, Vec<u8>>>,
45    /// `(tid, collection, segment_id) → compressed segment bytes`.
46    segments: RefCell<HashMap<TripleKey, Vec<u8>>>,
47}
48
49impl MemoryBackend {
50    pub fn new() -> Self {
51        Self::default()
52    }
53}
54
55fn triple(tid: u32, collection: &str, sub: &str) -> TripleKey {
56    (tid, collection.to_string(), sub.to_string())
57}
58
59fn pair(tid: u32, collection: &str) -> PairKey {
60    (tid, collection.to_string())
61}
62
63impl FtsBackend for MemoryBackend {
64    type Error = MemoryError;
65
66    fn read_postings(
67        &self,
68        tid: u32,
69        collection: &str,
70        term: &str,
71    ) -> Result<Vec<Posting>, Self::Error> {
72        Ok(self
73            .postings
74            .borrow()
75            .get(&triple(tid, collection, term))
76            .cloned()
77            .unwrap_or_default())
78    }
79
80    fn write_postings(
81        &self,
82        tid: u32,
83        collection: &str,
84        term: &str,
85        postings: &[Posting],
86    ) -> Result<(), Self::Error> {
87        let key = triple(tid, collection, term);
88        let mut map = self.postings.borrow_mut();
89        if postings.is_empty() {
90            map.remove(&key);
91        } else {
92            map.insert(key, postings.to_vec());
93        }
94        Ok(())
95    }
96
97    fn remove_postings(&self, tid: u32, collection: &str, term: &str) -> Result<(), Self::Error> {
98        self.postings
99            .borrow_mut()
100            .remove(&triple(tid, collection, term));
101        Ok(())
102    }
103
104    fn read_doc_length(
105        &self,
106        tid: u32,
107        collection: &str,
108        doc_id: &str,
109    ) -> Result<Option<u32>, Self::Error> {
110        Ok(self
111            .doc_lengths
112            .borrow()
113            .get(&triple(tid, collection, doc_id))
114            .copied())
115    }
116
117    fn write_doc_length(
118        &self,
119        tid: u32,
120        collection: &str,
121        doc_id: &str,
122        length: u32,
123    ) -> Result<(), Self::Error> {
124        self.doc_lengths
125            .borrow_mut()
126            .insert(triple(tid, collection, doc_id), length);
127        Ok(())
128    }
129
130    fn remove_doc_length(
131        &self,
132        tid: u32,
133        collection: &str,
134        doc_id: &str,
135    ) -> Result<(), Self::Error> {
136        self.doc_lengths
137            .borrow_mut()
138            .remove(&triple(tid, collection, doc_id));
139        Ok(())
140    }
141
142    fn collection_terms(&self, tid: u32, collection: &str) -> Result<Vec<String>, Self::Error> {
143        Ok(self
144            .postings
145            .borrow()
146            .keys()
147            .filter(|(t, c, _)| *t == tid && c == collection)
148            .map(|(_, _, term)| term.clone())
149            .collect())
150    }
151
152    fn collection_stats(&self, tid: u32, collection: &str) -> Result<(u32, u64), Self::Error> {
153        Ok(self
154            .stats
155            .borrow()
156            .get(&pair(tid, collection))
157            .copied()
158            .unwrap_or((0, 0)))
159    }
160
161    fn increment_stats(&self, tid: u32, collection: &str, doc_len: u32) -> Result<(), Self::Error> {
162        let mut stats = self.stats.borrow_mut();
163        let entry = stats.entry(pair(tid, collection)).or_insert((0, 0));
164        entry.0 += 1;
165        entry.1 += doc_len as u64;
166        Ok(())
167    }
168
169    fn decrement_stats(&self, tid: u32, collection: &str, doc_len: u32) -> Result<(), Self::Error> {
170        let mut stats = self.stats.borrow_mut();
171        let entry = stats.entry(pair(tid, collection)).or_insert((0, 0));
172        entry.0 = entry.0.saturating_sub(1);
173        entry.1 = entry.1.saturating_sub(doc_len as u64);
174        Ok(())
175    }
176
177    fn read_meta(
178        &self,
179        tid: u32,
180        collection: &str,
181        subkey: &str,
182    ) -> Result<Option<Vec<u8>>, Self::Error> {
183        Ok(self
184            .meta
185            .borrow()
186            .get(&triple(tid, collection, subkey))
187            .cloned())
188    }
189
190    fn write_meta(
191        &self,
192        tid: u32,
193        collection: &str,
194        subkey: &str,
195        value: &[u8],
196    ) -> Result<(), Self::Error> {
197        self.meta
198            .borrow_mut()
199            .insert(triple(tid, collection, subkey), value.to_vec());
200        Ok(())
201    }
202
203    fn write_segment(
204        &self,
205        tid: u32,
206        collection: &str,
207        segment_id: &str,
208        data: &[u8],
209    ) -> Result<(), Self::Error> {
210        self.segments
211            .borrow_mut()
212            .insert(triple(tid, collection, segment_id), data.to_vec());
213        Ok(())
214    }
215
216    fn read_segment(
217        &self,
218        tid: u32,
219        collection: &str,
220        segment_id: &str,
221    ) -> Result<Option<Vec<u8>>, Self::Error> {
222        Ok(self
223            .segments
224            .borrow()
225            .get(&triple(tid, collection, segment_id))
226            .cloned())
227    }
228
229    fn list_segments(&self, tid: u32, collection: &str) -> Result<Vec<String>, Self::Error> {
230        Ok(self
231            .segments
232            .borrow()
233            .keys()
234            .filter(|(t, c, _)| *t == tid && c == collection)
235            .map(|(_, _, seg)| seg.clone())
236            .collect())
237    }
238
239    fn remove_segment(
240        &self,
241        tid: u32,
242        collection: &str,
243        segment_id: &str,
244    ) -> Result<(), Self::Error> {
245        self.segments
246            .borrow_mut()
247            .remove(&triple(tid, collection, segment_id));
248        Ok(())
249    }
250
251    fn purge_collection(&self, tid: u32, collection: &str) -> Result<usize, Self::Error> {
252        let match_tc = |(t, c, _): &&TripleKey| *t == tid && c == collection;
253
254        let mut postings = self.postings.borrow_mut();
255        let mut doc_lengths = self.doc_lengths.borrow_mut();
256        let before = postings.len() + doc_lengths.len();
257        postings.retain(|k, _| !(k.0 == tid && k.1 == collection));
258        doc_lengths.retain(|k, _| !(k.0 == tid && k.1 == collection));
259        self.stats.borrow_mut().remove(&pair(tid, collection));
260        self.meta
261            .borrow_mut()
262            .retain(|k, _| !(k.0 == tid && k.1 == collection));
263        self.segments
264            .borrow_mut()
265            .retain(|k, _| !(k.0 == tid && k.1 == collection));
266        let after = postings.len() + doc_lengths.len();
267        let _ = match_tc;
268        Ok(before - after)
269    }
270
271    fn purge_tenant(&self, tid: u32) -> Result<usize, Self::Error> {
272        let mut postings = self.postings.borrow_mut();
273        let mut doc_lengths = self.doc_lengths.borrow_mut();
274        let before = postings.len() + doc_lengths.len();
275        postings.retain(|k, _| k.0 != tid);
276        doc_lengths.retain(|k, _| k.0 != tid);
277        self.stats.borrow_mut().retain(|k, _| k.0 != tid);
278        self.meta.borrow_mut().retain(|k, _| k.0 != tid);
279        self.segments.borrow_mut().retain(|k, _| k.0 != tid);
280        let after = postings.len() + doc_lengths.len();
281        Ok(before - after)
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Tenant id used by the single-tenant tests.
    const T: u32 = 1;

    /// Builds a posting for `doc_id` with a single occurrence at position 0.
    fn single_hit(doc_id: &'static str) -> Posting {
        Posting {
            doc_id: doc_id.into(),
            term_freq: 1,
            positions: vec![0],
        }
    }

    // Postings written for a term can be read back unchanged.
    #[test]
    fn roundtrip_postings() {
        let backend = MemoryBackend::new();
        let written = [Posting {
            doc_id: "d1".into(),
            term_freq: 2,
            positions: vec![0, 5],
        }];
        backend.write_postings(T, "col", "hello", &written).unwrap();

        let stored = backend.read_postings(T, "col", "hello").unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].doc_id, "d1");
    }

    // Doc lengths survive a write/read cycle and disappear after removal.
    #[test]
    fn roundtrip_doc_lengths() {
        let backend = MemoryBackend::new();

        backend.write_doc_length(T, "col", "d1", 42).unwrap();
        let stored = backend.read_doc_length(T, "col", "d1").unwrap();
        assert_eq!(stored, Some(42));

        backend.remove_doc_length(T, "col", "d1").unwrap();
        let after_removal = backend.read_doc_length(T, "col", "d1").unwrap();
        assert_eq!(after_removal, None);
    }

    // Stats accumulate per document and shrink again on decrement.
    #[test]
    fn incremental_stats() {
        let backend = MemoryBackend::new();
        for doc_len in [10, 20] {
            backend.increment_stats(T, "col", doc_len).unwrap();
        }
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (2, 30));

        backend.decrement_stats(T, "col", 10).unwrap();
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (1, 20));
    }

    // Decrementing an empty collection clamps at zero instead of underflowing.
    #[test]
    fn stats_saturating_sub() {
        let backend = MemoryBackend::new();
        backend.decrement_stats(T, "col", 100).unwrap();
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (0, 0));
    }

    // Purging one collection wipes its data and leaves a sibling untouched.
    #[test]
    fn purge_clears_stats_and_isolates_collections() {
        let backend = MemoryBackend::new();

        backend.increment_stats(T, "col", 10).unwrap();
        backend.write_doc_length(T, "col", "d1", 10).unwrap();
        backend
            .write_postings(T, "col", "hello", &[single_hit("d1")])
            .unwrap();

        backend.increment_stats(T, "other", 7).unwrap();
        backend.write_doc_length(T, "other", "d1", 7).unwrap();
        backend
            .write_postings(T, "other", "world", &[single_hit("d1")])
            .unwrap();

        backend.purge_collection(T, "col").unwrap();

        // The purged collection is empty everywhere.
        assert_eq!(backend.collection_stats(T, "col").unwrap(), (0, 0));
        assert!(backend.read_postings(T, "col", "hello").unwrap().is_empty());
        assert_eq!(backend.read_doc_length(T, "col", "d1").unwrap(), None);

        // The sibling collection is intact.
        assert_eq!(backend.collection_stats(T, "other").unwrap(), (1, 7));
        assert_eq!(backend.read_postings(T, "other", "world").unwrap().len(), 1);
        assert_eq!(backend.read_doc_length(T, "other", "d1").unwrap(), Some(7));
    }

    // Every term with postings in the collection is listed.
    #[test]
    fn collection_terms() {
        let backend = MemoryBackend::new();
        backend
            .write_postings(T, "col", "hello", &[single_hit("d1")])
            .unwrap();
        let at_position_one = Posting {
            positions: vec![1],
            ..single_hit("d1")
        };
        backend
            .write_postings(T, "col", "world", &[at_position_one])
            .unwrap();

        // Listing order is unspecified, so sort before comparing.
        let mut terms = backend.collection_terms(T, "col").unwrap();
        terms.sort();
        assert_eq!(terms, vec!["hello", "world"]);
    }

    // Segment bytes can be read back; unknown ids yield `None`.
    #[test]
    fn segment_roundtrip() {
        let backend = MemoryBackend::new();
        let data = b"compressed segment bytes";
        backend.write_segment(T, "col", "id1", data).unwrap();

        let stored = backend.read_segment(T, "col", "id1").unwrap();
        assert_eq!(stored, Some(data.to_vec()));
        let missing = backend.read_segment(T, "col", "missing").unwrap();
        assert_eq!(missing, None);
    }

    // Listing segments only returns ids from the requested collection.
    #[test]
    fn segment_list_filters_by_collection() {
        let backend = MemoryBackend::new();
        backend.write_segment(T, "col", "a", b"a").unwrap();
        backend.write_segment(T, "col", "b", b"b").unwrap();
        backend.write_segment(T, "other", "c", b"c").unwrap();

        let mut segs = backend.list_segments(T, "col").unwrap();
        segs.sort();
        assert_eq!(segs, vec!["a", "b"]);

        let other = backend.list_segments(T, "other").unwrap();
        assert_eq!(other, vec!["c"]);
    }

    // A removed segment can no longer be read.
    #[test]
    fn segment_remove() {
        let backend = MemoryBackend::new();
        backend.write_segment(T, "col", "id1", b"data").unwrap();
        backend.remove_segment(T, "col", "id1").unwrap();

        let stored = backend.read_segment(T, "col", "id1").unwrap();
        assert_eq!(stored, None);
    }

    // Purging a collection drops its segments but not a sibling's.
    #[test]
    fn purge_clears_segments() {
        let backend = MemoryBackend::new();
        backend.write_segment(T, "col", "a", b"a").unwrap();
        backend.write_segment(T, "other", "b", b"b").unwrap();

        backend.purge_collection(T, "col").unwrap();

        assert!(backend.list_segments(T, "col").unwrap().is_empty());
        assert_eq!(backend.list_segments(T, "other").unwrap().len(), 1);
    }

    // Purging tenant 1 leaves tenant 2's identically named collection alone.
    #[test]
    fn purge_tenant_isolates_tenants() {
        let backend = MemoryBackend::new();
        backend.increment_stats(1, "col", 5).unwrap();
        backend.increment_stats(2, "col", 7).unwrap();
        backend
            .write_postings(1, "col", "t", &[single_hit("d")])
            .unwrap();
        backend
            .write_postings(2, "col", "t", &[single_hit("d")])
            .unwrap();

        backend.purge_tenant(1).unwrap();

        assert_eq!(backend.collection_stats(1, "col").unwrap(), (0, 0));
        assert!(backend.read_postings(1, "col", "t").unwrap().is_empty());
        assert_eq!(backend.collection_stats(2, "col").unwrap(), (1, 7));
        assert_eq!(backend.read_postings(2, "col", "t").unwrap().len(), 1);
    }
}