use std::collections::HashMap;
/// Tuning parameters for [`bm25_score`].
///
/// Both fields enter the BM25 term-frequency normalisation
/// `tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl))`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Bm25Params {
    /// Term-frequency saturation: larger values let repeated occurrences of a
    /// token keep increasing the score for longer.
    pub k1: f64,
    /// Document-length normalisation strength: `0.0` ignores document length,
    /// `1.0` normalises fully by `doc_len / avg_doc_len`.
    pub b: f64,
}
/// Preset BM25 constants; compiled only for tests, where the preset
/// constructors on `Bm25Params` read them.
#[cfg(test)]
mod defaults {
    /// `k1` shared by every preset.
    pub const DEFAULT_K1: f64 = 1.2_f64;

    /// `b` for the knowledge preset (weaker length normalisation).
    pub const KNOWLEDGE_B: f64 = 0.3_f64;

    /// `b` for the code preset.
    pub const CODE_B: f64 = 0.5_f64;
}
/// Named parameter presets, available only in test builds.
#[cfg(test)]
impl Bm25Params {
    /// Preset for the "code" profile: `k1 = DEFAULT_K1`, `b = CODE_B`.
    pub fn code() -> Self {
        let b = defaults::CODE_B;
        Bm25Params {
            b,
            k1: defaults::DEFAULT_K1,
        }
    }

