Skip to main content

graphrag_core/retrieval/
bm25.rs

1use crate::Result;
2use std::collections::HashMap;
3
/// Document ID type for BM25 indexing.
///
/// A plain owned `String`. It keys every per-document map inside [`BM25Retriever`]
/// and is echoed back as [`BM25Result::doc_id`] in search results.
pub type DocumentId = String;
6
7/// A document for BM25 indexing
8#[derive(Debug, Clone)]
9pub struct Document {
10    /// Unique identifier for the document
11    pub id: DocumentId,
12    /// Text content of the document
13    pub content: String,
14    /// Key-value metadata associated with document
15    pub metadata: HashMap<String, String>,
16}
17
18/// BM25 search result
19#[derive(Debug, Clone)]
20pub struct BM25Result {
21    /// Document identifier for this result
22    pub doc_id: DocumentId,
23    /// BM25 relevance score for the result
24    pub score: f32,
25    /// Text content of the matched document
26    pub content: String,
27}
28
29/// BM25 retrieval system for keyword-based search
30pub struct BM25Retriever {
31    /// BM25 parameter k1 (term frequency saturation)
32    k1: f32,
33    /// BM25 parameter b (length normalization)
34    b: f32,
35    /// Indexed documents
36    documents: HashMap<DocumentId, Document>,
37    /// Term frequencies per document: term -> document_id -> frequency
38    term_frequencies: HashMap<String, HashMap<DocumentId, f32>>,
39    /// Document frequencies: term -> number of documents containing term
40    document_frequencies: HashMap<String, usize>,
41    /// Document lengths (in tokens)
42    document_lengths: HashMap<DocumentId, usize>,
43    /// Average document length
44    avg_doc_length: f32,
45    /// Total number of documents
46    total_docs: usize,
47}
48
49impl BM25Retriever {
50    /// Create a new BM25 retriever with default parameters
51    pub fn new() -> Self {
52        Self::with_parameters(1.2, 0.75)
53    }
54
55    /// Create a new BM25 retriever with custom parameters
56    pub fn with_parameters(k1: f32, b: f32) -> Self {
57        Self {
58            k1,
59            b,
60            documents: HashMap::new(),
61            term_frequencies: HashMap::new(),
62            document_frequencies: HashMap::new(),
63            document_lengths: HashMap::new(),
64            avg_doc_length: 0.0,
65            total_docs: 0,
66        }
67    }
68
69    /// Index a single document
70    pub fn index_document(&mut self, document: Document) -> Result<()> {
71        let doc_id = document.id.clone();
72        let tokens = self.tokenize(&document.content);
73        let doc_length = tokens.len();
74
75        // Calculate term frequencies for this document
76        let mut term_freq: HashMap<String, usize> = HashMap::new();
77        for token in &tokens {
78            *term_freq.entry(token.clone()).or_insert(0) += 1;
79        }
80
81        // Update document frequencies
82        for term in term_freq.keys() {
83            *self.document_frequencies.entry(term.clone()).or_insert(0) += 1;
84        }
85
86        // Store normalized term frequencies
87        for (term, freq) in term_freq {
88            let normalized_freq = freq as f32 / doc_length as f32;
89            self.term_frequencies
90                .entry(term)
91                .or_default()
92                .insert(doc_id.clone(), normalized_freq);
93        }
94
95        // Store document and metadata
96        self.document_lengths.insert(doc_id.clone(), doc_length);
97        self.documents.insert(doc_id, document);
98        self.total_docs += 1;
99
100        // Update average document length
101        self.update_avg_doc_length();
102
103        Ok(())
104    }
105
106    /// Index multiple documents
107    pub fn index_documents(&mut self, documents: &[Document]) -> Result<()> {
108        for document in documents {
109            self.index_document(document.clone())?;
110        }
111        Ok(())
112    }
113
114    /// Search for documents matching the query
115    pub fn search(&self, query: &str, limit: usize) -> Vec<BM25Result> {
116        if self.total_docs == 0 {
117            return Vec::new();
118        }
119
120        let query_tokens = self.tokenize(query);
121        let mut doc_scores: HashMap<DocumentId, f32> = HashMap::new();
122
123        // Calculate BM25 score for each document
124        for token in &query_tokens {
125            if let Some(doc_freqs) = self.term_frequencies.get(token) {
126                let idf = self.calculate_idf(token);
127
128                for (doc_id, tf) in doc_freqs {
129                    let doc_length = *self.document_lengths.get(doc_id).unwrap_or(&0);
130                    let bm25_term_score = self.calculate_bm25_term_score(*tf, doc_length, idf);
131
132                    *doc_scores.entry(doc_id.clone()).or_insert(0.0) += bm25_term_score;
133                }
134            }
135        }
136
137        // Convert to results and sort by score
138        let mut results: Vec<BM25Result> = doc_scores
139            .into_iter()
140            .filter_map(|(doc_id, score)| {
141                self.documents.get(&doc_id).map(|doc| BM25Result {
142                    doc_id: doc_id.clone(),
143                    score,
144                    content: doc.content.clone(),
145                })
146            })
147            .collect();
148
149        results.sort_by(|a, b| {
150            b.score
151                .partial_cmp(&a.score)
152                .unwrap_or(std::cmp::Ordering::Equal)
153        });
154        results.truncate(limit);
155        results
156    }
157
158    /// Get document by ID
159    pub fn get_document(&self, doc_id: &DocumentId) -> Option<&Document> {
160        self.documents.get(doc_id)
161    }
162
163    /// Get total number of indexed documents
164    pub fn document_count(&self) -> usize {
165        self.total_docs
166    }
167
168    /// Get number of unique terms in the index
169    pub fn term_count(&self) -> usize {
170        self.term_frequencies.len()
171    }
172
173    /// Calculate IDF (Inverse Document Frequency) for a term
174    /// Uses Lucene-style IDF to avoid negative values for common terms
175    fn calculate_idf(&self, term: &str) -> f32 {
176        let doc_freq = self.document_frequencies.get(term).unwrap_or(&0);
177        if *doc_freq == 0 {
178            return 0.0;
179        }
180
181        // Lucene-style IDF: log(N/df) + 1, which ensures non-negative values
182        (self.total_docs as f32 / *doc_freq as f32).ln() + 1.0
183    }
184
185    /// Calculate BM25 term score
186    fn calculate_bm25_term_score(&self, tf: f32, doc_length: usize, idf: f32) -> f32 {
187        let tf_component = (tf * (self.k1 + 1.0))
188            / (tf + self.k1 * (1.0 - self.b + self.b * (doc_length as f32 / self.avg_doc_length)));
189
190        idf * tf_component
191    }
192
193    /// Update average document length
194    fn update_avg_doc_length(&mut self) {
195        if self.total_docs > 0 {
196            let total_length: usize = self.document_lengths.values().sum();
197            self.avg_doc_length = total_length as f32 / self.total_docs as f32;
198        }
199    }
200
201    /// Tokenize text into terms
202    fn tokenize(&self, text: &str) -> Vec<String> {
203        text.to_lowercase()
204            .split_whitespace()
205            .map(|s| {
206                // Remove punctuation and clean up
207                s.chars()
208                    .filter(|c| c.is_alphanumeric())
209                    .collect::<String>()
210            })
211            .filter(|s| !s.is_empty() && s.len() > 2 && !self.is_stop_word(s))
212            .collect()
213    }
214
215    /// Check if a word is a stop word
216    fn is_stop_word(&self, word: &str) -> bool {
217        const STOP_WORDS: &[&str] = &[
218            "the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", "for", "not",
219            "on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
220            "they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
221            "there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
222            "go", "me", "when", "make", "can", "like", "time", "no", "just", "him", "know", "take",
223            "people", "into", "year", "your", "good", "some", "could", "them", "see", "other",
224            "than", "then", "now", "look", "only", "come", "its", "over", "think", "also", "back",
225            "after", "use", "two", "how", "our", "work", "first", "well", "way", "even", "new",
226            "want", "because", "any", "these", "give", "day", "most", "us",
227        ];
228        STOP_WORDS.contains(&word)
229    }
230
231    /// Clear all indexed data
232    pub fn clear(&mut self) {
233        self.documents.clear();
234        self.term_frequencies.clear();
235        self.document_frequencies.clear();
236        self.document_lengths.clear();
237        self.avg_doc_length = 0.0;
238        self.total_docs = 0;
239    }
240
241    /// Get statistics about the index
242    pub fn get_statistics(&self) -> BM25Statistics {
243        BM25Statistics {
244            total_documents: self.total_docs,
245            total_terms: self.term_frequencies.len(),
246            avg_doc_length: self.avg_doc_length,
247            parameters: BM25Parameters {
248                k1: self.k1,
249                b: self.b,
250            },
251        }
252    }
253}
254
255impl Default for BM25Retriever {
256    fn default() -> Self {
257        Self::new()
258    }
259}
260
/// BM25 scoring parameters.
///
/// The type holds just two `f32`s, so it is `Copy` and supports `PartialEq`.
/// Callers can pass it by value and compare two configurations directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BM25Parameters {
    /// Term-frequency saturation (`k1`). Larger values let repeated terms keep
    /// adding score for longer.
    pub k1: f32,
    /// Document-length normalization strength (`b`). `0.0` disables it and `1.0`
    /// applies it fully.
    pub b: f32,
}
269
270/// Statistics about the BM25 index
271#[derive(Debug, Clone)]
272pub struct BM25Statistics {
273    /// Total number of indexed documents
274    pub total_documents: usize,
275    /// Total number of unique terms
276    pub total_terms: usize,
277    /// Average document length in tokens
278    pub avg_doc_length: f32,
279    /// BM25 algorithm parameters used
280    pub parameters: BM25Parameters,
281}
282
283impl BM25Statistics {
284    /// Print statistics
285    pub fn print(&self) {
286        println!("BM25 Index Statistics:");
287        println!("  Total documents: {}", self.total_documents);
288        println!("  Total terms: {}", self.total_terms);
289        println!(
290            "  Average document length: {:.2} tokens",
291            self.avg_doc_length
292        );
293        println!(
294            "  Parameters: k1={:.2}, b={:.2}",
295            self.parameters.k1, self.parameters.b
296        );
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a document with empty metadata.
    fn doc(id: &str, content: &str) -> Document {
        Document {
            id: id.to_string(),
            content: content.to_string(),
            metadata: HashMap::new(),
        }
    }

    /// Three fixtures: two about a brown animal and one unrelated sentence.
    fn create_test_documents() -> Vec<Document> {
        [
            ("doc1", "The quick brown fox jumps over the lazy dog"),
            ("doc2", "A fast brown animal leaps across a sleeping canine"),
            ("doc3", "The weather is nice today"),
        ]
        .iter()
        .map(|&(id, content)| doc(id, content))
        .collect()
    }

    /// A retriever with the fixtures from `create_test_documents` already indexed.
    fn indexed_retriever() -> BM25Retriever {
        let mut retriever = BM25Retriever::new();
        retriever
            .index_documents(&create_test_documents())
            .expect("indexing in-memory fixtures should not fail");
        retriever
    }

    #[test]
    fn test_bm25_creation() {
        let retriever = BM25Retriever::new();
        assert_eq!(retriever.document_count(), 0);
        assert_eq!(retriever.term_count(), 0);
    }

    #[test]
    fn test_document_indexing() {
        let retriever = indexed_retriever();
        assert_eq!(retriever.document_count(), 3);
        assert!(retriever.term_count() > 0);
    }

    #[test]
    fn test_search() {
        let retriever = indexed_retriever();
        let results = retriever.search("brown fox", 10);

        // doc1 is the only document containing both query terms, so it ranks first.
        let top = results.first().expect("query terms occur in the corpus");
        assert_eq!(top.doc_id, "doc1");
        assert!(top.score > 0.0);
    }

    #[test]
    fn test_tokenization() {
        let tokens = BM25Retriever::new().tokenize("The quick, brown fox!");

        // Punctuation is stripped and stop words are dropped.
        for word in &["quick", "brown", "fox"] {
            assert!(tokens.contains(&word.to_string()), "missing token {}", word);
        }
        assert!(!tokens.contains(&"the".to_string()));
    }

    #[test]
    fn test_empty_search() {
        assert!(BM25Retriever::new().search("test", 10).is_empty());
    }
}