Skip to main content

nodedb_fts/search/
bm25_search.rs

1//! BM25 search over the FtsIndex with AND-first OR-fallback and phrase boost.
2
3use std::collections::HashMap;
4
5use crate::backend::FtsBackend;
6use crate::bm25::bm25_score;
7use crate::index::FtsIndex;
8use crate::posting::{Posting, QueryMode, TextSearchResult};
9use crate::search::phrase;
10
11impl<B: FtsBackend> FtsIndex<B> {
12    /// Search the index using BM25 scoring.
13    ///
14    /// Uses AND-first with automatic OR-fallback: if AND yields zero results
15    /// for a multi-term query, retries with OR and applies a coverage penalty
16    /// of `matched_terms / total_terms` to each document's score.
17    pub fn search(
18        &self,
19        collection: &str,
20        query: &str,
21        top_k: usize,
22        fuzzy_enabled: bool,
23    ) -> Result<Vec<TextSearchResult>, B::Error> {
24        self.search_with_mode(collection, query, top_k, fuzzy_enabled, QueryMode::And)
25    }
26
27    /// Search with explicit boolean mode (AND or OR).
28    ///
29    /// When `mode` is AND and a multi-term query returns zero results,
30    /// automatically falls back to OR with a coverage penalty.
31    ///
32    /// Dispatches to Block-Max WAND (BMW) for OR-mode queries when a DocIdMap
33    /// is available. Falls back to exhaustive BM25 scoring otherwise.
34    pub fn search_with_mode(
35        &self,
36        collection: &str,
37        query: &str,
38        top_k: usize,
39        fuzzy_enabled: bool,
40        mode: QueryMode,
41    ) -> Result<Vec<TextSearchResult>, B::Error> {
42        let query_tokens = self.analyze_for_collection(collection, query)?;
43        if query_tokens.is_empty() {
44            return Ok(Vec::new());
45        }
46        let num_query_terms = query_tokens.len();
47
48        // Raw (unstemmed) tokens for fuzzy matching — edit distance should be
49        // computed on original word forms, not after stemming distorts them.
50        let raw_tokens = if fuzzy_enabled {
51            self.tokenize_raw_for_collection(collection, query)?
52        } else {
53            Vec::new()
54        };
55
56        let (total_docs, avg_doc_len) = self.index_stats(collection)?;
57        if total_docs == 0 {
58            return Ok(Vec::new());
59        }
60
61        // Try BMW for OR-mode or as the first pass of AND-with-fallback.
62        let bmw_params = super::bmw::query::BmwParams {
63            query_tokens: &query_tokens,
64            raw_tokens: &raw_tokens,
65            fuzzy_enabled,
66            top_k: if mode == QueryMode::And && num_query_terms > 1 {
67                top_k.saturating_mul(3).max(20)
68            } else {
69                top_k
70            },
71            total_docs,
72            avg_doc_len,
73            bm25: &self.bm25_params,
74        };
75        if let Ok(Some(bmw_results)) = super::bmw::query::bmw_search(self, collection, &bmw_params)
76        {
77            if mode == QueryMode::Or || num_query_terms == 1 {
78                return Ok(bmw_results.into_iter().take(top_k).collect());
79            }
80
81            // AND mode: filter BMW results to docs matching all terms.
82            // Re-check term coverage for top candidates.
83            let and_results =
84                self.filter_and_mode(collection, &query_tokens, &bmw_results, num_query_terms)?;
85
86            if !and_results.is_empty() {
87                return Ok(and_results.into_iter().take(top_k).collect());
88            }
89
90            // AND returned nothing — apply coverage penalty to BMW OR results.
91            let penalized: Vec<TextSearchResult> = bmw_results
92                .into_iter()
93                .map(|mut r| {
94                    let matched = self.count_term_matches(collection, &query_tokens, &r.doc_id);
95                    let coverage = matched as f32 / num_query_terms as f32;
96                    r.score *= coverage;
97                    r
98                })
99                .collect();
100            let mut sorted = penalized;
101            sorted.sort_by(|a, b| {
102                b.score
103                    .partial_cmp(&a.score)
104                    .unwrap_or(std::cmp::Ordering::Equal)
105            });
106            sorted.truncate(top_k);
107            return Ok(sorted);
108        }
109
110        // Fallback: exhaustive BM25 scoring.
111        // Read directly from backend posting store (handles Origin's transaction-based writes
112        // which bypass the LSM memtable/segment path).
113        let mut term_postings: Vec<(Vec<Posting>, bool)> = Vec::with_capacity(num_query_terms);
114        for (i, token) in query_tokens.iter().enumerate() {
115            let postings = self.backend.read_postings(collection, token)?;
116            if !postings.is_empty() {
117                term_postings.push((postings, false));
118            } else if fuzzy_enabled {
119                // Use raw (unstemmed) token for fuzzy matching — stemming distorts edit distance.
120                let raw = raw_tokens
121                    .get(i)
122                    .map(String::as_str)
123                    .unwrap_or(token.as_str());
124                let (fuzzy_posts, is_fuzzy) = self.fuzzy_lookup(collection, raw)?;
125                term_postings.push((fuzzy_posts, is_fuzzy));
126            } else {
127                term_postings.push((Vec::new(), false));
128            }
129        }
130
131        // Score all documents.
132        // (score, fuzzy_flag, term_match_count)
133        let mut doc_scores: HashMap<String, (f32, bool, usize)> = HashMap::new();
134
135        for (token_idx, (postings, is_fuzzy)) in term_postings.iter().enumerate() {
136            if postings.is_empty() {
137                continue;
138            }
139            let df = postings.len() as u32;
140
141            for posting in postings {
142                let doc_len = self
143                    .backend
144                    .read_doc_length(collection, &posting.doc_id)?
145                    .unwrap_or(1);
146
147                let mut score = bm25_score(
148                    posting.term_freq,
149                    df,
150                    doc_len,
151                    total_docs,
152                    avg_doc_len,
153                    &self.bm25_params,
154                );
155
156                if *is_fuzzy {
157                    score *= crate::fuzzy::fuzzy_discount(1);
158                }
159
160                let entry = doc_scores
161                    .entry(posting.doc_id.clone())
162                    .or_insert((0.0, false, 0));
163                entry.0 += score;
164                if *is_fuzzy {
165                    entry.1 = true;
166                }
167                entry.2 += 1;
168            }
169            let _ = token_idx; // used by phrase boost below
170        }
171
172        // Apply phrase proximity boost.
173        if num_query_terms >= 2 {
174            let doc_postings_map =
175                phrase::collect_doc_postings(&query_tokens, &term_postings, &self.backend);
176            for (doc_id, token_postings) in &doc_postings_map {
177                if let Some(entry) = doc_scores.get_mut(doc_id.as_str()) {
178                    let boost = phrase::phrase_boost(&query_tokens, token_postings);
179                    entry.0 *= boost;
180                }
181            }
182        }
183
184        // AND mode with OR fallback.
185        if mode == QueryMode::And && num_query_terms > 1 {
186            let and_results: HashMap<String, (f32, bool, usize)> = doc_scores
187                .iter()
188                .filter(|(_, (_, _, match_count))| *match_count >= num_query_terms)
189                .map(|(k, v)| (k.clone(), *v))
190                .collect();
191
192            if !and_results.is_empty() {
193                return Ok(Self::to_sorted_results(and_results, top_k));
194            }
195
196            // AND returned nothing — fall back to OR with coverage penalty.
197            for (score, _, match_count) in doc_scores.values_mut() {
198                let coverage = *match_count as f32 / num_query_terms as f32;
199                *score *= coverage;
200            }
201        }
202
203        Ok(Self::to_sorted_results(doc_scores, top_k))
204    }
205
206    /// Filter BMW results to only docs matching ALL query terms (AND mode).
207    fn filter_and_mode(
208        &self,
209        collection: &str,
210        query_tokens: &[String],
211        candidates: &[TextSearchResult],
212        num_terms: usize,
213    ) -> Result<Vec<TextSearchResult>, B::Error> {
214        let doc_map = self.load_doc_id_map(collection)?;
215        let term_blocks = crate::lsm::query::collect_merged_term_blocks(
216            &self.backend,
217            collection,
218            self.memtable(),
219            query_tokens,
220        )?;
221
222        let mut results = Vec::new();
223        for candidate in candidates {
224            let int_id = doc_map.to_u32(&candidate.doc_id);
225            let matched = term_blocks
226                .iter()
227                .filter(|tb| {
228                    int_id.is_some_and(|id| tb.blocks.iter().any(|b| b.doc_ids.contains(&id)))
229                })
230                .count();
231            if matched >= num_terms {
232                results.push(candidate.clone());
233            }
234        }
235        Ok(results)
236    }
237
238    /// Count how many query terms appear in a document's posting lists (via LSM).
239    fn count_term_matches(&self, collection: &str, query_tokens: &[String], doc_id: &str) -> usize {
240        let doc_map = match self.load_doc_id_map(collection) {
241            Ok(m) => m,
242            Err(_) => return 0,
243        };
244        let Some(int_id) = doc_map.to_u32(doc_id) else {
245            return 0;
246        };
247        let term_blocks = match crate::lsm::query::collect_merged_term_blocks(
248            &self.backend,
249            collection,
250            self.memtable(),
251            query_tokens,
252        ) {
253            Ok(tb) => tb,
254            Err(_) => return 0,
255        };
256        term_blocks
257            .iter()
258            .filter(|tb| tb.blocks.iter().any(|b| b.doc_ids.contains(&int_id)))
259            .count()
260    }
261
262    /// Convert score map to sorted, truncated results.
263    fn to_sorted_results(
264        doc_scores: HashMap<String, (f32, bool, usize)>,
265        top_k: usize,
266    ) -> Vec<TextSearchResult> {
267        let mut results: Vec<TextSearchResult> = doc_scores
268            .into_iter()
269            .map(|(doc_id, (score, fuzzy_flag, _))| TextSearchResult {
270                doc_id,
271                score,
272                fuzzy: fuzzy_flag,
273            })
274            .collect();
275        results.sort_by(|a, b| {
276            b.score
277                .partial_cmp(&a.score)
278                .unwrap_or(std::cmp::Ordering::Equal)
279        });
280        results.truncate(top_k);
281        results
282    }
283}
284
#[cfg(test)]
mod tests {
    use crate::backend::memory::MemoryBackend;
    use crate::index::FtsIndex;
    use crate::posting::QueryMode;

