cognis_rag/retrievers/
bm25.rs1use async_trait::async_trait;
7use std::collections::HashMap;
8
9use cognis_core::{Result, Runnable, RunnableConfig};
10
11use crate::document::Document;
12
/// Default BM25 `k1`: controls how quickly repeated occurrences of a query
/// term stop adding to a document's score (`0` ignores term frequency and
/// scores a matching term by its IDF alone).
const DEFAULT_K1: f32 = 1.5_f32;
/// Default BM25 `b`: fraction of document-length normalisation applied
/// (`0` disables it, `1` normalises fully by length relative to the average).
const DEFAULT_B: f32 = 0.75_f32;
15
16pub struct BM25Retriever {
19 docs: Vec<Document>,
20 tf: Vec<HashMap<String, u32>>,
22 doc_lens: Vec<u32>,
24 idf: HashMap<String, f32>,
26 avg_doc_len: f32,
28 k: usize,
29 k1: f32,
30 b: f32,
31}
32
33impl BM25Retriever {
34 pub fn from_documents(docs: Vec<Document>) -> Self {
36 let n = docs.len();
37 let mut tf: Vec<HashMap<String, u32>> = Vec::with_capacity(n);
38 let mut doc_lens: Vec<u32> = Vec::with_capacity(n);
39 let mut df: HashMap<String, u32> = HashMap::new();
40
41 for d in &docs {
42 let tokens = tokenize(&d.content);
43 doc_lens.push(tokens.len() as u32);
44 let mut counts: HashMap<String, u32> = HashMap::new();
45 for t in &tokens {
46 *counts.entry(t.clone()).or_insert(0) += 1;
47 }
48 for t in counts.keys() {
49 *df.entry(t.clone()).or_insert(0) += 1;
50 }
51 tf.push(counts);
52 }
53
54 let n_f = n.max(1) as f32;
55 let avg_doc_len = if n == 0 {
56 0.0
57 } else {
58 doc_lens.iter().map(|&l| l as f32).sum::<f32>() / n_f
59 };
60
61 let mut idf: HashMap<String, f32> = HashMap::new();
62 for (term, dfreq) in df {
63 let v = ((n_f - dfreq as f32 + 0.5) / (dfreq as f32 + 0.5) + 1.0).ln();
64 idf.insert(term, v);
65 }
66
67 Self {
68 docs,
69 tf,
70 doc_lens,
71 idf,
72 avg_doc_len,
73 k: 4,
74 k1: DEFAULT_K1,
75 b: DEFAULT_B,
76 }
77 }
78
79 pub fn with_k(mut self, k: usize) -> Self {
81 self.k = k;
82 self
83 }
84
85 pub fn with_k1(mut self, k1: f32) -> Self {
87 self.k1 = k1;
88 self
89 }
90
91 pub fn with_b(mut self, b: f32) -> Self {
93 self.b = b;
94 self
95 }
96
97 fn score(&self, query_terms: &[String], doc_idx: usize) -> f32 {
99 let dl = self.doc_lens[doc_idx] as f32;
100 let mut score = 0.0;
101 for term in query_terms {
102 let f = self.tf[doc_idx].get(term).copied().unwrap_or(0) as f32;
103 if f == 0.0 {
104 continue;
105 }
106 let idf = self.idf.get(term).copied().unwrap_or(0.0);
107 let denom = f + self.k1 * (1.0 - self.b + self.b * dl / self.avg_doc_len.max(1e-6));
108 score += idf * (f * (self.k1 + 1.0)) / denom;
109 }
110 score
111 }
112}
113
114#[async_trait]
115impl Runnable<String, Vec<Document>> for BM25Retriever {
116 async fn invoke(&self, query: String, _: RunnableConfig) -> Result<Vec<Document>> {
117 let q = tokenize(&query);
118 let mut scored: Vec<(usize, f32)> = (0..self.docs.len())
119 .map(|i| (i, self.score(&q, i)))
120 .filter(|(_, s)| *s > 0.0)
121 .collect();
122 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
123 Ok(scored
124 .into_iter()
125 .take(self.k)
126 .map(|(i, _)| self.docs[i].clone())
127 .collect())
128 }
129
130 fn name(&self) -> &str {
131 "BM25Retriever"
132 }
133}
134
/// Splits `s` into lowercase alphanumeric tokens.
///
/// The whole string is lowercased first (Unicode-aware), then cut at every
/// non-alphanumeric character; empty pieces are discarded.
fn tokenize(s: &str) -> Vec<String> {
    let lowered = s.to_lowercase();
    let mut tokens = Vec::new();
    for piece in lowered.split(|c: char| !c.is_alphanumeric()) {
        if !piece.is_empty() {
            tokens.push(piece.to_owned());
        }
    }
    tokens
}
142
#[cfg(test)]
mod tests {
    use super::*;

    fn corpus() -> Vec<Document> {
        vec![
            Document::new("Rust is a systems programming language").with_id("1"),
            Document::new("Python is a high-level dynamic language").with_id("2"),
            Document::new("Rust has zero-cost abstractions and ownership").with_id("3"),
            Document::new("Cooking with cast iron pans is great").with_id("4"),
        ]
    }

    #[tokio::test]
    async fn ranks_relevant_first() {
        let r = BM25Retriever::from_documents(corpus()).with_k(2);
        let out = r
            .invoke("rust ownership".into(), RunnableConfig::default())
            .await
            .unwrap();
        let ids: Vec<_> = out.iter().filter_map(|d| d.id.clone()).collect();
        // Doc 3 matches both terms (including the rare "ownership"), doc 1 only
        // "rust"; docs 2 and 4 share no term with the query.
        assert_eq!(ids, vec!["3", "1"]);
    }

    #[tokio::test]
    async fn returns_empty_for_no_match() {
        let r = BM25Retriever::from_documents(corpus());
        let out = r
            .invoke("zzz unrelated query xyz".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert!(out.is_empty());
    }

    #[tokio::test]
    async fn respects_k() {
        let r = BM25Retriever::from_documents(corpus()).with_k(1);
        let out = r
            .invoke("language".into(), RunnableConfig::default())
            .await
            .unwrap();
        let ids: Vec<_> = out.iter().filter_map(|d| d.id.clone()).collect();
        // Docs 1 and 2 both contain "language" once; length normalisation
        // favours the shorter doc 1, and k = 1 cuts doc 2.
        assert_eq!(ids, vec!["1"]);
    }

    #[tokio::test]
    async fn empty_corpus_returns_nothing() {
        let r = BM25Retriever::from_documents(Vec::new());
        let out = r
            .invoke("rust".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert!(out.is_empty());
    }

    #[tokio::test]
    async fn query_is_case_insensitive() {
        let r = BM25Retriever::from_documents(corpus()).with_k(2);
        let upper = r
            .invoke("RUST OWNERSHIP".into(), RunnableConfig::default())
            .await
            .unwrap();
        let ids: Vec<_> = upper.iter().filter_map(|d| d.id.clone()).collect();
        assert_eq!(ids, vec!["3", "1"]);
    }
}
188}