use std::path::Path;
use lasso::{Spur, ThreadedRodeo};
use rayon::prelude::*;
use rustc_hash::{FxBuildHasher, FxHashMap};
use crate::chunk::CodeChunk;
use crate::encoder::ripvec::tokens::tokenize;
/// BM25 `k1` parameter. In `Bm25Index::score` it scales the
/// length-normalisation term of the denominator and appears as `K1 + 1.0`
/// in the numerator, which controls how quickly repeated occurrences of a
/// term stop adding to the score.
const K1: f32 = 1.5;
/// BM25 `b` parameter. In `Bm25Index::score` it blends between no length
/// normalisation (`B = 0`) and full `dl / avgdl` normalisation (`B = 1`).
const B: f32 = 0.75;
/// Builds the text that is tokenized for BM25 indexing of `chunk`.
///
/// The result is the chunk's `content`, followed by the file stem written
/// twice, followed by at most the last three directory components of
/// `file_path`, all separated by single spaces.
///
/// Directory components equal to `"."` or `"/"`, and components that are not
/// valid UTF-8, are dropped. A path without a usable stem contributes an
/// empty stem.
#[must_use]
pub fn enrich_for_bm25(chunk: &CodeChunk) -> String {
    let file = Path::new(&chunk.file_path);
    let stem = file
        .file_stem()
        .and_then(|name| name.to_str())
        .unwrap_or_default();

    // Collect the usable directory components in path order.
    let mut components: Vec<&str> = Vec::new();
    if let Some(dir) = file.parent() {
        for piece in dir.iter() {
            match piece.to_str() {
                Some("." | "/") | None => {}
                Some(text) => components.push(text),
            }
        }
    }

    // Keep only the trailing three components (or fewer when the path is shallow).
    let first_kept = components.len().saturating_sub(3);
    let dir_text = components[first_kept..].join(" ");

    format!("{} {stem} {stem} {dir_text}", chunk.content)
}
/// In-memory BM25 index over a fixed set of documents.
///
/// Documents are identified by their position in the slice passed to
/// `Bm25Index::build`; terms are identified by interned `Spur` keys.
pub struct Bm25Index {
/// Interner mapping token strings to `Spur` term ids.
rodeo: ThreadedRodeo<Spur, FxBuildHasher>,
/// Token count of each document, indexed by document id (saturated at `u32::MAX`).
doc_lengths: Vec<u32>,
/// Mean of `doc_lengths`; `0.0` for an empty index.
avgdl: f32,
/// Per term: (number of documents containing the term, precomputed IDF).
df_idf: FxHashMap<Spur, (u32, f32)>,
/// Per term: list of (document id, term frequency in that document).
postings: FxHashMap<Spur, Vec<(u32, u32)>>,
}
impl Bm25Index {
/// Builds a BM25 index over `chunks`.
///
/// Each chunk is expanded with [`enrich_for_bm25`], tokenized, and its tokens
/// interned. Tokenization and per-document term counting run in parallel;
/// postings are then assembled sequentially so that document ids follow the
/// order of `chunks`.
///
/// The stored IDF of a term is `ln((N - df + 0.5) / (df + 0.5) + 1)` where `N`
/// is the number of chunks and `df` the number of chunks containing the term.
///
/// An empty slice yields an empty index whose `avgdl` is `0.0`.
#[must_use]
pub fn build(chunks: &[CodeChunk]) -> Self {
    let doc_count = chunks.len();
    let rodeo: ThreadedRodeo<Spur, FxBuildHasher> = ThreadedRodeo::with_hasher(FxBuildHasher);
    if doc_count == 0 {
        return Self {
            rodeo,
            doc_lengths: Vec::new(),
            avgdl: 0.0,
            df_idf: FxHashMap::default(),
            postings: FxHashMap::default(),
        };
    }

    // Phase 1 (parallel): per-document term frequencies and token counts.
    let per_doc: Vec<(FxHashMap<Spur, u32>, u32)> = chunks
        .par_iter()
        .map(|chunk| {
            // `enriched` stays bound for the whole closure because the tokens
            // may borrow from it.
            let enriched = enrich_for_bm25(chunk);
            let tokens = tokenize(&enriched);
            let mut counts: FxHashMap<Spur, u32> =
                FxHashMap::with_capacity_and_hasher(tokens.len(), FxBuildHasher);
            for token in &tokens {
                let term = rodeo.get_or_intern(token);
                *counts.entry(term).or_insert(0) += 1;
            }
            let length = u32::try_from(tokens.len()).unwrap_or(u32::MAX);
            (counts, length)
        })
        .collect();

    // Phase 2 (sequential): document frequencies and posting lists.
    let mut doc_lengths: Vec<u32> = Vec::with_capacity(doc_count);
    let mut df: FxHashMap<Spur, u32> = FxHashMap::default();
    let mut postings: FxHashMap<Spur, Vec<(u32, u32)>> = FxHashMap::default();
    for (position, (counts, length)) in per_doc.into_iter().enumerate() {
        let doc_id = u32::try_from(position).unwrap_or(u32::MAX);
        doc_lengths.push(length);
        for (term, tf) in counts {
            *df.entry(term).or_insert(0) += 1;
            postings.entry(term).or_default().push((doc_id, tf));
        }
    }
    // The index is immutable from here on, so release spare capacity.
    for list in postings.values_mut() {
        list.shrink_to_fit();
    }

    #[expect(
        clippy::cast_precision_loss,
        reason = "doc counts are bounded; f32 precision is sufficient for idf"
    )]
    let n_f = doc_count as f32;
    let total_len: u64 = doc_lengths.iter().copied().map(u64::from).sum();
    #[expect(
        clippy::cast_precision_loss,
        reason = "doc counts are bounded; f32 precision is sufficient for avgdl"
    )]
    let avgdl = (total_len as f32) / n_f;

    // Phase 3: precompute the IDF of every term next to its document frequency.
    let mut df_idf: FxHashMap<Spur, (u32, f32)> =
        FxHashMap::with_capacity_and_hasher(df.len(), FxBuildHasher);
    for (term, df_count) in df {
        #[expect(
            clippy::cast_precision_loss,
            reason = "df is u32; f32 precision sufficient for idf"
        )]
        let df_f = df_count as f32;
        let ratio = (n_f - df_f + 0.5) / (df_f + 0.5);
        let idf = (ratio + 1.0).ln();
        df_idf.insert(term, (df_count, idf));
    }

    Self {
        rodeo,
        doc_lengths,
        avgdl,
        df_idf,
        postings,
    }
}
/// Returns the number of documents in the index.
#[must_use]
pub fn len(&self) -> usize {
self.doc_lengths.len()
}
/// Returns `true` when the index holds no documents.
#[must_use]
pub fn is_empty(&self) -> bool {
self.doc_lengths.is_empty()
}
/// Scores every indexed document against `query` using BM25.
///
/// The query is tokenized with the same tokenizer used when the index was
/// built. Tokens that were never interned at build time are ignored, and a
/// token that occurs several times in the query contributes only once.
///
/// Returns one score per document, in document-id order (`self.len()`
/// entries). Documents that contain none of the query terms score `0.0`; an
/// empty index yields an empty vector.
///
/// Contributions are accumulated sequentially, in the order the tokenizer
/// returned the query tokens. `f32` addition is not associative, so a
/// parallel fold/reduce over terms can produce slightly different scores from
/// run to run depending on how the work was split. A single accumulator keeps
/// scores reproducible and costs `O(n + total posting length)` instead of
/// allocating and merging a dense `n`-element vector per split.
#[must_use]
pub fn score(&self, query: &str) -> Vec<f32> {
    let n = self.doc_lengths.len();
    let q_tokens = tokenize(query);
    let mut scores = vec![0.0_f32; n];
    if q_tokens.is_empty() || n == 0 {
        return scores;
    }

    let avgdl = self.avgdl;
    let mut seen: rustc_hash::FxHashSet<Spur> = rustc_hash::FxHashSet::default();
    for term in &q_tokens {
        // Unknown tokens cannot match any document.
        let Some(term_id) = self.rodeo.get(term) else {
            continue;
        };
        // Repeated query tokens must not be counted twice.
        if !seen.insert(term_id) {
            continue;
        }
        let (Some(&(_, idf)), Some(posting)) =
            (self.df_idf.get(&term_id), self.postings.get(&term_id))
        else {
            continue;
        };
        #[expect(
            clippy::cast_precision_loss,
            reason = "tf/dl are u32 counts; f32 precision sufficient"
        )]
        for &(doc_idx, tf) in posting {
            let slot = doc_idx as usize;
            let tf_f = tf as f32;
            let dl = self.doc_lengths[slot] as f32;
            // Guard against a zero average length (all documents empty).
            let norm = if avgdl > 0.0 { dl / avgdl } else { 0.0 };
            let denom = tf_f + K1 * (1.0 - B + B * norm);
            scores[slot] += idf * tf_f * (K1 + 1.0) / denom.max(f32::EPSILON);
        }
    }
    scores
}
}
/// Converts an optional list of selected indices into a boolean mask.
///
/// Returns `None` when `selector` is `None`. Otherwise returns a vector of
/// `size` flags where position `i` is `true` exactly when `i` occurs in the
/// selector. Indices that are `>= size` are ignored; duplicates are harmless.
#[must_use]
pub fn selector_to_mask(selector: Option<&[usize]>, size: usize) -> Option<Vec<bool>> {
    let chosen = selector?;
    let mut mask = vec![false; size];
    for &idx in chosen {
        // `get_mut` skips out-of-range indices.
        if let Some(flag) = mask.get_mut(idx) {
            *flag = true;
        }
    }
    Some(mask)
}
/// Runs a BM25 search and returns the best-scoring documents.
///
/// Returns at most `top_k` `(document index, score)` pairs, ordered by
/// descending score with ties broken by ascending document index. Only
/// documents with a strictly positive score are returned.
///
/// When `selector` is `Some`, only the listed document indices are eligible;
/// entries outside the index are ignored. An empty index or `top_k == 0`
/// yields an empty result.
#[must_use]
pub fn search_bm25(
    query: &str,
    index: &Bm25Index,
    top_k: usize,
    selector: Option<&[usize]>,
) -> Vec<(usize, f32)> {
    if top_k == 0 || index.is_empty() {
        return Vec::new();
    }

    let allowed = selector_to_mask(selector, index.len());
    let scores = index.score(query);

    // Keep documents that pass the selector and actually matched the query.
    let mut hits: Vec<(usize, f32)> = Vec::new();
    for (doc, score) in scores.into_iter().enumerate() {
        let permitted = match &allowed {
            Some(mask) => mask[doc],
            None => true,
        };
        if permitted && score > 0.0 {
            hits.push((doc, score));
        }
    }

    hits.sort_unstable_by(|lhs, rhs| rhs.1.total_cmp(&lhs.1).then_with(|| lhs.0.cmp(&rhs.0)));
    hits.truncate(top_k);
    hits
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a one-line chunk at `path` whose content and enriched content
    /// are both `content`; the remaining metadata is left empty.
    fn chunk(path: &str, content: &str) -> CodeChunk {
        let body = String::from(content);
        CodeChunk {
            file_path: String::from(path),
            name: String::new(),
            kind: String::new(),
            start_line: 1,
            end_line: 1,
            content: body.clone(),
            enriched_content: body,
        }
    }

    /// Extracts the document indices from search hits, preserving rank order.
    fn doc_ids(hits: &[(usize, f32)]) -> Vec<usize> {
        hits.iter().map(|&(doc, _)| doc).collect()
    }

    /// The file stem is appended exactly twice by the enrichment step.
    #[test]
    fn bm25_enrich_stem_doubled() {
        let enriched = enrich_for_bm25(&chunk("src/foo.rs", "fn run() {}"));
        let occurrences = enriched.matches("foo").count();
        assert_eq!(occurrences, 2, "expected 'foo' twice; got: {enriched}");
    }

    /// Only the last three directory components are appended.
    #[test]
    fn bm25_enrich_last_3_dir_parts() {
        let enriched = enrich_for_bm25(&chunk("a/b/c/d/e/foo.rs", ""));
        assert!(enriched.contains("c d e"), "got: {enriched:?}");
        assert!(!enriched.contains(" b "), "got: {enriched:?}");
    }

    /// A selector restricts the hits to the selected documents.
    #[test]
    fn bm25_selector_mask_excludes_non_selected() {
        let corpus = [
            chunk("src/a.rs", "alpha bravo"),
            chunk("src/b.rs", "alpha gamma"),
        ];
        let index = Bm25Index::build(&corpus);

        let unrestricted = search_bm25("alpha", &index, 10, None);
        assert_eq!(unrestricted.len(), 2);

        let restricted = search_bm25("alpha", &index, 10, Some(&[0]));
        assert_eq!(doc_ids(&restricted), vec![0]);
    }

    /// Documents that do not match the query are left out of the results.
    #[test]
    fn bm25_zero_score_excluded() {
        let corpus = [chunk("src/a.rs", "alpha"), chunk("src/b.rs", "bravo")];
        let index = Bm25Index::build(&corpus);
        let hits = search_bm25("alpha", &index, 10, None);
        assert_eq!(doc_ids(&hits), vec![0]);
    }

    /// An empty query matches nothing.
    #[test]
    fn empty_query_returns_empty() {
        let corpus = [chunk("src/a.rs", "alpha")];
        let index = Bm25Index::build(&corpus);
        let hits = search_bm25("", &index, 10, None);
        assert!(hits.is_empty());
    }

    /// A query can match purely through the file stem added by enrichment.
    #[test]
    fn stem_hits_via_enrichment_only() {
        let corpus = [
            chunk("src/foo.rs", "alpha bravo"),
            chunk("src/bar.rs", "alpha bravo"),
        ];
        let index = Bm25Index::build(&corpus);
        let hits = search_bm25("foo", &index, 10, None);
        assert_eq!(doc_ids(&hits), vec![0]);
    }
}