1use crate::embedder::{DefaultEmbeddingSpace, Embedding};
2use std::{
3 cmp::Ordering,
4 collections::{HashMap, HashSet},
5 fmt::Debug,
6 hash::Hash,
7};
8
/// A document identifier paired with the score it obtained for a query.
///
/// `Clone` is derived so callers can keep or copy ranked results without
/// rebuilding them field by field.
#[derive(Clone, Debug, PartialEq)]
pub struct ScoredDocument<K> {
    /// Identifier of the scored document.
    pub id: K,
    /// Score assigned to the document for the query it was matched against.
    pub score: f32,
}
17
18#[derive(Default)]
21pub struct Scorer<K, D = DefaultEmbeddingSpace> {
22 embeddings: HashMap<K, Embedding<D>>,
24 inverted_token_index: HashMap<D, HashSet<K>>,
26}
27
28impl<K, D> Scorer<K, D>
29where
30 D: Eq + Hash + Clone,
31 K: Eq + Hash + Clone,
32{
33 pub fn new() -> Scorer<K, D> {
35 Scorer {
36 embeddings: HashMap::new(),
37 inverted_token_index: HashMap::new(),
38 }
39 }
40
41 pub fn upsert(&mut self, document_id: &K, embedding: Embedding<D>) {
46 if self.embeddings.contains_key(document_id) {
47 self.remove(document_id);
48 }
49 for token_index in embedding.indices() {
50 let documents_containing_token = self
51 .inverted_token_index
52 .entry(token_index.clone())
53 .or_default();
54 documents_containing_token.insert(document_id.clone());
55 }
56 self.embeddings.insert(document_id.clone(), embedding);
57 }
58
59 pub fn remove(&mut self, document_id: &K) {
61 if let Some(embedding) = self.embeddings.remove(document_id) {
62 for token_index in embedding.indices() {
63 if let Some(matches) = self.inverted_token_index.get_mut(token_index) {
64 matches.remove(document_id);
65 }
66 }
67 }
68 }
69
70 pub fn score(&self, document_id: &K, query_embedding: &Embedding<D>) -> Option<f32> {
73 let document_embedding = self.embeddings.get(document_id)?;
74 Some(self.score_(document_embedding, query_embedding))
75 }
76
77 pub fn matches(&self, query_embedding: &Embedding<D>) -> Vec<ScoredDocument<K>> {
80 let relevant_embeddings_it = query_embedding
81 .indices()
82 .filter_map(|token_index| self.inverted_token_index.get(token_index))
83 .flat_map(|document_set| document_set.iter())
84 .collect::<HashSet<_>>()
85 .into_iter()
86 .filter_map(|document_id| self.embeddings.get(document_id).map(|e| (document_id, e)));
87
88 let mut scores: Vec<_> = relevant_embeddings_it
89 .map(|(document_id, document_embedding)| ScoredDocument {
90 id: document_id.clone(),
91 score: self.score_(document_embedding, query_embedding),
92 })
93 .collect();
94
95 scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
96 scores
97 }
98
99 fn idf(&self, token_index: &D) -> f32 {
100 let token_frequency = self
101 .inverted_token_index
102 .get(token_index)
103 .map_or(0, |documents| documents.len()) as f32;
104 let numerator = self.embeddings.len() as f32 - token_frequency + 0.5;
105 let denominator = token_frequency + 0.5;
106 (1f32 + (numerator / denominator)).ln()
107 }
108
109 fn score_(&self, document_embedding: &Embedding<D>, query_embedding: &Embedding<D>) -> f32 {
110 let mut document_score = 0f32;
111
112 for token_index in query_embedding.indices() {
113 let token_idf = self.idf(token_index);
114 let token_index_value = document_embedding
115 .iter()
116 .find(|token_embedding| token_embedding.index == *token_index)
117 .map(|token_embedding| token_embedding.value)
118 .unwrap_or(0f32);
119 let token_score = token_idf * token_index_value;
120 document_score += token_score;
121 }
122 document_score
123 }
124}
125
#[cfg(test)]
mod tests {
    use crate::TokenEmbedding;

    use super::*;

