1use crate::{Document, RragResult, SearchResult};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct BM25Config {
16 pub k1: f32,
18
19 pub b: f32,
21
22 pub tokenizer: TokenizerType,
24
25 pub min_token_length: usize,
27
28 pub max_token_length: usize,
30
31 pub use_stemming: bool,
33
34 pub remove_stopwords: bool,
36
37 pub custom_stopwords: Option<HashSet<String>>,
39}
40
41impl Default for BM25Config {
42 fn default() -> Self {
43 Self {
44 k1: 1.2,
45 b: 0.75,
46 tokenizer: TokenizerType::Standard,
47 min_token_length: 2,
48 max_token_length: 50,
49 use_stemming: true,
50 remove_stopwords: true,
51 custom_stopwords: None,
52 }
53 }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub enum TokenizerType {
59 Standard,
61 NGram(usize),
63 Language(String),
65}
66
/// Internal representation of a document inside the BM25 index.
#[derive(Debug, Clone)]
struct BM25Document {
    /// Unique document id (mirrors `Document::id`).
    id: String,

    /// Original text, returned verbatim in search results.
    content: String,

    /// Relative term frequency per term (occurrence count / total tokens).
    term_frequencies: HashMap<String, f32>,

    /// Document length in tokens; used for BM25 length normalization.
    length: usize,

    /// Caller-supplied metadata, carried through to search results.
    metadata: HashMap<String, serde_json::Value>,
}
85
/// BM25 (Okapi) keyword retriever backed by an in-memory inverted index.
///
/// All mutable index state sits behind `tokio::sync::RwLock`s, so a single
/// instance can serve concurrent async readers and writers.
pub struct BM25Retriever {
    /// Scoring and tokenization parameters.
    config: BM25Config,

    /// Indexed documents keyed by document id.
    documents: Arc<RwLock<HashMap<String, BM25Document>>>,

    /// term -> ids of the documents containing that term.
    inverted_index: Arc<RwLock<HashMap<String, HashSet<String>>>>,

    /// term -> number of documents containing that term (document frequency).
    document_frequencies: Arc<RwLock<HashMap<String, usize>>>,

    /// Running mean document length in tokens.
    avg_doc_length: Arc<RwLock<f32>>,

    /// Total number of indexed documents.
    total_docs: Arc<RwLock<usize>>,

    /// Stop words removed during tokenization (empty when removal is disabled).
    stop_words: HashSet<String>,
}
109
110impl BM25Retriever {
111 pub fn new(config: BM25Config) -> Self {
113 let stop_words = if config.remove_stopwords {
114 Self::default_stop_words()
115 } else {
116 HashSet::new()
117 };
118
119 Self {
120 config,
121 documents: Arc::new(RwLock::new(HashMap::new())),
122 inverted_index: Arc::new(RwLock::new(HashMap::new())),
123 document_frequencies: Arc::new(RwLock::new(HashMap::new())),
124 avg_doc_length: Arc::new(RwLock::new(0.0)),
125 total_docs: Arc::new(RwLock::new(0)),
126 stop_words,
127 }
128 }
129
130 pub async fn index_document(&self, doc: &Document) -> RragResult<()> {
132 let tokens = self.tokenize(&doc.content);
133 let term_frequencies = self.calculate_term_frequencies(&tokens);
134
135 let bm25_doc = BM25Document {
136 id: doc.id.clone(),
137 content: doc.content.to_string(),
138 term_frequencies: term_frequencies.clone(),
139 length: tokens.len(),
140 metadata: doc.metadata.clone(),
141 };
142
143 let mut documents = self.documents.write().await;
145 documents.insert(doc.id.clone(), bm25_doc);
146
147 let mut inverted_index = self.inverted_index.write().await;
149 let mut doc_frequencies = self.document_frequencies.write().await;
150
151 for term in term_frequencies.keys() {
152 inverted_index
153 .entry(term.clone())
154 .or_insert_with(HashSet::new)
155 .insert(doc.id.clone());
156
157 *doc_frequencies.entry(term.clone()).or_insert(0) += 1;
158 }
159
160 let mut total_docs = self.total_docs.write().await;
162 *total_docs += 1;
163
164 let mut avg_length = self.avg_doc_length.write().await;
165 *avg_length =
166 (*avg_length * (*total_docs - 1) as f32 + tokens.len() as f32) / *total_docs as f32;
167
168 Ok(())
169 }
170
171 pub async fn index_batch(&self, documents: Vec<Document>) -> RragResult<()> {
173 for doc in documents {
174 self.index_document(&doc).await?;
175 }
176 Ok(())
177 }
178
179 pub async fn search(&self, query: &str, limit: usize) -> RragResult<Vec<SearchResult>> {
181 let query_tokens = self.tokenize(query);
182 if query_tokens.is_empty() {
183 return Ok(Vec::new());
184 }
185
186 let documents = self.documents.read().await;
187 let inverted_index = self.inverted_index.read().await;
188 let doc_frequencies = self.document_frequencies.read().await;
189 let avg_length = *self.avg_doc_length.read().await;
190 let total_docs = *self.total_docs.read().await;
191
192 let mut scores: HashMap<String, f32> = HashMap::new();
193
194 for term in &query_tokens {
196 if let Some(doc_ids) = inverted_index.get(term) {
197 let df = doc_frequencies.get(term).copied().unwrap_or(0) as f32;
198 let idf = ((total_docs as f32 - df + 0.5) / (df + 0.5) + 1.0).ln();
199
200 for doc_id in doc_ids {
201 if let Some(doc) = documents.get(doc_id) {
202 let tf = doc.term_frequencies.get(term).copied().unwrap_or(0.0);
203 let doc_length = doc.length as f32;
204
205 let numerator = tf * (self.config.k1 + 1.0);
207 let denominator = tf
208 + self.config.k1
209 * (1.0 - self.config.b + self.config.b * (doc_length / avg_length));
210 let score = idf * (numerator / denominator);
211
212 *scores.entry(doc_id.clone()).or_insert(0.0) += score;
213 }
214 }
215 }
216 }
217
218 let mut results: Vec<_> = scores
220 .into_iter()
221 .filter_map(|(doc_id, score)| {
222 documents.get(&doc_id).map(|doc| SearchResult {
223 id: doc_id.clone(),
224 content: doc.content.clone(),
225 score: score / query_tokens.len() as f32, rank: 0,
227 metadata: doc.metadata.clone(),
228 embedding: None,
229 })
230 })
231 .collect();
232
233 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
234 results.truncate(limit);
235
236 for (i, result) in results.iter_mut().enumerate() {
238 result.rank = i;
239 }
240
241 Ok(results)
242 }
243
244 fn tokenize(&self, text: &str) -> Vec<String> {
246 let lowercase = text.to_lowercase();
247 let tokens: Vec<String> = match &self.config.tokenizer {
248 TokenizerType::Standard => lowercase
249 .split(|c: char| !c.is_alphanumeric())
250 .filter(|s| !s.is_empty())
251 .filter(|s| s.len() >= self.config.min_token_length)
252 .filter(|s| s.len() <= self.config.max_token_length)
253 .filter(|s| !self.stop_words.contains(*s))
254 .map(|s| {
255 if self.config.use_stemming {
256 Self::simple_stem(s)
257 } else {
258 s.to_string()
259 }
260 })
261 .collect(),
262 TokenizerType::NGram(n) => {
263 let chars: Vec<char> = lowercase.chars().collect();
265 let mut ngrams = Vec::new();
266 for i in 0..chars.len().saturating_sub(n - 1) {
267 let ngram: String = chars[i..i + n].iter().collect();
268 if !ngram.trim().is_empty() {
269 ngrams.push(ngram);
270 }
271 }
272 ngrams
273 }
274 TokenizerType::Language(ref _lang) => {
275 lowercase
278 .split_whitespace()
279 .filter(|s| !self.stop_words.contains(*s))
280 .map(String::from)
281 .collect()
282 }
283 };
284
285 tokens
286 }
287
288 fn calculate_term_frequencies(&self, tokens: &[String]) -> HashMap<String, f32> {
290 let mut frequencies = HashMap::new();
291 let total = tokens.len() as f32;
292
293 for token in tokens {
294 *frequencies.entry(token.clone()).or_insert(0.0) += 1.0;
295 }
296
297 for freq in frequencies.values_mut() {
299 *freq /= total;
300 }
301
302 frequencies
303 }
304
305 fn simple_stem(word: &str) -> String {
307 let mut stem = word.to_string();
308
309 let suffixes = ["ing", "ed", "es", "s", "ly", "er", "est", "ness", "ment"];
311 for suffix in &suffixes {
312 if stem.len() > suffix.len() + 3 && stem.ends_with(suffix) {
313 stem.truncate(stem.len() - suffix.len());
314 break;
315 }
316 }
317
318 stem
319 }
320
321 fn default_stop_words() -> HashSet<String> {
323 let words = vec![
324 "a", "an", "and", "are", "as", "at", "be", "been", "by", "for", "from", "has", "have",
325 "he", "in", "is", "it", "its", "of", "on", "that", "the", "to", "was", "will", "with",
326 "the", "this", "these", "those", "i", "you", "we", "they", "them", "their", "what",
327 "which", "who", "when", "where", "why", "how", "all", "would", "there", "could",
328 ];
329
330 words.into_iter().map(String::from).collect()
331 }
332
333 pub async fn clear(&self) -> RragResult<()> {
335 let mut documents = self.documents.write().await;
336 let mut inverted_index = self.inverted_index.write().await;
337 let mut doc_frequencies = self.document_frequencies.write().await;
338 let mut avg_length = self.avg_doc_length.write().await;
339 let mut total_docs = self.total_docs.write().await;
340
341 documents.clear();
342 inverted_index.clear();
343 doc_frequencies.clear();
344 *avg_length = 0.0;
345 *total_docs = 0;
346
347 Ok(())
348 }
349
350 pub async fn stats(&self) -> HashMap<String, serde_json::Value> {
352 let documents = self.documents.read().await;
353 let inverted_index = self.inverted_index.read().await;
354 let total_docs = *self.total_docs.read().await;
355 let avg_length = *self.avg_doc_length.read().await;
356
357 let mut stats = HashMap::new();
358 stats.insert("total_documents".to_string(), total_docs.into());
359 stats.insert("unique_terms".to_string(), inverted_index.len().into());
360 stats.insert("average_document_length".to_string(), avg_length.into());
361 stats.insert(
362 "index_size_bytes".to_string(),
363 (documents.len() * std::mem::size_of::<BM25Document>()).into(),
364 );
365
366 stats
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check: index three short documents, then verify that a
    /// two-term query honors the result limit and ranks by descending score.
    #[tokio::test]
    async fn test_bm25_indexing_and_search() {
        let retriever = BM25Retriever::new(BM25Config::default());

        let corpus = [
            ("1", "The quick brown fox jumps over the lazy dog"),
            ("2", "A quick brown dog runs through the forest"),
            ("3", "The lazy cat sleeps in the warm sunshine"),
        ];
        let docs: Vec<Document> = corpus
            .iter()
            .map(|(id, text)| Document::with_id(*id, *text))
            .collect();

        retriever.index_batch(docs).await.unwrap();

        let hits = retriever.search("quick brown", 2).await.unwrap();
        assert_eq!(hits.len(), 2);
        assert!(hits[0].score > hits[1].score);
    }
}
390}