tantivy/query/
bm25.rs

1use serde::{Deserialize, Serialize};
2
3use crate::fieldnorm::FieldNormReader;
4use crate::query::Explanation;
5use crate::schema::Field;
6use crate::{Score, Searcher, Term};
7
/// BM25 `k1`, the term saturation parameter.
const K1: Score = 1.2;
/// BM25 `b`, the length normalization parameter.
const B: Score = 0.75;
10
/// An interface to compute the statistics needed in BM25 scoring.
///
/// The standard implementation is a [Searcher], but you can also
/// provide your own implementation to adjust the statistics.
///
/// Every method returns a `crate::Result`, so an implementation is allowed
/// to fail while gathering a statistic.
pub trait Bm25StatisticsProvider {
    /// The total number of tokens in a given field across all documents in
    /// the index.
    ///
    /// `Bm25Weight::for_terms` divides this value by
    /// [`Self::total_num_docs`] to obtain the average fieldnorm.
    fn total_num_tokens(&self, field: Field) -> crate::Result<u64>;

    /// The total number of documents in the index.
    ///
    /// This is the `N` of the idf formula and the denominator of the
    /// average fieldnorm.
    fn total_num_docs(&self) -> crate::Result<u64>;

    /// The number of documents containing the given term.
    ///
    /// This is the `n` of the idf formula. The idf computation asserts that
    /// it never exceeds [`Self::total_num_docs`].
    fn doc_freq(&self, term: &Term) -> crate::Result<u64>;
}
26
27impl Bm25StatisticsProvider for Searcher {
28    fn total_num_tokens(&self, field: Field) -> crate::Result<u64> {
29        let mut total_num_tokens = 0u64;
30
31        for segment_reader in self.segment_readers() {
32            let inverted_index = segment_reader.inverted_index(field)?;
33            total_num_tokens += inverted_index.total_num_tokens();
34        }
35        Ok(total_num_tokens)
36    }
37
38    fn total_num_docs(&self) -> crate::Result<u64> {
39        let mut total_num_docs = 0u64;
40
41        for segment_reader in self.segment_readers() {
42            total_num_docs += u64::from(segment_reader.max_doc());
43        }
44        Ok(total_num_docs)
45    }
46
47    fn doc_freq(&self, term: &Term) -> crate::Result<u64> {
48        self.doc_freq(term)
49    }
50}
51
52pub(crate) fn idf(doc_freq: u64, doc_count: u64) -> Score {
53    assert!(doc_count >= doc_freq, "{doc_count} >= {doc_freq}");
54    let x = ((doc_count - doc_freq) as Score + 0.5) / (doc_freq as Score + 0.5);
55    (1.0 + x).ln()
56}
57
58fn cached_tf_component(fieldnorm: u32, average_fieldnorm: Score) -> Score {
59    K1 * (1.0 - B + B * fieldnorm as Score / average_fieldnorm)
60}
61
62fn compute_tf_cache(average_fieldnorm: Score) -> [Score; 256] {
63    let mut cache: [Score; 256] = [0.0; 256];
64    for (fieldnorm_id, cache_mut) in cache.iter_mut().enumerate() {
65        let fieldnorm = FieldNormReader::id_to_fieldnorm(fieldnorm_id as u8);
66        *cache_mut = cached_tf_component(fieldnorm, average_fieldnorm);
67    }
68    cache
69}
70
/// A pair of BM25 inputs: an idf value and an average fieldnorm.
///
/// The type can be serialized and deserialized through serde.
///
/// NOTE(review): nothing in this file reads this struct; presumably it lets
/// callers carry precomputed BM25 statistics around — confirm against its
/// users.
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct Bm25Params {
    /// The inverse document frequency.
    pub idf: Score,
    /// The average fieldnorm.
    pub avg_fieldnorm: Score,
}
76
/// A struct used for computing BM25 scores.
///
/// It holds everything that does not depend on the scored document, so that
/// scoring a document only needs its fieldnorm id and its term frequency.
#[derive(Clone)]
pub struct Bm25Weight {
    // Explanation of the idf value; `None` when the weight was built through
    // one of the `*_without_explain` constructors.
    idf_explain: Option<Explanation>,
    // `idf * (1 + K1)`, multiplied by every boost applied through `boost_by`.
    weight: Score,
    // `K1 * (1 - B + B * dl / avgdl)` for each fieldnorm id, indexed by the id.
    cache: [Score; 256],
    // Average fieldnorm used to build `cache`; kept so `explain` can report it.
    average_fieldnorm: Score,
}
85
86impl Bm25Weight {
87    /// Increase the weight by a multiplicative factor.
88    pub fn boost_by(&self, boost: Score) -> Bm25Weight {
89        Bm25Weight {
90            idf_explain: self.idf_explain.clone(),
91            weight: self.weight * boost,
92            cache: self.cache,
93            average_fieldnorm: self.average_fieldnorm,
94        }
95    }
96
97    /// Construct a [Bm25Weight] for a phrase of terms.
98    pub fn for_terms(
99        statistics: &dyn Bm25StatisticsProvider,
100        terms: &[Term],
101    ) -> crate::Result<Bm25Weight> {
102        assert!(!terms.is_empty(), "Bm25 requires at least one term");
103        let field = terms[0].field();
104        for term in &terms[1..] {
105            assert_eq!(
106                term.field(),
107                field,
108                "All terms must belong to the same field."
109            );
110        }
111
112        let total_num_tokens = statistics.total_num_tokens(field)?;
113        let total_num_docs = statistics.total_num_docs()?;
114        let average_fieldnorm = total_num_tokens as Score / total_num_docs as Score;
115
116        if terms.len() == 1 {
117            let term_doc_freq = statistics.doc_freq(&terms[0])?;
118            Ok(Bm25Weight::for_one_term(
119                term_doc_freq,
120                total_num_docs,
121                average_fieldnorm,
122            ))
123        } else {
124            let mut idf_sum: Score = 0.0;
125            for term in terms {
126                let term_doc_freq = statistics.doc_freq(term)?;
127                idf_sum += idf(term_doc_freq, total_num_docs);
128            }
129            let idf_explain = Explanation::new("idf", idf_sum);
130            Ok(Bm25Weight::new(idf_explain, average_fieldnorm))
131        }
132    }
133
134    /// Construct a [Bm25Weight] for a single term.
135    pub fn for_one_term(
136        term_doc_freq: u64,
137        total_num_docs: u64,
138        avg_fieldnorm: Score,
139    ) -> Bm25Weight {
140        let idf = idf(term_doc_freq, total_num_docs);
141        let mut idf_explain =
142            Explanation::new("idf, computed as log(1 + (N - n + 0.5) / (n + 0.5))", idf);
143        idf_explain.add_const(
144            "n, number of docs containing this term",
145            term_doc_freq as Score,
146        );
147        idf_explain.add_const("N, total number of docs", total_num_docs as Score);
148        Bm25Weight::new(idf_explain, avg_fieldnorm)
149    }
150    /// Construct a [Bm25Weight] for a single term.
151    /// This method does not carry the [Explanation] for the idf.
152    pub fn for_one_term_without_explain(
153        term_doc_freq: u64,
154        total_num_docs: u64,
155        avg_fieldnorm: Score,
156    ) -> Bm25Weight {
157        let idf = idf(term_doc_freq, total_num_docs);
158        Bm25Weight::new_without_explain(idf, avg_fieldnorm)
159    }
160
161    pub(crate) fn new(idf_explain: Explanation, average_fieldnorm: Score) -> Bm25Weight {
162        let weight = idf_explain.value() * (1.0 + K1);
163        Bm25Weight {
164            idf_explain: Some(idf_explain),
165            weight,
166            cache: compute_tf_cache(average_fieldnorm),
167            average_fieldnorm,
168        }
169    }
170    pub(crate) fn new_without_explain(idf: f32, average_fieldnorm: Score) -> Bm25Weight {
171        let weight = idf * (1.0 + K1);
172        Bm25Weight {
173            idf_explain: None,
174            weight,
175            cache: compute_tf_cache(average_fieldnorm),
176            average_fieldnorm,
177        }
178    }
179
180    /// Compute the BM25 score of a single document.
181    #[inline]
182    pub fn score(&self, fieldnorm_id: u8, term_freq: u32) -> Score {
183        self.weight * self.tf_factor(fieldnorm_id, term_freq)
184    }
185
186    /// Compute the maximum possible BM25 score given this weight.
187    pub fn max_score(&self) -> Score {
188        self.score(255u8, 2_013_265_944)
189    }
190
191    #[inline]
192    pub(crate) fn tf_factor(&self, fieldnorm_id: u8, term_freq: u32) -> Score {
193        let term_freq = term_freq as Score;
194        let norm = self.cache[fieldnorm_id as usize];
195        term_freq / (term_freq + norm)
196    }
197
198    /// Produce an [Explanation] of a BM25 score.
199    pub fn explain(&self, fieldnorm_id: u8, term_freq: u32) -> Explanation {
200        // The explain format is directly copied from Lucene's.
201        // (So, Kudos to Lucene)
202        let score = self.score(fieldnorm_id, term_freq);
203
204        let norm = self.cache[fieldnorm_id as usize];
205        let term_freq = term_freq as Score;
206        let right_factor = term_freq / (term_freq + norm);
207
208        let mut tf_explanation = Explanation::new(
209            "freq / (freq + k1 * (1 - b + b * dl / avgdl))",
210            right_factor,
211        );
212
213        tf_explanation.add_const("freq, occurrences of term within document", term_freq);
214        tf_explanation.add_const("k1, term saturation parameter", K1);
215        tf_explanation.add_const("b, length normalization parameter", B);
216        tf_explanation.add_const(
217            "dl, length of field",
218            FieldNormReader::id_to_fieldnorm(fieldnorm_id) as Score,
219        );
220        tf_explanation.add_const("avgdl, average length of field", self.average_fieldnorm);
221
222        let mut explanation = Explanation::new("TermQuery, product of...", score);
223        explanation.add_detail(Explanation::new("(K1+1)", K1 + 1.0));
224        if let Some(idf_explain) = &self.idf_explain {
225            explanation.add_detail(idf_explain.clone());
226        }
227        explanation.add_detail(tf_explanation);
228        explanation
229    }
230}
231
#[cfg(test)]
mod tests {

    use super::idf;
    use crate::{assert_nearly_equals, Score};

    /// With one matching document out of two, the idf formula reduces to
    /// `ln(1 + (2 - 1 + 0.5) / (1 + 0.5)) = ln(2)`.
    #[test]
    fn test_idf() {
        let expected = Score::ln(2.0);
        assert_nearly_equals!(idf(1, 2), expected);
    }
}
243}