ripvec_core/encoder/ripvec/
bm25.rs1use std::path::Path;
16
17use lasso::{Spur, ThreadedRodeo};
18use rayon::prelude::*;
19use rustc_hash::{FxBuildHasher, FxHashMap};
20
21use crate::chunk::CodeChunk;
22use crate::encoder::ripvec::tokens::tokenize;
23
/// BM25 `k1`: how quickly a term's contribution saturates as its in-document frequency grows.
const K1: f32 = 1.5;
/// BM25 `b`: strength of document-length normalisation (0 disables it, 1 applies it fully).
const B: f32 = 0.75;
28
29#[must_use]
36pub fn enrich_for_bm25(chunk: &CodeChunk) -> String {
37 let path = Path::new(&chunk.file_path);
38 let stem = path
39 .file_stem()
40 .and_then(|s| s.to_str())
41 .unwrap_or_default();
42 let dir_parts: Vec<&str> = path
43 .parent()
44 .into_iter()
45 .flat_map(|p| p.iter())
46 .filter_map(|os| os.to_str())
47 .filter(|part| *part != "." && *part != "/")
48 .collect();
49 let tail_len = dir_parts.len().min(3);
51 let dir_text = dir_parts[dir_parts.len() - tail_len..].join(" ");
52 format!("{} {stem} {stem} {dir_text}", chunk.content)
53}
54
55pub struct Bm25Index {
61 rodeo: ThreadedRodeo<Spur, FxBuildHasher>,
68 doc_tfs: Vec<FxHashMap<Spur, u32>>,
73 doc_lengths: Vec<u32>,
75 avgdl: f32,
77 df_idf: FxHashMap<Spur, (u32, f32)>,
79}
80
81impl Bm25Index {
82 #[must_use]
102 pub fn build(chunks: &[CodeChunk]) -> Self {
103 let n = chunks.len();
104 let rodeo: ThreadedRodeo<Spur, FxBuildHasher> = ThreadedRodeo::with_hasher(FxBuildHasher);
105 if n == 0 {
106 return Self {
107 rodeo,
108 doc_tfs: Vec::new(),
109 doc_lengths: Vec::new(),
110 avgdl: 0.0,
111 df_idf: FxHashMap::default(),
112 };
113 }
114
115 let per_doc: Vec<(FxHashMap<Spur, u32>, u32)> = chunks
120 .par_iter()
121 .map(|chunk| {
122 let enriched = enrich_for_bm25(chunk);
123 let tokens = tokenize(&enriched);
124 let token_count = u32::try_from(tokens.len()).unwrap_or(u32::MAX);
125 let mut tfs: FxHashMap<Spur, u32> =
126 FxHashMap::with_capacity_and_hasher(tokens.len(), FxBuildHasher);
127 for tok in &tokens {
128 let id = rodeo.get_or_intern(tok);
129 *tfs.entry(id).or_insert(0) += 1;
130 }
131 (tfs, token_count)
132 })
133 .collect();
134
135 let mut doc_tfs: Vec<FxHashMap<Spur, u32>> = Vec::with_capacity(n);
139 let mut doc_lengths: Vec<u32> = Vec::with_capacity(n);
140 let mut df: FxHashMap<Spur, u32> = FxHashMap::default();
141 for (tfs, len) in per_doc {
142 for term_id in tfs.keys() {
143 *df.entry(*term_id).or_insert(0) += 1;
144 }
145 doc_lengths.push(len);
146 doc_tfs.push(tfs);
147 }
148
149 let total_len: u64 = doc_lengths.iter().map(|&l| u64::from(l)).sum();
150 #[expect(
151 clippy::cast_precision_loss,
152 reason = "doc counts are bounded; f32 precision is sufficient for avgdl"
153 )]
154 let avgdl = (total_len as f32) / (n as f32);
155
156 #[expect(
159 clippy::cast_precision_loss,
160 reason = "doc counts are bounded; f32 precision is sufficient for idf"
161 )]
162 let n_f = n as f32;
163 let df_idf: FxHashMap<Spur, (u32, f32)> = df
164 .into_iter()
165 .map(|(term_id, df_count)| {
166 #[expect(
167 clippy::cast_precision_loss,
168 reason = "df is u32; f32 precision sufficient for idf"
169 )]
170 let df_f = df_count as f32;
171 let idf = ((n_f - df_f + 0.5) / (df_f + 0.5) + 1.0).ln();
172 (term_id, (df_count, idf))
173 })
174 .collect();
175
176 Self {
177 rodeo,
178 doc_tfs,
179 doc_lengths,
180 avgdl,
181 df_idf,
182 }
183 }
184
185 #[must_use]
187 pub fn len(&self) -> usize {
188 self.doc_tfs.len()
189 }
190
191 #[must_use]
193 pub fn is_empty(&self) -> bool {
194 self.doc_tfs.is_empty()
195 }
196
197 #[must_use]
201 pub fn score(&self, query: &str) -> Vec<f32> {
202 let q_tokens = tokenize(query);
203 if q_tokens.is_empty() || self.doc_tfs.is_empty() {
204 return vec![0.0; self.doc_tfs.len()];
205 }
206 let mut query_ids: Vec<Spur> = Vec::with_capacity(q_tokens.len());
210 let mut seen: rustc_hash::FxHashSet<Spur> = rustc_hash::FxHashSet::default();
211 for term in &q_tokens {
212 if let Some(id) = self.rodeo.get(term)
213 && seen.insert(id)
214 {
215 query_ids.push(id);
216 }
217 }
218 if query_ids.is_empty() {
219 return vec![0.0; self.doc_tfs.len()];
220 }
221
222 let mut scores = vec![0.0_f32; self.doc_tfs.len()];
223 #[expect(
224 clippy::cast_precision_loss,
225 reason = "tf/dl are u32 counts; f32 precision sufficient"
226 )]
227 for &term_id in &query_ids {
228 let Some(&(_, idf)) = self.df_idf.get(&term_id) else {
229 continue;
230 };
231 for (doc_idx, tfs) in self.doc_tfs.iter().enumerate() {
232 let Some(&tf) = tfs.get(&term_id) else {
233 continue;
234 };
235 let tf_f = tf as f32;
236 let dl = self.doc_lengths[doc_idx] as f32;
237 let norm = if self.avgdl > 0.0 {
238 dl / self.avgdl
239 } else {
240 0.0
241 };
242 let denom = tf_f + K1 * (1.0 - B + B * norm);
243 scores[doc_idx] += idf * tf_f * (K1 + 1.0) / denom.max(f32::EPSILON);
244 }
245 }
246 scores
247 }
248}
249
/// Expand an optional list of eligible document indices into a dense mask of length `size`.
///
/// `None` means "no restriction" and maps to `None`. Indices `>= size` are ignored rather
/// than causing a panic, and duplicate indices are harmless.
#[must_use]
pub fn selector_to_mask(selector: Option<&[usize]>, size: usize) -> Option<Vec<bool>> {
    let indices = selector?;
    let mut mask = vec![false; size];
    for &idx in indices {
        if let Some(slot) = mask.get_mut(idx) {
            *slot = true;
        }
    }
    Some(mask)
}
265
266#[must_use]
271pub fn search_bm25(
272 query: &str,
273 index: &Bm25Index,
274 top_k: usize,
275 selector: Option<&[usize]>,
276) -> Vec<(usize, f32)> {
277 if index.is_empty() || top_k == 0 {
278 return Vec::new();
279 }
280 let mask = selector_to_mask(selector, index.len());
281 let mut scores = index.score(query);
282 if let Some(m) = &mask {
283 for (i, allowed) in m.iter().enumerate() {
284 if !allowed {
285 scores[i] = 0.0;
286 }
287 }
288 }
289 let mut indexed: Vec<(usize, f32)> = scores
290 .into_iter()
291 .enumerate()
292 .filter(|(_, s)| *s > 0.0)
293 .collect();
294 indexed.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
295 indexed.truncate(top_k);
296 indexed
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal chunk whose enriched content equals its raw content.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    #[test]
    fn bm25_enrich_stem_doubled() {
        let c = chunk("src/foo.rs", "fn run() {}");
        let enriched = enrich_for_bm25(&c);
        let occurrences = enriched.matches("foo").count();
        assert_eq!(occurrences, 2, "expected 'foo' twice; got: {enriched}");
    }

    #[test]
    fn bm25_enrich_last_3_dir_parts() {
        let c = chunk("a/b/c/d/e/foo.rs", "");
        let enriched = enrich_for_bm25(&c);
        assert!(enriched.contains("c d e"), "got: {enriched:?}");
        assert!(!enriched.contains(" b "), "got: {enriched:?}");
    }

    #[test]
    fn bm25_selector_mask_excludes_non_selected() {
        let chunks = vec![
            chunk("src/a.rs", "alpha bravo"),
            chunk("src/b.rs", "alpha gamma"),
        ];
        let idx = Bm25Index::build(&chunks);
        let all = search_bm25("alpha", &idx, 10, None);
        assert_eq!(all.len(), 2);
        let masked = search_bm25("alpha", &idx, 10, Some(&[0]));
        assert_eq!(masked.len(), 1);
        assert_eq!(masked[0].0, 0);
    }

    #[test]
    fn bm25_zero_score_excluded() {
        let chunks = vec![chunk("src/a.rs", "alpha"), chunk("src/b.rs", "bravo")];
        let idx = Bm25Index::build(&chunks);
        let r = search_bm25("alpha", &idx, 10, None);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].0, 0);
    }

    #[test]
    fn empty_query_returns_empty() {
        let chunks = vec![chunk("src/a.rs", "alpha")];
        let idx = Bm25Index::build(&chunks);
        assert!(search_bm25("", &idx, 10, None).is_empty());
    }

    #[test]
    fn stem_hits_via_enrichment_only() {
        let chunks = vec![
            chunk("src/foo.rs", "alpha bravo"),
            chunk("src/bar.rs", "alpha bravo"),
        ];
        let idx = Bm25Index::build(&chunks);
        let r = search_bm25("foo", &idx, 10, None);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].0, 0);
    }

    #[test]
    fn selector_to_mask_ignores_out_of_range() {
        assert_eq!(selector_to_mask(None, 3), None);
        assert_eq!(
            selector_to_mask(Some(&[2, 7][..]), 3),
            Some(vec![false, false, true])
        );
    }

    #[test]
    fn empty_index_is_empty() {
        let idx = Bm25Index::build(&[]);
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
        assert!(idx.score("alpha").is_empty());
        assert!(search_bm25("alpha", &idx, 10, None).is_empty());
    }

    #[test]
    fn unknown_query_scores_zero_for_every_doc() {
        let chunks = vec![chunk("src/a.rs", "alpha"), chunk("src/b.rs", "bravo")];
        let idx = Bm25Index::build(&chunks);
        assert_eq!(idx.score("nonexistent"), vec![0.0, 0.0]);
        assert!(search_bm25("nonexistent", &idx, 10, None).is_empty());
    }

    #[test]
    fn zero_top_k_returns_empty() {
        let chunks = vec![chunk("src/a.rs", "alpha")];
        let idx = Bm25Index::build(&chunks);
        assert!(search_bm25("alpha", &idx, 0, None).is_empty());
    }

    #[test]
    fn ties_break_by_index_and_top_k_truncates() {
        let chunks = vec![
            chunk("src/a.rs", "alpha"),
            chunk("src/b.rs", "alpha"),
            chunk("src/c.rs", "alpha"),
        ];
        let idx = Bm25Index::build(&chunks);
        let r = search_bm25("alpha", &idx, 2, None);
        let ids: Vec<usize> = r.iter().map(|&(i, _)| i).collect();
        assert_eq!(ids, vec![0, 1]);
    }

    #[test]
    fn shorter_document_ranks_higher() {
        let chunks = vec![
            chunk("src/long.rs", "alpha bravo charlie delta echo"),
            chunk("src/short.rs", "alpha"),
        ];
        let idx = Bm25Index::build(&chunks);
        let r = search_bm25("alpha", &idx, 10, None);
        assert_eq!(r.len(), 2);
        assert_eq!(r[0].0, 1, "length normalisation should favour the short doc: {r:?}");
    }
}