1use crate::Result;
2use std::collections::HashMap;
3
/// Unique identifier for an indexed document (used as the key in all index maps).
pub type DocumentId = String;
6
/// A unit of indexable text plus arbitrary string metadata.
#[derive(Debug, Clone)]
pub struct Document {
    /// Unique id; duplicates overwrite the stored document map entry.
    pub id: DocumentId,
    /// Raw text that gets tokenized and scored.
    pub content: String,
    /// Free-form key/value metadata; not consulted by the retriever in this file.
    pub metadata: HashMap<String, String>,
}
17
/// One search hit returned by [`BM25Retriever::search`].
#[derive(Debug, Clone)]
pub struct BM25Result {
    /// Id of the matching document.
    pub doc_id: DocumentId,
    /// BM25 score accumulated over all query terms (higher is better).
    pub score: f32,
    /// The matching document's content, cloned for convenience.
    pub content: String,
}
28
/// In-memory BM25 inverted index over [`Document`]s.
pub struct BM25Retriever {
    /// Term-frequency saturation parameter (classic default 1.2).
    k1: f32,
    /// Document-length normalization strength in [0, 1] (classic default 0.75).
    b: f32,
    /// All indexed documents, keyed by id.
    documents: HashMap<DocumentId, Document>,
    /// term -> (doc id -> length-normalized term frequency, i.e. count / doc length).
    term_frequencies: HashMap<String, HashMap<DocumentId, f32>>,
    /// term -> number of documents containing that term (used for IDF).
    document_frequencies: HashMap<String, usize>,
    /// doc id -> token count of that document.
    document_lengths: HashMap<DocumentId, usize>,
    /// Mean token count across all indexed documents.
    avg_doc_length: f32,
    /// Total number of indexed documents.
    total_docs: usize,
}
48
49impl BM25Retriever {
50 pub fn new() -> Self {
52 Self::with_parameters(1.2, 0.75)
53 }
54
55 pub fn with_parameters(k1: f32, b: f32) -> Self {
57 Self {
58 k1,
59 b,
60 documents: HashMap::new(),
61 term_frequencies: HashMap::new(),
62 document_frequencies: HashMap::new(),
63 document_lengths: HashMap::new(),
64 avg_doc_length: 0.0,
65 total_docs: 0,
66 }
67 }
68
69 pub fn index_document(&mut self, document: Document) -> Result<()> {
71 let doc_id = document.id.clone();
72 let tokens = self.tokenize(&document.content);
73 let doc_length = tokens.len();
74
75 let mut term_freq: HashMap<String, usize> = HashMap::new();
77 for token in &tokens {
78 *term_freq.entry(token.clone()).or_insert(0) += 1;
79 }
80
81 for term in term_freq.keys() {
83 *self.document_frequencies.entry(term.clone()).or_insert(0) += 1;
84 }
85
86 for (term, freq) in term_freq {
88 let normalized_freq = freq as f32 / doc_length as f32;
89 self.term_frequencies
90 .entry(term)
91 .or_default()
92 .insert(doc_id.clone(), normalized_freq);
93 }
94
95 self.document_lengths.insert(doc_id.clone(), doc_length);
97 self.documents.insert(doc_id, document);
98 self.total_docs += 1;
99
100 self.update_avg_doc_length();
102
103 Ok(())
104 }
105
106 pub fn index_documents(&mut self, documents: &[Document]) -> Result<()> {
108 for document in documents {
109 self.index_document(document.clone())?;
110 }
111 Ok(())
112 }
113
114 pub fn search(&self, query: &str, limit: usize) -> Vec<BM25Result> {
116 if self.total_docs == 0 {
117 return Vec::new();
118 }
119
120 let query_tokens = self.tokenize(query);
121 let mut doc_scores: HashMap<DocumentId, f32> = HashMap::new();
122
123 for token in &query_tokens {
125 if let Some(doc_freqs) = self.term_frequencies.get(token) {
126 let idf = self.calculate_idf(token);
127
128 for (doc_id, tf) in doc_freqs {
129 let doc_length = *self.document_lengths.get(doc_id).unwrap_or(&0);
130 let bm25_term_score = self.calculate_bm25_term_score(*tf, doc_length, idf);
131
132 *doc_scores.entry(doc_id.clone()).or_insert(0.0) += bm25_term_score;
133 }
134 }
135 }
136
137 let mut results: Vec<BM25Result> = doc_scores
139 .into_iter()
140 .filter_map(|(doc_id, score)| {
141 self.documents.get(&doc_id).map(|doc| BM25Result {
142 doc_id: doc_id.clone(),
143 score,
144 content: doc.content.clone(),
145 })
146 })
147 .collect();
148
149 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
150 results.truncate(limit);
151 results
152 }
153
154 pub fn get_document(&self, doc_id: &DocumentId) -> Option<&Document> {
156 self.documents.get(doc_id)
157 }
158
159 pub fn document_count(&self) -> usize {
161 self.total_docs
162 }
163
164 pub fn term_count(&self) -> usize {
166 self.term_frequencies.len()
167 }
168
169 fn calculate_idf(&self, term: &str) -> f32 {
172 let doc_freq = self.document_frequencies.get(term).unwrap_or(&0);
173 if *doc_freq == 0 {
174 return 0.0;
175 }
176
177 (self.total_docs as f32 / *doc_freq as f32).ln() + 1.0
179 }
180
181 fn calculate_bm25_term_score(&self, tf: f32, doc_length: usize, idf: f32) -> f32 {
183 let tf_component = (tf * (self.k1 + 1.0))
184 / (tf + self.k1 * (1.0 - self.b + self.b * (doc_length as f32 / self.avg_doc_length)));
185
186 idf * tf_component
187 }
188
189 fn update_avg_doc_length(&mut self) {
191 if self.total_docs > 0 {
192 let total_length: usize = self.document_lengths.values().sum();
193 self.avg_doc_length = total_length as f32 / self.total_docs as f32;
194 }
195 }
196
197 fn tokenize(&self, text: &str) -> Vec<String> {
199 text.to_lowercase()
200 .split_whitespace()
201 .map(|s| {
202 s.chars()
204 .filter(|c| c.is_alphanumeric())
205 .collect::<String>()
206 })
207 .filter(|s| !s.is_empty() && s.len() > 2 && !self.is_stop_word(s))
208 .collect()
209 }
210
211 fn is_stop_word(&self, word: &str) -> bool {
213 const STOP_WORDS: &[&str] = &[
214 "the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", "for", "not",
215 "on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
216 "they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
217 "there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
218 "go", "me", "when", "make", "can", "like", "time", "no", "just", "him", "know", "take",
219 "people", "into", "year", "your", "good", "some", "could", "them", "see", "other",
220 "than", "then", "now", "look", "only", "come", "its", "over", "think", "also", "back",
221 "after", "use", "two", "how", "our", "work", "first", "well", "way", "even", "new",
222 "want", "because", "any", "these", "give", "day", "most", "us",
223 ];
224 STOP_WORDS.contains(&word)
225 }
226
227 pub fn clear(&mut self) {
229 self.documents.clear();
230 self.term_frequencies.clear();
231 self.document_frequencies.clear();
232 self.document_lengths.clear();
233 self.avg_doc_length = 0.0;
234 self.total_docs = 0;
235 }
236
237 pub fn get_statistics(&self) -> BM25Statistics {
239 BM25Statistics {
240 total_documents: self.total_docs,
241 total_terms: self.term_frequencies.len(),
242 avg_doc_length: self.avg_doc_length,
243 parameters: BM25Parameters {
244 k1: self.k1,
245 b: self.b,
246 },
247 }
248 }
249}
250
251impl Default for BM25Retriever {
252 fn default() -> Self {
253 Self::new()
254 }
255}
256
/// The two tunable BM25 parameters, as configured on the retriever.
#[derive(Debug, Clone)]
pub struct BM25Parameters {
    /// Term-frequency saturation parameter.
    pub k1: f32,
    /// Document-length normalization strength.
    pub b: f32,
}
265
/// Point-in-time summary of the index, produced by `BM25Retriever::get_statistics`.
#[derive(Debug, Clone)]
pub struct BM25Statistics {
    /// Number of indexed documents.
    pub total_documents: usize,
    /// Number of distinct terms in the vocabulary.
    pub total_terms: usize,
    /// Mean token count per document.
    pub avg_doc_length: f32,
    /// The retriever's configured k1 and b.
    pub parameters: BM25Parameters,
}
278
279impl BM25Statistics {
280 pub fn print(&self) {
282 println!("BM25 Index Statistics:");
283 println!(" Total documents: {}", self.total_documents);
284 println!(" Total terms: {}", self.total_terms);
285 println!(
286 " Average document length: {:.2} tokens",
287 self.avg_doc_length
288 );
289 println!(
290 " Parameters: k1={:.2}, b={:.2}",
291 self.parameters.k1, self.parameters.b
292 );
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the three fixture documents used across the tests.
    fn create_test_documents() -> Vec<Document> {
        let make = |id: &str, content: &str| Document {
            id: id.to_string(),
            content: content.to_string(),
            metadata: HashMap::new(),
        };
        vec![
            make("doc1", "The quick brown fox jumps over the lazy dog"),
            make("doc2", "A fast brown animal leaps across a sleeping canine"),
            make("doc3", "The weather is nice today"),
        ]
    }

    #[test]
    fn test_bm25_creation() {
        // A fresh retriever starts completely empty.
        let retriever = BM25Retriever::new();
        assert_eq!(retriever.document_count(), 0);
        assert_eq!(retriever.term_count(), 0);
    }

    #[test]
    fn test_document_indexing() {
        let mut retriever = BM25Retriever::new();
        retriever.index_documents(&create_test_documents()).unwrap();

        assert_eq!(retriever.document_count(), 3);
        assert!(retriever.term_count() > 0);
    }

    #[test]
    fn test_search() {
        let mut retriever = BM25Retriever::new();
        retriever.index_documents(&create_test_documents()).unwrap();

        let results = retriever.search("brown fox", 10);
        assert!(!results.is_empty());
        // "fox" occurs only in doc1, so doc1 must rank first with a positive score.
        assert_eq!(results[0].doc_id, "doc1");
        assert!(results[0].score > 0.0);
    }

    #[test]
    fn test_tokenization() {
        let tokens = BM25Retriever::new().tokenize("The quick, brown fox!");

        for expected in ["quick", "brown", "fox"] {
            assert!(tokens.contains(&expected.to_string()));
        }
        // "the" is a stop word and must have been filtered out.
        assert!(!tokens.contains(&"the".to_string()));
    }

    #[test]
    fn test_empty_search() {
        // Searching an empty index yields no results rather than an error.
        assert!(BM25Retriever::new().search("test", 10).is_empty());
    }
}