    /// Builds a scorer in which each embedding is stored under its position
    /// in `embeddings`.
    fn scorer_with_embeddings(embeddings: &[Embedding]) -> Scorer<usize> {
        let mut scorer = Scorer::<usize>::new();

        for (i, document_embedding) in embeddings.iter().enumerate() {
            scorer.upsert(&i, document_embedding.clone());
        }

        scorer
    }

    /// An empty scorer has no score for any id and returns no matches.
    #[test]
    fn it_scores_missing_document_as_none() {
        let scorer = Scorer::<usize>::new();
        let query_embedding = Embedding::any();
        let score = scorer.score(&12345, &query_embedding);
        let matches = scorer.matches(&query_embedding);
        assert_eq!(score, None);
        assert!(matches.is_empty());
    }

    /// A document sharing no token index with the query scores exactly zero.
    #[test]
    fn it_scores_mutually_exclusive_indices_as_zero() {
        let document_embeddings = vec![Embedding(vec![TokenEmbedding::new(1, 1.0)])];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);
        let score = scorer.score(&0, &query_embedding);

        assert_eq!(score, Some(0.0));
    }

    /// Token 0 appears in two documents and token 1 in one, so a match on
    /// token 1 must score higher than a match on token 0.
    #[test]
    fn it_scores_rare_indices_higher_than_common_ones() {
        let document_embeddings = vec![
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
            Embedding(vec![TokenEmbedding::new(1, 1.0)]),
        ];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let score_1 = scorer.score(&0, &Embedding(vec![TokenEmbedding::new(0, 1.0)]));
        let score_2 = scorer.score(&2, &Embedding(vec![TokenEmbedding::new(1, 1.0)]));

        assert!(score_1.unwrap() < score_2.unwrap());
    }

    /// The two-token document holds a lower value (0.9) for token 0 than the
    /// single-token document (1.0), so it scores lower for that token.
    #[test]
    fn it_scores_longer_embeddings_lower_than_shorter_ones() {
        let document_embeddings = vec![
            Embedding(vec![
                TokenEmbedding::new(0, 0.9),
                TokenEmbedding::new(1, 0.9),
            ]),
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
        ];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let score_1 = scorer.score(&0, &Embedding(vec![TokenEmbedding::new(0, 1.0)]));
        let score_2 = scorer.score(&1, &Embedding(vec![TokenEmbedding::new(0, 1.0)]));

        assert!(score_1.unwrap() < score_2.unwrap());
    }

    /// Only the document sharing token 0 with the query is returned; with two
    /// documents and one containing the token, its score is ln(2).
    #[test]
    fn it_only_matches_embeddings_with_non_zero_score() {
        let document_embeddings = vec![
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
            Embedding(vec![TokenEmbedding::new(1, 1.0)]),
        ];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);
        let matches = scorer.matches(&query_embedding);

        assert_eq!(
            matches,
            vec![ScoredDocument {
                id: 0,
                score: 0.6931472
            }]
        );
    }

    /// A token present in every stored document must still yield a
    /// non-negative score.
    #[test]
    fn it_does_not_score_frequent_terms_negatively() {
        let document_embeddings = vec![Embedding(vec![
            TokenEmbedding::new(0, 1.5),
            TokenEmbedding::new(0, 1.5),
        ])];
        let scorer = scorer_with_embeddings(&document_embeddings);
        let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);

        let matches = scorer.matches(&query_embedding);

        assert!(matches[0].score >= 0.0);
    }

    /// Matches come back ordered from highest to lowest score.
    #[test]
    fn it_sorts_matches_by_score() {
        let document_embeddings = vec![
            Embedding(vec![
                TokenEmbedding::new(0, 0.9),
                TokenEmbedding::new(1, 0.9),
            ]),
            Embedding(vec![TokenEmbedding::new(0, 1.0)]),
        ];
        let scorer = scorer_with_embeddings(&document_embeddings);

        let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);
        let matches = scorer.matches(&query_embedding);

        assert_eq!(
            matches,
            vec![
                ScoredDocument {
                    id: 1,
                    score: 0.1823216
                },
                ScoredDocument {
                    id: 0,
                    score: 0.16408943
                }
            ]
        );
    }
}