#![deny(missing_docs)]
/// Tuning parameters for BM25 scoring, consumed by [`score`] and [`rerank`].
///
/// The [`Default`] implementation yields `k1 = 1.2` and `b = 0.75`.
#[derive(Debug, Clone, Copy)]
pub struct Bm25Opts {
    /// Term-frequency saturation parameter.
    ///
    /// Appears as `k1 + 1` in the numerator of the per-term score and scales
    /// the length-normalisation term in the denominator.
    pub k1: f32,
    /// Document-length normalisation weight.
    ///
    /// The denominator uses `1 - b + b * (len / avg_len)`, so `0.0` ignores
    /// document length entirely and `1.0` applies the length ratio in full.
    pub b: f32,
}

impl Default for Bm25Opts {
    /// Returns the default parameters `k1 = 1.2`, `b = 0.75`.
    fn default() -> Self {
        Self { k1: 1.2, b: 0.75 }
    }
}
/// Ranks `docs` by BM25 relevance to `query`, best match first.
///
/// Returns the indices of `docs` ordered by descending [`score`]. The sort is
/// stable, so documents with equal scores keep their original relative order.
///
/// A `NaN` score (only possible with degenerate `opts`, e.g. a `NaN` `k1`) is
/// treated as the lowest possible score, so such documents are ranked last.
pub fn rerank<S: AsRef<str>>(query: &str, docs: &[S], opts: Bm25Opts) -> Vec<usize> {
    let scores = score(query, docs, opts);

    // `sort_by` requires a total order. Comparing raw floats and treating the
    // NaN case as "equal" is not one (NaN would equal everything while other
    // values still differ), so NaN is mapped to -inf before comparing.
    let keys: Vec<f32> = scores
        .iter()
        .map(|&s| if s.is_nan() { f32::NEG_INFINITY } else { s })
        .collect();

    let mut indices: Vec<usize> = (0..docs.len()).collect();
    indices.sort_by(|&a, &b| {
        // Descending order. `keys` holds no NaN, so `partial_cmp` is always
        // `Some`; the fallback exists only to avoid a panic path.
        keys[b]
            .partial_cmp(&keys[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    indices
}
/// Computes a BM25 relevance score for every document against `query`.
///
/// The query and each document are split into terms with `tokenize`. Every
/// query term `t` that occurs in at least one document adds, for each
/// document `d` containing it,
///
/// ```text
/// idf(t) * tf * (k1 + 1) / (tf + k1 * (1 - b + b * len(d) / avg_len))
/// idf(t) = ln((N - df + 0.5) / (df + 0.5) + 1)
/// ```
///
/// where `tf` is the number of occurrences of `t` in `d`, `df` the number of
/// documents containing `t`, `N` the number of documents, and `avg_len` the
/// mean document length in tokens. A term that appears several times in
/// `query` contributes once per occurrence.
///
/// Returns one score per document, in the order of `docs`. The result is all
/// zeros when `query` yields no terms or `docs` is empty.
pub fn score<S: AsRef<str>>(query: &str, docs: &[S], opts: Bm25Opts) -> Vec<f32> {
    let q_terms = tokenize(query);
    if q_terms.is_empty() || docs.is_empty() {
        return vec![0.0; docs.len()];
    }

    let doc_tokens: Vec<Vec<String>> = docs.iter().map(|d| tokenize(d.as_ref())).collect();
    let lens: Vec<f32> = doc_tokens.iter().map(|t| t.len() as f32).collect();
    // `docs` is non-empty past the guard above, so `n > 0`.
    let n = doc_tokens.len() as f32;
    let avgdl = lens.iter().sum::<f32>() / n;

    let mut scores = vec![0.0_f32; doc_tokens.len()];
    for term in &q_terms {
        // Count occurrences once per document; the document frequency is then
        // simply the number of non-zero counts.
        let tfs: Vec<f32> = doc_tokens
            .iter()
            .map(|tokens| tokens.iter().filter(|x| *x == term).count() as f32)
            .collect();
        let df = tfs.iter().filter(|&&tf| tf > 0.0).count() as f32;
        if df == 0.0 {
            continue;
        }

        let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
        for ((total, &tf), &dl) in scores.iter_mut().zip(&tfs).zip(&lens) {
            if tf == 0.0 {
                continue;
            }
            // `tf > 0` means at least one document is non-empty, hence
            // `avgdl > 0`: the division is safe without clamping, and the
            // length ratio stays correct when the mean length is below 1.
            let denom = tf + opts.k1 * (1.0 - opts.b + opts.b * (dl / avgdl));
            *total += idf * (tf * (opts.k1 + 1.0)) / denom;
        }
    }
    scores
}
/// Splits `s` into lowercase terms.
///
/// Any character that is not alphanumeric (per `char::is_alphanumeric`) acts
/// as a separator, and empty fragments between separators are dropped.
/// Lowercasing is Unicode-aware so that it matches the Unicode-aware split:
/// non-ASCII words such as "Éclair" and "éclair" produce the same term.
fn tokenize(s: &str) -> Vec<String> {
    s.split(|c: char| !c.is_alphanumeric())
        .filter(|t| !t.is_empty())
        .map(str::to_lowercase)
        .collect()
}