Skip to main content

cognis_rag/retrievers/
bm25.rs

1//! BM25 retriever — classic Okapi BM25 over an in-memory corpus.
2//!
3//! Useful as a sparse-baseline retriever to complement embedding search,
4//! especially in [`super::EnsembleRetriever`] hybrids.
5
6use async_trait::async_trait;
7use std::collections::HashMap;
8
9use cognis_core::{Result, Runnable, RunnableConfig};
10
11use crate::document::Document;
12
/// Default term-frequency saturation parameter (`k1`) applied by
/// `BM25Retriever::from_documents`; override with `with_k1`.
const DEFAULT_K1: f32 = 1.5;
/// Default document-length normalisation parameter (`b`) applied by
/// `BM25Retriever::from_documents`; override with `with_b`.
const DEFAULT_B: f32 = 0.75;
15
16/// In-memory BM25 retriever. Construct from a corpus via
17/// [`BM25Retriever::from_documents`]; the index is built upfront.
18pub struct BM25Retriever {
19    docs: Vec<Document>,
20    /// Term frequencies per document.
21    tf: Vec<HashMap<String, u32>>,
22    /// Document length (in tokens) per document.
23    doc_lens: Vec<u32>,
24    /// Inverse document frequency per term.
25    idf: HashMap<String, f32>,
26    /// Average document length across the corpus.
27    avg_doc_len: f32,
28    k: usize,
29    k1: f32,
30    b: f32,
31}
32
33impl BM25Retriever {
34    /// Build an index from a static corpus.
35    pub fn from_documents(docs: Vec<Document>) -> Self {
36        let n = docs.len();
37        let mut tf: Vec<HashMap<String, u32>> = Vec::with_capacity(n);
38        let mut doc_lens: Vec<u32> = Vec::with_capacity(n);
39        let mut df: HashMap<String, u32> = HashMap::new();
40
41        for d in &docs {
42            let tokens = tokenize(&d.content);
43            doc_lens.push(tokens.len() as u32);
44            let mut counts: HashMap<String, u32> = HashMap::new();
45            for t in &tokens {
46                *counts.entry(t.clone()).or_insert(0) += 1;
47            }
48            for t in counts.keys() {
49                *df.entry(t.clone()).or_insert(0) += 1;
50            }
51            tf.push(counts);
52        }
53
54        let n_f = n.max(1) as f32;
55        let avg_doc_len = if n == 0 {
56            0.0
57        } else {
58            doc_lens.iter().map(|&l| l as f32).sum::<f32>() / n_f
59        };
60
61        let mut idf: HashMap<String, f32> = HashMap::new();
62        for (term, dfreq) in df {
63            let v = ((n_f - dfreq as f32 + 0.5) / (dfreq as f32 + 0.5) + 1.0).ln();
64            idf.insert(term, v);
65        }
66
67        Self {
68            docs,
69            tf,
70            doc_lens,
71            idf,
72            avg_doc_len,
73            k: 4,
74            k1: DEFAULT_K1,
75            b: DEFAULT_B,
76        }
77    }
78
79    /// Override the top-k.
80    pub fn with_k(mut self, k: usize) -> Self {
81        self.k = k;
82        self
83    }
84
85    /// Override the BM25 `k1` parameter (default 1.5).
86    pub fn with_k1(mut self, k1: f32) -> Self {
87        self.k1 = k1;
88        self
89    }
90
91    /// Override the BM25 `b` parameter (default 0.75).
92    pub fn with_b(mut self, b: f32) -> Self {
93        self.b = b;
94        self
95    }
96
97    /// Score a single document against the query.
98    fn score(&self, query_terms: &[String], doc_idx: usize) -> f32 {
99        let dl = self.doc_lens[doc_idx] as f32;
100        let mut score = 0.0;
101        for term in query_terms {
102            let f = self.tf[doc_idx].get(term).copied().unwrap_or(0) as f32;
103            if f == 0.0 {
104                continue;
105            }
106            let idf = self.idf.get(term).copied().unwrap_or(0.0);
107            let denom = f + self.k1 * (1.0 - self.b + self.b * dl / self.avg_doc_len.max(1e-6));
108            score += idf * (f * (self.k1 + 1.0)) / denom;
109        }
110        score
111    }
112}
113
114#[async_trait]
115impl Runnable<String, Vec<Document>> for BM25Retriever {
116    async fn invoke(&self, query: String, _: RunnableConfig) -> Result<Vec<Document>> {
117        let q = tokenize(&query);
118        let mut scored: Vec<(usize, f32)> = (0..self.docs.len())
119            .map(|i| (i, self.score(&q, i)))
120            .filter(|(_, s)| *s > 0.0)
121            .collect();
122        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
123        Ok(scored
124            .into_iter()
125            .take(self.k)
126            .map(|(i, _)| self.docs[i].clone())
127            .collect())
128    }
129
130    fn name(&self) -> &str {
131        "BM25Retriever"
132    }
133}
134
/// Lowercase `s` and break it into maximal runs of alphanumeric characters
/// (Unicode-aware, per `char::is_alphanumeric`). Every other character acts
/// as a separator and is discarded, so no empty tokens are produced.
fn tokenize(s: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    for ch in s.to_lowercase().chars() {
        if ch.is_alphanumeric() {
            current.push(ch);
        } else if !current.is_empty() {
            // Separator reached: emit the finished token and start afresh.
            tokens.push(std::mem::take(&mut current));
        }
    }
    // Flush a token that runs to the end of the input.
    if !current.is_empty() {
        tokens.push(current);
    }
    tokens
}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-document fixture: two Rust documents (ids "1" and "3"), one
    /// Python document ("2") and one unrelated cooking document ("4").
    fn corpus() -> Vec<Document> {
        vec![
            Document::new("Rust is a systems programming language").with_id("1"),
            Document::new("Python is a high-level dynamic language").with_id("2"),
            Document::new("Rust has zero-cost abstractions and ownership").with_id("3"),
            Document::new("Cooking with cast iron pans is great").with_id("4"),
        ]
    }

    /// Run `query` through `r` with a default config and unwrap the result.
    async fn search(r: &BM25Retriever, query: &str) -> Vec<Document> {
        r.invoke(query.to_string(), RunnableConfig::default())
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn ranks_relevant_first() {
        let r = BM25Retriever::from_documents(corpus()).with_k(2);
        let out = search(&r, "rust ownership").await;
        let ids: Vec<_> = out.iter().filter_map(|d| d.id.clone()).collect();
        // Doc 3 matches both query terms, doc 1 only "rust"; the others match
        // neither, so exactly these two come back, in this order.
        assert_eq!(ids.len(), 2);
        assert!(&ids[0] == "3");
        assert!(&ids[1] == "1");
    }

    #[tokio::test]
    async fn returns_empty_for_no_match() {
        let r = BM25Retriever::from_documents(corpus());
        let out = search(&r, "zzz unrelated query xyz").await;
        assert!(out.is_empty());
    }

    #[tokio::test]
    async fn respects_k() {
        let r = BM25Retriever::from_documents(corpus()).with_k(1);
        let out = search(&r, "language").await;
        let ids: Vec<_> = out.iter().filter_map(|d| d.id.clone()).collect();
        // Docs 1 and 2 both contain "language" once; doc 1 is shorter, so
        // length normalisation ranks it first and k = 1 keeps only it.
        assert_eq!(ids.len(), 1);
        assert!(&ids[0] == "1");
    }

    #[tokio::test]
    async fn zero_k_returns_nothing() {
        let r = BM25Retriever::from_documents(corpus()).with_k(0);
        assert!(search(&r, "rust").await.is_empty());
    }

    #[tokio::test]
    async fn empty_corpus_returns_nothing() {
        let r = BM25Retriever::from_documents(Vec::new());
        assert!(search(&r, "rust").await.is_empty());
    }

    #[tokio::test]
    async fn query_is_case_insensitive() {
        let r = BM25Retriever::from_documents(corpus());
        // "RUST" is lowercased before matching, so both Rust documents hit.
        assert_eq!(search(&r, "RUST").await.len(), 2);
    }

    #[test]
    fn tokenize_lowercases_and_splits_on_punctuation() {
        assert_eq!(tokenize("Zero-cost, RUST!"), vec!["zero", "cost", "rust"]);
        assert!(tokenize("  --  ").is_empty());
    }
}
188}