Skip to main content

graphrag_core/retrieval/
bm25.rs

1use crate::Result;
2use std::collections::HashMap;
3
/// Key under which a document is stored in, and reported back from, the BM25 index.
pub type DocumentId = String;
6
7/// A document for BM25 indexing
8#[derive(Debug, Clone)]
9pub struct Document {
10    /// Unique identifier for the document
11    pub id: DocumentId,
12    /// Text content of the document
13    pub content: String,
14    /// Key-value metadata associated with document
15    pub metadata: HashMap<String, String>,
16}
17
18/// BM25 search result
19#[derive(Debug, Clone)]
20pub struct BM25Result {
21    /// Document identifier for this result
22    pub doc_id: DocumentId,
23    /// BM25 relevance score for the result
24    pub score: f32,
25    /// Text content of the matched document
26    pub content: String,
27}
28
29/// BM25 retrieval system for keyword-based search
30pub struct BM25Retriever {
31    /// BM25 parameter k1 (term frequency saturation)
32    k1: f32,
33    /// BM25 parameter b (length normalization)
34    b: f32,
35    /// Indexed documents
36    documents: HashMap<DocumentId, Document>,
37    /// Term frequencies per document: term -> document_id -> frequency
38    term_frequencies: HashMap<String, HashMap<DocumentId, f32>>,
39    /// Document frequencies: term -> number of documents containing term
40    document_frequencies: HashMap<String, usize>,
41    /// Document lengths (in tokens)
42    document_lengths: HashMap<DocumentId, usize>,
43    /// Average document length
44    avg_doc_length: f32,
45    /// Total number of documents
46    total_docs: usize,
47}
48
49impl BM25Retriever {
50    /// Create a new BM25 retriever with default parameters
51    pub fn new() -> Self {
52        Self::with_parameters(1.2, 0.75)
53    }
54
55    /// Create a new BM25 retriever with custom parameters
56    pub fn with_parameters(k1: f32, b: f32) -> Self {
57        Self {
58            k1,
59            b,
60            documents: HashMap::new(),
61            term_frequencies: HashMap::new(),
62            document_frequencies: HashMap::new(),
63            document_lengths: HashMap::new(),
64            avg_doc_length: 0.0,
65            total_docs: 0,
66        }
67    }
68
69    /// Index a single document
70    pub fn index_document(&mut self, document: Document) -> Result<()> {
71        let doc_id = document.id.clone();
72        let tokens = self.tokenize(&document.content);
73        let doc_length = tokens.len();
74
75        // Calculate term frequencies for this document
76        let mut term_freq: HashMap<String, usize> = HashMap::new();
77        for token in &tokens {
78            *term_freq.entry(token.clone()).or_insert(0) += 1;
79        }
80
81        // Update document frequencies
82        for term in term_freq.keys() {
83            *self.document_frequencies.entry(term.clone()).or_insert(0) += 1;
84        }
85
86        // Store normalized term frequencies
87        for (term, freq) in term_freq {
88            let normalized_freq = freq as f32 / doc_length as f32;
89            self.term_frequencies
90                .entry(term)
91                .or_default()
92                .insert(doc_id.clone(), normalized_freq);
93        }
94
95        // Store document and metadata
96        self.document_lengths.insert(doc_id.clone(), doc_length);
97        self.documents.insert(doc_id, document);
98        self.total_docs += 1;
99
100        // Update average document length
101        self.update_avg_doc_length();
102
103        Ok(())
104    }
105
106    /// Index multiple documents
107    pub fn index_documents(&mut self, documents: &[Document]) -> Result<()> {
108        for document in documents {
109            self.index_document(document.clone())?;
110        }
111        Ok(())
112    }
113
114    /// Search for documents matching the query
115    pub fn search(&self, query: &str, limit: usize) -> Vec<BM25Result> {
116        if self.total_docs == 0 {
117            return Vec::new();
118        }
119
120        let query_tokens = self.tokenize(query);
121        let mut doc_scores: HashMap<DocumentId, f32> = HashMap::new();
122
123        // Calculate BM25 score for each document
124        for token in &query_tokens {
125            if let Some(doc_freqs) = self.term_frequencies.get(token) {
126                let idf = self.calculate_idf(token);
127
128                for (doc_id, tf) in doc_freqs {
129                    let doc_length = *self.document_lengths.get(doc_id).unwrap_or(&0);
130                    let bm25_term_score = self.calculate_bm25_term_score(*tf, doc_length, idf);
131
132                    *doc_scores.entry(doc_id.clone()).or_insert(0.0) += bm25_term_score;
133                }
134            }
135        }
136
137        // Convert to results and sort by score
138        let mut results: Vec<BM25Result> = doc_scores
139            .into_iter()
140            .filter_map(|(doc_id, score)| {
141                self.documents.get(&doc_id).map(|doc| BM25Result {
142                    doc_id: doc_id.clone(),
143                    score,
144                    content: doc.content.clone(),
145                })
146            })
147            .collect();
148
149        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
150        results.truncate(limit);
151        results
152    }
153
154    /// Get document by ID
155    pub fn get_document(&self, doc_id: &DocumentId) -> Option<&Document> {
156        self.documents.get(doc_id)
157    }
158
159    /// Get total number of indexed documents
160    pub fn document_count(&self) -> usize {
161        self.total_docs
162    }
163
164    /// Get number of unique terms in the index
165    pub fn term_count(&self) -> usize {
166        self.term_frequencies.len()
167    }
168
169    /// Calculate IDF (Inverse Document Frequency) for a term
170    /// Uses Lucene-style IDF to avoid negative values for common terms
171    fn calculate_idf(&self, term: &str) -> f32 {
172        let doc_freq = self.document_frequencies.get(term).unwrap_or(&0);
173        if *doc_freq == 0 {
174            return 0.0;
175        }
176
177        // Lucene-style IDF: log(N/df) + 1, which ensures non-negative values
178        (self.total_docs as f32 / *doc_freq as f32).ln() + 1.0
179    }
180
181    /// Calculate BM25 term score
182    fn calculate_bm25_term_score(&self, tf: f32, doc_length: usize, idf: f32) -> f32 {
183        let tf_component = (tf * (self.k1 + 1.0))
184            / (tf + self.k1 * (1.0 - self.b + self.b * (doc_length as f32 / self.avg_doc_length)));
185
186        idf * tf_component
187    }
188
189    /// Update average document length
190    fn update_avg_doc_length(&mut self) {
191        if self.total_docs > 0 {
192            let total_length: usize = self.document_lengths.values().sum();
193            self.avg_doc_length = total_length as f32 / self.total_docs as f32;
194        }
195    }
196
197    /// Tokenize text into terms
198    fn tokenize(&self, text: &str) -> Vec<String> {
199        text.to_lowercase()
200            .split_whitespace()
201            .map(|s| {
202                // Remove punctuation and clean up
203                s.chars()
204                    .filter(|c| c.is_alphanumeric())
205                    .collect::<String>()
206            })
207            .filter(|s| !s.is_empty() && s.len() > 2 && !self.is_stop_word(s))
208            .collect()
209    }
210
211    /// Check if a word is a stop word
212    fn is_stop_word(&self, word: &str) -> bool {
213        const STOP_WORDS: &[&str] = &[
214            "the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", "for", "not",
215            "on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
216            "they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
217            "there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
218            "go", "me", "when", "make", "can", "like", "time", "no", "just", "him", "know", "take",
219            "people", "into", "year", "your", "good", "some", "could", "them", "see", "other",
220            "than", "then", "now", "look", "only", "come", "its", "over", "think", "also", "back",
221            "after", "use", "two", "how", "our", "work", "first", "well", "way", "even", "new",
222            "want", "because", "any", "these", "give", "day", "most", "us",
223        ];
224        STOP_WORDS.contains(&word)
225    }
226
227    /// Clear all indexed data
228    pub fn clear(&mut self) {
229        self.documents.clear();
230        self.term_frequencies.clear();
231        self.document_frequencies.clear();
232        self.document_lengths.clear();
233        self.avg_doc_length = 0.0;
234        self.total_docs = 0;
235    }
236
237    /// Get statistics about the index
238    pub fn get_statistics(&self) -> BM25Statistics {
239        BM25Statistics {
240            total_documents: self.total_docs,
241            total_terms: self.term_frequencies.len(),
242            avg_doc_length: self.avg_doc_length,
243            parameters: BM25Parameters {
244                k1: self.k1,
245                b: self.b,
246            },
247        }
248    }
249}
250
251impl Default for BM25Retriever {
252    fn default() -> Self {
253        Self::new()
254    }
255}
256
/// The two tunable BM25 parameters.
///
/// A small plain-value type, so it is `Copy` and comparable with `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BM25Parameters {
    /// Term frequency saturation parameter
    pub k1: f32,
    /// Document length normalization parameter
    pub b: f32,
}
265
266/// Statistics about the BM25 index
267#[derive(Debug, Clone)]
268pub struct BM25Statistics {
269    /// Total number of indexed documents
270    pub total_documents: usize,
271    /// Total number of unique terms
272    pub total_terms: usize,
273    /// Average document length in tokens
274    pub avg_doc_length: f32,
275    /// BM25 algorithm parameters used
276    pub parameters: BM25Parameters,
277}
278
279impl BM25Statistics {
280    /// Print statistics
281    pub fn print(&self) {
282        println!("BM25 Index Statistics:");
283        println!("  Total documents: {}", self.total_documents);
284        println!("  Total terms: {}", self.total_terms);
285        println!(
286            "  Average document length: {:.2} tokens",
287            self.avg_doc_length
288        );
289        println!(
290            "  Parameters: k1={:.2}, b={:.2}",
291            self.parameters.k1, self.parameters.b
292        );
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a metadata-free document from an ID and its text.
    fn make_document(id: &str, content: &str) -> Document {
        Document {
            id: id.to_string(),
            content: content.to_string(),
            metadata: HashMap::new(),
        }
    }

    /// Three-document fixture: two paraphrases about a brown animal and one unrelated line.
    fn create_test_documents() -> Vec<Document> {
        [
            ("doc1", "The quick brown fox jumps over the lazy dog"),
            ("doc2", "A fast brown animal leaps across a sleeping canine"),
            ("doc3", "The weather is nice today"),
        ]
        .iter()
        .map(|&(id, content)| make_document(id, content))
        .collect()
    }

    /// A retriever with the fixture already indexed.
    fn indexed_retriever() -> BM25Retriever {
        let mut retriever = BM25Retriever::new();
        retriever
            .index_documents(&create_test_documents())
            .unwrap();
        retriever
    }

    #[test]
    fn test_bm25_creation() {
        let empty = BM25Retriever::new();
        assert_eq!(empty.document_count(), 0);
        assert_eq!(empty.term_count(), 0);
    }

    #[test]
    fn test_document_indexing() {
        let retriever = indexed_retriever();
        assert_eq!(retriever.document_count(), 3);
        assert!(retriever.term_count() > 0);
    }

    #[test]
    fn test_search() {
        let hits = indexed_retriever().search("brown fox", 10);
        assert!(!hits.is_empty());

        // doc1 contains both query terms, so it should rank first.
        let top = &hits[0];
        assert_eq!(top.doc_id, "doc1");
        assert!(top.score > 0.0);
    }

    #[test]
    fn test_tokenization() {
        let tokens = BM25Retriever::new().tokenize("The quick, brown fox!");

        // Punctuation is stripped and stop words are dropped.
        for expected in ["quick", "brown", "fox"].iter() {
            assert!(tokens.iter().any(|t| t == expected));
        }
        assert!(!tokens.iter().any(|t| t == "the")); // stop word
    }

    #[test]
    fn test_empty_search() {
        assert!(BM25Retriever::new().search("test", 10).is_empty());
    }
}