use std::path::Path;
use lasso::{Spur, ThreadedRodeo};
use rayon::prelude::*;
use rustc_hash::{FxBuildHasher, FxHashMap};
use crate::chunk::CodeChunk;
use crate::encoder::ripvec::tokens::tokenize;
/// BM25 `k1` parameter: controls how quickly repeated occurrences of a term
/// stop adding to a document's score (term-frequency saturation).
const K1: f32 = 1.5_f32;
/// BM25 `b` parameter: how strongly a document's length, relative to the
/// corpus average, normalizes its score (`0.0` = not at all, `1.0` = fully).
const B: f32 = 0.75_f32;
/// Build the text that is tokenized when indexing `chunk` for BM25.
///
/// The result is the chunk's `content`, then the file stem written twice
/// (doubling its term frequency), then the last (up to) three directory
/// names of `file_path`, all separated by single spaces.
///
/// Only real directory names ([`std::path::Component::Normal`]) are kept.
/// The root (`/`, or `\` on Windows), a Windows drive prefix such as `C:`,
/// `.` and `..` are all dropped. They therefore neither leak into the token
/// stream nor take one of the three tail slots. Components that are not
/// valid UTF-8 are skipped, and a missing or non-UTF-8 stem becomes the
/// empty string.
#[must_use]
pub fn enrich_for_bm25(chunk: &CodeChunk) -> String {
    use std::path::Component;

    let path = Path::new(&chunk.file_path);
    let stem = path
        .file_stem()
        .and_then(|s| s.to_str())
        .unwrap_or_default();
    let dir_parts: Vec<&str> = path
        .parent()
        .into_iter()
        .flat_map(Path::components)
        .filter_map(|component| match component {
            Component::Normal(name) => name.to_str(),
            // Root, prefix, `.` and `..` are structural, not directory names.
            _ => None,
        })
        .collect();
    let tail = &dir_parts[dir_parts.len().saturating_sub(3)..];
    format!("{} {stem} {stem} {}", chunk.content, tail.join(" "))
}
/// In-memory BM25 index over a fixed set of [`CodeChunk`]s.
///
/// Documents are identified by their position in the slice passed to
/// [`Bm25Index::build`]. `doc_tfs` and `doc_lengths` are parallel vectors
/// indexed by that position.
pub struct Bm25Index {
    // --- vocabulary-level data ---
    /// Interner mapping token text to the `Spur` ids used as map keys below.
    rodeo: ThreadedRodeo<Spur, FxBuildHasher>,
    /// Per term id: (number of documents containing the term, its IDF weight).
    df_idf: FxHashMap<Spur, (u32, f32)>,

