// BM25 inverted index with JSON persistence (serde / serde_json).
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Posting list for a single term: `(doc_id, term_frequency)` pairs, where the
/// frequency is the number of times the term occurs in that document's tokens.
type DocFrequency = Vec<(String, u32)>;
/// Inverted index holding the corpus statistics needed to compute BM25 scores.
///
/// The whole structure derives `Serialize`/`Deserialize`, so it can be
/// persisted and restored as-is. All fields are private; they are maintained
/// by the indexing and removal methods of this type.
#[derive(Debug, Serialize, Deserialize)]
pub struct Bm25Index {
term_to_counts: HashMap<String, DocFrequency>, // maps every known term to the docs that contain it — e.g. {"rust": [("doc1", 3), ("doc2", 1)]}
doc_lengths: HashMap<String, u32>, // total token count per document, used to normalise tf — e.g. {"doc1": 120, "doc2": 45}
avg_dl: f32, // average document length across all docs, recomputed on insert/remove — e.g. 82.5
n_docs: u32, // total number of indexed documents, used in IDF — e.g. 1000
}
impl Bm25Index {
    /// Creates an empty index: no terms, no documents, average length `0.0`.
    pub fn new() -> Self {
        Bm25Index {
            term_to_counts: HashMap::new(),
            doc_lengths: HashMap::new(),
            avg_dl: 0.0,
            n_docs: 0,
        }
    }

    /// Loads an index previously written by [`Bm25Index::save`].
    ///
    /// If the file cannot be read (missing, unreadable, not UTF-8, ...) an
    /// empty index is returned instead, so a first run without a saved index
    /// just starts from scratch.
    ///
    /// # Panics
    ///
    /// Panics if the file was read but its contents are not a valid
    /// serialized index.
    pub fn load(path: &str) -> Self {
        match std::fs::read_to_string(path) {
            Ok(contents) => serde_json::from_str(&contents).unwrap_or_else(|err| {
                panic!("failed to parse BM25 index from {}: {}", path, err)
            }),
            // Best-effort: any read failure is treated as "no saved index yet".
            Err(_) => Self::new(),
        }
    }

    /// Serializes the index to JSON and writes it to `path`, replacing any
    /// existing file.
    ///
    /// # Panics
    ///
    /// Panics if serialization fails or the file cannot be written.
    pub fn save(&self, path: &str) {
        let contents = serde_json::to_string(self)
            .unwrap_or_else(|err| panic!("failed to serialize BM25 index: {}", err));
        std::fs::write(path, contents)
            .unwrap_or_else(|err| panic!("failed to write BM25 index to {}: {}", path, err));
    }

    /// Indexes `tokens` under `doc_id`.
    ///
    /// Each distinct token gets one posting `(doc_id, occurrences)`, the
    /// document length is recorded as `tokens.len()`, and the corpus
    /// statistics (`n_docs`, `avg_dl`) are refreshed.
    ///
    /// Indexing a `doc_id` that is already present replaces the previous
    /// version of that document instead of counting it twice.
    pub fn index_record(&mut self, doc_id: &str, tokens: &[String]) {
        // The old token list is not available here, so a re-index has to sweep
        // every posting list to drop the stale entries for this document.
        if self.doc_lengths.contains_key(doc_id) {
            self.term_to_counts.retain(|_, postings| {
                postings.retain(|(id, _)| id != doc_id);
                !postings.is_empty()
            });
        }

        // Term frequency of every distinct token in this document.
        let mut term_frequencies: HashMap<&str, u32> = HashMap::new();
        for token in tokens {
            *term_frequencies.entry(token.as_str()).or_insert(0) += 1;
        }

        for (term, term_frequency) in term_frequencies {
            self.term_to_counts
                .entry(term.to_string())
                .or_default()
                .push((doc_id.to_string(), term_frequency));
        }

        // Clamp rather than silently wrap if the token count exceeds u32.
        let doc_len = tokens.len().min(u32::MAX as usize) as u32;
        self.doc_lengths.insert(doc_id.to_string(), doc_len);

        // Derive both statistics from `doc_lengths` so they can never drift
        // apart; sum in u64 so large corpora cannot overflow.
        self.n_docs = self.doc_lengths.len().min(u32::MAX as usize) as u32;
        let total_len: u64 = self.doc_lengths.values().map(|&len| u64::from(len)).sum();
        self.avg_dl = total_len as f32 / self.n_docs as f32;
    }

