use std::collections::HashMap;
/// Tuning parameters for BM25 scoring.
///
/// Derives `PartialEq` so that configurations can be compared directly
/// (for example in tests or when checking whether a preset was changed).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BM25Config {
    /// Term-frequency saturation parameter. It scales the numerator as
    /// `tf * (k1 + 1)` and the length-normalisation term in the denominator.
    pub k1: f32,
    /// Weight given to the `doc_len / avg_doc_len` ratio in the denominator;
    /// with `b = 0.0` the document length has no influence on the score.
    pub b: f32,
    /// Lower bound for IDF values: an IDF below this threshold is reported
    /// as `0.0` by the scorer.
    pub min_idf: f32,
}
impl Default for BM25Config {
fn default() -> Self {
Self {
k1: 1.2,
b: 0.75,
min_idf: 0.0,
}
}
}
impl BM25Config {
    /// Preset named after Lucene: `k1 = 1.2`, `b = 0.75`, `min_idf = 0.0`.
    pub fn lucene() -> Self {
        Self::with_k1_b(1.2, 0.75)
    }

    /// Preset named after Elasticsearch: `k1 = 1.2`, `b = 0.75`, `min_idf = 0.0`.
    pub fn elasticsearch() -> Self {
        Self::with_k1_b(1.2, 0.75)
    }

    /// Preset for short queries: `k1 = 1.5`, `b = 0.5`, `min_idf = 0.0`.
    pub fn short_queries() -> Self {
        Self::with_k1_b(1.5, 0.5)
    }

