use bm25::{EmbedderBuilder, Scorer, Tokenizer};
use std::collections::{BTreeMap, BTreeSet};
/// Splits `s` into lowercase tokens.
///
/// Every `char` that is not alphanumeric (per `char::is_alphanumeric`, i.e. Unicode-aware) acts
/// as a separator. Empty pieces between adjacent separators are dropped, and each remaining
/// piece is lowercased with `str::to_lowercase`.
pub(crate) fn tokenize(s: &str) -> Vec<String> {
    let is_separator = |c: char| !c.is_alphanumeric();
    s.split(is_separator)
        .filter_map(|piece| (!piece.is_empty()).then(|| piece.to_lowercase()))
        .collect()
}
/// One document of a lexical corpus.
///
/// `id` names the document and is what ranker targets refer to. Corpus ids are expected to be
/// unique; `assert_targets_subset` checks that in debug builds. `text` is the raw text that
/// the rankers tokenize.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct LexDoc {
    /// Identifier that ranker targets refer to.
    pub(crate) id: String,
    /// Raw document text, tokenized by the rankers.
    pub(crate) text: String,
}
/// The document collection a lexical ranker is fitted to.
pub(crate) enum LexicalCorpus<'a> {
    /// Borrowed documents, used as-is.
    Raw(&'a [LexDoc]),
}

