1use serde::{Deserialize, Serialize};
2
3use crate::fieldnorm::FieldNormReader;
4use crate::query::Explanation;
5use crate::schema::Field;
6use crate::{Score, Searcher, Term};
7
8const K1: Score = 1.2;
9const B: Score = 0.75;
10
11pub trait Bm25StatisticsProvider {
16 fn total_num_tokens(&self, field: Field) -> crate::Result<u64>;
19
20 fn total_num_docs(&self) -> crate::Result<u64>;
22
23 fn doc_freq(&self, term: &Term) -> crate::Result<u64>;
25}
26
27impl Bm25StatisticsProvider for Searcher {
28 fn total_num_tokens(&self, field: Field) -> crate::Result<u64> {
29 let mut total_num_tokens = 0u64;
30
31 for segment_reader in self.segment_readers() {
32 let inverted_index = segment_reader.inverted_index(field)?;
33 total_num_tokens += inverted_index.total_num_tokens();
34 }
35 Ok(total_num_tokens)
36 }
37
38 fn total_num_docs(&self) -> crate::Result<u64> {
39 let mut total_num_docs = 0u64;
40
41 for segment_reader in self.segment_readers() {
42 total_num_docs += u64::from(segment_reader.max_doc());
43 }
44 Ok(total_num_docs)
45 }
46
47 fn doc_freq(&self, term: &Term) -> crate::Result<u64> {
48 self.doc_freq(term)
49 }
50}
51
52pub(crate) fn idf(doc_freq: u64, doc_count: u64) -> Score {
53 assert!(doc_count >= doc_freq, "{doc_count} >= {doc_freq}");
54 let x = ((doc_count - doc_freq) as Score + 0.5) / (doc_freq as Score + 0.5);
55 (1.0 + x).ln()
56}
57
58fn cached_tf_component(fieldnorm: u32, average_fieldnorm: Score) -> Score {
59 K1 * (1.0 - B + B * fieldnorm as Score / average_fieldnorm)
60}
61
62fn compute_tf_cache(average_fieldnorm: Score) -> [Score; 256] {
63 let mut cache: [Score; 256] = [0.0; 256];
64 for (fieldnorm_id, cache_mut) in cache.iter_mut().enumerate() {
65 let fieldnorm = FieldNormReader::id_to_fieldnorm(fieldnorm_id as u8);
66 *cache_mut = cached_tf_component(fieldnorm, average_fieldnorm);
67 }
68 cache
69}
70
71#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
72pub struct Bm25Params {
73 pub idf: Score,
74 pub avg_fieldnorm: Score,
75}
76
77#[derive(Clone)]
79pub struct Bm25Weight {
80 idf_explain: Option<Explanation>,
81 weight: Score,
82 cache: [Score; 256],
83 average_fieldnorm: Score,
84}
85
86impl Bm25Weight {
87 pub fn boost_by(&self, boost: Score) -> Bm25Weight {
89 Bm25Weight {
90 idf_explain: self.idf_explain.clone(),
91 weight: self.weight * boost,
92 cache: self.cache,
93 average_fieldnorm: self.average_fieldnorm,
94 }
95 }
96
97 pub fn for_terms(
99 statistics: &dyn Bm25StatisticsProvider,
100 terms: &[Term],
101 ) -> crate::Result<Bm25Weight> {
102 assert!(!terms.is_empty(), "Bm25 requires at least one term");
103 let field = terms[0].field();
104 for term in &terms[1..] {
105 assert_eq!(
106 term.field(),
107 field,
108 "All terms must belong to the same field."
109 );
110 }
111
112 let total_num_tokens = statistics.total_num_tokens(field)?;
113 let total_num_docs = statistics.total_num_docs()?;
114 let average_fieldnorm = total_num_tokens as Score / total_num_docs as Score;
115
116 if terms.len() == 1 {
117 let term_doc_freq = statistics.doc_freq(&terms[0])?;
118 Ok(Bm25Weight::for_one_term(
119 term_doc_freq,
120 total_num_docs,
121 average_fieldnorm,
122 ))
123 } else {
124 let mut idf_sum: Score = 0.0;
125 for term in terms {
126 let term_doc_freq = statistics.doc_freq(term)?;
127 idf_sum += idf(term_doc_freq, total_num_docs);
128 }
129 let idf_explain = Explanation::new("idf", idf_sum);
130 Ok(Bm25Weight::new(idf_explain, average_fieldnorm))
131 }
132 }
133
134 pub fn for_one_term(
136 term_doc_freq: u64,
137 total_num_docs: u64,
138 avg_fieldnorm: Score,
139 ) -> Bm25Weight {
140 let idf = idf(term_doc_freq, total_num_docs);
141 let mut idf_explain =
142 Explanation::new("idf, computed as log(1 + (N - n + 0.5) / (n + 0.5))", idf);
143 idf_explain.add_const(
144 "n, number of docs containing this term",
145 term_doc_freq as Score,
146 );
147 idf_explain.add_const("N, total number of docs", total_num_docs as Score);
148 Bm25Weight::new(idf_explain, avg_fieldnorm)
149 }
150 pub fn for_one_term_without_explain(
153 term_doc_freq: u64,
154 total_num_docs: u64,
155 avg_fieldnorm: Score,
156 ) -> Bm25Weight {
157 let idf = idf(term_doc_freq, total_num_docs);
158 Bm25Weight::new_without_explain(idf, avg_fieldnorm)
159 }
160
161 pub(crate) fn new(idf_explain: Explanation, average_fieldnorm: Score) -> Bm25Weight {
162 let weight = idf_explain.value() * (1.0 + K1);
163 Bm25Weight {
164 idf_explain: Some(idf_explain),
165 weight,
166 cache: compute_tf_cache(average_fieldnorm),
167 average_fieldnorm,
168 }
169 }
170 pub(crate) fn new_without_explain(idf: f32, average_fieldnorm: Score) -> Bm25Weight {
171 let weight = idf * (1.0 + K1);
172 Bm25Weight {
173 idf_explain: None,
174 weight,
175 cache: compute_tf_cache(average_fieldnorm),
176 average_fieldnorm,
177 }
178 }
179
180 #[inline]
182 pub fn score(&self, fieldnorm_id: u8, term_freq: u32) -> Score {
183 self.weight * self.tf_factor(fieldnorm_id, term_freq)
184 }
185
186 pub fn max_score(&self) -> Score {
188 self.score(255u8, 2_013_265_944)
189 }
190
191 #[inline]
192 pub(crate) fn tf_factor(&self, fieldnorm_id: u8, term_freq: u32) -> Score {
193 let term_freq = term_freq as Score;
194 let norm = self.cache[fieldnorm_id as usize];
195 term_freq / (term_freq + norm)
196 }
197
198 pub fn explain(&self, fieldnorm_id: u8, term_freq: u32) -> Explanation {
200 let score = self.score(fieldnorm_id, term_freq);
203
204 let norm = self.cache[fieldnorm_id as usize];
205 let term_freq = term_freq as Score;
206 let right_factor = term_freq / (term_freq + norm);
207
208 let mut tf_explanation = Explanation::new(
209 "freq / (freq + k1 * (1 - b + b * dl / avgdl))",
210 right_factor,
211 );
212
213 tf_explanation.add_const("freq, occurrences of term within document", term_freq);
214 tf_explanation.add_const("k1, term saturation parameter", K1);
215 tf_explanation.add_const("b, length normalization parameter", B);
216 tf_explanation.add_const(
217 "dl, length of field",
218 FieldNormReader::id_to_fieldnorm(fieldnorm_id) as Score,
219 );
220 tf_explanation.add_const("avgdl, average length of field", self.average_fieldnorm);
221
222 let mut explanation = Explanation::new("TermQuery, product of...", score);
223 explanation.add_detail(Explanation::new("(K1+1)", K1 + 1.0));
224 if let Some(idf_explain) = &self.idf_explain {
225 explanation.add_detail(idf_explain.clone());
226 }
227 explanation.add_detail(tf_explanation);
228 explanation
229 }
230}
231
#[cfg(test)]
mod tests {

    use super::{idf, Bm25Weight};
    use crate::{assert_nearly_equals, Score};

    #[test]
    fn test_idf() {
        let score: Score = 2.0;
        assert_nearly_equals!(idf(1, 2), score.ln());
    }

    #[test]
    fn test_idf_boundaries() {
        // Empty corpus: ln(1 + 0.5 / 0.5) = ln(2).
        let two: Score = 2.0;
        assert_nearly_equals!(idf(0, 0), two.ln());
        // A term present in every document still gets a strictly positive idf.
        assert!(idf(10, 10) > 0.0);
    }

    #[test]
    fn test_idf_decreases_with_doc_freq() {
        assert!(idf(0, 10) > idf(1, 10));
        assert!(idf(1, 10) > idf(5, 10));
        assert!(idf(5, 10) > idf(10, 10));
    }

    #[test]
    #[should_panic]
    fn test_idf_doc_freq_above_doc_count_panics() {
        idf(3, 2);
    }

    #[test]
    fn test_boost_scales_score() {
        let weight = Bm25Weight::for_one_term_without_explain(3, 10, 5.0);
        let boosted = weight.boost_by(2.0);
        assert_nearly_equals!(boosted.score(10, 2), 2.0 * weight.score(10, 2));
    }
}
243}