Skip to main content

nodedb_fts/search/
bm25_search.rs

1//! BM25 search over the FtsIndex with AND-first OR-fallback and phrase boost.
2
3use std::collections::HashMap;
4
5use crate::backend::FtsBackend;
6use crate::bm25::bm25_score;
7use crate::index::FtsIndex;
8use crate::posting::{Posting, QueryMode, TextSearchResult};
9use crate::search::phrase;
10
11impl<B: FtsBackend> FtsIndex<B> {
12    /// Search the index using BM25 scoring.
13    pub fn search(
14        &self,
15        tid: u32,
16        collection: &str,
17        query: &str,
18        top_k: usize,
19        fuzzy_enabled: bool,
20    ) -> Result<Vec<TextSearchResult>, B::Error> {
21        self.search_with_mode(tid, collection, query, top_k, fuzzy_enabled, QueryMode::And)
22    }
23
24    /// Search with explicit boolean mode (AND or OR).
25    pub fn search_with_mode(
26        &self,
27        tid: u32,
28        collection: &str,
29        query: &str,
30        top_k: usize,
31        fuzzy_enabled: bool,
32        mode: QueryMode,
33    ) -> Result<Vec<TextSearchResult>, B::Error> {
34        let query_tokens = self.analyze_for_collection(tid, collection, query)?;
35        if query_tokens.is_empty() {
36            return Ok(Vec::new());
37        }
38        let num_query_terms = query_tokens.len();
39
40        let raw_tokens = if fuzzy_enabled {
41            self.tokenize_raw_for_collection(tid, collection, query)?
42        } else {
43            Vec::new()
44        };
45
46        let (total_docs, avg_doc_len) = self.index_stats(tid, collection)?;
47        if total_docs == 0 {
48            return Ok(Vec::new());
49        }
50
51        let bmw_params = super::bmw::query::BmwParams {
52            query_tokens: &query_tokens,
53            raw_tokens: &raw_tokens,
54            fuzzy_enabled,
55            top_k: if mode == QueryMode::And && num_query_terms > 1 {
56                top_k.saturating_mul(3).max(20)
57            } else {
58                top_k
59            },
60            total_docs,
61            avg_doc_len,
62            bm25: &self.bm25_params,
63        };
64        if let Ok(Some(bmw_results)) =
65            super::bmw::query::bmw_search(self, tid, collection, &bmw_params)
66        {
67            if mode == QueryMode::Or || num_query_terms == 1 {
68                return Ok(bmw_results.into_iter().take(top_k).collect());
69            }
70
71            let and_results = self.filter_and_mode(
72                tid,
73                collection,
74                &query_tokens,
75                &bmw_results,
76                num_query_terms,
77            )?;
78
79            if !and_results.is_empty() {
80                return Ok(and_results.into_iter().take(top_k).collect());
81            }
82
83            let penalized: Vec<TextSearchResult> = bmw_results
84                .into_iter()
85                .map(|mut r| {
86                    let matched =
87                        self.count_term_matches(tid, collection, &query_tokens, &r.doc_id);
88                    let coverage = matched as f32 / num_query_terms as f32;
89                    r.score *= coverage;
90                    r
91                })
92                .collect();
93            let mut sorted = penalized;
94            sorted.sort_by(|a, b| {
95                b.score
96                    .partial_cmp(&a.score)
97                    .unwrap_or(std::cmp::Ordering::Equal)
98            });
99            sorted.truncate(top_k);
100            return Ok(sorted);
101        }
102
103        // Fallback: exhaustive BM25 scoring reading directly from the backend.
104        let mut term_postings: Vec<(Vec<Posting>, bool)> = Vec::with_capacity(num_query_terms);
105        for (i, token) in query_tokens.iter().enumerate() {
106            let postings = self.backend.read_postings(tid, collection, token)?;
107            if !postings.is_empty() {
108                term_postings.push((postings, false));
109            } else if fuzzy_enabled {
110                let raw = raw_tokens
111                    .get(i)
112                    .map(String::as_str)
113                    .unwrap_or(token.as_str());
114                let (fuzzy_posts, is_fuzzy) = self.fuzzy_lookup(tid, collection, raw)?;
115                term_postings.push((fuzzy_posts, is_fuzzy));
116            } else {
117                term_postings.push((Vec::new(), false));
118            }
119        }
120
121        let mut doc_scores: HashMap<String, (f32, bool, usize)> = HashMap::new();
122
123        for (token_idx, (postings, is_fuzzy)) in term_postings.iter().enumerate() {
124            if postings.is_empty() {
125                continue;
126            }
127            let df = postings.len() as u32;
128
129            for posting in postings {
130                let doc_len = self
131                    .backend
132                    .read_doc_length(tid, collection, &posting.doc_id)?
133                    .unwrap_or(1);
134
135                let mut score = bm25_score(
136                    posting.term_freq,
137                    df,
138                    doc_len,
139                    total_docs,
140                    avg_doc_len,
141                    &self.bm25_params,
142                );
143
144                if *is_fuzzy {
145                    score *= crate::fuzzy::fuzzy_discount(1);
146                }
147
148                let entry = doc_scores
149                    .entry(posting.doc_id.clone())
150                    .or_insert((0.0, false, 0));
151                entry.0 += score;
152                if *is_fuzzy {
153                    entry.1 = true;
154                }
155                entry.2 += 1;
156            }
157            let _ = token_idx;
158        }
159
160        if num_query_terms >= 2 {
161            let doc_postings_map = phrase::collect_doc_postings(&query_tokens, &term_postings);
162            for (doc_id, token_postings) in &doc_postings_map {
163                if let Some(entry) = doc_scores.get_mut(doc_id.as_str()) {
164                    let boost = phrase::phrase_boost(&query_tokens, token_postings);
165                    entry.0 *= boost;
166                }
167            }
168        }
169
170        if mode == QueryMode::And && num_query_terms > 1 {
171            let and_results: HashMap<String, (f32, bool, usize)> = doc_scores
172                .iter()
173                .filter(|(_, (_, _, match_count))| *match_count >= num_query_terms)
174                .map(|(k, v)| (k.clone(), *v))
175                .collect();
176
177            if !and_results.is_empty() {
178                return Ok(Self::to_sorted_results(and_results, top_k));
179            }
180
181            for (score, _, match_count) in doc_scores.values_mut() {
182                let coverage = *match_count as f32 / num_query_terms as f32;
183                *score *= coverage;
184            }
185        }
186
187        Ok(Self::to_sorted_results(doc_scores, top_k))
188    }
189
190    fn filter_and_mode(
191        &self,
192        tid: u32,
193        collection: &str,
194        query_tokens: &[String],
195        candidates: &[TextSearchResult],
196        num_terms: usize,
197    ) -> Result<Vec<TextSearchResult>, B::Error> {
198        let doc_map = self.load_doc_id_map(tid, collection)?;
199        let term_blocks = crate::lsm::query::collect_merged_term_blocks(
200            &self.backend,
201            tid,
202            collection,
203            self.memtable(),
204            query_tokens,
205        )?;
206
207        let mut results = Vec::new();
208        for candidate in candidates {
209            let int_id = doc_map.to_u32(&candidate.doc_id);
210            let matched = term_blocks
211                .iter()
212                .filter(|tb| {
213                    int_id.is_some_and(|id| tb.blocks.iter().any(|b| b.doc_ids.contains(&id)))
214                })
215                .count();
216            if matched >= num_terms {
217                results.push(candidate.clone());
218            }
219        }
220        Ok(results)
221    }
222
223    fn count_term_matches(
224        &self,
225        tid: u32,
226        collection: &str,
227        query_tokens: &[String],
228        doc_id: &str,
229    ) -> usize {
230        let doc_map = match self.load_doc_id_map(tid, collection) {
231            Ok(m) => m,
232            Err(_) => return 0,
233        };
234        let Some(int_id) = doc_map.to_u32(doc_id) else {
235            return 0;
236        };
237        let term_blocks = match crate::lsm::query::collect_merged_term_blocks(
238            &self.backend,
239            tid,
240            collection,
241            self.memtable(),
242            query_tokens,
243        ) {
244            Ok(tb) => tb,
245            Err(_) => return 0,
246        };
247        term_blocks
248            .iter()
249            .filter(|tb| tb.blocks.iter().any(|b| b.doc_ids.contains(&int_id)))
250            .count()
251    }
252
253    fn to_sorted_results(
254        doc_scores: HashMap<String, (f32, bool, usize)>,
255        top_k: usize,
256    ) -> Vec<TextSearchResult> {
257        let mut results: Vec<TextSearchResult> = doc_scores
258            .into_iter()
259            .map(|(doc_id, (score, fuzzy_flag, _))| TextSearchResult {
260                doc_id,
261                score,
262                fuzzy: fuzzy_flag,
263            })
264            .collect();
265        results.sort_by(|a, b| {
266            b.score
267                .partial_cmp(&a.score)
268                .unwrap_or(std::cmp::Ordering::Equal)
269        });
270        results.truncate(top_k);
271        results
272    }
273}
274
#[cfg(test)]
mod tests {
    use crate::backend::memory::MemoryBackend;
    use crate::index::FtsIndex;
    use crate::posting::QueryMode;