impl LexicalCorpus<'_> {
    /// Returns every document in the corpus, in the order they were supplied.
    fn docs(&self) -> &[LexDoc] {
        // Single-variant enum: the destructure is irrefutable. Adding a variant turns this into
        // a compile error, which forces this accessor to be revisited.
        let Self::Raw(docs) = self;
        docs
    }
}
/// A lexical (token-based) scorer over a [`LexicalCorpus`].
pub(crate) trait LexicalRanker {
    /// Scores each id in `targets` against `query`, using `corpus` as the fitted collection.
    ///
    /// `query == None` means there is no query text. Each returned pair is a target id and a
    /// `u32` score, where higher means a better lexical match.
    ///
    /// NOTE(review): both implementations in this module return exactly one entry per target,
    /// in `targets` order, and panic if a target is not a corpus id. The signature itself does
    /// not enforce this, so confirm it is the intended contract for every implementor.
    fn score(
        &self,
        query: Option<&str>,
        corpus: &LexicalCorpus<'_>,
        targets: &[&str],
    ) -> Vec<(String, u32)>;
}
#[cfg(test)]
/// Test-only [`LexicalRanker`] that scores a target by the number of distinct query tokens
/// that also occur in the target's text.
pub(crate) struct OverlapRanker;
/// Fixed-point scale applied by `quantize`: a raw score of `1.0` becomes `1_000_000`.
///
/// Quantized scores therefore keep at most six decimal places of the raw score (less for
/// large scores, given `f32` precision). Raw scores above roughly `u32::MAX / LEX_SCALE`
/// (about 4294.97) saturate to `u32::MAX`.
pub(crate) const LEX_SCALE: f32 = 1_000_000.0;
/// Converts a raw lexical score into a fixed-point `u32`.
///
/// Negative scores are clamped to `0`. The result is `round(score * LEX_SCALE)`, saturating
/// at `u32::MAX`. That includes scores whose product with `LEX_SCALE` overflows `f32` to
/// `+inf`, e.g. `f32::MAX`.
///
/// # Panics
///
/// In debug builds, panics if `score` is NaN or infinite, since that indicates a scorer bug.
/// In release builds such scores map to `0`.
pub(crate) fn quantize(score: f32) -> u32 {
    debug_assert!(score.is_finite(), "non-finite lexical score: scorer bug");
    if !score.is_finite() {
        return 0;
    }
    // `score` is finite here, but the product may still overflow to `+inf` for very large
    // scores. The saturating `as` cast below maps that, like any value above `u32::MAX`, to
    // `u32::MAX`.
    let scaled = (score.max(0.0) * LEX_SCALE).round();
    #[expect(
        clippy::as_conversions,
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss,
        reason = "saturating float→u32 (Rust >= 1.45); scaled is non-NaN and >= 0 (may be +inf, which saturates to u32::MAX); no safe std API"
    )]
    let q = scaled as u32;
    q
}
/// Checks the precondition shared by the rankers: every target id names a document in `corpus`.
///
/// Corpus ids are also expected to be unique. That is only verified in debug builds.
///
/// # Panics
///
/// - In all builds, if some target id is not a corpus document id. The message names the
///   offending id.
/// - In debug builds, if two corpus documents share an id.
fn assert_targets_subset(corpus: &LexicalCorpus<'_>, targets: &[&str]) {
    let docs = corpus.docs();
    // One id set serves both checks: a size mismatch means duplicate ids, and it is the
    // lookup table for the membership check.
    let ids: BTreeSet<&str> = docs.iter().map(|d| d.id.as_str()).collect();
    debug_assert!(
        ids.len() == docs.len(),
        "duplicate corpus ids violate uniqueness"
    );
    for t in targets {
        assert!(
            ids.contains(t),
            "target id {t:?} not in fit corpus (targets ⊆ corpus violated)"
        );
    }
}
#[cfg(test)]
impl LexicalRanker for OverlapRanker {
    /// Scores each target by how many *distinct* query tokens also appear in its text.
    ///
    /// A missing query, or one that tokenizes to nothing, scores every target `0`. Output
    /// order follows `targets`. Counts that do not fit in `u32` saturate to `u32::MAX`.
    ///
    /// # Panics
    ///
    /// Via `assert_targets_subset`, if a target is not a corpus id (and, in debug builds, on
    /// duplicate corpus ids).
    fn score(
        &self,
        query: Option<&str>,
        corpus: &LexicalCorpus<'_>,
        targets: &[&str],
    ) -> Vec<(String, u32)> {
        assert_targets_subset(corpus, targets);

        // Set semantics: repeating a word in the query does not add weight.
        let query_terms: BTreeSet<String> = query
            .map(tokenize)
            .unwrap_or_default()
            .into_iter()
            .collect();

        let text_of: BTreeMap<&str, &str> = corpus
            .docs()
            .iter()
            .map(|doc| (doc.id.as_str(), doc.text.as_str()))
            .collect();

        let mut out = Vec::with_capacity(targets.len());
        for &target in targets {
            let overlap = if query_terms.is_empty() {
                0
            } else if let Some(text) = text_of.get(target) {
                let doc_terms: BTreeSet<String> = tokenize(text).into_iter().collect();
                query_terms.intersection(&doc_terms).count()
            } else {
                0
            };
            out.push((target.to_owned(), u32::try_from(overlap).unwrap_or(u32::MAX)));
        }
        out
    }
}
/// Embedding dimension type shared by every `bm25` embedder and scorer in this module.
///
/// NOTE(review): in the `bm25` crate this is presumably the type tokens are hashed into, so a
/// wider type would reduce token collisions. Confirm against the crate docs before changing it.
type Space = u32;
/// `bm25` tokenizer adapter that reuses this module's free `tokenize` function.
///
/// Corpus fitting, document embeddings and query embeddings therefore all split text exactly
/// like the rest of this module.
struct LexTokenizer;

