//! BM25 scorer (`bm25/scorer.rs`).

1use crate::embedder::{DefaultEmbeddingSpace, Embedding};
2use std::{
3    cmp::Ordering,
4    collections::{HashMap, HashSet},
5    fmt::Debug,
6    hash::Hash,
7};
8
/// The result of scoring one document with BM25. `K` is the document-id type.
#[derive(Debug, PartialEq)]
pub struct ScoredDocument<K> {
    /// Identifier of the scored document.
    pub id: K,
    /// BM25 relevance score of the document for the query.
    pub score: f32,
}
17
18/// Efficiently scores the relevance of a query embedding to document embeddings using BM25.
19/// K is the type of the document id and D is the type of the embedding space.
20#[derive(Default)]
21pub struct Scorer<K, D = DefaultEmbeddingSpace> {
22    // A mapping from document ids to the document embeddings.
23    embeddings: HashMap<K, Embedding<D>>,
24    // A mapping from token indices to the set of documents that contain that token.
25    inverted_token_index: HashMap<D, HashSet<K>>,
26}
27
28impl<K, D> Scorer<K, D>
29where
30    D: Eq + Hash + Clone,
31    K: Eq + Hash + Clone,
32{
33    /// Creates a new `Scorer`.
34    pub fn new() -> Scorer<K, D> {
35        Scorer {
36            embeddings: HashMap::new(),
37            inverted_token_index: HashMap::new(),
38        }
39    }
40
41    /// Upserts a document embedding into the scorer. If an embedding with the same id already
42    /// exists, it will be replaced. Note that upserting a document will change the true value of
43    /// `avgdl`. The more `avgdl` drifts from its true value, the less accurate the BM25 scores
44    /// will be.
45    pub fn upsert(&mut self, document_id: &K, embedding: Embedding<D>) {
46        if self.embeddings.contains_key(document_id) {
47            self.remove(document_id);
48        }
49        for token_index in embedding.indices() {
50            let documents_containing_token = self
51                .inverted_token_index
52                .entry(token_index.clone())
53                .or_default();
54            documents_containing_token.insert(document_id.clone());
55        }
56        self.embeddings.insert(document_id.clone(), embedding);
57    }
58
59    /// Removes a document embedding from the scorer if it exists.
60    pub fn remove(&mut self, document_id: &K) {
61        if let Some(embedding) = self.embeddings.remove(document_id) {
62            for token_index in embedding.indices() {
63                if let Some(matches) = self.inverted_token_index.get_mut(token_index) {
64                    matches.remove(document_id);
65                }
66            }
67        }
68    }
69
70    /// Scores the embedding for the given document against a given query embedding. Returns `None`
71    /// if the document does not exist in the scorer.
72    pub fn score(&self, document_id: &K, query_embedding: &Embedding<D>) -> Option<f32> {
73        let document_embedding = self.embeddings.get(document_id)?;
74        Some(self.score_(document_embedding, query_embedding))
75    }
76
77    /// Returns all documents relevant (i.e., score > 0) to the given query embedding, sorted by
78    /// relevance.
79    pub fn matches(&self, query_embedding: &Embedding<D>) -> Vec<ScoredDocument<K>> {
80        let relevant_embeddings_it = query_embedding
81            .indices()
82            .filter_map(|token_index| self.inverted_token_index.get(token_index))
83            .flat_map(|document_set| document_set.iter())
84            .collect::<HashSet<_>>()
85            .into_iter()
86            .filter_map(|document_id| self.embeddings.get(document_id).map(|e| (document_id, e)));
87
88        let mut scores: Vec<_> = relevant_embeddings_it
89            .map(|(document_id, document_embedding)| ScoredDocument {
90                id: document_id.clone(),
91                score: self.score_(document_embedding, query_embedding),
92            })
93            .collect();
94
95        scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
96        scores
97    }
98
99    fn idf(&self, token_index: &D) -> f32 {
100        let token_frequency = self
101            .inverted_token_index
102            .get(token_index)
103            .map_or(0, |documents| documents.len()) as f32;
104        let numerator = self.embeddings.len() as f32 - token_frequency + 0.5;
105        let denominator = token_frequency + 0.5;
106        (1f32 + (numerator / denominator)).ln()
107    }
108
109    fn score_(&self, document_embedding: &Embedding<D>, query_embedding: &Embedding<D>) -> f32 {
110        let mut document_score = 0f32;
111
112        for token_index in query_embedding.indices() {
113            let token_idf = self.idf(token_index);
114            let token_index_value = document_embedding
115                .iter()
116                .find(|token_embedding| token_embedding.index == *token_index)
117                .map(|token_embedding| token_embedding.value)
118                .unwrap_or(0f32);
119            let token_score = token_idf * token_index_value;
120            document_score += token_score;
121        }
122        document_score
123    }
124}
125
#[cfg(test)]
mod tests {
    use crate::TokenEmbedding;