    /// Shared constructor for the presets; every preset leaves `min_idf` at `0.0`.
    fn with_k1_b(k1: f32, b: f32) -> Self {
        Self {
            k1,
            b,
            min_idf: 0.0,
        }
    }
}
/// Corpus-level statistics needed to compute BM25 scores.
///
/// Derives `Debug` so the scorer can be logged or embedded in other `Debug`
/// types, and `Clone` so an index snapshot can be duplicated.
#[derive(Debug, Clone)]
pub struct BM25Scorer {
    /// Scoring parameters.
    config: BM25Config,
    /// Number of documents currently indexed.
    num_docs: usize,
    /// Sum of the token counts of all indexed documents.
    total_len: usize,
    /// For each lowercased term, the number of indexed documents containing it.
    doc_freqs: HashMap<String, usize>,
}
impl BM25Scorer {
pub fn new(config: BM25Config) -> Self {
Self {
config,
num_docs: 0,
total_len: 0,
doc_freqs: HashMap::new(),
}
}
pub fn build<I, D, T>(documents: I, config: BM25Config) -> Self
where
I: IntoIterator<Item = D>,
D: IntoIterator<Item = T>,
T: AsRef<str>,
{
let mut scorer = Self::new(config);
let mut total_len = 0usize;
let mut num_docs = 0usize;
let mut doc_freqs: HashMap<String, usize> = HashMap::new();
for doc in documents {
num_docs += 1;
let mut seen_terms: std::collections::HashSet<String> =
std::collections::HashSet::new();
let mut doc_len = 0usize;
for token in doc {
let term = token.as_ref().to_lowercase();
if !term.is_empty() {
seen_terms.insert(term);
doc_len += 1;
}
}
total_len += doc_len;
for term in seen_terms {
*doc_freqs.entry(term).or_insert(0) += 1;
}
}
scorer.num_docs = num_docs;
scorer.total_len = total_len;
scorer.doc_freqs = doc_freqs;
scorer
}
#[inline]
pub fn avg_doc_len(&self) -> f32 {
if self.num_docs > 0 {
self.total_len as f32 / self.num_docs as f32
} else {
0.0
}
}
#[inline]
pub fn config(&self) -> BM25Config {
self.config
}
#[inline]
fn compute_idf(&self, df: usize, n: usize) -> f32 {
let n = n as f32;
let df = df as f32;
((n - df + 0.5) / (df + 0.5) + 1.0).ln()
}
pub fn idf(&self, term: &str) -> f32 {
let df = self
.doc_freqs
.get(&term.to_lowercase())
.copied()
.unwrap_or(0);
let idf = self.compute_idf(df, self.num_docs);
if idf < self.config.min_idf { 0.0 } else { idf }
}
pub fn score<I, T>(&self, query_terms: I, doc_terms: &[T], doc_len: usize) -> f32
where
I: IntoIterator<Item = T>,
T: AsRef<str> + std::hash::Hash + Eq,
{
let mut tf: HashMap<&str, usize> = HashMap::new();
for term in doc_terms {
*tf.entry(term.as_ref()).or_insert(0) += 1;
}
let k1 = self.config.k1;
let b = self.config.b;
let dl = doc_len as f32;
let avgdl = self.avg_doc_len();
let mut score = 0.0f32;
for query_term in query_terms {
let term = query_term.as_ref().to_lowercase();
let term_str = term.as_str();
let term_tf = *tf.get(term_str).unwrap_or(&0) as f32;
if term_tf == 0.0 {
continue;
}
let idf = self.idf(&term);
let numerator = term_tf * (k1 + 1.0);
let denominator = term_tf + k1 * (1.0 - b + b * dl / avgdl);
score += idf * numerator / denominator;
}
score
}
#[inline]
pub fn score_with_tf(
&self,
query_terms: &[String],
doc_tf: &HashMap<String, usize>,
doc_len: usize,
) -> f32 {
self.score_tf_lookup(query_terms, doc_len, |term| {
*doc_tf.get(term).unwrap_or(&0) as f32
})
}
#[inline]
pub fn score_with_tf_u32(
&self,
query_terms: &[String],
doc_tf: &HashMap<String, u32>,
doc_len: usize,
) -> f32 {
self.score_tf_lookup(query_terms, doc_len, |term| {
*doc_tf.get(term).unwrap_or(&0) as f32
})
}
#[inline]
fn score_tf_lookup<F>(&self, query_terms: &[String], doc_len: usize, mut tf_of: F) -> f32
where
F: FnMut(&str) -> f32,
{
let k1 = self.config.k1;
let b = self.config.b;
let dl = doc_len as f32;
let avgdl = self.avg_doc_len();
let mut score = 0.0f32;
for term in query_terms {
let term_tf = tf_of(term);
if term_tf == 0.0 {
continue;
}
let idf = self.idf(term);
let numerator = term_tf * (k1 + 1.0);
let denominator = term_tf + k1 * (1.0 - b + b * dl / avgdl);
score += idf * numerator / denominator;
}
score
}
pub fn add_document<I, T>(&mut self, tokens: I)
where
I: IntoIterator<Item = T>,
T: AsRef<str>,
{
let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
let mut doc_len = 0usize;
for token in tokens {
let term = token.as_ref().to_lowercase();
if !term.is_empty() {
seen.insert(term);
doc_len += 1;
}
}
self.num_docs += 1;
self.total_len += doc_len;
for term in seen {
*self.doc_freqs.entry(term).or_insert(0) += 1;
}
}
pub fn remove_document<'a, I>(&mut self, unique_terms: I, doc_len: usize)
where
I: IntoIterator<Item = &'a str>,
{
if self.num_docs == 0 {
return;
}
self.num_docs -= 1;
self.total_len = self.total_len.saturating_sub(doc_len);
for term in unique_terms {
let term = term.to_lowercase();
if let Some(df) = self.doc_freqs.get_mut(&term) {
*df -= 1;
if *df == 0 {
self.doc_freqs.remove(&term);
}
}
}
}
pub fn stats(&self) -> BM25Stats {
BM25Stats {
num_docs: self.num_docs,
avg_doc_len: self.avg_doc_len(),
vocab_size: self.doc_freqs.len(),
}
}
}
/// Snapshot of the corpus statistics held by a scorer.
///
/// Derives `PartialEq` so two snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct BM25Stats {
    /// Number of indexed documents.
    pub num_docs: usize,
    /// Mean document length in tokens (`0.0` when nothing is indexed).
    pub avg_doc_len: f32,
    /// Number of distinct terms with a recorded document frequency.
    pub vocab_size: usize,
}
/// Splits `text` on whitespace and lowercases every piece.
///
/// Punctuation is left attached to the neighbouring word.
pub fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for word in text.split_whitespace() {
        let lowered = word.to_lowercase();
        if lowered.is_empty() {
            continue;
        }
        tokens.push(lowered);
    }
    tokens
}
/// Splits `text` on whitespace and ASCII punctuation, lowercases each piece
/// and keeps only pieces longer than one byte.
///
/// The length check counts UTF-8 bytes, so a single ASCII character is
/// dropped while a single multi-byte character is kept.
pub fn tokenize_minimal(text: &str) -> Vec<String> {
    let is_separator = |c: char| c.is_whitespace() || c.is_ascii_punctuation();
    text.split(is_separator)
        .filter_map(|piece| {
            let lowered = piece.to_lowercase();
            if lowered.len() > 1 {
                Some(lowered)
            } else {
                None
            }
        })
        .collect()
}
/// Tokenizes a query string: splits on whitespace and lowercases each word.
pub fn tokenize_query(text: &str) -> Vec<String> {
    text.split_whitespace()
        .map(|word| word.to_lowercase())
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a scorer with the default configuration from tokenized documents.
    fn index(corpus: &[Vec<&str>]) -> BM25Scorer {
        BM25Scorer::build(corpus.iter().map(|doc| doc.iter()), BM25Config::default())
    }

    /// Three two-token documents give three documents of average length two.
    #[test]
    fn test_bm25_basic() {
        let scorer = index(&[
            vec!["hello", "world"],
            vec!["hello", "there"],
            vec!["goodbye", "world"],
        ]);
        assert_eq!(scorer.num_docs, 3);
        assert!((scorer.avg_doc_len() - 2.0).abs() < 0.001);
    }

    /// A term found in fewer documents must have the higher IDF.
    #[test]
    fn test_bm25_idf() {
        let scorer = index(&[
            vec!["common", "common", "rare"],
            vec!["common", "other"],
            vec!["common", "another"],
        ]);
        let (common, rare) = (scorer.idf("common"), scorer.idf("rare"));
        assert!(rare > common);
    }

    /// A document repeating the query term outscores one mentioning it once.
    #[test]
    fn test_bm25_scoring() {
        let scorer = index(&[
            vec!["the", "quick", "brown", "fox"],
            vec!["the", "lazy", "dog"],
            vec!["quick", "quick", "quick"],
        ]);
        let repeated = scorer.score(vec!["quick"], &["quick", "quick", "quick"], 3);
        assert!(repeated > 0.0);
        let single = scorer.score(vec!["quick"], &["the", "quick", "brown", "fox"], 4);
        assert!(repeated > single);
    }

    /// Whitespace tokenization keeps punctuation attached to the words.
    #[test]
    fn test_tokenize() {
        let tokens = tokenize("Hello, World! This is a test.");
        assert_eq!(tokens, vec!["hello,", "world!", "this", "is", "a", "test."]);
    }

    /// Minimal tokenization strips punctuation and drops one-letter tokens.
    #[test]
    fn test_tokenize_minimal() {
        let tokens = tokenize_minimal("Hello, World! This is a test.");
        assert!(tokens.iter().any(|t| t.as_str() == "hello"));
        assert!(tokens.iter().any(|t| t.as_str() == "world"));
        assert!(!tokens.iter().any(|t| t.as_str() == "a"));
    }

    /// Adding documents one by one updates the count and the average length.
    #[test]
    fn test_add_document() {
        let mut scorer = BM25Scorer::new(BM25Config::default());

        scorer.add_document(vec!["hello", "world"]);
        assert_eq!(scorer.num_docs, 1);

        scorer.add_document(vec!["hello", "there", "friend"]);
        assert_eq!(scorer.num_docs, 2);
        assert!((scorer.avg_doc_len() - 2.5).abs() < 0.001);
    }

    /// Batch construction and incremental insertion must agree bit-for-bit.
    #[test]
    fn test_build_equals_incremental() {
        let docs: Vec<Vec<&str>> = vec![
            vec!["the", "quick", "brown", "fox"],
            vec!["the", "lazy", "dog", "sleeps"],
            vec!["quick", "quick", "brown", "dog"],
            vec!["the", "fox", "and", "the", "dog"],
        ];

        let batch = index(&docs);
        let mut incremental = BM25Scorer::new(BM25Config::default());
        for d in &docs {
            incremental.add_document(d.iter().copied());
        }

        assert_eq!(batch.num_docs, incremental.num_docs);
        assert_eq!(batch.total_len, incremental.total_len);
        assert_eq!(
            batch.avg_doc_len().to_bits(),
            incremental.avg_doc_len().to_bits()
        );

        let vocabulary = [
            "the", "quick", "brown", "fox", "lazy", "dog", "sleeps", "and",
        ];
        for term in vocabulary {
            assert_eq!(
                batch.idf(term).to_bits(),
                incremental.idf(term).to_bits(),
                "IDF mismatch for term {term:?}"
            );
        }

        let doc = ["quick", "quick", "brown", "dog"];
        let from_batch = batch.score(vec!["quick", "dog"], &doc, doc.len());
        let from_incremental = incremental.score(vec!["quick", "dog"], &doc, doc.len());
        assert_eq!(from_batch.to_bits(), from_incremental.to_bits());
    }
}