impl Tokenizer for LexTokenizer {
    fn tokenize(&self, input_text: &str) -> Vec<String> {
        // A bare name inside a method body resolves to the module-level function, not to this
        // trait method, so this is not recursive.
        tokenize(input_text)
    }
}
/// [`LexicalRanker`] backed by the `bm25` crate.
///
/// Each `score` call fits a fresh embedder and scorer to the full corpus.
pub(crate) struct Bm25Ranker;
/// Pairs every target id with a score of `0`, preserving order and duplicates.
fn zeros(targets: &[&str]) -> Vec<(String, u32)> {
    let mut out = Vec::with_capacity(targets.len());
    for &id in targets {
        out.push((String::from(id), 0));
    }
    out
}
impl LexicalRanker for Bm25Ranker {
    /// Scores `targets` for `query` with BM25 fitted to the *full* `corpus`.
    ///
    /// The embedder is fitted to, and the scorer populated with, every corpus document, so
    /// corpus statistics do not depend on which ids are targeted. A target the scorer does not
    /// return as a match gets a raw score of `0.0`. Every raw score goes through `quantize`.
    /// The output has one entry per target, in `targets` order.
    ///
    /// All scores are `0` when `query` is `None`, when it tokenizes to nothing, or when the
    /// corpus is empty.
    ///
    /// # Panics
    ///
    /// Via `assert_targets_subset`, if a target is not a corpus id (and, in debug builds, on
    /// duplicate corpus ids). In debug builds `quantize` also panics on a non-finite score.
    fn score(
        &self,
        query: Option<&str>,
        corpus: &LexicalCorpus<'_>,
        targets: &[&str],
    ) -> Vec<(String, u32)> {
        assert_targets_subset(corpus, targets);
        let docs = corpus.docs();

        let Some(query_text) = query.filter(|text| !tokenize(text).is_empty()) else {
            return zeros(targets);
        };
        if docs.is_empty() {
            return zeros(targets);
        }

        // Fit on every document, not just the targets.
        let corpus_texts: Vec<&str> = docs.iter().map(|doc| doc.text.as_str()).collect();
        let embedder = EmbedderBuilder::<Space, _>::with_tokenizer_and_fit_to_corpus(
            LexTokenizer,
            &corpus_texts,
        )
        .build();

        let mut scorer = Scorer::<String, Space>::new();
        for doc in docs {
            let doc_embedding = embedder.embed(doc.text.as_str());
            scorer.upsert(&doc.id, doc_embedding);
        }

        let query_embedding = embedder.embed(query_text);
        let hits = scorer.matches(&query_embedding);
        let raw_by_id: BTreeMap<&str, f32> =
            hits.iter().map(|hit| (hit.id.as_str(), hit.score)).collect();

        let mut out = Vec::with_capacity(targets.len());
        for &target in targets {
            let raw = raw_by_id.get(target).copied().unwrap_or(0.0);
            out.push((target.to_owned(), quantize(raw)));
        }
        out
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for tokenization, quantization, the overlap ranker and the BM25 ranker.
    use super::*;

    #[test]
    fn tokenize_casefolds_and_splits_on_non_alphanumeric() {
        assert_eq!(tokenize("Auth Bug"), vec!["auth", "bug"]);
        assert_eq!(
            tokenize("src/memory.rs__OK"),
            vec!["src", "memory", "rs", "ok"]
        );
        assert_eq!(tokenize("mem.auth.token"), vec!["mem", "auth", "token"]);
        assert!(tokenize(" ").is_empty());
    }

    #[test]
    fn quantize_zero_is_zero() {
        assert_eq!(quantize(0.0), 0);
    }

    #[test]
    fn quantize_is_monotonic_non_decreasing() {
        let xs = [0.0_f32, 1e-6, 0.1, 1.0, 2.5, 30.0, 1000.0];
        for w in xs.windows(2) {
            assert!(
                quantize(w[0]) <= quantize(w[1]),
                "quantize not monotonic at {:?}",
                w
            );
        }
    }

    #[test]
    fn quantize_saturates() {
        assert_eq!(quantize(f32::MAX), u32::MAX);
        assert_eq!(quantize(1e30), u32::MAX);
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "non-finite lexical score")]
    fn quantize_non_finite_panics_in_debug() {
        let _ = quantize(f32::NAN);
    }

    #[test]
    #[cfg(not(debug_assertions))]
    fn quantize_non_finite_is_zero_in_release() {
        assert_eq!(quantize(f32::NAN), 0);
        assert_eq!(quantize(f32::INFINITY), 0);
    }