    /// Preset for the "knowledge" profile: `k1 = DEFAULT_K1`, `b = KNOWLEDGE_B`.
    pub fn knowledge() -> Self {
        let b = defaults::KNOWLEDGE_B;
        Bm25Params {
            b,
            k1: defaults::DEFAULT_K1,
        }
    }
}
/// Scores one document against a query using Okapi BM25.
///
/// For every query token that occurs in the document (count > 0 in `tf`), the
/// score gains `idf * tf_norm`, where
///
/// * `idf = ln((N - df + 0.5) / (df + 0.5) + 1)`. The `+ 1` keeps `idf`
///   positive whenever `df <= total_docs`.
/// * `tf_norm = tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl))`.
///
/// Each query token contributes once per occurrence in `query_tokens`. A token
/// missing from `doc_freq` is treated as `df = 0`, which gives the maximum idf.
///
/// # Arguments
/// * `tf` - per-token occurrence counts for this document.
/// * `doc_len` - length of this document (`dl`).
/// * `query_tokens` - the query terms.
/// * `doc_freq` - number of documents containing each token.
/// * `total_docs` - corpus size `N`.
/// * `avg_doc_len` - mean document length (`avgdl`). If it is not a positive,
///   finite number, length normalisation is skipped (the document is treated
///   as average length) instead of producing NaN or infinity.
/// * `params` - `k1` / `b` tuning.
///
/// Returns `0.0` when no query token occurs in the document.
pub fn bm25_score(
    tf: &HashMap<String, u32>,
    doc_len: u32,
    query_tokens: &[String],
    doc_freq: &HashMap<String, u32>,
    total_docs: usize,
    avg_doc_len: f64,
    params: &Bm25Params,
) -> f64 {
    let n = total_docs as f64;
    let dl = f64::from(doc_len);

    // The length-normalisation factor does not depend on the token, so compute
    // it once. A zero or otherwise invalid average would turn this into NaN or
    // infinity (e.g. 0/0), so fall back to the neutral factor 1.0.
    let length_norm = if avg_doc_len.is_finite() && avg_doc_len > 0.0 {
        1.0 - params.b + params.b * dl / avg_doc_len
    } else {
        1.0
    };
    let saturation = params.k1 * length_norm;

    let mut score = 0.0;
    for token in query_tokens {
        // Tokens absent from the document (or with a zero count) contribute nothing.
        let term_freq = match tf.get(token) {
            Some(&count) if count > 0 => f64::from(count),
            _ => continue,
        };
        let df = f64::from(doc_freq.get(token).copied().unwrap_or(0));
        let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
        let tf_norm = (term_freq * (params.k1 + 1.0)) / (term_freq + saturation);
        score += idf * tf_norm;
    }
    score
}
/// Truncates a ranked list at its first steep score drop (the "elbow").
///
/// Walks `scored` in order and keeps `(index, score)` pairs until either
/// condition is met:
///
/// * `max_k` entries have been kept, or
/// * `previous_kept_score / (score + 0.001)` exceeds `drop_threshold`.
///
/// The entry that triggers the drop, and everything after it, is discarded.
/// For non-negative scores, the `+ 0.001` keeps the ratio finite when a score
/// is exactly zero.
///
/// The input is assumed to be sorted by descending score. This function
/// neither sorts nor verifies that (TODO: confirm with callers).
///
/// Returns at most `max_k` entries. `max_k == 0` yields an empty vector, and
/// otherwise the first entry is always kept.
pub fn elbow_cutoff(
    scored: &[(usize, f64)],
    max_k: usize,
    drop_threshold: f64,
) -> Vec<(usize, f64)> {
    let mut result: Vec<(usize, f64)> = Vec::with_capacity(max_k.min(scored.len()));
    for &(i, score) in scored {
        // Check the cap before pushing so that `max_k == 0` keeps nothing.
        if result.len() >= max_k {
            break;
        }
        if let Some(&(_, prev_score)) = result.last() {
            if prev_score / (score + 0.001) > drop_threshold {
                break;
            }
        }
        result.push((i, score));
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned token-count map from borrowed `(token, count)` pairs.
    fn counts(pairs: &[(&str, u32)]) -> HashMap<String, u32> {
        pairs
            .iter()
            .map(|&(token, count)| (String::from(token), count))
            .collect()
    }

    /// Converts borrowed words into the owned query tokens `bm25_score` takes.
    fn tokens(words: &[&str]) -> Vec<String> {
        words.iter().map(|&word| String::from(word)).collect()
    }

    #[test]
    fn test_bm25_score_basic() {
        let tf = counts(&[("error", 3), ("handling", 1)]);
        let doc_freq = counts(&[("error", 2), ("handling", 1)]);
        let query = tokens(&["error", "handling"]);
        let params = Bm25Params::code();
        let score = bm25_score(&tf, 10, &query, &doc_freq, 5, 8.0, &params);
        assert!(score > 0.0, "score should be positive for matching terms");
    }

    #[test]
    fn test_bm25_score_zero_for_no_match() {
        let tf = counts(&[]);
        let doc_freq = counts(&[("error", 1)]);
        let query = tokens(&["error"]);
        let params = Bm25Params::knowledge();
        let score = bm25_score(&tf, 0, &query, &doc_freq, 3, 5.0, &params);
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_elbow_cutoff_stops_at_drop() {
        // Indices 0..=4 paired with a sharp drop between 8.0 and 2.0.
        let ranked: Vec<(usize, f64)> = [10.0, 9.5, 8.0, 2.0, 1.0]
            .iter()
            .copied()
            .enumerate()
            .collect();
        let kept = elbow_cutoff(&ranked, 10, 3.0);
        assert_eq!(kept.len(), 3);
    }

    #[test]
    fn test_elbow_cutoff_respects_max_k() {
        let ranked: Vec<(usize, f64)> = [10.0, 9.5, 9.0, 8.5]
            .iter()
            .copied()
            .enumerate()
            .collect();
        let kept = elbow_cutoff(&ranked, 2, 3.0);
        assert_eq!(kept.len(), 2);
    }

    #[test]
    fn test_params_presets() {
        let Bm25Params { k1, b } = Bm25Params::code();
        assert_eq!((k1, b), (1.2, 0.5));
        let Bm25Params { k1, b } = Bm25Params::knowledge();
        assert_eq!((k1, b), (1.2, 0.3));
    }
}