ripvec_core/encoder/ripvec/bm25.rs
//! BM25 with ripvec's stem-doubled path enrichment.
//!
//! Port of `~/src/semble/src/semble/index/sparse.py` (`enrich_for_bm25`
//! and `selector_to_mask`) plus the BM25 scoring loop used in
//! `~/src/semble/src/semble/search.py:search_bm25`. The enrichment
//! appends the file stem twice and the last three directory components
//! to the chunk content before tokenization, so path-based queries match
//! even when the query terms do not appear in the chunk text.
//!
//! Python uses the `bm25s` library; this port hand-rolls Okapi BM25
//! (k1 = 1.5, b = 0.75) to avoid another dependency. Results are ordered
//! by descending score, matching `bm25s`, and zero-score documents are
//! excluded as in `search.py:search_bm25`; ties are broken by ascending
//! chunk index.
14
15use std::path::Path;
16
17use lasso::{Spur, ThreadedRodeo};
18use rayon::prelude::*;
19use rustc_hash::{FxBuildHasher, FxHashMap};
20
21use crate::chunk::CodeChunk;
22use crate::encoder::ripvec::tokens::tokenize;
23
/// Okapi BM25 `k1`: term-frequency saturation. It scales the `tf` term in
/// the score denominator, so larger values let repeated occurrences keep
/// raising a document's score for longer before saturating.
const K1: f32 = 1.5_f32;
/// Okapi BM25 `b`: document-length normalization strength. `0.0` ignores
/// document length entirely; `1.0` normalizes fully by `dl / avgdl`.
const B: f32 = 0.75_f32;
28
29/// Append the file stem (twice, for up-weight) and the last three
30/// directory components to a chunk's text content. Mirrors
31/// `enrich_for_bm25` from `sparse.py:18`.
32///
33/// Assumes `chunk.file_path` is already repo-relative so
34/// machine-specific directory components don't leak into the index.
35#[must_use]
36pub fn enrich_for_bm25(chunk: &CodeChunk) -> String {
37 let path = Path::new(&chunk.file_path);
38 let stem = path
39 .file_stem()
40 .and_then(|s| s.to_str())
41 .unwrap_or_default();
42 let dir_parts: Vec<&str> = path
43 .parent()
44 .into_iter()
45 .flat_map(|p| p.iter())
46 .filter_map(|os| os.to_str())
47 .filter(|part| *part != "." && *part != "/")
48 .collect();
49 // Last 3 directory components (mirrors Python's dir_parts[-3:]).
50 let tail_len = dir_parts.len().min(3);
51 let dir_text = dir_parts[dir_parts.len() - tail_len..].join(" ");
52 format!("{} {stem} {stem} {dir_text}", chunk.content)
53}
54
55/// Hand-rolled Okapi BM25 index over a set of enriched documents.
56///
57/// Built once via [`Bm25Index::build`]; queried repeatedly via
58/// [`Bm25Index::score`]. Document order matches the chunk-index
59/// convention used elsewhere in the ripvec port.
60pub struct Bm25Index {
61 /// String interner. All term `String`s in the corpus deduplicate to
62 /// a `Spur` (32-bit ID). A 92K-file linux corpus has ~250K chunks ×
63 /// ~50 unique terms each = ~12.5M term references; before interning
64 /// each was a separately-allocated `String` (~500 MB of duplicated
65 /// keys). After interning the keys are 4-byte IDs and only ~500K
66 /// unique strings live in the rodeo (~10 MB).
67 rodeo: ThreadedRodeo<Spur, FxBuildHasher>,
68 /// Per-document length (token count).
69 doc_lengths: Vec<u32>,
70 /// Average document length across the corpus.
71 avgdl: f32,
72 /// Inverted index: term_id -> (doc_frequency, idf).
73 df_idf: FxHashMap<Spur, (u32, f32)>,
74 /// Inverted postings: term_id -> Vec<(doc_idx, tf)>.
75 ///
76 /// Replaces the prior per-document `doc_tfs: Vec<FxHashMap<Spur, u32>>`
77 /// for query scoring. The old layout forced
78 /// `O(query_terms × total_docs)` per query — the score loop iterated
79 /// every doc and HashMap-missed ~99% of them. With postings the
80 /// per-query cost is `O(query_terms × postings_length_per_term)`,
81 /// which on the 1M-chunk corpus collapses from ~5M lookups to ~5K
82 /// updates per query, a ~100x algorithmic win independent of any
83 /// parallelism or SIMD. Profile evidence: `search_bm25` was 41.5%
84 /// of `search_hybrid` wall time post-2A+2B (samply, 2026-05-21).
85 postings: FxHashMap<Spur, Vec<(u32, u32)>>,
86}
87
88impl Bm25Index {
89 /// Build an index over enriched chunks. Tokenization uses
90 /// `crate::encoder::ripvec::tokens::tokenize`.
91 ///
92 /// Three-pass build:
93 ///
94 /// 1. **par_iter (tokenize + intern + TF)**: each chunk is enriched,
95 /// tokenized, and its tokens interned into a shared
96 /// `ThreadedRodeo`. The per-doc TF map keys on the `Spur` ID
97 /// instead of `String`, eliminating the duplicated-string
98 /// storage that dominated memory + hashing in the previous
99 /// version.
100 /// 2. **serial DF merge**: walk per-doc TF maps and increment a
101 /// global `Spur`-keyed counter. With `Spur` keys (4-byte
102 /// `NonZeroU32`), FxHash lookups are a single multiply.
103 /// 3. **serial IDF compute**: produce the final df_idf map.
104 ///
105 /// On a 92K-file linux corpus (~250K chunks): bm25_build drops
106 /// from 35s serial → ~14s parallel without interning → ~7s with
107 /// interning.
108 #[must_use]
109 pub fn build(chunks: &[CodeChunk]) -> Self {
110 let n = chunks.len();
111 let rodeo: ThreadedRodeo<Spur, FxBuildHasher> = ThreadedRodeo::with_hasher(FxBuildHasher);
112 if n == 0 {
113 return Self {
114 rodeo,
115 doc_lengths: Vec::new(),
116 avgdl: 0.0,
117 df_idf: FxHashMap::default(),
118 postings: FxHashMap::default(),
119 };
120 }
121
122 // Stage 1: par_iter — produce per-doc (tfs, token_count) pairs.
123 // `ThreadedRodeo::get_or_intern` is lock-free for the common
124 // (already-interned) case and uses a sharded lock only on
125 // first insert. Worker threads share `&rodeo` safely.
126 let per_doc: Vec<(FxHashMap<Spur, u32>, u32)> = chunks
127 .par_iter()
128 .map(|chunk| {
129 let enriched = enrich_for_bm25(chunk);
130 let tokens = tokenize(&enriched);
131 let token_count = u32::try_from(tokens.len()).unwrap_or(u32::MAX);
132 let mut tfs: FxHashMap<Spur, u32> =
133 FxHashMap::with_capacity_and_hasher(tokens.len(), FxBuildHasher);
134 for tok in &tokens {
135 let id = rodeo.get_or_intern(tok);
136 *tfs.entry(id).or_insert(0) += 1;
137 }
138 (tfs, token_count)
139 })
140 .collect();
141
142 // Stage 2: serial — invert per-doc TF maps into postings, drop
143 // doc_tfs entirely. The postings index maps each term to the
144 // list of (doc_idx, tf) pairs that contain it, enabling
145 // O(posting_length) per-query-term scoring instead of
146 // O(total_docs). For a 1M-chunk corpus with average posting
147 // length ~1K, this is a ~1000x reduction in per-query work.
148 let mut doc_lengths: Vec<u32> = Vec::with_capacity(n);
149 let mut df: FxHashMap<Spur, u32> = FxHashMap::default();
150 let mut postings: FxHashMap<Spur, Vec<(u32, u32)>> = FxHashMap::default();
151 for (doc_idx, (tfs, len)) in per_doc.into_iter().enumerate() {
152 doc_lengths.push(len);
153 let d = u32::try_from(doc_idx).unwrap_or(u32::MAX);
154 for (term_id, tf) in tfs {
155 *df.entry(term_id).or_insert(0) += 1;
156 postings.entry(term_id).or_default().push((d, tf));
157 }
158 }
159 // Shrink each posting list to fit so the index doesn't carry
160 // headroom across the whole corpus.
161 postings.values_mut().for_each(Vec::shrink_to_fit);
162
163 let total_len: u64 = doc_lengths.iter().map(|&l| u64::from(l)).sum();
164 #[expect(
165 clippy::cast_precision_loss,
166 reason = "doc counts are bounded; f32 precision is sufficient for avgdl"
167 )]
168 let avgdl = (total_len as f32) / (n as f32);
169
170 // BM25 idf with the "plus 1" smoothing used by bm25s:
171 // idf(t) = ln( (N - df + 0.5) / (df + 0.5) + 1 )
172 #[expect(
173 clippy::cast_precision_loss,
174 reason = "doc counts are bounded; f32 precision is sufficient for idf"
175 )]
176 let n_f = n as f32;
177 let df_idf: FxHashMap<Spur, (u32, f32)> = df
178 .into_iter()
179 .map(|(term_id, df_count)| {
180 #[expect(
181 clippy::cast_precision_loss,
182 reason = "df is u32; f32 precision sufficient for idf"
183 )]
184 let df_f = df_count as f32;
185 let idf = ((n_f - df_f + 0.5) / (df_f + 0.5) + 1.0).ln();
186 (term_id, (df_count, idf))
187 })
188 .collect();
189
190 Self {
191 rodeo,
192 doc_lengths,
193 avgdl,
194 df_idf,
195 postings,
196 }
197 }
198
199 /// Number of indexed documents.
200 #[must_use]
201 pub fn len(&self) -> usize {
202 self.doc_lengths.len()
203 }
204
205 /// Whether the index has zero documents.
206 #[must_use]
207 pub fn is_empty(&self) -> bool {
208 self.doc_lengths.is_empty()
209 }
210
211 /// Compute BM25 scores for `query` against every document.
212 /// Returns a `Vec<f32>` of length `self.len()` (one score per doc).
213 /// Zero scores indicate no query terms matched.
214 ///
215 /// Postings-list scoring: walks `postings[term]` for each query
216 /// term (typically <1% of corpus). Per-term work is dispatched via
217 /// rayon: each thread accumulates a local scores vector, all
218 /// vectors fold-reduce at the end. Parallelism is bounded by the
219 /// number of distinct query terms; for the common 1-5-term query
220 /// rayon uses 1-5 workers, which is appropriate — the algorithmic
221 /// win from inversion dwarfs any further parallel scaling.
222 #[must_use]
223 pub fn score(&self, query: &str) -> Vec<f32> {
224 let n = self.doc_lengths.len();
225 let q_tokens = tokenize(query);
226 if q_tokens.is_empty() || n == 0 {
227 return vec![0.0; n];
228 }
229 // Resolve query terms to interned IDs, dropping unknown terms
230 // (they can't possibly score) and deduplicating.
231 let mut query_ids: Vec<Spur> = Vec::with_capacity(q_tokens.len());
232 let mut seen: rustc_hash::FxHashSet<Spur> = rustc_hash::FxHashSet::default();
233 for term in &q_tokens {
234 if let Some(id) = self.rodeo.get(term)
235 && seen.insert(id)
236 {
237 query_ids.push(id);
238 }
239 }
240 if query_ids.is_empty() {
241 return vec![0.0; n];
242 }
243
244 let avgdl = self.avgdl;
245 let doc_lengths = &self.doc_lengths;
246 let df_idf = &self.df_idf;
247 let postings = &self.postings;
248
249 // par_iter over query terms; each thread walks the term's
250 // posting list and writes into a thread-local accumulator.
251 // Reduce sums the accumulators element-wise.
252 query_ids
253 .par_iter()
254 .fold(
255 || vec![0.0_f32; n],
256 |mut acc, term_id| {
257 let Some(&(_, idf)) = df_idf.get(term_id) else {
258 return acc;
259 };
260 let Some(posting) = postings.get(term_id) else {
261 return acc;
262 };
263 #[expect(
264 clippy::cast_precision_loss,
265 reason = "tf/dl are u32 counts; f32 precision sufficient"
266 )]
267 for &(doc_idx, tf) in posting {
268 let tf_f = tf as f32;
269 let dl = doc_lengths[doc_idx as usize] as f32;
270 let norm = if avgdl > 0.0 { dl / avgdl } else { 0.0 };
271 let denom = tf_f + K1 * (1.0 - B + B * norm);
272 acc[doc_idx as usize] += idf * tf_f * (K1 + 1.0) / denom.max(f32::EPSILON);
273 }
274 acc
275 },
276 )
277 .reduce(
278 || vec![0.0_f32; n],
279 |mut a, b| {
280 for i in 0..n {
281 a[i] += b[i];
282 }
283 a
284 },
285 )
286 }
287}
288
/// Expand an optional list of allowed chunk indices into a dense
/// `Vec<bool>` of length `size`, where `true` marks an allowed chunk.
/// Indices `>= size` are ignored. Mirrors `selector_to_mask` from
/// `sparse.py:9`.
///
/// Returns `None` when `selector` is `None`, meaning "no restriction".
#[must_use]
pub fn selector_to_mask(selector: Option<&[usize]>, size: usize) -> Option<Vec<bool>> {
    let allowed = selector?;
    let mut keep = vec![false; size];
    allowed
        .iter()
        .copied()
        .filter(|&idx| idx < size)
        .for_each(|idx| keep[idx] = true);
    Some(keep)
}
304
305/// Top-k BM25 search with optional selector mask and zero-score
306/// exclusion. Mirrors `search.py:search_bm25`.
307///
308/// Returns `(chunk_index, score)` pairs sorted by score descending.
309#[must_use]
310pub fn search_bm25(
311 query: &str,
312 index: &Bm25Index,
313 top_k: usize,
314 selector: Option<&[usize]>,
315) -> Vec<(usize, f32)> {
316 if index.is_empty() || top_k == 0 {
317 return Vec::new();
318 }
319 let mask = selector_to_mask(selector, index.len());
320 let mut scores = index.score(query);
321 if let Some(m) = &mask {
322 for (i, allowed) in m.iter().enumerate() {
323 if !allowed {
324 scores[i] = 0.0;
325 }
326 }
327 }
328 let mut indexed: Vec<(usize, f32)> = scores
329 .into_iter()
330 .enumerate()
331 .filter(|(_, s)| *s > 0.0)
332 .collect();
333 indexed.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
334 indexed.truncate(top_k);
335 indexed
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    fn chunk(path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    fn indices(hits: &[(usize, f32)]) -> Vec<usize> {
        hits.iter().map(|&(i, _)| i).collect()
    }

    /// `test:bm25-enrich-stem-doubled` — the file stem appears twice in
    /// the enriched text so BM25 up-weights stem matches.
    #[test]
    fn bm25_enrich_stem_doubled() {
        let c = chunk("src/foo.rs", "fn run() {}");
        let enriched = enrich_for_bm25(&c);
        let occurrences = enriched.matches("foo").count();
        assert_eq!(occurrences, 2, "expected 'foo' twice; got: {enriched}");
    }

    /// `test:bm25-enrich-last-3-dir-parts` — only the last 3 directory
    /// components are appended (mirrors Python's `dir_parts[-3:]`).
    #[test]
    fn bm25_enrich_last_3_dir_parts() {
        let c = chunk("a/b/c/d/e/foo.rs", "");
        let enriched = enrich_for_bm25(&c);
        // The dir part text should include the last three dirs c, d, e
        // (in path order), not a or b.
        assert!(enriched.contains("c d e"), "got: {enriched:?}");
        assert!(!enriched.contains(" b "), "got: {enriched:?}");
    }

    /// `test:bm25-selector-mask-excludes-non-selected` — masked chunks
    /// receive zero score even when they contain query terms.
    #[test]
    fn bm25_selector_mask_excludes_non_selected() {
        let chunks = vec![
            chunk("src/a.rs", "alpha bravo"),
            chunk("src/b.rs", "alpha gamma"),
        ];
        let idx = Bm25Index::build(&chunks);
        // Without mask both docs match "alpha".
        let all = search_bm25("alpha", &idx, 10, None);
        assert_eq!(all.len(), 2);
        // With selector [0], only doc 0 is allowed.
        let masked = search_bm25("alpha", &idx, 10, Some(&[0]));
        assert_eq!(masked.len(), 1);
        assert_eq!(masked[0].0, 0);
    }

    /// `test:bm25-zero-score-excluded` — documents with no query-term
    /// matches don't appear in the results.
    #[test]
    fn bm25_zero_score_excluded() {
        let chunks = vec![chunk("src/a.rs", "alpha"), chunk("src/b.rs", "bravo")];
        let idx = Bm25Index::build(&chunks);
        let r = search_bm25("alpha", &idx, 10, None);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].0, 0);
    }

    #[test]
    fn empty_query_returns_empty() {
        let chunks = vec![chunk("src/a.rs", "alpha")];
        let idx = Bm25Index::build(&chunks);
        assert!(search_bm25("", &idx, 10, None).is_empty());
    }

    /// Stem appears doubled even when the chunk doesn't otherwise
    /// mention it; this lets a query like "foo" hit `foo.rs` files.
    #[test]
    fn stem_hits_via_enrichment_only() {
        let chunks = vec![
            chunk("src/foo.rs", "alpha bravo"),
            chunk("src/bar.rs", "alpha bravo"),
        ];
        let idx = Bm25Index::build(&chunks);
        let r = search_bm25("foo", &idx, 10, None);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].0, 0);
    }

    /// `top_k` caps the result count, and results come back ordered by
    /// score descending (more occurrences of the term rank higher here).
    #[test]
    fn top_k_truncates_and_orders_by_score() {
        let chunks = vec![
            chunk("src/one.rs", "alpha"),
            chunk("src/two.rs", "alpha alpha"),
            chunk("src/three.rs", "alpha alpha alpha"),
        ];
        let idx = Bm25Index::build(&chunks);
        let r = search_bm25("alpha", &idx, 2, None);
        assert_eq!(indices(&r), vec![2, 1]);
        assert!(r[0].1 > r[1].1, "got: {r:?}");
    }

    /// Equal scores are ordered by ascending chunk index, including when
    /// `top_k` truncates inside a run of ties.
    #[test]
    fn ties_break_by_ascending_index() {
        let chunks = vec![
            chunk("src/a.rs", "alpha"),
            chunk("src/b.rs", "alpha"),
            chunk("src/c.rs", "alpha"),
        ];
        let idx = Bm25Index::build(&chunks);
        assert_eq!(indices(&search_bm25("alpha", &idx, 10, None)), vec![0, 1, 2]);
        assert_eq!(indices(&search_bm25("alpha", &idx, 2, None)), vec![0, 1]);
    }

    #[test]
    fn top_k_zero_returns_empty() {
        let idx = Bm25Index::build(&[chunk("src/a.rs", "alpha")]);
        assert!(search_bm25("alpha", &idx, 0, None).is_empty());
    }

    #[test]
    fn empty_index_scores_nothing() {
        let idx = Bm25Index::build(&[]);
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
        assert!(idx.score("alpha").is_empty());
        assert!(search_bm25("alpha", &idx, 10, None).is_empty());
    }

    /// Unknown query terms score zero; repeating a query term does not
    /// change the scores.
    #[test]
    fn unknown_and_duplicate_query_terms() {
        let idx = Bm25Index::build(&[chunk("src/a.rs", "alpha")]);
        assert_eq!(idx.score("zulu"), vec![0.0]);
        assert_eq!(idx.score("alpha alpha"), idx.score("alpha"));
    }

    /// Scoring the same multi-term query repeatedly yields bit-identical
    /// scores.
    #[test]
    fn score_is_deterministic() {
        let chunks = vec![
            chunk("src/a.rs", "alpha bravo charlie"),
            chunk("src/b.rs", "bravo charlie delta"),
            chunk("src/c.rs", "charlie delta echo alpha"),
        ];
        let idx = Bm25Index::build(&chunks);
        let query = "alpha bravo charlie delta echo";
        let first = idx.score(query);
        for _ in 0..8 {
            let again = idx.score(query);
            assert_eq!(first.len(), again.len());
            assert!(
                first
                    .iter()
                    .zip(&again)
                    .all(|(a, b)| a.to_bits() == b.to_bits()),
                "scores differ: {first:?} vs {again:?}"
            );
        }
    }

    #[test]
    fn selector_to_mask_ignores_out_of_range() {
        assert_eq!(selector_to_mask(None, 3), None);
        assert_eq!(
            selector_to_mask(Some(&[0, 2, 7]), 3),
            Some(vec![true, false, true])
        );
    }
}