    #[test]
    fn quantize_negative_is_zero() {
        assert_eq!(quantize(-1.0), 0);
    }

    /// Builds a corpus document from borrowed parts.
    fn doc(id: &str, text: &str) -> LexDoc {
        LexDoc {
            id: id.to_string(),
            text: text.to_string(),
        }
    }

    #[test]
    fn overlap_counts_distinct_query_tokens_over_text() {
        let docs = vec![doc("a", "token expiry middleware check rust mem auth flow")];
        let corpus = LexicalCorpus::Raw(&docs);
        assert_eq!(
            OverlapRanker.score(Some("token middleware rust auth"), &corpus, &["a"]),
            vec![("a".to_string(), 4)]
        );
        assert_eq!(
            OverlapRanker.score(Some("token token token"), &corpus, &["a"]),
            vec![("a".to_string(), 1)]
        );
        assert_eq!(
            OverlapRanker.score(Some("python django"), &corpus, &["a"]),
            vec![("a".to_string(), 0)]
        );
    }

    #[test]
    fn overlap_no_query_is_all_zero() {
        let docs = vec![doc("a", "token"), doc("b", "auth")];
        let corpus = LexicalCorpus::Raw(&docs);
        assert_eq!(
            OverlapRanker.score(None, &corpus, &["a", "b"]),
            vec![("a".to_string(), 0), ("b".to_string(), 0)]
        );
        assert_eq!(
            OverlapRanker.score(Some(""), &corpus, &["a"]),
            vec![("a".to_string(), 0)]
        );
        assert_eq!(
            OverlapRanker.score(Some(" ... "), &corpus, &["a"]),
            vec![("a".to_string(), 0)]
        );
    }

    #[test]
    fn overlap_is_positional_over_targets() {
        let docs = vec![doc("a", "token"), doc("b", "auth"), doc("c", "rust")];
        let corpus = LexicalCorpus::Raw(&docs);
        let got = OverlapRanker.score(Some("rust token"), &corpus, &["c", "a"]);
        assert_eq!(got, vec![("c".to_string(), 1), ("a".to_string(), 1)]);
        assert_eq!(got.len(), 2);
    }

    #[test]
    fn overlap_empty_targets_is_empty_vec() {
        let docs = vec![doc("a", "token")];
        let corpus = LexicalCorpus::Raw(&docs);
        assert!(OverlapRanker.score(Some("token"), &corpus, &[]).is_empty());
    }

    #[test]
    #[should_panic(expected = "targets ⊆ corpus")]
    fn target_outside_corpus_panics_in_all_builds() {
        let docs = vec![doc("a", "token")];
        let corpus = LexicalCorpus::Raw(&docs);
        let _ = OverlapRanker.score(Some("token"), &corpus, &["ghost"]);
    }

