bm25/

use crate::embedder::{DefaultTokenEmbedder, Embedding};
use std::{
    cmp::Ordering,
    collections::{HashMap, HashSet},
    fmt::Debug,
    hash::Hash,
};

/// A document scored by the BM25 algorithm. K is the type of the document id.
#[derive(PartialEq, Debug)]
pub struct ScoredDocument<K> {
    /// The id of the document.
    pub id: K,
    /// The BM25 score of the document.
    pub score: f32,
}

/// Efficiently scores the relevance of a query embedding to document embeddings using BM25.
/// K is the type of the document id and D is the type of the token embedder.
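///
/// A minimal usage sketch (types are referred to by their bare names here; the
/// import paths depend on how the crate re-exports them, so this block is not
/// compiled as a doctest):
///
/// ```ignore
/// let mut scorer = Scorer::<usize>::new();
/// scorer.upsert(&0, Embedding(vec![TokenEmbedding::new(42, 1.0)]));
///
/// let query = Embedding(vec![TokenEmbedding::new(42, 1.0)]);
/// let matches = scorer.matches(&query);
/// assert_eq!(matches[0].id, 0);
/// ```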
#[derive(Default)]
pub struct Scorer<K, D = DefaultTokenEmbedder> {
    // A mapping from document ids to the document embeddings.
    embeddings: HashMap<K, Embedding<D>>,
    // A mapping from token indices to the number of documents that contain that token.
    token_frequencies: HashMap<D, u32>,
    // A mapping from token indices to the set of documents that contain that token.
    inverted_token_index: HashMap<D, HashSet<K>>,
}

impl<K, D> Scorer<K, D>
where
    D: Eq + Hash + Clone,
    K: Eq + Hash + Clone,
{
    /// Creates a new `Scorer`.
    pub fn new() -> Scorer<K, D> {
        Scorer {
            embeddings: HashMap::new(),
            token_frequencies: HashMap::new(),
            inverted_token_index: HashMap::new(),
        }
    }

    /// Upserts a document embedding into the scorer. If an embedding with the same id already
    /// exists, it will be replaced. Note that upserting a document will change the true value of
    /// `avgdl`. The more `avgdl` drifts from its true value, the less accurate the BM25 scores
    /// will be.
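    /// (`avgdl` is the average document length; it is presumably baked into the embedding
    /// values by the embedder, so the scorer itself cannot recompute it as documents change.)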
    pub fn upsert(&mut self, document_id: &K, embedding: Embedding<D>) {
        if self.embeddings.contains_key(document_id) {
            self.remove(document_id);
        }
        for token_index in embedding.indices() {
            let token_frequency = self
                .token_frequencies
                .entry(token_index.clone())
                .or_insert(0);
            *token_frequency += 1;
            let documents_containing_token = self
                .inverted_token_index
                .entry(token_index.clone())
                .or_default();
            documents_containing_token.insert(document_id.clone());
        }
        self.embeddings.insert(document_id.clone(), embedding);
    }

    /// Removes a document embedding from the scorer if it exists.
    pub fn remove(&mut self, document_id: &K) {
        if let Some(embedding) = self.embeddings.remove(document_id) {
            for token_index in embedding.indices() {
                if let Some(token_frequency) = self.token_frequencies.get_mut(token_index) {
                    *token_frequency -= 1;
                }
                if let Some(matches) = self.inverted_token_index.get_mut(token_index) {
                    matches.remove(document_id);
                }
            }
        }
    }

    /// Scores the embedding for the given document against a given query embedding. Returns `None`
    /// if the document does not exist in the scorer.
    pub fn score(&self, document_id: &K, query_embedding: &Embedding<D>) -> Option<f32> {
        let document_embedding = self.embeddings.get(document_id)?;
        Some(self.score_(document_embedding, query_embedding))
    }

    /// Returns all documents relevant (i.e., score > 0) to the given query embedding, sorted by
    /// relevance.
    pub fn matches(&self, query_embedding: &Embedding<D>) -> Vec<ScoredDocument<K>> {
        let relevant_embeddings_it = query_embedding
            .indices()
            .filter_map(|token_index| self.inverted_token_index.get(token_index))
            .flat_map(|document_set| document_set.iter())
            .collect::<HashSet<_>>()
            .into_iter()
            .filter_map(|document_id| self.embeddings.get(document_id).map(|e| (document_id, e)));

        let mut scores: Vec<_> = relevant_embeddings_it
            .map(|(document_id, document_embedding)| ScoredDocument {
                id: document_id.clone(),
                score: self.score_(document_embedding, query_embedding),
            })
            .collect();

        scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        scores
    }

    // BM25 inverse document frequency for a token:
    // idf(t) = ln(1 + (N - n_t + 0.5) / (n_t + 0.5)),
    // where N is the number of documents and n_t is the number of documents containing t.
    fn idf(&self, token_index: &D) -> f32 {
        let token_frequency = *self.token_frequencies.get(token_index).unwrap_or(&0) as f32;
        let numerator = self.embeddings.len() as f32 - token_frequency + 0.5;
        let denominator = token_frequency + 0.5;
        (1f32 + (numerator / denominator)).ln()
    }

    // Sums idf(token) * stored token value over the query's tokens. Term-frequency
    // saturation and document-length normalization are presumably applied upstream by
    // the embedder that produced the token values.
    fn score_(&self, document_embedding: &Embedding<D>, query_embedding: &Embedding<D>) -> f32 {
        let mut document_score = 0f32;

        for token_index in query_embedding.indices() {
            let token_idf = self.idf(token_index);
            let token_index_value = document_embedding
                .iter()
                .find(|token_embedding| token_embedding.index == *token_index)
                .map(|token_embedding| token_embedding.value)
                .unwrap_or(0f32);
            let token_score = token_idf * token_index_value;
            document_score += token_score;
        }
        document_score
    }
}

#[cfg(test)]
mod tests {
    use crate::TokenEmbedding;

    use super::*;

    fn scorer_with_embeddings(embeddings: &[Embedding]) -> Scorer<usize> {
        let mut scorer = Scorer::<usize>::new();

        for (i, document_embedding) in embeddings.iter().enumerate() {
            scorer.upsert(&i, document_embedding.clone());
        }

        scorer
    }

    #[test]
    fn it_scores_missing_document_as_none() {
        let scorer = Scorer::<usize>::new();
        let query_embedding = Embedding::any();
        let score = scorer.score(&12345, &query_embedding);
        let matches = scorer.matches(&query_embedding);
        assert_eq!(score, None);
        assert!(matches.is_empty());
    }

    #[test]
    fn it_scores_mutually_exclusive_indices_as_zero() {
        let document_embeddings = vec![Embedding(vec![TokenEmbedding::new(1, 1.0)])];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);
        let score = scorer.score(&0, &query_embedding);

        assert_eq!(score, Some(0.0));
    }

    #[test]
    fn it_scores_rare_indices_higher_than_common_ones() {
        // BM25 should score rare token matches higher than common token matches.
        let document_embeddings = vec![
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
            Embedding(vec![TokenEmbedding::new(1, 1.0)]),
        ];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let score_1 = scorer.score(&0, &Embedding(vec![TokenEmbedding::new(0, 1.0)]));
        let score_2 = scorer.score(&2, &Embedding(vec![TokenEmbedding::new(1, 1.0)]));

        assert!(score_1.unwrap() < score_2.unwrap());
    }

    #[test]
    fn it_scores_longer_embeddings_lower_than_shorter_ones() {
        let document_embeddings = vec![
            // Longer embeddings will have a lower value for unique tokens.
            Embedding(vec![
                TokenEmbedding::new(0, 0.9),
                TokenEmbedding::new(1, 0.9),
            ]),
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
        ];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let score_1 = scorer.score(&0, &Embedding(vec![TokenEmbedding::new(0, 1.0)]));
        let score_2 = scorer.score(&1, &Embedding(vec![TokenEmbedding::new(0, 1.0)]));

        assert!(score_1.unwrap() < score_2.unwrap());
    }

    #[test]
    fn it_only_matches_embeddings_with_non_zero_score() {
        let document_embeddings = vec![
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
            Embedding(vec![TokenEmbedding::new(1, 1.0)]),
        ];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);
        let matches = scorer.matches(&query_embedding);

        assert_eq!(
            matches,
            vec![ScoredDocument {
                id: 0,
                score: 0.6931472
            }]
        );
    }

    #[test]
    fn it_sorts_matches_by_score() {
        let document_embeddings = vec![
            Embedding(vec![
                TokenEmbedding::new(0, 0.9),
                TokenEmbedding::new(1, 0.9),
            ]),
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
        ];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);
        let matches = scorer.matches(&query_embedding);

        assert_eq!(
            matches,
            vec![
                ScoredDocument {
                    id: 1,
                    score: 0.1823216
                },
                ScoredDocument {
                    id: 0,
                    score: 0.16408943
                }
            ]
        );
    }
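
    // A minimal sketch covering `remove`, which the tests above do not exercise:
    // once a document is removed it should neither score nor show up in matches.
    #[test]
    fn it_stops_scoring_and_matching_removed_documents() {
        let document_embeddings = vec![Embedding(vec![TokenEmbedding::new(0, 1.0)])];
        let mut scorer = scorer_with_embeddings(&document_embeddings);

        let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);
        assert!(scorer.score(&0, &query_embedding).is_some());

        scorer.remove(&0);

        assert_eq!(scorer.score(&0, &query_embedding), None);
        assert!(scorer.matches(&query_embedding).is_empty());
    }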
}