1use std::collections::HashMap;
15
16pub use nodedb_document::text_analyzer::analyze;
18
/// In-memory inverted index scored with BM25 (see [`Bm25Params`]).
#[derive(Debug, Default)]
pub struct InvertedIndex {
    /// token -> (doc_id -> term frequency of that token in that document).
    postings: HashMap<String, HashMap<String, u32>>,
    /// doc_id -> number of tokens the document produced when indexed.
    doc_lengths: HashMap<String, u32>,
    /// Number of currently indexed documents.
    doc_count: u32,
    /// Sum of all document lengths; used to compute the BM25 average length.
    total_length: u64,
}
33
/// A single search hit: the matching document id and its BM25 relevance score.
#[derive(Debug, Clone)]
pub struct TextSearchResult {
    pub doc_id: String,
    pub score: f64,
}
40
/// Tuning parameters for the BM25 ranking formula.
#[derive(Debug, Clone, Copy)]
pub struct Bm25Params {
    /// Term-frequency saturation (`k1`): higher values let repeated
    /// occurrences of a term keep adding score before flattening out.
    pub k1: f64,
    /// Document-length normalization strength (`b`): 0.0 disables it,
    /// 1.0 fully scales term weight by `doc_len / avg_doc_len`.
    pub b: f64,
}
49
50impl Default for Bm25Params {
51 fn default() -> Self {
52 Self { k1: 1.2, b: 0.75 }
53 }
54}
55
/// How multiple query terms combine during a search.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryMode {
    /// Every query term must appear in a matching document.
    And,
    /// Any query term is enough for a document to match.
    Or,
}
62
63impl InvertedIndex {
64 pub fn new() -> Self {
65 Self::default()
66 }
67
68 pub fn index_document(&mut self, doc_id: &str, text: &str) {
73 self.remove_document(doc_id);
75
76 let tokens = analyze(text);
77 if tokens.is_empty() {
78 return;
79 }
80
81 let doc_len = tokens.len() as u32;
82 self.doc_lengths.insert(doc_id.to_string(), doc_len);
83 self.doc_count += 1;
84 self.total_length += doc_len as u64;
85
86 let mut tf: HashMap<String, u32> = HashMap::new();
88 for token in &tokens {
89 *tf.entry(token.clone()).or_insert(0) += 1;
90 }
91
92 for (token, freq) in tf {
94 self.postings
95 .entry(token)
96 .or_default()
97 .insert(doc_id.to_string(), freq);
98 }
99 }
100
101 pub fn remove_document(&mut self, doc_id: &str) {
103 if let Some(old_len) = self.doc_lengths.remove(doc_id) {
104 self.doc_count = self.doc_count.saturating_sub(1);
105 self.total_length = self.total_length.saturating_sub(old_len as u64);
106
107 self.postings.retain(|_, docs| {
109 docs.remove(doc_id);
110 !docs.is_empty()
111 });
112 }
113 }
114
115 pub fn search(
119 &self,
120 query: &str,
121 top_k: usize,
122 mode: QueryMode,
123 params: Bm25Params,
124 ) -> Vec<TextSearchResult> {
125 let tokens = analyze(query);
126 if tokens.is_empty() {
127 return Vec::new();
128 }
129
130 let avg_dl = if self.doc_count > 0 {
131 self.total_length as f64 / self.doc_count as f64
132 } else {
133 1.0
134 };
135
136 let mut scores: HashMap<String, f64> = HashMap::new();
137
138 for token in &tokens {
139 let Some(posting) = self.postings.get(token) else {
140 continue;
141 };
142
143 let df = posting.len() as f64;
145 let idf = ((self.doc_count as f64 - df + 0.5) / (df + 0.5) + 1.0).ln();
146
147 for (doc_id, &tf) in posting {
148 let dl = *self.doc_lengths.get(doc_id).unwrap_or(&1) as f64;
149 let tf_f = tf as f64;
151 let numerator = tf_f * (params.k1 + 1.0);
152 let denominator = tf_f + params.k1 * (1.0 - params.b + params.b * dl / avg_dl);
153 let bm25 = idf * numerator / denominator;
154
155 *scores.entry(doc_id.clone()).or_insert(0.0) += bm25;
156 }
157 }
158
159 if mode == QueryMode::And {
161 let query_token_count = tokens.len();
162 scores.retain(|doc_id, _| {
163 let matched_tokens = tokens
164 .iter()
165 .filter(|t| {
166 self.postings
167 .get(*t)
168 .is_some_and(|p| p.contains_key(doc_id))
169 })
170 .count();
171 matched_tokens == query_token_count
172 });
173 }
174
175 let mut results: Vec<TextSearchResult> = scores
176 .into_iter()
177 .map(|(doc_id, score)| TextSearchResult { doc_id, score })
178 .collect();
179
180 results.sort_by(|a, b| {
181 b.score
182 .partial_cmp(&a.score)
183 .unwrap_or(std::cmp::Ordering::Equal)
184 });
185 results.truncate(top_k);
186 results
187 }
188
189 pub fn search_fuzzy(
191 &self,
192 query: &str,
193 max_distance: usize,
194 top_k: usize,
195 params: Bm25Params,
196 ) -> Vec<TextSearchResult> {
197 let tokens = analyze(query);
198 if tokens.is_empty() {
199 return Vec::new();
200 }
201
202 let mut expanded_query = String::new();
204 for token in &tokens {
205 let matching: Vec<&str> = self
206 .postings
207 .keys()
208 .filter(|idx_token| levenshtein(token, idx_token) <= max_distance)
209 .map(|s| s.as_str())
210 .collect();
211 if !matching.is_empty() {
212 if !expanded_query.is_empty() {
213 expanded_query.push(' ');
214 }
215 expanded_query.push_str(&matching.join(" "));
216 }
217 }
218
219 if expanded_query.is_empty() {
220 return Vec::new();
221 }
222
223 self.search(&expanded_query, top_k, QueryMode::Or, params)
224 }
225
226 pub fn doc_count(&self) -> u32 {
227 self.doc_count
228 }
229
230 pub fn token_count(&self) -> usize {
231 self.postings.len()
232 }
233}
234
/// Levenshtein edit distance between `a` and `b`; delegates to the shared
/// implementation in the `nodedb_document` crate.
fn levenshtein(a: &str, b: &str) -> usize {
    nodedb_document::levenshtein(a, b)
}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn analyze_basic() {
        let tokens = analyze("The quick brown fox jumps over the lazy dog");
        assert!(!tokens.is_empty());
        // Stopwords such as "the" must have been filtered out.
        assert!(tokens.iter().all(|t| t != "the"));
    }

    #[test]
    fn analyze_stemming() {
        let tokens = analyze("running jumps quickly");
        for stem in ["run", "jump", "quick"] {
            assert!(tokens.contains(&stem.to_string()));
        }
    }

    #[test]
    fn index_and_search() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "Rust is a systems programming language");
        index.index_document("d2", "Python is great for machine learning");
        index.index_document("d3", "Rust and Python are both great languages");

        let hits = index.search("rust programming", 10, QueryMode::Or, Bm25Params::default());
        assert!(!hits.is_empty());
        // d1 mentions both query terms, so it should rank first.
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn and_mode() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "Rust programming language");
        index.index_document("d2", "Python programming language");

        let hits = index.search(
            "rust programming",
            10,
            QueryMode::And,
            Bm25Params::default(),
        );
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn fuzzy_search() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "programming language design");
        index.index_document("d2", "progrmmng language review");

        let hits = index.search_fuzzy("programming", 3, 10, Bm25Params::default());
        assert!(!hits.is_empty(), "fuzzy search should find matches");
        let ids: Vec<&str> = hits.iter().map(|r| r.doc_id.as_str()).collect();
        assert!(ids.contains(&"d1"), "should find d1 (exact match)");
    }

    #[test]
    fn remove_document() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "hello world");
        assert_eq!(index.doc_count(), 1);

        index.remove_document("d1");
        assert_eq!(index.doc_count(), 0);

        let hits = index.search("hello", 10, QueryMode::Or, Bm25Params::default());
        assert!(hits.is_empty());
    }

    #[test]
    fn levenshtein_basic() {
        assert_eq!(levenshtein("kitten", "sitting"), 3);
        assert_eq!(levenshtein("", "abc"), 3);
        assert_eq!(levenshtein("abc", "abc"), 0);
        assert_eq!(levenshtein("abc", "ab"), 1);
    }

    #[test]
    fn reindex_replaces() {
        let mut index = InvertedIndex::new();
        index.index_document("d1", "old content");
        index.index_document("d1", "new content");
        assert_eq!(index.doc_count(), 1);

        let hits = index.search("old", 10, QueryMode::Or, Bm25Params::default());
        assert!(hits.is_empty());
    }
}