    /// Removes `doc_id` from the index.
    ///
    /// `tokens` must be the token list the document was indexed with: only the
    /// posting lists of these tokens are visited, so a posting for a term that
    /// is not listed stays behind. Posting lists that become empty are dropped.
    /// If `doc_id` is not indexed, the corpus statistics are left untouched.
    pub fn remove_record(&mut self, doc_id: &str, tokens: &[String]) {
        // Visit each distinct term once; repeated tokens would only rescan a
        // posting list the document has already been removed from.
        let unique_terms: std::collections::HashSet<&str> =
            tokens.iter().map(String::as_str).collect();
        for term in unique_terms {
            if let Some(postings) = self.term_to_counts.get_mut(term) {
                postings.retain(|(id, _)| id != doc_id);
                if postings.is_empty() {
                    self.term_to_counts.remove(term);
                }
            }
        }

        if let Some(removed_len) = self.doc_lengths.remove(doc_id) {
            let previous_n = self.n_docs;
            // Saturating: an inconsistent (e.g. hand-edited) saved index must
            // not underflow the counter.
            self.n_docs = previous_n.saturating_sub(1);
            self.avg_dl = if self.n_docs > 0 {
                // new_avg = (old_avg * old_n - removed_length) / new_n
                (self.avg_dl * previous_n as f32 - removed_len as f32) / self.n_docs as f32
            } else {
                0.0
            };
        }
    }

    /// Returns `(doc_id, bm25_score)` for all documents containing at least one query term.
    ///
    /// Results are sorted by descending score; equal scores are ordered by
    /// ascending `doc_id` so the output is deterministic. A token that appears
    /// several times in `query_tokens` contributes once per occurrence.
    pub fn score(&self, query_tokens: &[String]) -> Vec<(String, f32)> {
        const K: f32 = 1.5; // how strongly term frequency influences the score
        const B: f32 = 0.75; // how strongly long documents are penalised

        // Accumulate on borrowed ids; owned strings are only built for the result.
        let mut scores: HashMap<&str, f32> = HashMap::new();

        for token in query_tokens {
            let Some(postings) = self.term_to_counts.get(token) else { continue };
            let documents_count = postings.len() as f32;

            // IDF: rare terms weigh more. The 0.5 terms smooth the ratio and
            // the +1 inside the logarithm keeps the value positive for very
            // common terms.
            let idf_numerator = self.n_docs as f32 - documents_count + 0.5;
            let idf_denominator = documents_count + 0.5;
            let idf = (idf_numerator / idf_denominator + 1.0).ln();

            for (doc_id, term_frequency) in postings {
                let doc_len = self.doc_lengths.get(doc_id).copied().unwrap_or(0) as f32;
                let term_frequency = *term_frequency as f32;

                // Length normalisation: > 1 for longer-than-average documents,
                // < 1 for shorter ones. With no usable average (avg_dl == 0)
                // the document is treated as average-length rather than
                // dividing by zero and producing NaN.
                let length_penalty = if self.avg_dl > 0.0 {
                    B * doc_len / self.avg_dl
                } else {
                    B
                };
                let length_norm = 1.0 - B + length_penalty;

                // Saturating term frequency: grows with tf but is bounded by K + 1.
                let tf_numerator = term_frequency * (K + 1.0);
                let tf_denominator = term_frequency + K * length_norm;
                let tf_norm = tf_numerator / tf_denominator;

                *scores.entry(doc_id.as_str()).or_insert(0.0) += idf * tf_norm;
            }
        }

        let mut results: Vec<(String, f32)> = scores
            .into_iter()
            .map(|(doc_id, score)| (doc_id.to_string(), score))
            .collect();
        // `total_cmp` is a total order, so sorting can never panic.
        results.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        results
    }
}