// rexis_rag/retrieval/bm25.rs

//! # BM25 Keyword-based Retrieval
//!
//! Implementation of the BM25 algorithm for keyword-based document retrieval.
//! BM25 is a probabilistic retrieval model that ranks documents based on term
//! frequency and inverse document frequency.

use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

use crate::{Document, RragResult, SearchResult};

13/// BM25 retriever configuration
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct BM25Config {
16    /// k1 parameter: controls term frequency saturation (typically 1.2-2.0)
17    pub k1: f32,
18
19    /// b parameter: controls length normalization (typically 0.75)
20    pub b: f32,
21
22    /// Tokenizer type to use
23    pub tokenizer: TokenizerType,
24
25    /// Minimum token length to index
26    pub min_token_length: usize,
27
28    /// Maximum token length to index
29    pub max_token_length: usize,
30
31    /// Whether to use stemming
32    pub use_stemming: bool,
33
34    /// Whether to remove stop words
35    pub remove_stopwords: bool,
36
37    /// Custom stop words list
38    pub custom_stopwords: Option<HashSet<String>>,
39}
40
41impl Default for BM25Config {
42    fn default() -> Self {
43        Self {
44            k1: 1.2,
45            b: 0.75,
46            tokenizer: TokenizerType::Standard,
47            min_token_length: 2,
48            max_token_length: 50,
49            use_stemming: true,
50            remove_stopwords: true,
51            custom_stopwords: None,
52        }
53    }
54}
55
56/// Tokenizer types for text processing
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub enum TokenizerType {
59    /// Standard whitespace and punctuation tokenizer
60    Standard,
61    /// N-gram based tokenizer
62    NGram(usize),
63    /// Language-specific tokenizer
64    Language(String),
65}
66
67/// BM25 index entry for a document
68#[derive(Debug, Clone)]
69struct BM25Document {
70    /// Document ID
71    id: String,
72
73    /// Original content
74    content: String,
75
76    /// Tokenized terms with frequencies
77    term_frequencies: HashMap<String, f32>,
78
79    /// Document length (number of tokens)
80    length: usize,
81
82    /// Additional metadata
83    metadata: HashMap<String, serde_json::Value>,
84}
85
86/// BM25 retriever implementation
87pub struct BM25Retriever {
88    /// Configuration
89    config: BM25Config,
90
91    /// Document storage
92    documents: Arc<RwLock<HashMap<String, BM25Document>>>,
93
94    /// Inverted index: term -> document IDs
95    inverted_index: Arc<RwLock<HashMap<String, HashSet<String>>>>,
96
97    /// Document frequencies for each term
98    document_frequencies: Arc<RwLock<HashMap<String, usize>>>,
99
100    /// Average document length
101    avg_doc_length: Arc<RwLock<f32>>,
102
103    /// Total number of documents
104    total_docs: Arc<RwLock<usize>>,
105
106    /// Stop words set
107    stop_words: HashSet<String>,
108}
109
110impl BM25Retriever {
111    /// Create a new BM25 retriever with configuration
112    pub fn new(config: BM25Config) -> Self {
113        let stop_words = if config.remove_stopwords {
114            Self::default_stop_words()
115        } else {
116            HashSet::new()
117        };
118
119        Self {
120            config,
121            documents: Arc::new(RwLock::new(HashMap::new())),
122            inverted_index: Arc::new(RwLock::new(HashMap::new())),
123            document_frequencies: Arc::new(RwLock::new(HashMap::new())),
124            avg_doc_length: Arc::new(RwLock::new(0.0)),
125            total_docs: Arc::new(RwLock::new(0)),
126            stop_words,
127        }
128    }
129
130    /// Index a document
131    pub async fn index_document(&self, doc: &Document) -> RragResult<()> {
132        let tokens = self.tokenize(&doc.content);
133        let term_frequencies = self.calculate_term_frequencies(&tokens);
134
135        let bm25_doc = BM25Document {
136            id: doc.id.clone(),
137            content: doc.content.to_string(),
138            term_frequencies: term_frequencies.clone(),
139            length: tokens.len(),
140            metadata: doc.metadata.clone(),
141        };
142
143        // Update document storage
144        let mut documents = self.documents.write().await;
145        documents.insert(doc.id.clone(), bm25_doc);
146
147        // Update inverted index
148        let mut inverted_index = self.inverted_index.write().await;
149        let mut doc_frequencies = self.document_frequencies.write().await;
150
151        for term in term_frequencies.keys() {
152            inverted_index
153                .entry(term.clone())
154                .or_insert_with(HashSet::new)
155                .insert(doc.id.clone());
156
157            *doc_frequencies.entry(term.clone()).or_insert(0) += 1;
158        }
159
160        // Update statistics
161        let mut total_docs = self.total_docs.write().await;
162        *total_docs += 1;
163
164        let mut avg_length = self.avg_doc_length.write().await;
165        *avg_length =
166            (*avg_length * (*total_docs - 1) as f32 + tokens.len() as f32) / *total_docs as f32;
167
168        Ok(())
169    }
170
171    /// Index multiple documents in batch
172    pub async fn index_batch(&self, documents: Vec<Document>) -> RragResult<()> {
173        for doc in documents {
174            self.index_document(&doc).await?;
175        }
176        Ok(())
177    }
178
179    /// Search using BM25 algorithm
180    pub async fn search(&self, query: &str, limit: usize) -> RragResult<Vec<SearchResult>> {
181        let query_tokens = self.tokenize(query);
182        if query_tokens.is_empty() {
183            return Ok(Vec::new());
184        }
185
186        let documents = self.documents.read().await;
187        let inverted_index = self.inverted_index.read().await;
188        let doc_frequencies = self.document_frequencies.read().await;
189        let avg_length = *self.avg_doc_length.read().await;
190        let total_docs = *self.total_docs.read().await;
191
192        let mut scores: HashMap<String, f32> = HashMap::new();
193
194        // Calculate BM25 scores for each document
195        for term in &query_tokens {
196            if let Some(doc_ids) = inverted_index.get(term) {
197                let df = doc_frequencies.get(term).copied().unwrap_or(0) as f32;
198                let idf = ((total_docs as f32 - df + 0.5) / (df + 0.5) + 1.0).ln();
199
200                for doc_id in doc_ids {
201                    if let Some(doc) = documents.get(doc_id) {
202                        let tf = doc.term_frequencies.get(term).copied().unwrap_or(0.0);
203                        let doc_length = doc.length as f32;
204
205                        // BM25 formula
206                        let numerator = tf * (self.config.k1 + 1.0);
207                        let denominator = tf
208                            + self.config.k1
209                                * (1.0 - self.config.b + self.config.b * (doc_length / avg_length));
210                        let score = idf * (numerator / denominator);
211
212                        *scores.entry(doc_id.clone()).or_insert(0.0) += score;
213                    }
214                }
215            }
216        }
217
218        // Sort by score and return top results
219        let mut results: Vec<_> = scores
220            .into_iter()
221            .filter_map(|(doc_id, score)| {
222                documents.get(&doc_id).map(|doc| SearchResult {
223                    id: doc_id.clone(),
224                    content: doc.content.clone(),
225                    score: score / query_tokens.len() as f32, // Normalize by query length
226                    rank: 0,
227                    metadata: doc.metadata.clone(),
228                    embedding: None,
229                })
230            })
231            .collect();
232
233        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
234        results.truncate(limit);
235
236        // Update ranks
237        for (i, result) in results.iter_mut().enumerate() {
238            result.rank = i;
239        }
240
241        Ok(results)
242    }
243
244    /// Tokenize text into terms
245    fn tokenize(&self, text: &str) -> Vec<String> {
246        let lowercase = text.to_lowercase();
247        let tokens: Vec<String> = match &self.config.tokenizer {
248            TokenizerType::Standard => lowercase
249                .split(|c: char| !c.is_alphanumeric())
250                .filter(|s| !s.is_empty())
251                .filter(|s| s.len() >= self.config.min_token_length)
252                .filter(|s| s.len() <= self.config.max_token_length)
253                .filter(|s| !self.stop_words.contains(*s))
254                .map(|s| {
255                    if self.config.use_stemming {
256                        Self::simple_stem(s)
257                    } else {
258                        s.to_string()
259                    }
260                })
261                .collect(),
262            TokenizerType::NGram(n) => {
263                // N-gram tokenization
264                let chars: Vec<char> = lowercase.chars().collect();
265                let mut ngrams = Vec::new();
266                for i in 0..chars.len().saturating_sub(n - 1) {
267                    let ngram: String = chars[i..i + n].iter().collect();
268                    if !ngram.trim().is_empty() {
269                        ngrams.push(ngram);
270                    }
271                }
272                ngrams
273            }
274            TokenizerType::Language(ref _lang) => {
275                // For now, use standard tokenization
276                // In production, integrate language-specific tokenizers
277                lowercase
278                    .split_whitespace()
279                    .filter(|s| !self.stop_words.contains(*s))
280                    .map(String::from)
281                    .collect()
282            }
283        };
284
285        tokens
286    }
287
288    /// Calculate term frequencies for a list of tokens
289    fn calculate_term_frequencies(&self, tokens: &[String]) -> HashMap<String, f32> {
290        let mut frequencies = HashMap::new();
291        let total = tokens.len() as f32;
292
293        for token in tokens {
294            *frequencies.entry(token.clone()).or_insert(0.0) += 1.0;
295        }
296
297        // Normalize frequencies
298        for freq in frequencies.values_mut() {
299            *freq /= total;
300        }
301
302        frequencies
303    }
304
305    /// Simple stemming algorithm (Porter stemmer simplified)
306    fn simple_stem(word: &str) -> String {
307        let mut stem = word.to_string();
308
309        // Remove common suffixes
310        let suffixes = ["ing", "ed", "es", "s", "ly", "er", "est", "ness", "ment"];
311        for suffix in &suffixes {
312            if stem.len() > suffix.len() + 3 && stem.ends_with(suffix) {
313                stem.truncate(stem.len() - suffix.len());
314                break;
315            }
316        }
317
318        stem
319    }
320
321    /// Default English stop words
322    fn default_stop_words() -> HashSet<String> {
323        let words = vec![
324            "a", "an", "and", "are", "as", "at", "be", "been", "by", "for", "from", "has", "have",
325            "he", "in", "is", "it", "its", "of", "on", "that", "the", "to", "was", "will", "with",
326            "the", "this", "these", "those", "i", "you", "we", "they", "them", "their", "what",
327            "which", "who", "when", "where", "why", "how", "all", "would", "there", "could",
328        ];
329
330        words.into_iter().map(String::from).collect()
331    }
332
333    /// Clear the index
334    pub async fn clear(&self) -> RragResult<()> {
335        let mut documents = self.documents.write().await;
336        let mut inverted_index = self.inverted_index.write().await;
337        let mut doc_frequencies = self.document_frequencies.write().await;
338        let mut avg_length = self.avg_doc_length.write().await;
339        let mut total_docs = self.total_docs.write().await;
340
341        documents.clear();
342        inverted_index.clear();
343        doc_frequencies.clear();
344        *avg_length = 0.0;
345        *total_docs = 0;
346
347        Ok(())
348    }
349
350    /// Get index statistics
351    pub async fn stats(&self) -> HashMap<String, serde_json::Value> {
352        let documents = self.documents.read().await;
353        let inverted_index = self.inverted_index.read().await;
354        let total_docs = *self.total_docs.read().await;
355        let avg_length = *self.avg_doc_length.read().await;
356
357        let mut stats = HashMap::new();
358        stats.insert("total_documents".to_string(), total_docs.into());
359        stats.insert("unique_terms".to_string(), inverted_index.len().into());
360        stats.insert("average_document_length".to_string(), avg_length.into());
361        stats.insert(
362            "index_size_bytes".to_string(),
363            (documents.len() * std::mem::size_of::<BM25Document>()).into(),
364        );
365
366        stats
367    }
368}
369
370#[cfg(test)]
371mod tests {
372    use super::*;
373
374    #[tokio::test]
375    async fn test_bm25_indexing_and_search() {
376        let retriever = BM25Retriever::new(BM25Config::default());
377
378        let docs = vec![
379            Document::with_id("1", "The quick brown fox jumps over the lazy dog"),
380            Document::with_id("2", "A quick brown dog runs through the forest"),
381            Document::with_id("3", "The lazy cat sleeps in the warm sunshine"),
382        ];
383
384        retriever.index_batch(docs).await.unwrap();
385
386        let results = retriever.search("quick brown", 2).await.unwrap();
387        assert_eq!(results.len(), 2);
388        assert!(results[0].score > results[1].score);
389    }
390}