// nodedb_fts/backend/memory.rs
//! In-memory FTS backend for Lite and WASM deployments.
//!
//! All data lives in HashMaps behind `RefCell` for interior mutability,
//! matching the `&self` trait signature. Rebuilt from documents on cold
//! start — acceptable for edge-scale datasets.

7use std::cell::RefCell;
8use std::collections::HashMap;
9use std::fmt;
10
11use crate::backend::FtsBackend;
12use crate::posting::Posting;
13
/// In-memory backend error (infallible in practice, but trait requires it).
#[derive(Debug)]
pub struct MemoryError(String);

impl fmt::Display for MemoryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "memory backend: {}", self.0)
    }
}

// Fix: implement `std::error::Error` so `MemoryError` can flow through `?`
// conversions, `Box<dyn Error>`, and error-reporting layers. Previously only
// `Display` was implemented, which blocked those uses.
impl std::error::Error for MemoryError {}
23
/// In-memory FTS backend backed by HashMaps.
///
/// Keys are stored as `"{collection}:{term}"` for postings and
/// `"{collection}:{doc_id}"` for document lengths, matching the
/// scoping pattern used by the redb backend.
///
/// Uses `RefCell` for interior mutability so the `FtsBackend` trait
/// can use `&self` uniformly (redb has its own transactional isolation).
/// NOTE(review): `RefCell` makes this type `!Sync`; presumably fine for
/// single-threaded Lite/WASM deployments — confirm it is never shared
/// across threads.
#[derive(Debug, Default)]
pub struct MemoryBackend {
    /// Scoped key "{collection}:{term}" → posting list.
    postings: RefCell<HashMap<String, Vec<Posting>>>,
    /// Scoped key "{collection}:{doc_id}" → token count.
    doc_lengths: RefCell<HashMap<String, u32>>,
    /// Per-collection incremental stats: collection → (doc_count, total_token_sum).
    stats: RefCell<HashMap<String, (u32, u64)>>,
    /// Generic metadata blobs (DocIdMap, fieldnorms, etc.), keyed by caller-supplied string.
    meta: RefCell<HashMap<String, Vec<u8>>>,
    /// Segment blobs: key (convention "{collection}:seg:{id}") → compressed segment bytes.
    segments: RefCell<HashMap<String, Vec<u8>>>,
}
45
46impl MemoryBackend {
47    pub fn new() -> Self {
48        Self::default()
49    }
50}
51
52impl FtsBackend for MemoryBackend {
53    type Error = MemoryError;
54
55    fn read_postings(&self, collection: &str, term: &str) -> Result<Vec<Posting>, Self::Error> {
56        let key = format!("{collection}:{term}");
57        Ok(self
58            .postings
59            .borrow()
60            .get(&key)
61            .cloned()
62            .unwrap_or_default())
63    }
64
65    fn write_postings(
66        &self,
67        collection: &str,
68        term: &str,
69        postings: &[Posting],
70    ) -> Result<(), Self::Error> {
71        let key = format!("{collection}:{term}");
72        let mut map = self.postings.borrow_mut();
73        if postings.is_empty() {
74            map.remove(&key);
75        } else {
76            map.insert(key, postings.to_vec());
77        }
78        Ok(())
79    }
80
81    fn remove_postings(&self, collection: &str, term: &str) -> Result<(), Self::Error> {
82        let key = format!("{collection}:{term}");
83        self.postings.borrow_mut().remove(&key);
84        Ok(())
85    }
86
87    fn read_doc_length(&self, collection: &str, doc_id: &str) -> Result<Option<u32>, Self::Error> {
88        let key = format!("{collection}:{doc_id}");
89        Ok(self.doc_lengths.borrow().get(&key).copied())
90    }
91
92    fn write_doc_length(
93        &self,
94        collection: &str,
95        doc_id: &str,
96        length: u32,
97    ) -> Result<(), Self::Error> {
98        let key = format!("{collection}:{doc_id}");
99        self.doc_lengths.borrow_mut().insert(key, length);
100        Ok(())
101    }
102
103    fn remove_doc_length(&self, collection: &str, doc_id: &str) -> Result<(), Self::Error> {
104        let key = format!("{collection}:{doc_id}");
105        self.doc_lengths.borrow_mut().remove(&key);
106        Ok(())
107    }
108
109    fn collection_terms(&self, collection: &str) -> Result<Vec<String>, Self::Error> {
110        let prefix = format!("{collection}:");
111        Ok(self
112            .postings
113            .borrow()
114            .keys()
115            .filter_map(|k| k.strip_prefix(&prefix).map(String::from))
116            .collect())
117    }
118
119    fn collection_stats(&self, collection: &str) -> Result<(u32, u64), Self::Error> {
120        Ok(self
121            .stats
122            .borrow()
123            .get(collection)
124            .copied()
125            .unwrap_or((0, 0)))
126    }
127
128    fn increment_stats(&self, collection: &str, doc_len: u32) -> Result<(), Self::Error> {
129        let mut stats = self.stats.borrow_mut();
130        let entry = stats.entry(collection.to_string()).or_insert((0, 0));
131        entry.0 += 1;
132        entry.1 += doc_len as u64;
133        Ok(())
134    }
135
136    fn decrement_stats(&self, collection: &str, doc_len: u32) -> Result<(), Self::Error> {
137        let mut stats = self.stats.borrow_mut();
138        let entry = stats.entry(collection.to_string()).or_insert((0, 0));
139        entry.0 = entry.0.saturating_sub(1);
140        entry.1 = entry.1.saturating_sub(doc_len as u64);
141        Ok(())
142    }
143
144    fn read_meta(&self, key: &str) -> Result<Option<Vec<u8>>, Self::Error> {
145        Ok(self.meta.borrow().get(key).cloned())
146    }
147
148    fn write_meta(&self, key: &str, value: &[u8]) -> Result<(), Self::Error> {
149        self.meta
150            .borrow_mut()
151            .insert(key.to_string(), value.to_vec());
152        Ok(())
153    }
154
155    fn write_segment(&self, key: &str, data: &[u8]) -> Result<(), Self::Error> {
156        self.segments
157            .borrow_mut()
158            .insert(key.to_string(), data.to_vec());
159        Ok(())
160    }
161
162    fn read_segment(&self, key: &str) -> Result<Option<Vec<u8>>, Self::Error> {
163        Ok(self.segments.borrow().get(key).cloned())
164    }
165
166    fn list_segments(&self, collection: &str) -> Result<Vec<String>, Self::Error> {
167        let prefix = format!("{collection}:seg:");
168        Ok(self
169            .segments
170            .borrow()
171            .keys()
172            .filter(|k| k.starts_with(&prefix))
173            .cloned()
174            .collect())
175    }
176
177    fn remove_segment(&self, key: &str) -> Result<(), Self::Error> {
178        self.segments.borrow_mut().remove(key);
179        Ok(())
180    }
181
182    fn purge_collection(&self, collection: &str) -> Result<usize, Self::Error> {
183        let prefix = format!("{collection}:");
184        let mut postings = self.postings.borrow_mut();
185        let mut doc_lengths = self.doc_lengths.borrow_mut();
186        let before = postings.len() + doc_lengths.len();
187        postings.retain(|k, _| !k.starts_with(&prefix));
188        doc_lengths.retain(|k, _| !k.starts_with(&prefix));
189        self.stats.borrow_mut().remove(collection);
190        let meta_prefix = format!("{collection}:");
191        self.meta
192            .borrow_mut()
193            .retain(|k, _| !k.starts_with(&meta_prefix));
194        self.segments
195            .borrow_mut()
196            .retain(|k, _| !k.starts_with(&prefix));
197        let after = postings.len() + doc_lengths.len();
198        Ok(before - after)
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Postings written under one (collection, term) pair come back intact.
    #[test]
    fn roundtrip_postings() {
        let backend = MemoryBackend::new();
        let postings = vec![Posting {
            doc_id: "d1".into(),
            term_freq: 2,
            positions: vec![0, 5],
        }];
        backend.write_postings("col", "hello", &postings).unwrap();

        let read = backend.read_postings("col", "hello").unwrap();
        assert_eq!(read.len(), 1);
        assert_eq!(read[0].doc_id, "d1");
    }

    /// Doc lengths can be written, read back, and removed (reads as `None` after removal).
    #[test]
    fn roundtrip_doc_lengths() {
        let backend = MemoryBackend::new();
        backend.write_doc_length("col", "d1", 42).unwrap();
        assert_eq!(backend.read_doc_length("col", "d1").unwrap(), Some(42));

        backend.remove_doc_length("col", "d1").unwrap();
        assert_eq!(backend.read_doc_length("col", "d1").unwrap(), None);
    }

    /// Stats accumulate per collection as (doc_count, total_token_sum) and
    /// decrement symmetrically.
    #[test]
    fn incremental_stats() {
        let backend = MemoryBackend::new();
        backend.increment_stats("col", 10).unwrap();
        backend.increment_stats("col", 20).unwrap();
        assert_eq!(backend.collection_stats("col").unwrap(), (2, 30));

        backend.decrement_stats("col", 10).unwrap();
        assert_eq!(backend.collection_stats("col").unwrap(), (1, 20));
    }

    /// Decrementing a fresh collection clamps at (0, 0) instead of underflowing.
    #[test]
    fn stats_saturating_sub() {
        let backend = MemoryBackend::new();
        backend.decrement_stats("col", 100).unwrap();
        assert_eq!(backend.collection_stats("col").unwrap(), (0, 0));
    }

    /// Purging one collection clears its postings, doc lengths, and stats
    /// while leaving a sibling collection completely untouched.
    #[test]
    fn purge_clears_stats_and_isolates_collections() {
        let backend = MemoryBackend::new();
        // Set up two collections.
        backend.increment_stats("col", 10).unwrap();
        backend.write_doc_length("col", "d1", 10).unwrap();
        backend
            .write_postings(
                "col",
                "hello",
                &[Posting {
                    doc_id: "d1".into(),
                    term_freq: 1,
                    positions: vec![0],
                }],
            )
            .unwrap();

        backend.increment_stats("other", 7).unwrap();
        backend.write_doc_length("other", "d1", 7).unwrap();
        backend
            .write_postings(
                "other",
                "world",
                &[Posting {
                    doc_id: "d1".into(),
                    term_freq: 1,
                    positions: vec![0],
                }],
            )
            .unwrap();

        // Purge only "col".
        backend.purge_collection("col").unwrap();
        assert_eq!(backend.collection_stats("col").unwrap(), (0, 0));
        assert!(backend.read_postings("col", "hello").unwrap().is_empty());
        assert_eq!(backend.read_doc_length("col", "d1").unwrap(), None);

        // "other" must be completely unaffected.
        assert_eq!(backend.collection_stats("other").unwrap(), (1, 7));
        assert_eq!(backend.read_postings("other", "world").unwrap().len(), 1);
        assert_eq!(backend.read_doc_length("other", "d1").unwrap(), Some(7));
    }

    /// `collection_terms` strips the "{collection}:" prefix and returns bare terms.
    /// Sorted before comparison because HashMap iteration order is unspecified.
    #[test]
    fn collection_terms() {
        let backend = MemoryBackend::new();
        backend
            .write_postings(
                "col",
                "hello",
                &[Posting {
                    doc_id: "d1".into(),
                    term_freq: 1,
                    positions: vec![0],
                }],
            )
            .unwrap();
        backend
            .write_postings(
                "col",
                "world",
                &[Posting {
                    doc_id: "d1".into(),
                    term_freq: 1,
                    positions: vec![1],
                }],
            )
            .unwrap();

        let mut terms = backend.collection_terms("col").unwrap();
        terms.sort();
        assert_eq!(terms, vec!["hello", "world"]);
    }

    /// Segment blobs round-trip by exact key; missing keys read as `None`.
    #[test]
    fn segment_roundtrip() {
        let backend = MemoryBackend::new();
        let data = b"compressed segment bytes";
        backend.write_segment("col:seg:id1", data).unwrap();
        assert_eq!(
            backend.read_segment("col:seg:id1").unwrap(),
            Some(data.to_vec())
        );
        assert_eq!(backend.read_segment("col:seg:missing").unwrap(), None);
    }

    /// `list_segments` matches on the "{collection}:seg:" prefix and
    /// returns full keys, not bare segment ids.
    #[test]
    fn segment_list_filters_by_collection() {
        let backend = MemoryBackend::new();
        backend.write_segment("col:seg:a", b"a").unwrap();
        backend.write_segment("col:seg:b", b"b").unwrap();
        backend.write_segment("other:seg:c", b"c").unwrap();

        let mut segs = backend.list_segments("col").unwrap();
        segs.sort();
        assert_eq!(segs, vec!["col:seg:a", "col:seg:b"]);

        let other = backend.list_segments("other").unwrap();
        assert_eq!(other, vec!["other:seg:c"]);
    }

    /// Removed segments read back as `None`.
    #[test]
    fn segment_remove() {
        let backend = MemoryBackend::new();
        backend.write_segment("col:seg:id1", b"data").unwrap();
        backend.remove_segment("col:seg:id1").unwrap();
        assert_eq!(backend.read_segment("col:seg:id1").unwrap(), None);
    }

    /// `purge_collection` also drops the purged collection's segments,
    /// leaving other collections' segments in place.
    #[test]
    fn purge_clears_segments() {
        let backend = MemoryBackend::new();
        backend.write_segment("col:seg:a", b"a").unwrap();
        backend.write_segment("other:seg:b", b"b").unwrap();

        backend.purge_collection("col").unwrap();
        assert!(backend.list_segments("col").unwrap().is_empty());
        assert_eq!(backend.list_segments("other").unwrap().len(), 1);
    }
}
369}