    // The uniqueness check is a `debug_assert!`, so it only fires when debug assertions are
    // enabled. Without the `cfg` gate this `should_panic` test would fail under
    // `cargo test --release`.
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "duplicate corpus ids")]
    fn duplicate_corpus_ids_trip_debug_assert() {
        let docs = vec![doc("a", "token"), doc("a", "other")];
        let corpus = LexicalCorpus::Raw(&docs);
        let _ = OverlapRanker.score(Some("token"), &corpus, &["a"]);
    }

    #[test]
    fn bm25_idf_rare_outranks_common() {
        let docs = vec![
            doc("a", "common alpha"),
            doc("b", "rare beta"),
            doc("c", "common gamma"),
            doc("d", "common delta"),
        ];
        let corpus = LexicalCorpus::Raw(&docs);
        let got = Bm25Ranker.score(Some("common rare"), &corpus, &["a", "b"]);
        let (sa, sb) = (got[0].1, got[1].1);
        assert!(
            sb > sa,
            "rare-term target must outrank common-term: a={sa} b={sb}"
        );
    }

    #[test]
    fn bm25_length_norm_shorter_outranks_longer() {
        let docs = vec![
            doc("short", "term"),
            doc("long", "term alpha beta gamma delta"),
        ];
        let corpus = LexicalCorpus::Raw(&docs);
        let got = Bm25Ranker.score(Some("term"), &corpus, &["short", "long"]);
        assert!(
            got[0].1 > got[1].1,
            "shorter doc must outrank longer at equal TF: {got:?}"
        );
    }

    #[test]
    fn bm25_is_shuffle_invariant_and_repeatable() {
        let d1 = vec![
            doc("a", "alpha beta"),
            doc("b", "beta gamma gamma"),
            doc("c", "alpha"),
        ];
        let d2 = vec![
            doc("c", "alpha"),
            doc("a", "alpha beta"),
            doc("b", "beta gamma gamma"),
        ];
        let c1 = LexicalCorpus::Raw(&d1);
        let c2 = LexicalCorpus::Raw(&d2);
        let targets = ["a", "b", "c"];
        let r1 = Bm25Ranker.score(Some("beta gamma"), &c1, &targets);
        let r2 = Bm25Ranker.score(Some("beta gamma"), &c2, &targets);
        assert_eq!(
            r1, r2,
            "permuted corpus order must yield identical per-target scores"
        );
        assert_eq!(r1, Bm25Ranker.score(Some("beta gamma"), &c1, &targets));
    }

    #[test]
    fn bm25_edges_are_all_zero_with_exact_arity() {
        let docs = vec![doc("a", "token auth"), doc("b", "rust")];
        let corpus = LexicalCorpus::Raw(&docs);
        let z = vec![("a".to_string(), 0), ("b".to_string(), 0)];
        assert_eq!(Bm25Ranker.score(None, &corpus, &["a", "b"]), z);
        assert_eq!(Bm25Ranker.score(Some(""), &corpus, &["a", "b"]), z);
        assert_eq!(Bm25Ranker.score(Some(" ... "), &corpus, &["a", "b"]), z);
    }

    #[test]
    fn bm25_empty_corpus_returns_empty_vec() {
        let docs: Vec<LexDoc> = vec![];
        let corpus = LexicalCorpus::Raw(&docs);
        assert!(Bm25Ranker.score(Some("token"), &corpus, &[]).is_empty());
    }

    #[test]
    fn bm25_survivor_untouched_by_query_is_zero() {
        let docs = vec![doc("a", "token auth"), doc("b", "python django")];
        let corpus = LexicalCorpus::Raw(&docs);
        let got = Bm25Ranker.score(Some("token"), &corpus, &["a", "b"]);
        assert!(got[0].1 > 0, "matched target must be nonzero");
        assert_eq!(got[1].1, 0, "survivor no query term touches ⇒ 0");
    }

    #[test]
    fn bm25_df_reflects_full_corpus_not_targets() {
        let docs = vec![
            doc("a", "common"),
            doc("b", "rare"),
            doc("x", "common"),
            doc("y", "common"),
            doc("z", "common"),
        ];
        let corpus = LexicalCorpus::Raw(&docs);
        let got = Bm25Ranker.score(Some("common rare"), &corpus, &["a", "b"]);
        assert!(
            got[1].1 > got[0].1,
            "df over full corpus must depress common: {got:?}"
        );
    }

    #[test]
    fn bm25_avgdl_equals_multiset_mean_on_real_tokenizer() {
        let texts = ["mem.auth.token", "src/x.rs ok", "single"];
        let reference_mean = {
            let total: usize = texts.iter().map(|t| tokenize(t).len()).sum();
            total as f32 / texts.len() as f32
        };
        let embedder =
            EmbedderBuilder::<Space, _>::with_tokenizer_and_fit_to_corpus(LexTokenizer, &texts)
                .build();
        assert_eq!(
            embedder.avgdl(),
            reference_mean,
            "self-computed avgdl must equal multiset mean over real tokenize (A3, VT-5)"
        );
    }
}