    /// Collection used by every fixture except the isolation test.
    const DOCS: &str = "docs";

    /// Build an in-memory index holding `docs` (`(doc_id, text)` pairs) in [`DOCS`].
    fn index_docs(docs: &[(&str, &str)]) -> FtsIndex<MemoryBackend> {
        let index = FtsIndex::new(MemoryBackend::new());
        for &(doc_id, text) in docs {
            index.index_document(DOCS, doc_id, text).unwrap();
        }
        index
    }

    /// Three-document fixture shared by the general-purpose tests.
    fn make_index() -> FtsIndex<MemoryBackend> {
        index_docs(&[
            ("d1", "The quick brown fox jumps over the lazy dog"),
            ("d2", "A fast brown dog runs across the field"),
            ("d3", "Rust programming language for systems"),
        ])
    }

    /// Two-document fixture that differs only in the language name.
    fn language_index() -> FtsIndex<MemoryBackend> {
        index_docs(&[
            ("d1", "rust programming language"),
            ("d2", "python programming language"),
        ])
    }

    #[test]
    fn basic_search() {
        let hits = make_index().search(DOCS, "brown fox", 10, false).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn search_with_stemming() {
        let index = index_docs(&[
            ("d1", "running distributed databases"),
            ("d2", "the cat sat on a mat"),
        ]);

        let hits = index
            .search(DOCS, "database distribution", 10, false)
            .unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn or_mode() {
        let hits = make_index()
            .search_with_mode(DOCS, "brown fox", 10, false, QueryMode::Or)
            .unwrap();
        assert!(hits.len() >= 2);
    }

    #[test]
    fn and_mode_filters() {
        let index = index_docs(&[
            ("d1", "Rust programming language"),
            ("d2", "Python programming language"),
        ]);

        let hits = index
            .search_with_mode(DOCS, "rust programming", 10, false, QueryMode::And)
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn and_fallback_to_or() {
        // "rust python" — no doc has BOTH, AND yields nothing, falls back to OR.
        let hits = language_index()
            .search(DOCS, "rust python", 10, false)
            .unwrap();
        assert_eq!(hits.len(), 2);
        // Coverage penalty: each doc matches 1/2 terms → scores penalized by 0.5,
        // which must still leave them strictly positive.
        assert!(hits.iter().all(|hit| hit.score > 0.0));
    }

    #[test]
    fn and_no_fallback_when_results_exist() {
        // "rust programming" — d1 has both, AND succeeds, no fallback.
        let hits = language_index()
            .search(DOCS, "rust programming", 10, false)
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn empty_query() {
        // The query is expected to analyze to no tokens at all.
        let hits = make_index().search(DOCS, "the a is", 10, false).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn collections_isolated() {
        let index = FtsIndex::new(MemoryBackend::new());
        index
            .index_document("col_a", "d1", "alpha bravo charlie")
            .unwrap();
        index
            .index_document("col_b", "d1", "delta echo foxtrot")
            .unwrap();

        let in_a = index.search("col_a", "alpha", 10, false).unwrap();
        let in_b = index.search("col_b", "alpha", 10, false).unwrap();
        assert_eq!(in_a.len(), 1);
        assert!(in_b.is_empty());
    }

    #[test]
    fn fuzzy_search() {
        let index = index_docs(&[("d1", "distributed database systems")]);

        // "databse" (7 chars raw) fuzzy-matches "databas" (stemmed from "database").
        // Fuzzy uses raw tokens: levenshtein("databse", "databas") = 2, max_dist(7) = 2 → match.
        let hits = index.search(DOCS, "databse", 10, true).unwrap();
        assert!(!hits.is_empty());
        assert!(hits[0].fuzzy);
    }

    #[test]
    fn phrase_boost_consecutive() {
        let index = index_docs(&[
            // d1 has "brown fox" as consecutive tokens.
            ("d1", "the quick brown fox jumped"),
            // d2 has "brown" and "fox" but separated.
            ("d2", "a brown dog chased a fox"),
        ]);

        let hits = index
            .search_with_mode(DOCS, "brown fox", 10, false, QueryMode::Or)
            .unwrap();
        assert!(hits.len() >= 2);
        // d1 should rank higher due to phrase boost.
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn phrase_boost_no_effect_single_term() {
        let index = index_docs(&[("d1", "hello world")]);

        let hits = index.search(DOCS, "hello", 10, false).unwrap();
        assert_eq!(hits.len(), 1);
    }
}