1use std::collections::HashMap;
23
24use crate::sparse::SparseVector;
25
/// Inverted index over sparse vectors, addressed by caller-supplied string ids.
///
/// Every document occupies a numbered *slot*. The per-slot vectors
/// (`dims_of`, `ext_of`, `doclen`) are indexed by that slot number, and slots
/// released by `remove` are recycled through `free`.
#[derive(Debug, Default)]
pub struct SparseInvertedIndex {
    /// Posting lists: dimension -> (slot -> weight stored for that document).
    postings: HashMap<u32, HashMap<u32, f32>>,
    /// Per slot: the dimensions the document was indexed under, so its
    /// postings can be found again when it is replaced or removed.
    dims_of: Vec<Vec<u32>>,
    /// Per slot: the external id of the document (reset to an empty string
    /// when the slot is freed).
    ext_of: Vec<String>,
    /// External id -> slot, for live documents only.
    slot_of: HashMap<String, u32>,
    /// Slots released by `remove`, reused before `ext_of` is grown.
    free: Vec<u32>,
    /// Per slot: document length, i.e. the sum of the document's stored
    /// weights (reset to `0.0` when the slot is freed).
    doclen: Vec<f32>,
    /// Running sum of `doclen`, adjusted on every upsert and removal; BM25
    /// divides it by `len` to obtain the average document length.
    total_len: f64,
    /// Number of live documents.
    len: usize,
}
53
/// `k1` value the tests in this file pass to `SparseInvertedIndex::bm25_search`.
///
/// In that formula `k1` appears both in the numerator (`tf * (k1 + 1)`) and as
/// the factor on the length-normalisation term of the denominator.
pub const BM25_K1: f32 = 1.2;
/// `b` value the tests in this file pass to `SparseInvertedIndex::bm25_search`.
///
/// In that formula `b` blends between no length normalisation (`1 - b`) and
/// the document-length / average-length ratio (`b * dl / avgdl`).
pub const BM25_B: f32 = 0.75;
58
59impl SparseInvertedIndex {
60 pub fn new() -> Self {
62 Self::default()
63 }
64
65 pub fn len(&self) -> usize {
67 self.len
68 }
69
70 pub fn is_empty(&self) -> bool {
72 self.len == 0
73 }
74
75 pub fn contains(&self, ext_id: &str) -> bool {
77 self.slot_of.contains_key(ext_id)
78 }
79
80 pub fn upsert(&mut self, ext_id: &str, sv: &SparseVector) {
85 let slot = match self.slot_of.get(ext_id).copied() {
86 Some(slot) => {
87 self.clear_postings(slot);
88 slot
89 }
90 None => self.allocate(ext_id),
91 };
92 self.total_len -= self.doclen[slot as usize] as f64;
95 let mut weights: HashMap<u32, f32> = HashMap::with_capacity(sv.indices.len());
98 for (&dim, &w) in sv.indices.iter().zip(sv.values.iter()) {
99 weights.insert(dim, w);
100 }
101 let dl: f32 = weights.values().copied().sum();
104 let mut dims = Vec::with_capacity(weights.len());
105 for (dim, w) in weights {
106 self.postings.entry(dim).or_default().insert(slot, w);
107 dims.push(dim);
108 }
109 self.dims_of[slot as usize] = dims;
110 self.doclen[slot as usize] = dl;
111 self.total_len += dl as f64;
112 }
113
114 pub fn remove(&mut self, ext_id: &str) -> bool {
116 let Some(slot) = self.slot_of.remove(ext_id) else {
117 return false;
118 };
119 self.clear_postings(slot);
120 self.total_len -= self.doclen[slot as usize] as f64;
121 self.doclen[slot as usize] = 0.0;
122 self.dims_of[slot as usize] = Vec::new();
123 self.ext_of[slot as usize] = String::new();
124 self.free.push(slot);
125 self.len -= 1;
126 true
127 }
128
129 pub fn search(&self, query: &SparseVector) -> Vec<(String, f32)> {
136 let mut qweights: HashMap<u32, f32> = HashMap::with_capacity(query.indices.len());
137 for (&dim, &w) in query.indices.iter().zip(query.values.iter()) {
138 qweights.insert(dim, w);
139 }
140 let mut scores: HashMap<u32, f32> = HashMap::new();
141 for (dim, qw) in qweights {
142 if let Some(plist) = self.postings.get(&dim) {
143 for (&slot, &w) in plist {
144 *scores.entry(slot).or_insert(0.0) += qw * w;
145 }
146 }
147 }
148 let mut out: Vec<(String, f32)> = scores
149 .into_iter()
150 .filter(|&(_, score)| score > 0.0)
151 .map(|(slot, score)| (self.ext_of[slot as usize].clone(), score))
152 .collect();
153 out.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
154 out
155 }
156
157 pub fn bm25_search(&self, query_terms: &[u32], k1: f32, b: f32) -> Vec<(String, f32)> {
168 if self.len == 0 {
169 return Vec::new();
170 }
171 let n = self.len as f64;
172 let avgdl = (self.total_len / n).max(f64::MIN_POSITIVE);
173 let (k1, b) = (k1 as f64, b as f64);
174 let mut scores: HashMap<u32, f32> = HashMap::new();
175 let mut seen: std::collections::HashSet<u32> = std::collections::HashSet::new();
176 for &term in query_terms {
177 if !seen.insert(term) {
178 continue; }
180 let Some(plist) = self.postings.get(&term) else {
181 continue;
182 };
183 let df = plist.len() as f64;
184 let idf = (1.0 + (n - df + 0.5) / (df + 0.5)).ln();
185 for (&slot, &tf) in plist {
186 let tf = tf as f64;
187 let dl = self.doclen[slot as usize] as f64;
188 let denom = tf + k1 * (1.0 - b + b * (dl / avgdl));
189 let contribution = idf * (tf * (k1 + 1.0)) / denom;
190 *scores.entry(slot).or_insert(0.0) += contribution as f32;
191 }
192 }
193 let mut out: Vec<(String, f32)> = scores
194 .into_iter()
195 .filter(|&(_, score)| score > 0.0)
196 .map(|(slot, score)| (self.ext_of[slot as usize].clone(), score))
197 .collect();
198 out.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
199 out
200 }
201
202 fn allocate(&mut self, ext_id: &str) -> u32 {
204 let slot = match self.free.pop() {
205 Some(slot) => {
206 self.ext_of[slot as usize] = ext_id.to_owned();
207 slot
208 }
209 None => {
210 let slot = self.ext_of.len() as u32;
211 self.ext_of.push(ext_id.to_owned());
212 self.dims_of.push(Vec::new());
213 self.doclen.push(0.0);
214 slot
215 }
216 };
217 self.slot_of.insert(ext_id.to_owned(), slot);
218 self.len += 1;
219 slot
220 }
221
222 fn clear_postings(&mut self, slot: u32) {
226 for dim in std::mem::take(&mut self.dims_of[slot as usize]) {
227 if let Some(plist) = self.postings.get_mut(&dim) {
228 plist.remove(&slot);
229 if plist.is_empty() {
230 self.postings.remove(&dim);
231 }
232 }
233 }
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SparseVector` from parallel index/value slices.
    fn sv(indices: &[u32], values: &[f32]) -> SparseVector {
        let (indices, values) = (indices.to_vec(), values.to_vec());
        SparseVector { indices, values }
    }

    /// Projects search results down to their external ids, keeping order.
    fn ids(results: &[(String, f32)]) -> Vec<&str> {
        let mut names = Vec::with_capacity(results.len());
        for (name, _score) in results {
            names.push(name.as_str());
        }
        names
    }

    #[test]
    fn empty_index_reports_empty_and_searches_to_nothing() {
        let index = SparseInvertedIndex::new();
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
        assert!(!index.contains("x"));
        let hits = index.search(&sv(&[1, 2], &[1.0, 1.0]));
        assert!(hits.is_empty());
    }

    #[test]
    fn ranks_by_dot_product_and_breaks_ties_by_id() {
        let corpus = [
            ("a", sv(&[1, 2], &[1.0, 1.0])),
            ("b", sv(&[2, 3], &[1.0, 1.0])),
            ("c", sv(&[1], &[2.0])),
            ("tie", sv(&[1, 2], &[1.0, 1.0])),
        ];
        let mut index = SparseInvertedIndex::new();
        for (id, vector) in &corpus {
            index.upsert(id, vector);
        }
        assert_eq!(index.len(), 4);

        // Expected dot products: a = tie = 2*1 + 3*1 = 5, c = 2*2 = 4, b = 3*1 = 3.
        let query = sv(&[1, 2], &[2.0, 3.0]);
        let hits = index.search(&query);
        assert_eq!(ids(&hits), vec!["a", "tie", "c", "b"]);
        assert_eq!(hits[0].1, 5.0);
        assert_eq!(hits[3].1, 3.0);
    }

    #[test]
    fn matches_brute_force_dot_product() {
        let corpus = [
            ("a", sv(&[1, 5, 9], &[1.0, 2.0, 3.0])),
            ("b", sv(&[9, 1, 7], &[10.0, 4.0, 1.0])),
            ("c", sv(&[2, 4], &[5.0, 5.0])),
            // Shares no dimension with the query below.
            ("z", sv(&[100], &[5.0])),
        ];
        let mut index = SparseInvertedIndex::new();
        corpus.iter().for_each(|(id, vector)| index.upsert(id, vector));

        let query = sv(&[1, 9, 4], &[1.5, 0.5, 2.0]);

        // Reference ranking: score every document directly, keep positives,
        // then apply the same ordering rule as the index.
        let mut expected: Vec<(String, f32)> = Vec::new();
        for (id, vector) in &corpus {
            let score = query.dot(vector);
            if score > 0.0 {
                expected.push(((*id).to_owned(), score));
            }
        }
        expected.sort_by(|x, y| y.1.total_cmp(&x.1).then_with(|| x.0.cmp(&y.0)));

        assert_eq!(index.search(&query), expected);
    }

    #[test]
    fn reupsert_replaces_old_postings_without_double_counting() {
        let mut index = SparseInvertedIndex::new();
        index.upsert("a", &sv(&[1, 2], &[1.0, 1.0]));
        index.upsert("a", &sv(&[3], &[5.0]));
        assert_eq!(index.len(), 1);

        // The first version's dimensions must no longer match anything.
        let stale = index.search(&sv(&[1, 2], &[1.0, 1.0]));
        assert!(stale.is_empty());

        let hits = index.search(&sv(&[3], &[2.0]));
        assert_eq!(ids(&hits), vec!["a"]);
        assert_eq!(hits[0].1, 10.0);
    }

    #[test]
    fn remove_drops_from_results_and_reuses_the_slot() {
        let mut index = SparseInvertedIndex::new();
        index.upsert("a", &sv(&[1], &[1.0]));
        index.upsert("b", &sv(&[1], &[1.0]));

        assert!(index.remove("a"));
        // A second removal of the same id reports that nothing was there.
        assert!(!index.remove("a"));
        assert!(!index.contains("a"));
        assert_eq!(index.len(), 1);
        let survivors = index.search(&sv(&[1], &[1.0]));
        assert_eq!(ids(&survivors), vec!["b"]);

        // Inserting a new document must recycle the freed slot rather than
        // grow the per-slot storage.
        let slots_before = index.ext_of.len();
        index.upsert("c", &sv(&[1], &[1.0]));
        assert_eq!(index.ext_of.len(), slots_before);
        assert_eq!(index.len(), 2);
    }

    #[test]
    fn query_sharing_no_dimension_scores_nothing() {
        let mut index = SparseInvertedIndex::new();
        index.upsert("a", &sv(&[1, 2], &[1.0, 1.0]));
        let hits = index.search(&sv(&[7, 8], &[1.0, 1.0]));
        assert!(hits.is_empty());
    }

    #[test]
    fn duplicate_dimensions_are_deduplicated_last_wins() {
        let mut index = SparseInvertedIndex::new();
        // Stored weight for dim 1 is 3.0; effective query weight is 10.0.
        index.upsert("a", &sv(&[1, 1], &[2.0, 3.0]));
        let hits = index.search(&sv(&[1, 1], &[5.0, 10.0]));
        let expected = vec![(String::from("a"), 30.0_f32)];
        assert_eq!(hits, expected);
    }

    #[test]
    fn negative_and_zero_scores_are_dropped() {
        let mut index = SparseInvertedIndex::new();
        index.upsert("neg", &sv(&[1], &[-1.0]));
        index.upsert("zero", &sv(&[2], &[0.0]));
        let hits = index.search(&sv(&[1, 2], &[1.0, 1.0]));
        assert!(hits.is_empty());
    }

    #[test]
    fn empty_sparse_vector_is_a_live_doc_with_no_postings() {
        let mut index = SparseInvertedIndex::new();
        index.upsert("a", &sv(&[], &[]));
        assert_eq!(index.len(), 1);
        assert!(index.contains("a"));
        let hits = index.search(&sv(&[1], &[1.0]));
        assert!(hits.is_empty());
    }

    #[test]
    fn debug_is_derivable() {
        let mut index = SparseInvertedIndex::new();
        index.upsert("a", &sv(&[1], &[1.0]));
        let rendered = format!("{index:?}");
        assert!(rendered.contains("SparseInvertedIndex"));
    }

    #[test]
    fn bm25_ranks_by_term_frequency() {
        let mut index = SparseInvertedIndex::new();
        index.upsert("hi", &sv(&[1], &[3.0]));
        index.upsert("lo", &sv(&[1], &[1.0]));
        index.upsert("other", &sv(&[2], &[5.0]));

        let hits = index.bm25_search(&[1], BM25_K1, BM25_B);
        assert_eq!(ids(&hits), vec!["hi", "lo"], "other lacks the term");
        let (top, runner_up) = (hits[0].1, hits[1].1);
        assert!(top > runner_up);
    }

    #[test]
    fn bm25_rewards_shorter_documents_at_equal_term_frequency() {
        let mut index = SparseInvertedIndex::new();
        // Both documents carry weight 2.0 on term 1; "long" adds 8.0 elsewhere.
        index.upsert("short", &sv(&[1], &[2.0]));
        index.upsert("long", &sv(&[1, 2], &[2.0, 8.0]));

        let hits = index.bm25_search(&[1], BM25_K1, BM25_B);
        assert_eq!(
            hits[0].0, "short",
            "length normalization favours the shorter doc"
        );
    }

    #[test]
    fn bm25_empty_index_and_unknown_terms_score_nothing() {
        let empty = SparseInvertedIndex::new();
        assert!(empty.bm25_search(&[1], BM25_K1, BM25_B).is_empty());

        let mut index = SparseInvertedIndex::new();
        index.upsert("a", &sv(&[1], &[1.0]));
        let hits = index.bm25_search(&[999], BM25_K1, BM25_B);
        assert!(hits.is_empty());
    }

    #[test]
    fn bm25_deduplicates_query_terms() {
        let mut index = SparseInvertedIndex::new();
        index.upsert("a", &sv(&[1], &[1.0]));
        index.upsert("b", &sv(&[1, 2], &[1.0, 1.0]));

        let once = index.bm25_search(&[1], BM25_K1, BM25_B);
        let twice = index.bm25_search(&[1, 1, 1], BM25_K1, BM25_B);
        assert_eq!(once, twice, "a repeated query term scores once");
    }

    #[test]
    fn bm25_tracks_document_length_through_update_and_delete() {
        let mut index = SparseInvertedIndex::new();

        // Insert: 1 + 2.
        index.upsert("a", &sv(&[1, 2], &[1.0, 2.0]));
        assert_eq!(index.total_len, 3.0);

        // Replace: the old length is retired, the new one is 5.
        index.upsert("a", &sv(&[1], &[5.0]));
        assert_eq!(index.total_len, 5.0);

        // Second document adds 2.
        index.upsert("b", &sv(&[1], &[2.0]));
        assert_eq!(index.total_len, 7.0);

        // Removing "a" leaves only "b"'s length behind.
        assert!(index.remove("a"));
        assert_eq!(index.total_len, 2.0);
        let b_slot = index.slot_of["b"] as usize;
        assert_eq!(index.doclen[b_slot], 2.0);
    }

    #[test]
    fn clear_postings_tolerates_a_dim_missing_from_postings() {
        let mut index = SparseInvertedIndex::new();
        index.upsert("a", &sv(&[1], &[1.0]));

        // Record a dimension for the slot that has no posting list at all.
        let slot = index.slot_of["a"] as usize;
        index.dims_of[slot].push(42);

        assert!(index.remove("a"));
        assert!(index.is_empty());
    }
}