    /// Tenant id shared by the single-tenant tests.
    const T: u32 = 1;

    /// Build a fresh in-memory index holding `docs` as `(doc_id, text)` pairs.
    /// Documents are indexed in the given order under tenant `T` and
    /// collection `"docs"`.
    fn index_with(docs: &[(&str, &str)]) -> FtsIndex<MemoryBackend> {
        let idx = FtsIndex::new(MemoryBackend::new());
        for &(doc_id, text) in docs {
            idx.index_document(T, "docs", doc_id, text).unwrap();
        }
        idx
    }

    /// Three-document corpus: two docs mentioning brown animals, one about Rust.
    fn make_index() -> FtsIndex<MemoryBackend> {
        index_with(&[
            ("d1", "The quick brown fox jumps over the lazy dog"),
            ("d2", "A fast brown dog runs across the field"),
            ("d3", "Rust programming language for systems"),
        ])
    }

    /// Two docs sharing "programming language" but naming different languages.
    fn rust_and_python() -> FtsIndex<MemoryBackend> {
        index_with(&[
            ("d1", "rust programming language"),
            ("d2", "python programming language"),
        ])
    }

    #[test]
    fn basic_search() {
        let hits = make_index().search(T, "docs", "brown fox", 10, false).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn search_with_stemming() {
        let idx = index_with(&[
            ("d1", "running distributed databases"),
            ("d2", "the cat sat on a mat"),
        ]);
        let hits = idx
            .search(T, "docs", "database distribution", 10, false)
            .unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn or_mode() {
        let hits = make_index()
            .search_with_mode(T, "docs", "brown fox", 10, false, QueryMode::Or)
            .unwrap();
        assert!(hits.len() >= 2);
    }

    #[test]
    fn and_mode_filters() {
        let idx = index_with(&[
            ("d1", "Rust programming language"),
            ("d2", "Python programming language"),
        ]);
        let hits = idx
            .search_with_mode(T, "docs", "rust programming", 10, false, QueryMode::And)
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn and_fallback_to_or() {
        // No document contains both terms, so AND has to fall back to OR.
        let hits = rust_and_python()
            .search(T, "docs", "rust python", 10, false)
            .unwrap();
        assert_eq!(hits.len(), 2);
        assert!(hits.iter().all(|hit| hit.score > 0.0));
    }

    #[test]
    fn and_no_fallback_when_results_exist() {
        let hits = rust_and_python()
            .search(T, "docs", "rust programming", 10, false)
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn empty_query() {
        // Only stopword-like tokens; expected to produce no results.
        let hits = make_index().search(T, "docs", "the a is", 10, false).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn collections_isolated() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(T, "col_a", "d1", "alpha bravo charlie")
            .unwrap();
        idx.index_document(T, "col_b", "d1", "delta echo foxtrot")
            .unwrap();

        let in_a = idx.search(T, "col_a", "alpha", 10, false).unwrap();
        let in_b = idx.search(T, "col_b", "alpha", 10, false).unwrap();
        assert_eq!(in_a.len(), 1);
        assert!(in_b.is_empty());
    }

    #[test]
    fn fuzzy_search() {
        let idx = index_with(&[("d1", "distributed database systems")]);
        let hits = idx.search(T, "docs", "databse", 10, true).unwrap();
        assert!(!hits.is_empty());
        assert!(hits[0].fuzzy);
    }

    #[test]
    fn phrase_boost_consecutive() {
        let idx = index_with(&[
            ("d1", "the quick brown fox jumped"),
            ("d2", "a brown dog chased a fox"),
        ]);
        let hits = idx
            .search_with_mode(T, "docs", "brown fox", 10, false, QueryMode::Or)
            .unwrap();
        assert!(hits.len() >= 2);
        assert_eq!(hits[0].doc_id, "d1");
    }

    #[test]
    fn phrase_boost_no_effect_single_term() {
        let idx = index_with(&[("d1", "hello world")]);
        let hits = idx.search(T, "docs", "hello", 10, false).unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn tenants_isolated() {
        let idx = FtsIndex::new(MemoryBackend::new());
        idx.index_document(1, "docs", "d1", "alpha bravo").unwrap();
        idx.index_document(2, "docs", "d1", "charlie delta")
            .unwrap();

        let tenant_one = idx.search(1, "docs", "alpha", 10, false).unwrap();
        let tenant_two = idx.search(2, "docs", "alpha", 10, false).unwrap();
        assert_eq!(tenant_one.len(), 1);
        assert!(tenant_two.is_empty());
    }
}