Skip to main content

synaptic_retrieval/
bm25.rs

use std::collections::HashMap;

use async_trait::async_trait;
use synaptic_core::SynapticError;

use crate::{tokenize_to_vec, Document, Retriever};

8/// BM25 (Best Matching 25) retriever using Okapi BM25 scoring.
9///
10/// Pre-computes term frequencies, document lengths, and inverse document
11/// frequencies at construction time for efficient retrieval.
12#[derive(Debug, Clone)]
13pub struct BM25Retriever {
14    documents: Vec<Document>,
15    /// Term frequency per document: doc_term_freqs[doc_index][term] = count
16    doc_term_freqs: Vec<HashMap<String, usize>>,
17    /// Token count per document
18    doc_lengths: Vec<usize>,
19    /// Average document length across the corpus
20    avg_doc_length: f64,
21    /// Number of documents containing each term
22    doc_freq: HashMap<String, usize>,
23    /// Term saturation parameter (default 1.5)
24    k1: f64,
25    /// Length normalization parameter (default 0.75)
26    b: f64,
27}
28
29impl BM25Retriever {
30    /// Create a new BM25Retriever with default parameters (k1=1.5, b=0.75).
31    pub fn new(documents: Vec<Document>) -> Self {
32        Self::with_params(documents, 1.5, 0.75)
33    }
34
35    /// Create a new BM25Retriever with custom k1 and b parameters.
36    pub fn with_params(documents: Vec<Document>, k1: f64, b: f64) -> Self {
37        let mut doc_term_freqs = Vec::with_capacity(documents.len());
38        let mut doc_lengths = Vec::with_capacity(documents.len());
39        let mut doc_freq: HashMap<String, usize> = HashMap::new();
40
41        for doc in &documents {
42            let tokens = tokenize_to_vec(&doc.content);
43            let mut term_freq: HashMap<String, usize> = HashMap::new();
44
45            for token in &tokens {
46                *term_freq.entry(token.clone()).or_insert(0) += 1;
47            }
48
49            // Each unique term in this doc increments its document frequency
50            for term in term_freq.keys() {
51                *doc_freq.entry(term.clone()).or_insert(0) += 1;
52            }
53
54            doc_term_freqs.push(term_freq);
55            doc_lengths.push(tokens.len());
56        }
57
58        let avg_doc_length = if documents.is_empty() {
59            0.0
60        } else {
61            doc_lengths.iter().sum::<usize>() as f64 / documents.len() as f64
62        };
63
64        Self {
65            documents,
66            doc_term_freqs,
67            doc_lengths,
68            avg_doc_length,
69            doc_freq,
70            k1,
71            b,
72        }
73    }
74
75    /// Compute BM25 score for a single document given query terms.
76    fn score(&self, doc_idx: usize, query_terms: &[String]) -> f64 {
77        let n = self.documents.len() as f64;
78        let doc_len = self.doc_lengths[doc_idx] as f64;
79        let term_freqs = &self.doc_term_freqs[doc_idx];
80
81        let mut score = 0.0;
82
83        for term in query_terms {
84            let tf = *term_freqs.get(term).unwrap_or(&0) as f64;
85            let df = *self.doc_freq.get(term).unwrap_or(&0) as f64;
86
87            if df == 0.0 || tf == 0.0 {
88                continue;
89            }
90
91            // IDF: ln((N - df + 0.5) / (df + 0.5) + 1)
92            let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
93
94            // BM25 term score: idf * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * dl / avgdl))
95            let numerator = tf * (self.k1 + 1.0);
96            let denominator =
97                tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avg_doc_length);
98
99            score += idf * numerator / denominator;
100        }
101
102        score
103    }
104}
105
106#[async_trait]
107impl Retriever for BM25Retriever {
108    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapticError> {
109        let query_terms = tokenize_to_vec(query);
110
111        if query_terms.is_empty() {
112            return Ok(vec![]);
113        }
114
115        let mut scored: Vec<(f64, usize)> = self
116            .documents
117            .iter()
118            .enumerate()
119            .map(|(idx, _)| (self.score(idx, &query_terms), idx))
120            .filter(|(score, _)| *score > 0.0)
121            .collect();
122
123        // Sort descending by score
124        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
125
126        Ok(scored
127            .into_iter()
128            .take(top_k)
129            .map(|(_, idx)| self.documents[idx].clone())
130            .collect())
131    }
132}