1use crate::Result;
2use std::collections::HashMap;
3
/// Identifier under which a [`Document`] is stored in, and returned from, the index.
pub type DocumentId = String;
6
7#[derive(Debug, Clone)]
9pub struct Document {
10 pub id: DocumentId,
12 pub content: String,
14 pub metadata: HashMap<String, String>,
16}
17
18#[derive(Debug, Clone)]
20pub struct BM25Result {
21 pub doc_id: DocumentId,
23 pub score: f32,
25 pub content: String,
27}
28
29pub struct BM25Retriever {
31 k1: f32,
33 b: f32,
35 documents: HashMap<DocumentId, Document>,
37 term_frequencies: HashMap<String, HashMap<DocumentId, f32>>,
39 document_frequencies: HashMap<String, usize>,
41 document_lengths: HashMap<DocumentId, usize>,
43 avg_doc_length: f32,
45 total_docs: usize,
47}
48
49impl BM25Retriever {
50 pub fn new() -> Self {
52 Self::with_parameters(1.2, 0.75)
53 }
54
55 pub fn with_parameters(k1: f32, b: f32) -> Self {
57 Self {
58 k1,
59 b,
60 documents: HashMap::new(),
61 term_frequencies: HashMap::new(),
62 document_frequencies: HashMap::new(),
63 document_lengths: HashMap::new(),
64 avg_doc_length: 0.0,
65 total_docs: 0,
66 }
67 }
68
69 pub fn index_document(&mut self, document: Document) -> Result<()> {
71 let doc_id = document.id.clone();
72 let tokens = self.tokenize(&document.content);
73 let doc_length = tokens.len();
74
75 let mut term_freq: HashMap<String, usize> = HashMap::new();
77 for token in &tokens {
78 *term_freq.entry(token.clone()).or_insert(0) += 1;
79 }
80
81 for term in term_freq.keys() {
83 *self.document_frequencies.entry(term.clone()).or_insert(0) += 1;
84 }
85
86 for (term, freq) in term_freq {
88 let normalized_freq = freq as f32 / doc_length as f32;
89 self.term_frequencies
90 .entry(term)
91 .or_default()
92 .insert(doc_id.clone(), normalized_freq);
93 }
94
95 self.document_lengths.insert(doc_id.clone(), doc_length);
97 self.documents.insert(doc_id, document);
98 self.total_docs += 1;
99
100 self.update_avg_doc_length();
102
103 Ok(())
104 }
105
106 pub fn index_documents(&mut self, documents: &[Document]) -> Result<()> {
108 for document in documents {
109 self.index_document(document.clone())?;
110 }
111 Ok(())
112 }
113
114 pub fn search(&self, query: &str, limit: usize) -> Vec<BM25Result> {
116 if self.total_docs == 0 {
117 return Vec::new();
118 }
119
120 let query_tokens = self.tokenize(query);
121 let mut doc_scores: HashMap<DocumentId, f32> = HashMap::new();
122
123 for token in &query_tokens {
125 if let Some(doc_freqs) = self.term_frequencies.get(token) {
126 let idf = self.calculate_idf(token);
127
128 for (doc_id, tf) in doc_freqs {
129 let doc_length = *self.document_lengths.get(doc_id).unwrap_or(&0);
130 let bm25_term_score = self.calculate_bm25_term_score(*tf, doc_length, idf);
131
132 *doc_scores.entry(doc_id.clone()).or_insert(0.0) += bm25_term_score;
133 }
134 }
135 }
136
137 let mut results: Vec<BM25Result> = doc_scores
139 .into_iter()
140 .filter_map(|(doc_id, score)| {
141 self.documents.get(&doc_id).map(|doc| BM25Result {
142 doc_id: doc_id.clone(),
143 score,
144 content: doc.content.clone(),
145 })
146 })
147 .collect();
148
149 results.sort_by(|a, b| {
150 b.score
151 .partial_cmp(&a.score)
152 .unwrap_or(std::cmp::Ordering::Equal)
153 });
154 results.truncate(limit);
155 results
156 }
157
158 pub fn get_document(&self, doc_id: &DocumentId) -> Option<&Document> {
160 self.documents.get(doc_id)
161 }
162
163 pub fn document_count(&self) -> usize {
165 self.total_docs
166 }
167
168 pub fn term_count(&self) -> usize {
170 self.term_frequencies.len()
171 }
172
173 fn calculate_idf(&self, term: &str) -> f32 {
176 let doc_freq = self.document_frequencies.get(term).unwrap_or(&0);
177 if *doc_freq == 0 {
178 return 0.0;
179 }
180
181 (self.total_docs as f32 / *doc_freq as f32).ln() + 1.0
183 }
184
185 fn calculate_bm25_term_score(&self, tf: f32, doc_length: usize, idf: f32) -> f32 {
187 let tf_component = (tf * (self.k1 + 1.0))
188 / (tf + self.k1 * (1.0 - self.b + self.b * (doc_length as f32 / self.avg_doc_length)));
189
190 idf * tf_component
191 }
192
193 fn update_avg_doc_length(&mut self) {
195 if self.total_docs > 0 {
196 let total_length: usize = self.document_lengths.values().sum();
197 self.avg_doc_length = total_length as f32 / self.total_docs as f32;
198 }
199 }
200
201 fn tokenize(&self, text: &str) -> Vec<String> {
203 text.to_lowercase()
204 .split_whitespace()
205 .map(|s| {
206 s.chars()
208 .filter(|c| c.is_alphanumeric())
209 .collect::<String>()
210 })
211 .filter(|s| !s.is_empty() && s.len() > 2 && !self.is_stop_word(s))
212 .collect()
213 }
214
215 fn is_stop_word(&self, word: &str) -> bool {
217 const STOP_WORDS: &[&str] = &[
218 "the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", "for", "not",
219 "on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
220 "they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
221 "there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
222 "go", "me", "when", "make", "can", "like", "time", "no", "just", "him", "know", "take",
223 "people", "into", "year", "your", "good", "some", "could", "them", "see", "other",
224 "than", "then", "now", "look", "only", "come", "its", "over", "think", "also", "back",
225 "after", "use", "two", "how", "our", "work", "first", "well", "way", "even", "new",
226 "want", "because", "any", "these", "give", "day", "most", "us",
227 ];
228 STOP_WORDS.contains(&word)
229 }
230
231 pub fn clear(&mut self) {
233 self.documents.clear();
234 self.term_frequencies.clear();
235 self.document_frequencies.clear();
236 self.document_lengths.clear();
237 self.avg_doc_length = 0.0;
238 self.total_docs = 0;
239 }
240
241 pub fn get_statistics(&self) -> BM25Statistics {
243 BM25Statistics {
244 total_documents: self.total_docs,
245 total_terms: self.term_frequencies.len(),
246 avg_doc_length: self.avg_doc_length,
247 parameters: BM25Parameters {
248 k1: self.k1,
249 b: self.b,
250 },
251 }
252 }
253}
254
255impl Default for BM25Retriever {
256 fn default() -> Self {
257 Self::new()
258 }
259}
260
/// The two tunable constants of the BM25 scoring formula.
///
/// The type is two plain floats, so it is `Copy` and can be compared with
/// `==` (no `Eq`, because floats are only partially ordered).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BM25Parameters {
    /// Term-frequency saturation parameter.
    pub k1: f32,
    /// Document-length normalisation parameter.
    pub b: f32,
}
269
270#[derive(Debug, Clone)]
272pub struct BM25Statistics {
273 pub total_documents: usize,
275 pub total_terms: usize,
277 pub avg_doc_length: f32,
279 pub parameters: BM25Parameters,
281}
282
283impl BM25Statistics {
284 pub fn print(&self) {
286 println!("BM25 Index Statistics:");
287 println!(" Total documents: {}", self.total_documents);
288 println!(" Total terms: {}", self.total_terms);
289 println!(
290 " Average document length: {:.2} tokens",
291 self.avg_doc_length
292 );
293 println!(
294 " Parameters: k1={:.2}, b={:.2}",
295 self.parameters.k1, self.parameters.b
296 );
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a document with the given id and content and no metadata.
    fn make_doc(id: &str, content: &str) -> Document {
        Document {
            id: id.to_string(),
            content: content.to_string(),
            metadata: HashMap::new(),
        }
    }

    /// Three short documents: two about a brown animal, one unrelated.
    fn create_test_documents() -> Vec<Document> {
        vec![
            make_doc("doc1", "The quick brown fox jumps over the lazy dog"),
            make_doc("doc2", "A fast brown animal leaps across a sleeping canine"),
            make_doc("doc3", "The weather is nice today"),
        ]
    }

    /// A fresh retriever starts out with no documents and no terms.
    #[test]
    fn test_bm25_creation() {
        let fresh = BM25Retriever::new();
        assert_eq!(fresh.document_count(), 0);
        assert_eq!(fresh.term_count(), 0);
    }

    /// Bulk indexing registers every document and at least one term.
    #[test]
    fn test_document_indexing() {
        let corpus = create_test_documents();
        let mut index = BM25Retriever::new();

        index.index_documents(&corpus).unwrap();

        assert_eq!(index.document_count(), 3);
        assert!(index.term_count() > 0);
    }

    /// The document containing both query terms ranks first.
    #[test]
    fn test_search() {
        let corpus = create_test_documents();
        let mut index = BM25Retriever::new();
        index.index_documents(&corpus).unwrap();

        let hits = index.search("brown fox", 10);
        assert!(!hits.is_empty());

        let best = &hits[0];
        assert_eq!(best.doc_id, "doc1");
        assert!(best.score > 0.0);
    }

    /// Punctuation is stripped, case is folded and stop words are dropped.
    #[test]
    fn test_tokenization() {
        let index = BM25Retriever::new();
        let tokens = index.tokenize("The quick, brown fox!");

        for expected in ["quick", "brown", "fox"].iter() {
            assert!(tokens.iter().any(|token| token == expected));
        }
        assert!(!tokens.iter().any(|token| token == "the"));
    }

    /// Searching an empty index yields no results.
    #[test]
    fn test_empty_search() {
        let index = BM25Retriever::new();
        let hits = index.search("test", 10);
        assert!(hits.is_empty());
    }
}