Skip to main content

ripvec_core/encoder/ripvec/
bm25.rs

1//! BM25 with ripvec's stem-doubled path enrichment.
2//!
3//! Port of `~/src/semble/src/semble/index/sparse.py` (`enrich_for_bm25`
4//! and `selector_to_mask`) plus the BM25 scoring loop used in
5//! `~/src/semble/src/semble/search.py:search_bm25`. The enrichment
6//! appends the file stem twice and the last three directory components
7//! to chunk content before tokenization, so path-based queries hit
8//! even when the query terms aren't in the chunk text.
9//!
10//! Python uses the `bm25s` library; this port hand-rolls Okapi BM25
11//! (k1=1.5, b=0.75) to avoid another dependency. The output ordering
12//! matches `bm25s`'s descending-score semantics with zero-score
13//! exclusion as in `search.py:search_bm25`.
14
15use std::path::Path;
16
17use lasso::{Spur, ThreadedRodeo};
18use rayon::prelude::*;
19use rustc_hash::{FxBuildHasher, FxHashMap};
20
21use crate::chunk::CodeChunk;
22use crate::encoder::ripvec::tokens::tokenize;
23
/// Okapi BM25 `k1`: controls how quickly repeated occurrences of a term
/// stop adding to the score (term-frequency saturation).
const K1: f32 = 1.5;
/// Okapi BM25 `b`: how strongly a document's length, relative to the
/// corpus average, scales the term-frequency component.
const B: f32 = 0.75;
28
29/// Append the file stem (twice, for up-weight) and the last three
30/// directory components to a chunk's text content. Mirrors
31/// `enrich_for_bm25` from `sparse.py:18`.
32///
33/// Assumes `chunk.file_path` is already repo-relative so
34/// machine-specific directory components don't leak into the index.
35#[must_use]
36pub fn enrich_for_bm25(chunk: &CodeChunk) -> String {
37    let path = Path::new(&chunk.file_path);
38    let stem = path
39        .file_stem()
40        .and_then(|s| s.to_str())
41        .unwrap_or_default();
42    let dir_parts: Vec<&str> = path
43        .parent()
44        .into_iter()
45        .flat_map(|p| p.iter())
46        .filter_map(|os| os.to_str())
47        .filter(|part| *part != "." && *part != "/")
48        .collect();
49    // Last 3 directory components (mirrors Python's dir_parts[-3:]).
50    let tail_len = dir_parts.len().min(3);
51    let dir_text = dir_parts[dir_parts.len() - tail_len..].join(" ");
52    format!("{} {stem} {stem} {dir_text}", chunk.content)
53}
54
55/// Hand-rolled Okapi BM25 index over a set of enriched documents.
56///
57/// Built once via [`Bm25Index::build`]; queried repeatedly via
58/// [`Bm25Index::score`]. Document order matches the chunk-index
59/// convention used elsewhere in the ripvec port.
60pub struct Bm25Index {
61    /// String interner. All term `String`s in the corpus deduplicate to
62    /// a `Spur` (32-bit ID). A 92K-file linux corpus has ~250K chunks ×
63    /// ~50 unique terms each = ~12.5M term references; before interning
64    /// each was a separately-allocated `String` (~500 MB of duplicated
65    /// keys). After interning the keys are 4-byte IDs and only ~500K
66    /// unique strings live in the rodeo (~10 MB).
67    rodeo: ThreadedRodeo<Spur, FxBuildHasher>,
68    /// Per-document term frequencies: doc_tfs[doc][term_id] = count.
69    /// `FxHashMap<Spur, u32>` — Spur is a 32-bit `NonZeroU32`, so
70    /// hashing is a single-multiply (vs SipHash on a String pointer
71    /// chain) and the map is half the size.
72    doc_tfs: Vec<FxHashMap<Spur, u32>>,
73    /// Per-document length (token count).
74    doc_lengths: Vec<u32>,
75    /// Average document length across the corpus.
76    avgdl: f32,
77    /// Inverted index: term_id -> (doc_frequency, idf).
78    df_idf: FxHashMap<Spur, (u32, f32)>,
79}
80
81impl Bm25Index {
82    /// Build an index over enriched chunks. Tokenization uses
83    /// `crate::encoder::ripvec::tokens::tokenize`.
84    ///
85    /// Three-pass build:
86    ///
87    /// 1. **par_iter (tokenize + intern + TF)**: each chunk is enriched,
88    ///    tokenized, and its tokens interned into a shared
89    ///    `ThreadedRodeo`. The per-doc TF map keys on the `Spur` ID
90    ///    instead of `String`, eliminating the duplicated-string
91    ///    storage that dominated memory + hashing in the previous
92    ///    version.
93    /// 2. **serial DF merge**: walk per-doc TF maps and increment a
94    ///    global `Spur`-keyed counter. With `Spur` keys (4-byte
95    ///    `NonZeroU32`), FxHash lookups are a single multiply.
96    /// 3. **serial IDF compute**: produce the final df_idf map.
97    ///
98    /// On a 92K-file linux corpus (~250K chunks): bm25_build drops
99    /// from 35s serial → ~14s parallel without interning → ~7s with
100    /// interning.
101    #[must_use]
102    pub fn build(chunks: &[CodeChunk]) -> Self {
103        let n = chunks.len();
104        let rodeo: ThreadedRodeo<Spur, FxBuildHasher> = ThreadedRodeo::with_hasher(FxBuildHasher);
105        if n == 0 {
106            return Self {
107                rodeo,
108                doc_tfs: Vec::new(),
109                doc_lengths: Vec::new(),
110                avgdl: 0.0,
111                df_idf: FxHashMap::default(),
112            };
113        }
114
115        // Stage 1: par_iter — produce per-doc (tfs, token_count) pairs.
116        // `ThreadedRodeo::get_or_intern` is lock-free for the common
117        // (already-interned) case and uses a sharded lock only on
118        // first insert. Worker threads share `&rodeo` safely.
119        let per_doc: Vec<(FxHashMap<Spur, u32>, u32)> = chunks
120            .par_iter()
121            .map(|chunk| {
122                let enriched = enrich_for_bm25(chunk);
123                let tokens = tokenize(&enriched);
124                let token_count = u32::try_from(tokens.len()).unwrap_or(u32::MAX);
125                let mut tfs: FxHashMap<Spur, u32> =
126                    FxHashMap::with_capacity_and_hasher(tokens.len(), FxBuildHasher);
127                for tok in &tokens {
128                    let id = rodeo.get_or_intern(tok);
129                    *tfs.entry(id).or_insert(0) += 1;
130                }
131                (tfs, token_count)
132            })
133            .collect();
134
135        // Stage 2: serial — extract doc_tfs/doc_lengths and accumulate df.
136        // With Spur keys this is just a u32 counter increment; no
137        // allocations on the hot path.
138        let mut doc_tfs: Vec<FxHashMap<Spur, u32>> = Vec::with_capacity(n);
139        let mut doc_lengths: Vec<u32> = Vec::with_capacity(n);
140        let mut df: FxHashMap<Spur, u32> = FxHashMap::default();
141        for (tfs, len) in per_doc {
142            for term_id in tfs.keys() {
143                *df.entry(*term_id).or_insert(0) += 1;
144            }
145            doc_lengths.push(len);
146            doc_tfs.push(tfs);
147        }
148
149        let total_len: u64 = doc_lengths.iter().map(|&l| u64::from(l)).sum();
150        #[expect(
151            clippy::cast_precision_loss,
152            reason = "doc counts are bounded; f32 precision is sufficient for avgdl"
153        )]
154        let avgdl = (total_len as f32) / (n as f32);
155
156        // BM25 idf with the "plus 1" smoothing used by bm25s:
157        //   idf(t) = ln( (N - df + 0.5) / (df + 0.5) + 1 )
158        #[expect(
159            clippy::cast_precision_loss,
160            reason = "doc counts are bounded; f32 precision is sufficient for idf"
161        )]
162        let n_f = n as f32;
163        let df_idf: FxHashMap<Spur, (u32, f32)> = df
164            .into_iter()
165            .map(|(term_id, df_count)| {
166                #[expect(
167                    clippy::cast_precision_loss,
168                    reason = "df is u32; f32 precision sufficient for idf"
169                )]
170                let df_f = df_count as f32;
171                let idf = ((n_f - df_f + 0.5) / (df_f + 0.5) + 1.0).ln();
172                (term_id, (df_count, idf))
173            })
174            .collect();
175
176        Self {
177            rodeo,
178            doc_tfs,
179            doc_lengths,
180            avgdl,
181            df_idf,
182        }
183    }
184
185    /// Number of indexed documents.
186    #[must_use]
187    pub fn len(&self) -> usize {
188        self.doc_tfs.len()
189    }
190
191    /// Whether the index has zero documents.
192    #[must_use]
193    pub fn is_empty(&self) -> bool {
194        self.doc_tfs.is_empty()
195    }
196
197    /// Compute BM25 scores for `query` against every document.
198    /// Returns a `Vec<f32>` of length `self.len()` (one score per doc).
199    /// Zero scores indicate no query terms matched.
200    #[must_use]
201    pub fn score(&self, query: &str) -> Vec<f32> {
202        let q_tokens = tokenize(query);
203        if q_tokens.is_empty() || self.doc_tfs.is_empty() {
204            return vec![0.0; self.doc_tfs.len()];
205        }
206        // Resolve query terms to interned IDs, dropping unknown terms
207        // (they can't possibly score) and deduplicating. `self.rodeo.get`
208        // is a read-only lookup — never modifies the interner.
209        let mut query_ids: Vec<Spur> = Vec::with_capacity(q_tokens.len());
210        let mut seen: rustc_hash::FxHashSet<Spur> = rustc_hash::FxHashSet::default();
211        for term in &q_tokens {
212            if let Some(id) = self.rodeo.get(term)
213                && seen.insert(id)
214            {
215                query_ids.push(id);
216            }
217        }
218        if query_ids.is_empty() {
219            return vec![0.0; self.doc_tfs.len()];
220        }
221
222        let mut scores = vec![0.0_f32; self.doc_tfs.len()];
223        #[expect(
224            clippy::cast_precision_loss,
225            reason = "tf/dl are u32 counts; f32 precision sufficient"
226        )]
227        for &term_id in &query_ids {
228            let Some(&(_, idf)) = self.df_idf.get(&term_id) else {
229                continue;
230            };
231            for (doc_idx, tfs) in self.doc_tfs.iter().enumerate() {
232                let Some(&tf) = tfs.get(&term_id) else {
233                    continue;
234                };
235                let tf_f = tf as f32;
236                let dl = self.doc_lengths[doc_idx] as f32;
237                let norm = if self.avgdl > 0.0 {
238                    dl / self.avgdl
239                } else {
240                    0.0
241                };
242                let denom = tf_f + K1 * (1.0 - B + B * norm);
243                scores[doc_idx] += idf * tf_f * (K1 + 1.0) / denom.max(f32::EPSILON);
244            }
245        }
246        scores
247    }
248}
249
/// Expand an optional list of chunk indices into a dense keep-mask of
/// length `size`, where `mask[i]` is `true` exactly when `i` appears in
/// `selector`. Mirrors `selector_to_mask` from `sparse.py:9`.
///
/// A `None` selector yields `None`. Indices at or beyond `size` are
/// ignored, and duplicates have no extra effect.
#[must_use]
pub fn selector_to_mask(selector: Option<&[usize]>, size: usize) -> Option<Vec<bool>> {
    let wanted = selector?;
    let mut mask = vec![false; size];
    for &index in wanted {
        // `get_mut` returns `None` for out-of-range indices, which are skipped.
        if let Some(slot) = mask.get_mut(index) {
            *slot = true;
        }
    }
    Some(mask)
}
265
266/// Top-k BM25 search with optional selector mask and zero-score
267/// exclusion. Mirrors `search.py:search_bm25`.
268///
269/// Returns `(chunk_index, score)` pairs sorted by score descending.
270#[must_use]
271pub fn search_bm25(
272    query: &str,
273    index: &Bm25Index,
274    top_k: usize,
275    selector: Option<&[usize]>,
276) -> Vec<(usize, f32)> {
277    if index.is_empty() || top_k == 0 {
278        return Vec::new();
279    }
280    let mask = selector_to_mask(selector, index.len());
281    let mut scores = index.score(query);
282    if let Some(m) = &mask {
283        for (i, allowed) in m.iter().enumerate() {
284            if !allowed {
285                scores[i] = 0.0;
286            }
287        }
288    }
289    let mut indexed: Vec<(usize, f32)> = scores
290        .into_iter()
291        .enumerate()
292        .filter(|(_, s)| *s > 0.0)
293        .collect();
294    indexed.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
295    indexed.truncate(top_k);
296    indexed
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Make a one-line, unnamed chunk at `path` whose enriched content
    /// is the same as its raw `content`.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        let text = content.to_string();
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            enriched_content: text.clone(),
            content: text,
        }
    }

    /// `test:bm25-enrich-stem-doubled` — enrichment writes the file stem
    /// exactly twice, which is what up-weights stem matches in BM25.
    #[test]
    fn bm25_enrich_stem_doubled() {
        let enriched = enrich_for_bm25(&chunk("src/foo.rs", "fn run() {}"));
        let occurrences = enriched.match_indices("foo").count();
        assert_eq!(occurrences, 2, "expected 'foo' twice; got: {enriched}");
    }

    /// `test:bm25-enrich-last-3-dir-parts` — of the directory
    /// components, only the final three are appended (Python's
    /// `dir_parts[-3:]`).
    #[test]
    fn bm25_enrich_last_3_dir_parts() {
        let enriched = enrich_for_bm25(&chunk("a/b/c/d/e/foo.rs", ""));
        // c, d and e must appear in path order; b (and so a) must not.
        assert!(enriched.contains("c d e"), "got: {enriched:?}");
        assert!(!enriched.contains(" b "), "got: {enriched:?}");
    }

    /// `test:bm25-selector-mask-excludes-non-selected` — a chunk left
    /// out of the selector is not returned even though it contains the
    /// query term.
    #[test]
    fn bm25_selector_mask_excludes_non_selected() {
        let idx = Bm25Index::build(&[
            chunk("src/a.rs", "alpha bravo"),
            chunk("src/b.rs", "alpha gamma"),
        ]);

        // Unrestricted: both documents contain "alpha".
        let unmasked = search_bm25("alpha", &idx, 10, None);
        assert_eq!(unmasked.len(), 2);

        // Restricted to document 0: it is the only hit.
        let masked = search_bm25("alpha", &idx, 10, Some(&[0]));
        let hit_ids: Vec<usize> = masked.iter().map(|&(doc, _)| doc).collect();
        assert_eq!(hit_ids, vec![0]);
    }

    /// `test:bm25-zero-score-excluded` — a document containing none of
    /// the query terms is left out of the results.
    #[test]
    fn bm25_zero_score_excluded() {
        let idx = Bm25Index::build(&[chunk("src/a.rs", "alpha"), chunk("src/b.rs", "bravo")]);
        let hits = search_bm25("alpha", &idx, 10, None);
        let hit_ids: Vec<usize> = hits.iter().map(|&(doc, _)| doc).collect();
        assert_eq!(hit_ids, vec![0]);
    }

    /// An empty query string produces no hits.
    #[test]
    fn empty_query_returns_empty() {
        let idx = Bm25Index::build(&[chunk("src/a.rs", "alpha")]);
        let hits = search_bm25("", &idx, 10, None);
        assert!(hits.is_empty());
    }

    /// The stem is indexed even when the chunk text never mentions it,
    /// so a query such as "foo" finds `foo.rs` but not `bar.rs`.
    #[test]
    fn stem_hits_via_enrichment_only() {
        let idx = Bm25Index::build(&[
            chunk("src/foo.rs", "alpha bravo"),
            chunk("src/bar.rs", "alpha bravo"),
        ]);
        let hits = search_bm25("foo", &idx, 10, None);
        let hit_ids: Vec<usize> = hits.iter().map(|&(doc, _)| doc).collect();
        assert_eq!(hit_ids, vec![0]);
    }
}