    // --- per-document data ---
    /// Per document: term id -> number of occurrences in the enriched text.
    doc_tfs: Vec<FxHashMap<Spur, u32>>,
    /// Per document: total token count of its enriched text.
    doc_lengths: Vec<u32>,
    /// Mean of `doc_lengths` (`0.0` for an empty index).
    avgdl: f32,
}
impl Bm25Index {
    /// Build an index over `chunks`, one document per chunk, in the same order.
    ///
    /// Each chunk is expanded with [`enrich_for_bm25`] and tokenized, and its
    /// term counts are gathered in parallel. Tokens from every worker thread
    /// are interned into the single shared `ThreadedRodeo`. The per-chunk
    /// results are then folded serially into:
    /// - document lengths,
    /// - document frequencies,
    /// - the average document length,
    /// - a per-term IDF of `ln(1 + (N - df + 0.5) / (df + 0.5))`, where `N`
    ///   is the chunk count.
    ///
    /// An empty `chunks` slice yields an empty index with `avgdl == 0.0`.
    #[must_use]
    pub fn build(chunks: &[CodeChunk]) -> Self {
        let rodeo: ThreadedRodeo<Spur, FxBuildHasher> = ThreadedRodeo::with_hasher(FxBuildHasher);
        if chunks.is_empty() {
            return Self {
                rodeo,
                doc_tfs: Vec::new(),
                doc_lengths: Vec::new(),
                avgdl: 0.0,
                df_idf: FxHashMap::default(),
            };
        }
        let doc_count = chunks.len();

        // Parallel phase: tokenize every enriched chunk and count its terms.
        let per_doc: Vec<(FxHashMap<Spur, u32>, u32)> = chunks
            .par_iter()
            .map(|chunk| {
                // Keep the enriched text bound while its tokens are in use.
                let enriched = enrich_for_bm25(chunk);
                let tokens = tokenize(&enriched);
                // Saturate instead of wrapping on absurdly long chunks.
                let length = u32::try_from(tokens.len()).unwrap_or(u32::MAX);
                let mut counts: FxHashMap<Spur, u32> =
                    FxHashMap::with_capacity_and_hasher(tokens.len(), FxBuildHasher);
                for token in &tokens {
                    *counts.entry(rodeo.get_or_intern(token)).or_insert(0) += 1;
                }
                (counts, length)
            })
            .collect();

        // Serial phase: split into parallel per-document vectors, then count
        // how many documents each term occurs in.
        let (doc_tfs, doc_lengths): (Vec<FxHashMap<Spur, u32>>, Vec<u32>) =
            per_doc.into_iter().unzip();
        let mut df: FxHashMap<Spur, u32> = FxHashMap::default();
        for &term_id in doc_tfs.iter().flat_map(|tfs| tfs.keys()) {
            *df.entry(term_id).or_insert(0) += 1;
        }

        let total_tokens: u64 = doc_lengths.iter().copied().map(u64::from).sum();
        #[expect(
            clippy::cast_precision_loss,
            reason = "doc counts are bounded; f32 precision is sufficient for avgdl and idf"
        )]
        let (n_f, avgdl) = (doc_count as f32, total_tokens as f32 / doc_count as f32);

        let df_idf: FxHashMap<Spur, (u32, f32)> = df
            .into_iter()
            .map(|(term_id, docs_with_term)| {
                #[expect(
                    clippy::cast_precision_loss,
                    reason = "df is u32; f32 precision sufficient for idf"
                )]
                let df_f = docs_with_term as f32;
                let idf = ((n_f - df_f + 0.5) / (df_f + 0.5) + 1.0).ln();
                (term_id, (docs_with_term, idf))
            })
            .collect();

        Self {
            rodeo,
            doc_tfs,
            doc_lengths,
            avgdl,
            df_idf,
        }
    }

    /// Number of documents in the index.
    #[must_use]
    pub fn len(&self) -> usize {
        self.doc_tfs.len()
    }

    /// `true` when the index holds no documents.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.doc_tfs.is_empty()
    }

    /// Score every indexed document against `query` using BM25.
    ///
    /// Returns one score per document, in index order. Query tokens that were
    /// never interned during [`Bm25Index::build`] are ignored. Each distinct
    /// query term contributes at most once, so repeating a word in the query
    /// does not change the scores. When the query yields no known term, every
    /// score is `0.0`.
    #[must_use]
    pub fn score(&self, query: &str) -> Vec<f32> {
        let mut scores = vec![0.0_f32; self.doc_tfs.len()];
        let q_tokens = tokenize(query);
        if q_tokens.is_empty() || scores.is_empty() {
            return scores;
        }

        // Resolve tokens to interned ids, keeping first-occurrence order and
        // dropping duplicates and out-of-vocabulary tokens.
        let mut seen: rustc_hash::FxHashSet<Spur> = rustc_hash::FxHashSet::default();
        let query_ids: Vec<Spur> = q_tokens
            .iter()
            .filter_map(|term| self.rodeo.get(term))
            .filter(|id| seen.insert(*id))
            .collect();
        if query_ids.is_empty() {
            return scores;
        }

        for term_id in query_ids {
            let Some(&(_, idf)) = self.df_idf.get(&term_id) else {
                continue;
            };
            let docs = self.doc_tfs.iter().zip(&self.doc_lengths);
            for (slot, (tfs, &doc_len)) in scores.iter_mut().zip(docs) {
                let Some(&tf) = tfs.get(&term_id) else {
                    continue;
                };
                #[expect(
                    clippy::cast_precision_loss,
                    reason = "tf/dl are u32 counts; f32 precision sufficient"
                )]
                let (tf_f, dl) = (tf as f32, doc_len as f32);
                let norm = if self.avgdl > 0.0 { dl / self.avgdl } else { 0.0 };
                let denom = tf_f + K1 * (1.0 - B + B * norm);
                *slot += idf * tf_f * (K1 + 1.0) / denom.max(f32::EPSILON);
            }
        }
        scores
    }
}
/// Convert an optional list of document indices into a membership mask.
///
/// `None` stays `None`, meaning "no restriction". `Some(indices)` becomes a
/// vector of length `size` that is `true` exactly at the listed positions.
/// Indices `>= size` are silently ignored, and duplicates are harmless.
#[must_use]
pub fn selector_to_mask(selector: Option<&[usize]>, size: usize) -> Option<Vec<bool>> {
    let indices = selector?;
    let mut mask = vec![false; size];
    indices
        .iter()
        .copied()
        .filter(|&idx| idx < size)
        .for_each(|idx| mask[idx] = true);
    Some(mask)
}
/// Return the best BM25 matches for `query` as `(doc_index, score)` pairs.
///
/// At most `top_k` pairs are returned, and only documents with a strictly
/// positive score are included. When `selector` is `Some`, documents whose
/// index is not listed are excluded; indices `>= index.len()` are ignored
/// (see [`selector_to_mask`]). Results are ordered by descending score, with
/// ties broken by ascending document index so the output is deterministic.
/// An empty index or `top_k == 0` yields an empty vector.
#[must_use]
pub fn search_bm25(
    query: &str,
    index: &Bm25Index,
    top_k: usize,
    selector: Option<&[usize]>,
) -> Vec<(usize, f32)> {
    if top_k == 0 || index.is_empty() {
        return Vec::new();
    }
    let allowed = selector_to_mask(selector, index.len());
    let mut ranked: Vec<(usize, f32)> = index
        .score(query)
        .into_iter()
        .enumerate()
        // Masked-out documents are dropped here rather than zeroed first.
        .filter(|&(doc, score)| {
            score > 0.0 && allowed.as_ref().is_none_or(|mask| mask[doc])
        })
        .collect();
    ranked.sort_unstable_by(|lhs, rhs| rhs.1.total_cmp(&lhs.1).then(lhs.0.cmp(&rhs.0)));
    ranked.truncate(top_k);
    ranked
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal chunk whose only meaningful fields are path and content.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        CodeChunk {
            file_path: path.to_string(),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: content.to_string(),
            enriched_content: content.to_string(),
        }
    }

    #[test]
    fn bm25_enrich_stem_doubled() {
        let c = chunk("src/foo.rs", "fn run() {}");
        let enriched = enrich_for_bm25(&c);
        let occurrences = enriched.matches("foo").count();
        assert_eq!(occurrences, 2, "expected 'foo' twice; got: {enriched}");
    }

    #[test]
    fn bm25_enrich_last_3_dir_parts() {
        let c = chunk("a/b/c/d/e/foo.rs", "");
        let enriched = enrich_for_bm25(&c);
        assert!(enriched.contains("c d e"), "got: {enriched:?}");
        assert!(!enriched.contains(" b "), "got: {enriched:?}");
    }

    #[test]
    fn bm25_selector_mask_excludes_non_selected() {
        let chunks = vec![
            chunk("src/a.rs", "alpha bravo"),
            chunk("src/b.rs", "alpha gamma"),
        ];
        let idx = Bm25Index::build(&chunks);
        let all = search_bm25("alpha", &idx, 10, None);
        assert_eq!(all.len(), 2);
        let masked = search_bm25("alpha", &idx, 10, Some(&[0]));
        assert_eq!(masked.len(), 1);
        assert_eq!(masked[0].0, 0);
    }

    #[test]
    fn bm25_zero_score_excluded() {
        let chunks = vec![chunk("src/a.rs", "alpha"), chunk("src/b.rs", "bravo")];
        let idx = Bm25Index::build(&chunks);
        let r = search_bm25("alpha", &idx, 10, None);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].0, 0);
    }

    #[test]
    fn empty_query_returns_empty() {
        let chunks = vec![chunk("src/a.rs", "alpha")];
        let idx = Bm25Index::build(&chunks);
        assert!(search_bm25("", &idx, 10, None).is_empty());
    }

    #[test]
    fn stem_hits_via_enrichment_only() {
        let chunks = vec![
            chunk("src/foo.rs", "alpha bravo"),
            chunk("src/bar.rs", "alpha bravo"),
        ];
        let idx = Bm25Index::build(&chunks);
        let r = search_bm25("foo", &idx, 10, None);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].0, 0);
    }

    #[test]
    fn empty_index_is_empty_and_searches_to_nothing() {
        let idx = Bm25Index::build(&[]);
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
        assert!(idx.score("alpha").is_empty());
        assert!(search_bm25("alpha", &idx, 10, None).is_empty());
    }

    #[test]
    fn unknown_query_term_scores_zero() {
        let idx = Bm25Index::build(&[chunk("src/a.rs", "alpha")]);
        assert_eq!(idx.score("zzqx"), vec![0.0]);
        assert!(search_bm25("zzqx", &idx, 10, None).is_empty());
    }

    #[test]
    fn top_k_truncates_and_zero_top_k_is_empty() {
        let chunks = vec![
            chunk("src/a.rs", "alpha"),
            chunk("src/b.rs", "alpha"),
            chunk("src/c.rs", "alpha"),
        ];
        let idx = Bm25Index::build(&chunks);
        assert_eq!(search_bm25("alpha", &idx, 10, None).len(), 3);
        assert_eq!(search_bm25("alpha", &idx, 2, None).len(), 2);
        assert!(search_bm25("alpha", &idx, 0, None).is_empty());
    }

    #[test]
    fn higher_term_frequency_ranks_first() {
        // Same document length, different tf for the query term.
        let chunks = vec![
            chunk("src/foo.rs", "alpha alpha alpha bravo"),
            chunk("src/bar.rs", "alpha bravo charlie delta"),
        ];
        let idx = Bm25Index::build(&chunks);
        let r = search_bm25("alpha", &idx, 10, None);
        assert_eq!(r.len(), 2);
        assert_eq!(r[0].0, 0);
        assert!(r[0].1 > r[1].1, "got: {r:?}");
    }

    #[test]
    fn selector_mask_ignores_out_of_range_and_none() {
        assert_eq!(selector_to_mask(None, 3), None);
        assert_eq!(
            selector_to_mask(Some(&[0, 2, 9]), 3),
            Some(vec![true, false, true])
        );
    }

    #[test]
    fn repeated_query_terms_do_not_change_scores() {
        let chunks = vec![
            chunk("src/a.rs", "alpha bravo"),
            chunk("src/b.rs", "alpha gamma"),
        ];
        let idx = Bm25Index::build(&chunks);
        assert_eq!(idx.score("alpha"), idx.score("alpha alpha"));
    }
}