1use std::cell::RefCell;
8use std::collections::HashMap;
9use std::fmt;
10
11use crate::backend::FtsBackend;
12use crate::posting::Posting;
13
/// Error type for the in-memory backend.
///
/// The underlying `HashMap` operations cannot actually fail, so this type
/// exists mainly to satisfy the `FtsBackend::Error` associated type; the
/// payload is a human-readable description.
#[derive(Debug)]
pub struct MemoryError(String);

impl fmt::Display for MemoryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "memory backend: {}", self.0)
    }
}

// A type used as a trait's `Error` associated type should interoperate with
// `?`-conversion, `Box<dyn Error>`, anyhow, etc., which requires
// `std::error::Error`. `Debug + Display` are already implemented above, so
// the trait's default methods suffice and the impl body is empty.
impl std::error::Error for MemoryError {}
23
/// A purely in-memory [`FtsBackend`] implementation backed by `HashMap`s.
///
/// All state sits behind `RefCell`s so the `&self` trait methods can mutate;
/// this makes the backend single-threaded (interior mutability, not `Sync`).
/// Composite keys are formatted as `"{collection}:{suffix}"`.
#[derive(Debug, Default)]
pub struct MemoryBackend {
    // Term postings, keyed by "collection:term".
    postings: RefCell<HashMap<String, Vec<Posting>>>,
    // Per-document token counts, keyed by "collection:doc_id".
    doc_lengths: RefCell<HashMap<String, u32>>,
    // Per-collection (doc_count, total_token_count), keyed by collection name.
    stats: RefCell<HashMap<String, (u32, u64)>>,
    // Opaque metadata blobs, keyed by caller-supplied key.
    meta: RefCell<HashMap<String, Vec<u8>>>,
    // Serialized segment blobs; callers use "collection:seg:id" keys
    // (see `list_segments`, which filters on that prefix).
    segments: RefCell<HashMap<String, Vec<u8>>>,
}
45
46impl MemoryBackend {
47 pub fn new() -> Self {
48 Self::default()
49 }
50}
51
52impl FtsBackend for MemoryBackend {
53 type Error = MemoryError;
54
55 fn read_postings(&self, collection: &str, term: &str) -> Result<Vec<Posting>, Self::Error> {
56 let key = format!("{collection}:{term}");
57 Ok(self
58 .postings
59 .borrow()
60 .get(&key)
61 .cloned()
62 .unwrap_or_default())
63 }
64
65 fn write_postings(
66 &self,
67 collection: &str,
68 term: &str,
69 postings: &[Posting],
70 ) -> Result<(), Self::Error> {
71 let key = format!("{collection}:{term}");
72 let mut map = self.postings.borrow_mut();
73 if postings.is_empty() {
74 map.remove(&key);
75 } else {
76 map.insert(key, postings.to_vec());
77 }
78 Ok(())
79 }
80
81 fn remove_postings(&self, collection: &str, term: &str) -> Result<(), Self::Error> {
82 let key = format!("{collection}:{term}");
83 self.postings.borrow_mut().remove(&key);
84 Ok(())
85 }
86
87 fn read_doc_length(&self, collection: &str, doc_id: &str) -> Result<Option<u32>, Self::Error> {
88 let key = format!("{collection}:{doc_id}");
89 Ok(self.doc_lengths.borrow().get(&key).copied())
90 }
91
92 fn write_doc_length(
93 &self,
94 collection: &str,
95 doc_id: &str,
96 length: u32,
97 ) -> Result<(), Self::Error> {
98 let key = format!("{collection}:{doc_id}");
99 self.doc_lengths.borrow_mut().insert(key, length);
100 Ok(())
101 }
102
103 fn remove_doc_length(&self, collection: &str, doc_id: &str) -> Result<(), Self::Error> {
104 let key = format!("{collection}:{doc_id}");
105 self.doc_lengths.borrow_mut().remove(&key);
106 Ok(())
107 }
108
109 fn collection_terms(&self, collection: &str) -> Result<Vec<String>, Self::Error> {
110 let prefix = format!("{collection}:");
111 Ok(self
112 .postings
113 .borrow()
114 .keys()
115 .filter_map(|k| k.strip_prefix(&prefix).map(String::from))
116 .collect())
117 }
118
119 fn collection_stats(&self, collection: &str) -> Result<(u32, u64), Self::Error> {
120 Ok(self
121 .stats
122 .borrow()
123 .get(collection)
124 .copied()
125 .unwrap_or((0, 0)))
126 }
127
128 fn increment_stats(&self, collection: &str, doc_len: u32) -> Result<(), Self::Error> {
129 let mut stats = self.stats.borrow_mut();
130 let entry = stats.entry(collection.to_string()).or_insert((0, 0));
131 entry.0 += 1;
132 entry.1 += doc_len as u64;
133 Ok(())
134 }
135
136 fn decrement_stats(&self, collection: &str, doc_len: u32) -> Result<(), Self::Error> {
137 let mut stats = self.stats.borrow_mut();
138 let entry = stats.entry(collection.to_string()).or_insert((0, 0));
139 entry.0 = entry.0.saturating_sub(1);
140 entry.1 = entry.1.saturating_sub(doc_len as u64);
141 Ok(())
142 }
143
144 fn read_meta(&self, key: &str) -> Result<Option<Vec<u8>>, Self::Error> {
145 Ok(self.meta.borrow().get(key).cloned())
146 }
147
148 fn write_meta(&self, key: &str, value: &[u8]) -> Result<(), Self::Error> {
149 self.meta
150 .borrow_mut()
151 .insert(key.to_string(), value.to_vec());
152 Ok(())
153 }
154
155 fn write_segment(&self, key: &str, data: &[u8]) -> Result<(), Self::Error> {
156 self.segments
157 .borrow_mut()
158 .insert(key.to_string(), data.to_vec());
159 Ok(())
160 }
161
162 fn read_segment(&self, key: &str) -> Result<Option<Vec<u8>>, Self::Error> {
163 Ok(self.segments.borrow().get(key).cloned())
164 }
165
166 fn list_segments(&self, collection: &str) -> Result<Vec<String>, Self::Error> {
167 let prefix = format!("{collection}:seg:");
168 Ok(self
169 .segments
170 .borrow()
171 .keys()
172 .filter(|k| k.starts_with(&prefix))
173 .cloned()
174 .collect())
175 }
176
177 fn remove_segment(&self, key: &str) -> Result<(), Self::Error> {
178 self.segments.borrow_mut().remove(key);
179 Ok(())
180 }
181
182 fn purge_collection(&self, collection: &str) -> Result<usize, Self::Error> {
183 let prefix = format!("{collection}:");
184 let mut postings = self.postings.borrow_mut();
185 let mut doc_lengths = self.doc_lengths.borrow_mut();
186 let before = postings.len() + doc_lengths.len();
187 postings.retain(|k, _| !k.starts_with(&prefix));
188 doc_lengths.retain(|k, _| !k.starts_with(&prefix));
189 self.stats.borrow_mut().remove(collection);
190 let meta_prefix = format!("{collection}:");
191 self.meta
192 .borrow_mut()
193 .retain(|k, _| !k.starts_with(&meta_prefix));
194 self.segments
195 .borrow_mut()
196 .retain(|k, _| !k.starts_with(&prefix));
197 let after = postings.len() + doc_lengths.len();
198 Ok(before - after)
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip_postings() {
        let be = MemoryBackend::new();
        let list = vec![Posting {
            doc_id: "d1".into(),
            term_freq: 2,
            positions: vec![0, 5],
        }];
        be.write_postings("col", "hello", &list).unwrap();

        let got = be.read_postings("col", "hello").unwrap();
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].doc_id, "d1");
    }

    #[test]
    fn roundtrip_doc_lengths() {
        let be = MemoryBackend::new();
        be.write_doc_length("col", "d1", 42).unwrap();
        assert_eq!(be.read_doc_length("col", "d1").unwrap(), Some(42));

        be.remove_doc_length("col", "d1").unwrap();
        assert_eq!(be.read_doc_length("col", "d1").unwrap(), None);
    }

    #[test]
    fn incremental_stats() {
        let be = MemoryBackend::new();
        be.increment_stats("col", 10).unwrap();
        be.increment_stats("col", 20).unwrap();
        assert_eq!(be.collection_stats("col").unwrap(), (2, 30));

        be.decrement_stats("col", 10).unwrap();
        assert_eq!(be.collection_stats("col").unwrap(), (1, 20));
    }

    #[test]
    fn stats_saturating_sub() {
        // Decrementing a collection with no stats must clamp at zero.
        let be = MemoryBackend::new();
        be.decrement_stats("col", 100).unwrap();
        assert_eq!(be.collection_stats("col").unwrap(), (0, 0));
    }

    #[test]
    fn purge_clears_stats_and_isolates_collections() {
        let be = MemoryBackend::new();
        // Populate two collections with identical shapes.
        for (coll, term, len) in [("col", "hello", 10), ("other", "world", 7)] {
            be.increment_stats(coll, len).unwrap();
            be.write_doc_length(coll, "d1", len).unwrap();
            be.write_postings(
                coll,
                term,
                &[Posting {
                    doc_id: "d1".into(),
                    term_freq: 1,
                    positions: vec![0],
                }],
            )
            .unwrap();
        }

        be.purge_collection("col").unwrap();
        assert_eq!(be.collection_stats("col").unwrap(), (0, 0));
        assert!(be.read_postings("col", "hello").unwrap().is_empty());
        assert_eq!(be.read_doc_length("col", "d1").unwrap(), None);

        // The sibling collection must be untouched.
        assert_eq!(be.collection_stats("other").unwrap(), (1, 7));
        assert_eq!(be.read_postings("other", "world").unwrap().len(), 1);
        assert_eq!(be.read_doc_length("other", "d1").unwrap(), Some(7));
    }

    #[test]
    fn collection_terms() {
        let be = MemoryBackend::new();
        for (term, pos) in [("hello", 0), ("world", 1)] {
            be.write_postings(
                "col",
                term,
                &[Posting {
                    doc_id: "d1".into(),
                    term_freq: 1,
                    positions: vec![pos],
                }],
            )
            .unwrap();
        }

        let mut terms = be.collection_terms("col").unwrap();
        terms.sort();
        assert_eq!(terms, vec!["hello", "world"]);
    }

    #[test]
    fn segment_roundtrip() {
        let be = MemoryBackend::new();
        let payload = b"compressed segment bytes";
        be.write_segment("col:seg:id1", payload).unwrap();
        assert_eq!(
            be.read_segment("col:seg:id1").unwrap(),
            Some(payload.to_vec())
        );
        assert_eq!(be.read_segment("col:seg:missing").unwrap(), None);
    }

    #[test]
    fn segment_list_filters_by_collection() {
        let be = MemoryBackend::new();
        be.write_segment("col:seg:a", b"a").unwrap();
        be.write_segment("col:seg:b", b"b").unwrap();
        be.write_segment("other:seg:c", b"c").unwrap();

        let mut listed = be.list_segments("col").unwrap();
        listed.sort();
        assert_eq!(listed, vec!["col:seg:a", "col:seg:b"]);
        assert_eq!(be.list_segments("other").unwrap(), vec!["other:seg:c"]);
    }

    #[test]
    fn segment_remove() {
        let be = MemoryBackend::new();
        be.write_segment("col:seg:id1", b"data").unwrap();
        be.remove_segment("col:seg:id1").unwrap();
        assert_eq!(be.read_segment("col:seg:id1").unwrap(), None);
    }

    #[test]
    fn purge_clears_segments() {
        let be = MemoryBackend::new();
        be.write_segment("col:seg:a", b"a").unwrap();
        be.write_segment("other:seg:b", b"b").unwrap();

        be.purge_collection("col").unwrap();
        assert!(be.list_segments("col").unwrap().is_empty());
        assert_eq!(be.list_segments("other").unwrap().len(), 1);
    }
}