Skip to main content

nodedb_query/
text_search.rs

//! In-memory BM25 text search engine.
//!
//! Shared between Origin (as a complement to the redb-backed index) and
//! Lite (as the primary text search engine). Rebuilt from documents on
//! cold start — acceptable for edge-scale datasets.
//!
//! Features:
//! - BM25 scoring with configurable k1/b parameters
//! - Snowball stemming (15 languages, defaults to English)
//! - Unicode normalization + stop word removal
//! - Fuzzy matching via Levenshtein edit distance
//! - AND/OR boolean query modes

14use std::collections::HashMap;
15
16// Re-export text analysis from shared nodedb-document crate.
17pub use nodedb_document::text_analyzer::analyze;
18
// ── Inverted Index ────────────────────────────────────────────────────

/// In-memory inverted index with BM25 scoring.
///
/// Documents are identified by caller-supplied string ids; text is turned
/// into tokens by [`analyze`] before insertion, so postings keys are
/// normalized/stemmed tokens, not raw words.
#[derive(Debug, Default)]
pub struct InvertedIndex {
    /// token → { doc_id → term_frequency }.
    postings: HashMap<String, HashMap<String, u32>>,
    /// doc_id → total token count (document length).
    doc_lengths: HashMap<String, u32>,
    /// Total number of documents indexed.
    doc_count: u32,
    /// Sum of all document lengths (used to derive the average document
    /// length for BM25 length normalization).
    total_length: u64,
}
33
/// A single search result with BM25 score.
#[derive(Debug, Clone)]
pub struct TextSearchResult {
    /// Identifier of the matched document.
    pub doc_id: String,
    /// BM25 relevance score; higher means more relevant.
    pub score: f64,
}
40
/// BM25 parameters.
///
/// These feed directly into the scoring formula used by
/// [`InvertedIndex::search`].
#[derive(Debug, Clone, Copy)]
pub struct Bm25Params {
    /// Term frequency saturation. Default: 1.2.
    pub k1: f64,
    /// Length normalization. Default: 0.75.
    pub b: f64,
}
49
50impl Default for Bm25Params {
51    fn default() -> Self {
52        Self { k1: 1.2, b: 0.75 }
53    }
54}
55
/// Query mode: AND (all terms must match) or OR (any term can match).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryMode {
    /// Every analyzed query token must appear in a document for it to match.
    And,
    /// Any analyzed query token may match; scores accumulate per token.
    Or,
}
62
63impl InvertedIndex {
64    pub fn new() -> Self {
65        Self::default()
66    }
67
68    /// Index a document's text content.
69    ///
70    /// `doc_id` is the unique identifier. `text` is analyzed into tokens.
71    /// Calling again with the same `doc_id` replaces the previous entry.
72    pub fn index_document(&mut self, doc_id: &str, text: &str) {
73        // Remove old entry if re-indexing.
74        self.remove_document(doc_id);
75
76        let tokens = analyze(text);
77        if tokens.is_empty() {
78            return;
79        }
80
81        let doc_len = tokens.len() as u32;
82        self.doc_lengths.insert(doc_id.to_string(), doc_len);
83        self.doc_count += 1;
84        self.total_length += doc_len as u64;
85
86        // Count term frequencies.
87        let mut tf: HashMap<String, u32> = HashMap::new();
88        for token in &tokens {
89            *tf.entry(token.clone()).or_insert(0) += 1;
90        }
91
92        // Insert into postings.
93        for (token, freq) in tf {
94            self.postings
95                .entry(token)
96                .or_default()
97                .insert(doc_id.to_string(), freq);
98        }
99    }
100
101    /// Remove a document from the index.
102    pub fn remove_document(&mut self, doc_id: &str) {
103        if let Some(old_len) = self.doc_lengths.remove(doc_id) {
104            self.doc_count = self.doc_count.saturating_sub(1);
105            self.total_length = self.total_length.saturating_sub(old_len as u64);
106
107            // Remove from all postings lists.
108            self.postings.retain(|_, docs| {
109                docs.remove(doc_id);
110                !docs.is_empty()
111            });
112        }
113    }
114
115    /// Search with BM25 scoring.
116    ///
117    /// Returns results sorted by descending score, limited to `top_k`.
118    pub fn search(
119        &self,
120        query: &str,
121        top_k: usize,
122        mode: QueryMode,
123        params: Bm25Params,
124    ) -> Vec<TextSearchResult> {
125        let tokens = analyze(query);
126        if tokens.is_empty() {
127            return Vec::new();
128        }
129
130        let avg_dl = if self.doc_count > 0 {
131            self.total_length as f64 / self.doc_count as f64
132        } else {
133            1.0
134        };
135
136        let mut scores: HashMap<String, f64> = HashMap::new();
137
138        for token in &tokens {
139            let Some(posting) = self.postings.get(token) else {
140                continue;
141            };
142
143            // IDF: log((N - df + 0.5) / (df + 0.5) + 1)
144            let df = posting.len() as f64;
145            let idf = ((self.doc_count as f64 - df + 0.5) / (df + 0.5) + 1.0).ln();
146
147            for (doc_id, &tf) in posting {
148                let dl = *self.doc_lengths.get(doc_id).unwrap_or(&1) as f64;
149                // BM25: idf * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * dl / avgdl))
150                let tf_f = tf as f64;
151                let numerator = tf_f * (params.k1 + 1.0);
152                let denominator = tf_f + params.k1 * (1.0 - params.b + params.b * dl / avg_dl);
153                let bm25 = idf * numerator / denominator;
154
155                *scores.entry(doc_id.clone()).or_insert(0.0) += bm25;
156            }
157        }
158
159        // AND mode: remove docs that don't match ALL query tokens.
160        if mode == QueryMode::And {
161            let query_token_count = tokens.len();
162            scores.retain(|doc_id, _| {
163                let matched_tokens = tokens
164                    .iter()
165                    .filter(|t| {
166                        self.postings
167                            .get(*t)
168                            .is_some_and(|p| p.contains_key(doc_id))
169                    })
170                    .count();
171                matched_tokens == query_token_count
172            });
173        }
174
175        let mut results: Vec<TextSearchResult> = scores
176            .into_iter()
177            .map(|(doc_id, score)| TextSearchResult { doc_id, score })
178            .collect();
179
180        results.sort_by(|a, b| {
181            b.score
182                .partial_cmp(&a.score)
183                .unwrap_or(std::cmp::Ordering::Equal)
184        });
185        results.truncate(top_k);
186        results
187    }
188
189    /// Fuzzy search: find documents matching query terms within Levenshtein distance.
190    pub fn search_fuzzy(
191        &self,
192        query: &str,
193        max_distance: usize,
194        top_k: usize,
195        params: Bm25Params,
196    ) -> Vec<TextSearchResult> {
197        let tokens = analyze(query);
198        if tokens.is_empty() {
199            return Vec::new();
200        }
201
202        // Expand each query token to matching index tokens within edit distance.
203        let mut expanded_query = String::new();
204        for token in &tokens {
205            let matching: Vec<&str> = self
206                .postings
207                .keys()
208                .filter(|idx_token| levenshtein(token, idx_token) <= max_distance)
209                .map(|s| s.as_str())
210                .collect();
211            if !matching.is_empty() {
212                if !expanded_query.is_empty() {
213                    expanded_query.push(' ');
214                }
215                expanded_query.push_str(&matching.join(" "));
216            }
217        }
218
219        if expanded_query.is_empty() {
220            return Vec::new();
221        }
222
223        self.search(&expanded_query, top_k, QueryMode::Or, params)
224    }
225
226    pub fn doc_count(&self) -> u32 {
227        self.doc_count
228    }
229
230    pub fn token_count(&self) -> usize {
231        self.postings.len()
232    }
233}
234
// Thin local wrapper around the Levenshtein implementation in the shared
// nodedb-document crate (not a re-export; kept as a fn so call sites in
// this module stay short).
fn levenshtein(a: &str, b: &str) -> usize {
    nodedb_document::levenshtein(a, b)
}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn analyze_basic() {
        let tokens = analyze("The quick brown fox jumps over the lazy dog");
        assert!(!tokens.is_empty());
        // Stop words such as "the" must be filtered out by the analyzer.
        assert!(!tokens.iter().any(|t| t == "the"));
    }

    #[test]
    fn analyze_stemming() {
        // Snowball stemming: "running" → "run", "jumps" → "jump",
        // "quickly" → "quick".
        let tokens = analyze("running jumps quickly");
        for stem in ["run", "jump", "quick"] {
            assert!(tokens.contains(&stem.to_string()));
        }
    }

    #[test]
    fn index_and_search() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "Rust is a systems programming language");
        index.index_document("d2", "Python is great for machine learning");
        index.index_document("d3", "Rust and Python are both great languages");

        let hits = index.search("rust programming", 10, QueryMode::Or, Bm25Params::default());
        assert!(!hits.is_empty());
        // d1 matches both "rust" and "programming", so it must rank first.
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn and_mode() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "Rust programming language");
        index.index_document("d2", "Python programming language");

        let hits = index.search(
            "rust programming",
            10,
            QueryMode::And,
            Bm25Params::default(),
        );
        // Only d1 contains every query term.
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn fuzzy_search() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "programming language design");
        index.index_document("d2", "progrmmng language review"); // typos

        // Fuzzy expansion matches index tokens within edit distance of the
        // (stemmed) query token, so at least the exact-match doc is found.
        let hits = index.search_fuzzy("programming", 3, 10, Bm25Params::default());
        assert!(!hits.is_empty(), "fuzzy search should find matches");
        let ids: Vec<&str> = hits.iter().map(|r| r.doc_id.as_str()).collect();
        assert!(ids.contains(&"d1"), "should find d1 (exact match)");
    }

    #[test]
    fn remove_document() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "hello world");
        assert_eq!(index.doc_count(), 1);

        index.remove_document("d1");
        assert_eq!(index.doc_count(), 0);
        assert!(index
            .search("hello", 10, QueryMode::Or, Bm25Params::default())
            .is_empty());
    }

    #[test]
    fn levenshtein_basic() {
        assert_eq!(levenshtein("kitten", "sitting"), 3);
        assert_eq!(levenshtein("", "abc"), 3);
        assert_eq!(levenshtein("abc", "abc"), 0);
        assert_eq!(levenshtein("abc", "ab"), 1);
    }

    #[test]
    fn reindex_replaces() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "old content");
        index.index_document("d1", "new content");
        assert_eq!(index.doc_count(), 1);

        // The first version of d1 must be fully replaced.
        assert!(index
            .search("old", 10, QueryMode::Or, Bm25Params::default())
            .is_empty());
    }
}