Skip to main content

ripvec_core/encoder/ripvec/
bm25.rs

1//! BM25 with ripvec's stem-doubled path enrichment.
2//!
3//! Port of `~/src/semble/src/semble/index/sparse.py` (`enrich_for_bm25`
4//! and `selector_to_mask`) plus the BM25 scoring loop used in
5//! `~/src/semble/src/semble/search.py:search_bm25`. The enrichment
6//! appends the file stem twice and the last three directory components
7//! to chunk content before tokenization, so path-based queries hit
8//! even when the query terms aren't in the chunk text.
9//!
10//! Python uses the `bm25s` library; this port hand-rolls Okapi BM25
11//! (k1=1.5, b=0.75) to avoid another dependency. The output ordering
12//! matches `bm25s`'s descending-score semantics with zero-score
13//! exclusion as in `search.py:search_bm25`.
14
15use std::path::Path;
16
17use lasso::{Spur, ThreadedRodeo};
18use rayon::prelude::*;
19use rustc_hash::{FxBuildHasher, FxHashMap};
20
21use crate::chunk::CodeChunk;
22use crate::encoder::ripvec::tokens::tokenize;
23
/// Okapi BM25 `k1`: term-frequency saturation — how quickly additional
/// occurrences of a term stop adding to a document's score.
const K1: f32 = 1.5_f32;
/// Okapi BM25 `b`: document-length normalization strength relative to
/// the corpus average length (0 disables it, 1 applies it fully).
const B: f32 = 0.75_f32;
28
29/// Append the file stem (twice, for up-weight) and the last three
30/// directory components to a chunk's text content. Mirrors
31/// `enrich_for_bm25` from `sparse.py:18`.
32///
33/// Assumes `chunk.file_path` is already repo-relative so
34/// machine-specific directory components don't leak into the index.
35#[must_use]
36pub fn enrich_for_bm25(chunk: &CodeChunk) -> String {
37    let path = Path::new(&chunk.file_path);
38    let stem = path
39        .file_stem()
40        .and_then(|s| s.to_str())
41        .unwrap_or_default();
42    let dir_parts: Vec<&str> = path
43        .parent()
44        .into_iter()
45        .flat_map(|p| p.iter())
46        .filter_map(|os| os.to_str())
47        .filter(|part| *part != "." && *part != "/")
48        .collect();
49    // Last 3 directory components (mirrors Python's dir_parts[-3:]).
50    let tail_len = dir_parts.len().min(3);
51    let dir_text = dir_parts[dir_parts.len() - tail_len..].join(" ");
52    format!("{} {stem} {stem} {dir_text}", chunk.content)
53}
54
55/// Hand-rolled Okapi BM25 index over a set of enriched documents.
56///
57/// Built once via [`Bm25Index::build`]; queried repeatedly via
58/// [`Bm25Index::score`]. Document order matches the chunk-index
59/// convention used elsewhere in the ripvec port.
60pub struct Bm25Index {
61    /// String interner. All term `String`s in the corpus deduplicate to
62    /// a `Spur` (32-bit ID). A 92K-file linux corpus has ~250K chunks ×
63    /// ~50 unique terms each = ~12.5M term references; before interning
64    /// each was a separately-allocated `String` (~500 MB of duplicated
65    /// keys). After interning the keys are 4-byte IDs and only ~500K
66    /// unique strings live in the rodeo (~10 MB).
67    rodeo: ThreadedRodeo<Spur, FxBuildHasher>,
68    /// Per-document length (token count).
69    doc_lengths: Vec<u32>,
70    /// Average document length across the corpus.
71    avgdl: f32,
72    /// Inverted index: term_id -> (doc_frequency, idf).
73    df_idf: FxHashMap<Spur, (u32, f32)>,
74    /// Inverted postings: term_id -> Vec<(doc_idx, tf)>.
75    ///
76    /// Replaces the prior per-document `doc_tfs: Vec<FxHashMap<Spur, u32>>`
77    /// for query scoring. The old layout forced
78    /// `O(query_terms × total_docs)` per query — the score loop iterated
79    /// every doc and HashMap-missed ~99% of them. With postings the
80    /// per-query cost is `O(query_terms × postings_length_per_term)`,
81    /// which on the 1M-chunk corpus collapses from ~5M lookups to ~5K
82    /// updates per query, a ~100x algorithmic win independent of any
83    /// parallelism or SIMD. Profile evidence: `search_bm25` was 41.5%
84    /// of `search_hybrid` wall time post-2A+2B (samply, 2026-05-21).
85    postings: FxHashMap<Spur, Vec<(u32, u32)>>,
86}
87
88impl Bm25Index {
89    /// Build an index over enriched chunks. Tokenization uses
90    /// `crate::encoder::ripvec::tokens::tokenize`.
91    ///
92    /// Three-pass build:
93    ///
94    /// 1. **par_iter (tokenize + intern + TF)**: each chunk is enriched,
95    ///    tokenized, and its tokens interned into a shared
96    ///    `ThreadedRodeo`. The per-doc TF map keys on the `Spur` ID
97    ///    instead of `String`, eliminating the duplicated-string
98    ///    storage that dominated memory + hashing in the previous
99    ///    version.
100    /// 2. **serial DF merge**: walk per-doc TF maps and increment a
101    ///    global `Spur`-keyed counter. With `Spur` keys (4-byte
102    ///    `NonZeroU32`), FxHash lookups are a single multiply.
103    /// 3. **serial IDF compute**: produce the final df_idf map.
104    ///
105    /// On a 92K-file linux corpus (~250K chunks): bm25_build drops
106    /// from 35s serial → ~14s parallel without interning → ~7s with
107    /// interning.
108    #[must_use]
109    pub fn build(chunks: &[CodeChunk]) -> Self {
110        let n = chunks.len();
111        let rodeo: ThreadedRodeo<Spur, FxBuildHasher> = ThreadedRodeo::with_hasher(FxBuildHasher);
112        if n == 0 {
113            return Self {
114                rodeo,
115                doc_lengths: Vec::new(),
116                avgdl: 0.0,
117                df_idf: FxHashMap::default(),
118                postings: FxHashMap::default(),
119            };
120        }
121
122        // Stage 1: par_iter — produce per-doc (tfs, token_count) pairs.
123        // `ThreadedRodeo::get_or_intern` is lock-free for the common
124        // (already-interned) case and uses a sharded lock only on
125        // first insert. Worker threads share `&rodeo` safely.
126        let per_doc: Vec<(FxHashMap<Spur, u32>, u32)> = chunks
127            .par_iter()
128            .map(|chunk| {
129                let enriched = enrich_for_bm25(chunk);
130                let tokens = tokenize(&enriched);
131                let token_count = u32::try_from(tokens.len()).unwrap_or(u32::MAX);
132                let mut tfs: FxHashMap<Spur, u32> =
133                    FxHashMap::with_capacity_and_hasher(tokens.len(), FxBuildHasher);
134                for tok in &tokens {
135                    let id = rodeo.get_or_intern(tok);
136                    *tfs.entry(id).or_insert(0) += 1;
137                }
138                (tfs, token_count)
139            })
140            .collect();
141
142        // Stage 2: serial — invert per-doc TF maps into postings, drop
143        // doc_tfs entirely. The postings index maps each term to the
144        // list of (doc_idx, tf) pairs that contain it, enabling
145        // O(posting_length) per-query-term scoring instead of
146        // O(total_docs). For a 1M-chunk corpus with average posting
147        // length ~1K, this is a ~1000x reduction in per-query work.
148        let mut doc_lengths: Vec<u32> = Vec::with_capacity(n);
149        let mut df: FxHashMap<Spur, u32> = FxHashMap::default();
150        let mut postings: FxHashMap<Spur, Vec<(u32, u32)>> = FxHashMap::default();
151        for (doc_idx, (tfs, len)) in per_doc.into_iter().enumerate() {
152            doc_lengths.push(len);
153            let d = u32::try_from(doc_idx).unwrap_or(u32::MAX);
154            for (term_id, tf) in tfs {
155                *df.entry(term_id).or_insert(0) += 1;
156                postings.entry(term_id).or_default().push((d, tf));
157            }
158        }
159        // Shrink each posting list to fit so the index doesn't carry
160        // headroom across the whole corpus.
161        postings.values_mut().for_each(Vec::shrink_to_fit);
162
163        let total_len: u64 = doc_lengths.iter().map(|&l| u64::from(l)).sum();
164        #[expect(
165            clippy::cast_precision_loss,
166            reason = "doc counts are bounded; f32 precision is sufficient for avgdl"
167        )]
168        let avgdl = (total_len as f32) / (n as f32);
169
170        // BM25 idf with the "plus 1" smoothing used by bm25s:
171        //   idf(t) = ln( (N - df + 0.5) / (df + 0.5) + 1 )
172        #[expect(
173            clippy::cast_precision_loss,
174            reason = "doc counts are bounded; f32 precision is sufficient for idf"
175        )]
176        let n_f = n as f32;
177        let df_idf: FxHashMap<Spur, (u32, f32)> = df
178            .into_iter()
179            .map(|(term_id, df_count)| {
180                #[expect(
181                    clippy::cast_precision_loss,
182                    reason = "df is u32; f32 precision sufficient for idf"
183                )]
184                let df_f = df_count as f32;
185                let idf = ((n_f - df_f + 0.5) / (df_f + 0.5) + 1.0).ln();
186                (term_id, (df_count, idf))
187            })
188            .collect();
189
190        Self {
191            rodeo,
192            doc_lengths,
193            avgdl,
194            df_idf,
195            postings,
196        }
197    }
198
199    /// Number of indexed documents.
200    #[must_use]
201    pub fn len(&self) -> usize {
202        self.doc_lengths.len()
203    }
204
205    /// Whether the index has zero documents.
206    #[must_use]
207    pub fn is_empty(&self) -> bool {
208        self.doc_lengths.is_empty()
209    }
210
211    /// Compute BM25 scores for `query` against every document.
212    /// Returns a `Vec<f32>` of length `self.len()` (one score per doc).
213    /// Zero scores indicate no query terms matched.
214    ///
215    /// Postings-list scoring: walks `postings[term]` for each query
216    /// term (typically <1% of corpus). Per-term work is dispatched via
217    /// rayon: each thread accumulates a local scores vector, all
218    /// vectors fold-reduce at the end. Parallelism is bounded by the
219    /// number of distinct query terms; for the common 1-5-term query
220    /// rayon uses 1-5 workers, which is appropriate — the algorithmic
221    /// win from inversion dwarfs any further parallel scaling.
222    #[must_use]
223    pub fn score(&self, query: &str) -> Vec<f32> {
224        let n = self.doc_lengths.len();
225        let q_tokens = tokenize(query);
226        if q_tokens.is_empty() || n == 0 {
227            return vec![0.0; n];
228        }
229        // Resolve query terms to interned IDs, dropping unknown terms
230        // (they can't possibly score) and deduplicating.
231        let mut query_ids: Vec<Spur> = Vec::with_capacity(q_tokens.len());
232        let mut seen: rustc_hash::FxHashSet<Spur> = rustc_hash::FxHashSet::default();
233        for term in &q_tokens {
234            if let Some(id) = self.rodeo.get(term)
235                && seen.insert(id)
236            {
237                query_ids.push(id);
238            }
239        }
240        if query_ids.is_empty() {
241            return vec![0.0; n];
242        }
243
244        let avgdl = self.avgdl;
245        let doc_lengths = &self.doc_lengths;
246        let df_idf = &self.df_idf;
247        let postings = &self.postings;
248
249        // par_iter over query terms; each thread walks the term's
250        // posting list and writes into a thread-local accumulator.
251        // Reduce sums the accumulators element-wise.
252        query_ids
253            .par_iter()
254            .fold(
255                || vec![0.0_f32; n],
256                |mut acc, term_id| {
257                    let Some(&(_, idf)) = df_idf.get(term_id) else {
258                        return acc;
259                    };
260                    let Some(posting) = postings.get(term_id) else {
261                        return acc;
262                    };
263                    #[expect(
264                        clippy::cast_precision_loss,
265                        reason = "tf/dl are u32 counts; f32 precision sufficient"
266                    )]
267                    for &(doc_idx, tf) in posting {
268                        let tf_f = tf as f32;
269                        let dl = doc_lengths[doc_idx as usize] as f32;
270                        let norm = if avgdl > 0.0 { dl / avgdl } else { 0.0 };
271                        let denom = tf_f + K1 * (1.0 - B + B * norm);
272                        acc[doc_idx as usize] += idf * tf_f * (K1 + 1.0) / denom.max(f32::EPSILON);
273                    }
274                    acc
275                },
276            )
277            .reduce(
278                || vec![0.0_f32; n],
279                |mut a, b| {
280                    for i in 0..n {
281                        a[i] += b[i];
282                    }
283                    a
284                },
285            )
286    }
287}
288
/// Expand an optional list of allowed chunk indices into a dense
/// boolean mask of length `size`. Port of `selector_to_mask`
/// (`sparse.py:9`).
///
/// A `None` selector means "no restriction" and passes through as
/// `None`. Indices at or beyond `size` are ignored rather than
/// panicking.
#[must_use]
pub fn selector_to_mask(selector: Option<&[usize]>, size: usize) -> Option<Vec<bool>> {
    let allowed = selector?;
    let mut mask = vec![false; size];
    for &idx in allowed {
        if let Some(slot) = mask.get_mut(idx) {
            *slot = true;
        }
    }
    Some(mask)
}
304
305/// Top-k BM25 search with optional selector mask and zero-score
306/// exclusion. Mirrors `search.py:search_bm25`.
307///
308/// Returns `(chunk_index, score)` pairs sorted by score descending.
309#[must_use]
310pub fn search_bm25(
311    query: &str,
312    index: &Bm25Index,
313    top_k: usize,
314    selector: Option<&[usize]>,
315) -> Vec<(usize, f32)> {
316    if index.is_empty() || top_k == 0 {
317        return Vec::new();
318    }
319    let mask = selector_to_mask(selector, index.len());
320    let mut scores = index.score(query);
321    if let Some(m) = &mask {
322        for (i, allowed) in m.iter().enumerate() {
323            if !allowed {
324                scores[i] = 0.0;
325            }
326        }
327    }
328    let mut indexed: Vec<(usize, f32)> = scores
329        .into_iter()
330        .enumerate()
331        .filter(|(_, s)| *s > 0.0)
332        .collect();
333    indexed.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
334    indexed.truncate(top_k);
335    indexed
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    fn chunk(path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    /// `test:bm25-enrich-stem-doubled` — the file stem appears twice in
    /// the enriched text so BM25 up-weights stem matches.
    #[test]
    fn bm25_enrich_stem_doubled() {
        let c = chunk("src/foo.rs", "fn run() {}");
        let enriched = enrich_for_bm25(&c);
        let occurrences = enriched.matches("foo").count();
        assert_eq!(occurrences, 2, "expected 'foo' twice; got: {enriched}");
    }

    /// `test:bm25-enrich-last-3-dir-parts` — only the last 3 directory
    /// components are appended (mirrors Python's `dir_parts[-3:]`).
    #[test]
    fn bm25_enrich_last_3_dir_parts() {
        let c = chunk("a/b/c/d/e/foo.rs", "");
        let enriched = enrich_for_bm25(&c);
        // The dir part text should include the last three dirs c, d, e
        // (in path order), not a or b.
        assert!(enriched.contains("c d e"), "got: {enriched:?}");
        assert!(!enriched.contains(" b "), "got: {enriched:?}");
    }

    /// `test:bm25-selector-mask-excludes-non-selected` — masked chunks
    /// receive zero score even when they contain query terms.
    #[test]
    fn bm25_selector_mask_excludes_non_selected() {
        let chunks = vec![
            chunk("src/a.rs", "alpha bravo"),
            chunk("src/b.rs", "alpha gamma"),
        ];
        let idx = Bm25Index::build(&chunks);
        // Without mask both docs match "alpha".
        let all = search_bm25("alpha", &idx, 10, None);
        assert_eq!(all.len(), 2);
        // With selector [0], only doc 0 is allowed.
        let masked = search_bm25("alpha", &idx, 10, Some(&[0]));
        assert_eq!(masked.len(), 1);
        assert_eq!(masked[0].0, 0);
    }

    /// `test:bm25-zero-score-excluded` — documents with no query-term
    /// matches don't appear in the results.
    #[test]
    fn bm25_zero_score_excluded() {
        let chunks = vec![chunk("src/a.rs", "alpha"), chunk("src/b.rs", "bravo")];
        let idx = Bm25Index::build(&chunks);
        let r = search_bm25("alpha", &idx, 10, None);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].0, 0);
    }

    #[test]
    fn empty_query_returns_empty() {
        let chunks = vec![chunk("src/a.rs", "alpha")];
        let idx = Bm25Index::build(&chunks);
        assert!(search_bm25("", &idx, 10, None).is_empty());
    }

    /// Stem appears doubled even when the chunk doesn't otherwise
    /// mention it; this lets a query like "foo" hit `foo.rs` files.
    #[test]
    fn stem_hits_via_enrichment_only() {
        let chunks = vec![
            chunk("src/foo.rs", "alpha bravo"),
            chunk("src/bar.rs", "alpha bravo"),
        ];
        let idx = Bm25Index::build(&chunks);
        let r = search_bm25("foo", &idx, 10, None);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].0, 0);
    }

    /// Identical documents tie exactly; ties resolve by ascending chunk
    /// index, and `top_k` truncation keeps the best-ranked entries.
    #[test]
    fn ties_break_by_index_and_top_k_truncates() {
        let chunks = vec![
            chunk("src/a.rs", "alpha"),
            chunk("src/a.rs", "alpha"),
            chunk("src/b.rs", "bravo"),
        ];
        let idx = Bm25Index::build(&chunks);
        let all = search_bm25("alpha", &idx, 10, None);
        let order: Vec<usize> = all.iter().map(|&(i, _)| i).collect();
        assert_eq!(order, vec![0, 1]);
        assert_eq!(all[0].1.to_bits(), all[1].1.to_bits());
        let top1 = search_bm25("alpha", &idx, 1, None);
        assert_eq!(top1.len(), 1);
        assert_eq!(top1[0].0, 0);
    }

    #[test]
    fn top_k_zero_returns_empty() {
        let chunks = vec![chunk("src/a.rs", "alpha")];
        let idx = Bm25Index::build(&chunks);
        assert!(search_bm25("alpha", &idx, 0, None).is_empty());
    }

    /// Query terms are deduplicated, so repeating one doesn't inflate
    /// any document's score.
    #[test]
    fn repeated_query_terms_do_not_inflate_scores() {
        let chunks = vec![chunk("src/a.rs", "alpha bravo"), chunk("src/b.rs", "gamma")];
        let idx = Bm25Index::build(&chunks);
        assert_eq!(idx.score("alpha"), idx.score("alpha alpha alpha"));
    }

    /// Terms that never appear in the corpus match nothing, but `score`
    /// still returns one entry per document.
    #[test]
    fn unknown_terms_match_nothing() {
        let chunks = vec![chunk("src/a.rs", "alpha")];
        let idx = Bm25Index::build(&chunks);
        assert_eq!(idx.score("zzqxv").len(), 1);
        assert!(search_bm25("zzqxv", &idx, 10, None).is_empty());
    }

    #[test]
    fn empty_corpus_is_empty() {
        let idx = Bm25Index::build(&[]);
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
        assert!(idx.score("alpha").is_empty());
        assert!(search_bm25("alpha", &idx, 5, None).is_empty());
    }
}