    use super::*;

    /// Builds a `Scorer` containing the given embeddings, keyed by their slice index.
    /// Takes `&[Embedding]` rather than `&Vec<Embedding>` (clippy `ptr_arg`); callers'
    /// `&Vec` arguments deref-coerce transparently.
    fn scorer_with_embeddings(embeddings: &[Embedding]) -> Scorer<usize> {
        let mut scorer = Scorer::<usize>::new();

        for (i, document_embedding) in embeddings.iter().enumerate() {
            scorer.upsert(&i, document_embedding.clone());
        }

        scorer
    }

    #[test]
    fn it_scores_missing_document_as_none() {
        let scorer = Scorer::<usize>::new();
        let query_embedding = Embedding::any();
        let score = scorer.score(&12345, &query_embedding);
        let matches = scorer.matches(&query_embedding);
        assert_eq!(score, None);
        assert!(matches.is_empty());
    }

    #[test]
    fn it_scores_mutually_exclusive_indices_as_zero() {
        let document_embeddings = vec![Embedding(vec![TokenEmbedding::new(1, 1.0)])];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);
        let score = scorer.score(&0, &query_embedding);

        assert_eq!(score, Some(0.0));
    }

    #[test]
    fn it_scores_rare_indices_higher_than_common_ones() {
        // BM25 should score rare token matches higher than common token matches.
        let document_embeddings = vec![
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
            Embedding(vec![TokenEmbedding::new(1, 1.0)]),
        ];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let score_1 = scorer.score(&0, &Embedding(vec![TokenEmbedding::new(0, 1.0)]));
        let score_2 = scorer.score(&2, &Embedding(vec![TokenEmbedding::new(1, 1.0)]));

        assert!(score_1.unwrap() < score_2.unwrap());
    }

    #[test]
    fn it_scores_longer_embeddings_lower_than_shorter_ones() {
        let document_embeddings = vec![
            // Longer embeddings will have a lower value for unique tokens.
            Embedding(vec![
                TokenEmbedding::new(0, 0.9),
                TokenEmbedding::new(1, 0.9),
            ]),
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
        ];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let score_1 = scorer.score(&0, &Embedding(vec![TokenEmbedding::new(0, 1.0)]));
        let score_2 = scorer.score(&1, &Embedding(vec![TokenEmbedding::new(0, 1.0)]));

        assert!(score_1.unwrap() < score_2.unwrap());
    }

    #[test]
    fn it_only_matches_embeddings_with_non_zero_score() {
        let document_embeddings = vec![
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
            Embedding(vec![TokenEmbedding::new(1, 1.0)]),
        ];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);
        let matches = scorer.matches(&query_embedding);

        assert_eq!(
            matches,
            vec![ScoredDocument {
                id: 0,
                score: 0.6931472
            }]
        );
    }

    #[test]
    fn it_does_not_score_frequent_terms_negatively() {
        // In versions 2.2.1 and earlier, the IDF considered the total occurrences of a token where
        // it should have considered the total number of documents containing the token. In
        // instances where the occurrences exceeded the number of documents, the IDF (and therefore
        // the score) would be negative.
        // See this bug report for more information: https://github.com/Michael-JB/bm25/pull/20
        let document_embeddings = vec![Embedding(vec![
            TokenEmbedding::new(0, 1.5),
            TokenEmbedding::new(0, 1.5),
        ])];
        let scorer = scorer_with_embeddings(&document_embeddings);
        let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);

        let matches = scorer.matches(&query_embedding);

        assert!(matches[0].score >= 0.0);
    }

    #[test]
    fn it_sorts_matches_by_score() {
        let document_embeddings = vec![
            Embedding(vec![
                TokenEmbedding::new(0, 0.9),
                TokenEmbedding::new(1, 0.9),
            ]),
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
        ];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);
        let matches = scorer.matches(&query_embedding);

        // Highest score first.
        assert_eq!(
            matches,
            vec![
                ScoredDocument {
                    id: 1,
                    score: 0.1823216
                },
                ScoredDocument {
                    id: 0,
                    score: 0.16408943
                }
            ]
        );
    }
}