use std::collections::HashMap;
use std::f64;
use super::common::{BM25Params, FieldWeights};
use terraphim_types::Document;
/// Scores documents with the BM25F ranking function: an extension of BM25
/// that weights term frequencies per document field (title, body,
/// description, tags) before saturation and length normalization.
///
/// Call `initialize` with the corpus before calling `score`.
pub struct BM25FScorer {
    // BM25 tuning parameters (k1 saturation, b length normalization).
    params: BM25Params,
    // Relative importance of each field when accumulating term frequency.
    weights: FieldWeights,
    // Average whitespace-token count per document over the initialized corpus.
    avg_doc_length: f64,
    // Number of documents seen during `initialize`.
    doc_count: usize,
    // For each lowercased term, how many corpus documents contain it at least once.
    term_doc_frequencies: HashMap<String, usize>,
}
impl BM25FScorer {
    /// Creates a scorer with default BM25 parameters and field weights.
    pub fn new() -> Self {
        Self::with_params(BM25Params::default(), FieldWeights::default())
    }

    /// Creates a scorer with custom BM25 parameters and field weights.
    #[allow(dead_code)]
    pub fn with_params(params: BM25Params, weights: FieldWeights) -> Self {
        Self {
            params,
            weights,
            avg_doc_length: 0.0,
            doc_count: 0,
            term_doc_frequencies: HashMap::new(),
        }
    }

    /// Total whitespace-token count of a document across all scored fields
    /// (title, body, description, tags). Shared by `initialize` and `score`
    /// so the two always agree on what "document length" means.
    fn doc_length(doc: &Document) -> usize {
        let title_len = doc.title.split_whitespace().count();
        let body_len = doc.body.split_whitespace().count();
        let desc_len = doc
            .description
            .as_ref()
            .map_or(0, |d| d.split_whitespace().count());
        let tags_len = doc.tags.as_ref().map_or(0, |t| {
            t.iter().map(|tag| tag.split_whitespace().count()).sum()
        });
        title_len + body_len + desc_len + tags_len
    }

    /// Builds corpus statistics from `documents`: the average document length
    /// and, for each lowercased term, the number of documents containing it.
    /// Must be called before `score`.
    pub fn initialize(&mut self, documents: &[Document]) {
        self.doc_count = documents.len();
        let total_length: usize = documents.iter().map(Self::doc_length).sum();
        if self.doc_count > 0 {
            self.avg_doc_length = total_length as f64 / self.doc_count as f64;
        }

        let mut term_doc_frequencies: HashMap<String, usize> = HashMap::new();
        for doc in documents {
            // Collect this document's distinct lowercased terms directly into
            // a set so each document adds at most 1 to a term's doc frequency
            // (no intermediate Vec needed).
            let mut doc_terms = std::collections::HashSet::new();
            doc_terms.extend(doc.title.split_whitespace().map(|s| s.to_lowercase()));
            doc_terms.extend(doc.body.split_whitespace().map(|s| s.to_lowercase()));
            if let Some(desc) = &doc.description {
                doc_terms.extend(desc.split_whitespace().map(|s| s.to_lowercase()));
            }
            if let Some(tags) = &doc.tags {
                for tag in tags {
                    doc_terms.extend(tag.split_whitespace().map(|s| s.to_lowercase()));
                }
            }
            for term in doc_terms {
                *term_doc_frequencies.entry(term).or_insert(0) += 1;
            }
        }
        self.term_doc_frequencies = term_doc_frequencies;
    }

    /// Scores `doc` against `query` using BM25F: each field's raw term
    /// frequency is scaled by its configured weight, summed, then passed
    /// through the usual BM25 saturation with document-length normalization.
    ///
    /// Returns 0.0 for an empty query or an uninitialized/empty corpus.
    /// Terms absent from the corpus contribute nothing.
    pub fn score(&self, query: &str, doc: &Document) -> f64 {
        let query_terms: Vec<String> = query.split_whitespace().map(|s| s.to_lowercase()).collect();
        if query_terms.is_empty() || self.doc_count == 0 {
            return 0.0;
        }

        // Document length and its normalization factor do not depend on the
        // query term, so compute them once instead of once per term.
        let doc_length = Self::doc_length(doc) as f64;
        // Guard against a degenerate corpus of all-empty documents, where
        // avg_doc_length stays 0.0 and the division would yield NaN/inf.
        let length_norm = if self.avg_doc_length > 0.0 {
            1.0 - self.params.b + self.params.b * (doc_length / self.avg_doc_length)
        } else {
            1.0
        };

        let mut score = 0.0;
        for term in &query_terms {
            let n_docs_with_term = self.term_doc_frequencies.get(term).copied().unwrap_or(0);
            if n_docs_with_term == 0 {
                continue;
            }
            // Robertson–Spärck Jones IDF with +1 inside the log so it stays positive.
            let idf = f64::ln(
                (self.doc_count as f64 - n_docs_with_term as f64 + 0.5)
                    / (n_docs_with_term as f64 + 0.5)
                    + 1.0,
            );

            // Field-weighted term frequency: raw per-field counts scaled by
            // their configured weights, then summed.
            let mut weighted_tf =
                self.weights.title * count_term_occurrences(&doc.title, term) as f64;
            weighted_tf += self.weights.body * count_term_occurrences(&doc.body, term) as f64;
            if let Some(desc) = &doc.description {
                weighted_tf +=
                    self.weights.description * count_term_occurrences(desc, term) as f64;
            }
            if let Some(tags) = &doc.tags {
                for tag in tags {
                    weighted_tf += self.weights.tags * count_term_occurrences(tag, term) as f64;
                }
            }

            score += idf * (weighted_tf / (self.params.k1 * length_norm + weighted_tf));
        }
        score
    }
}
/// Scores documents with the BM25+ ranking function: classic BM25 over the
/// document body only, with a `delta` lower bound added to the term-frequency
/// component (the BM25+ variant).
///
/// Call `initialize` with the corpus before calling `score`.
pub struct BM25PlusScorer {
    // BM25 tuning parameters (k1 saturation, b length normalization, delta lower bound).
    params: BM25Params,
    // Average body token count per document over the initialized corpus.
    avg_doc_length: f64,
    // Number of documents seen during `initialize`.
    doc_count: usize,
    // For each lowercased body term, how many corpus documents contain it at least once.
    term_doc_frequencies: HashMap<String, usize>,
}
impl BM25PlusScorer {
    /// Creates a scorer with default BM25 parameters.
    pub fn new() -> Self {
        Self::with_params(BM25Params::default())
    }

    /// Creates a scorer with custom BM25 parameters.
    #[allow(dead_code)]
    pub fn with_params(params: BM25Params) -> Self {
        Self {
            params,
            avg_doc_length: 0.0,
            doc_count: 0,
            term_doc_frequencies: HashMap::new(),
        }
    }

    /// Builds corpus statistics from `documents`. Unlike BM25F, this scorer
    /// indexes the document body only: average body length and per-term
    /// document frequencies. Must be called before `score`.
    pub fn initialize(&mut self, documents: &[Document]) {
        self.doc_count = documents.len();
        let total_length: usize = documents
            .iter()
            .map(|doc| doc.body.split_whitespace().count())
            .sum();
        if self.doc_count > 0 {
            self.avg_doc_length = total_length as f64 / self.doc_count as f64;
        }

        let mut term_doc_frequencies: HashMap<String, usize> = HashMap::new();
        for doc in documents {
            // Collect distinct lowercased body terms directly into a set so
            // each document adds at most 1 to a term's document frequency.
            let doc_terms: std::collections::HashSet<String> = doc
                .body
                .split_whitespace()
                .map(|s| s.to_lowercase())
                .collect();
            for term in doc_terms {
                *term_doc_frequencies.entry(term).or_insert(0) += 1;
            }
        }
        self.term_doc_frequencies = term_doc_frequencies;
    }

    /// Scores `doc` against `query` using BM25+ over the document body:
    /// `idf * ((tf * (k1 + 1)) / (k1 * length_norm + tf) + delta)`.
    ///
    /// Returns 0.0 for an empty query or an uninitialized/empty corpus.
    /// Terms absent from the corpus contribute nothing.
    pub fn score(&self, query: &str, doc: &Document) -> f64 {
        let query_terms: Vec<String> = query.split_whitespace().map(|s| s.to_lowercase()).collect();
        if query_terms.is_empty() || self.doc_count == 0 {
            return 0.0;
        }

        // Body length and its normalization factor are query-independent, so
        // compute them once instead of once per term.
        let doc_length = doc.body.split_whitespace().count() as f64;
        // Guard against a degenerate corpus of all-empty bodies, where
        // avg_doc_length stays 0.0 and the division would yield NaN/inf.
        let length_norm = if self.avg_doc_length > 0.0 {
            1.0 - self.params.b + self.params.b * (doc_length / self.avg_doc_length)
        } else {
            1.0
        };

        let mut score = 0.0;
        for term in &query_terms {
            let n_docs_with_term = self.term_doc_frequencies.get(term).copied().unwrap_or(0);
            if n_docs_with_term == 0 {
                continue;
            }
            // Robertson–Spärck Jones IDF with +1 inside the log so it stays positive.
            let idf = f64::ln(
                (self.doc_count as f64 - n_docs_with_term as f64 + 0.5)
                    / (n_docs_with_term as f64 + 0.5)
                    + 1.0,
            );
            let tf = count_term_occurrences(&doc.body, term) as f64;
            // BM25+ adds `delta` so that any document is lower-bounded in the
            // term-frequency component (mitigates the long-document penalty).
            score += idf
                * ((tf * (self.params.k1 + 1.0)) / (self.params.k1 * length_norm + tf)
                    + self.params.delta);
        }
        score
    }
}
/// Counts the whitespace-separated words of `text` that equal `term`,
/// comparing case-insensitively (`term` is expected to already be lowercase).
fn count_term_occurrences(text: &str, term: &str) -> usize {
    let lowered = text.to_lowercase();
    lowered
        .split_whitespace()
        .fold(0, |count, word| if word == term { count + 1 } else { count })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: a Rust-themed and a Python-themed document, used as
    /// the corpus for both scorer tests.
    fn sample_documents() -> Vec<Document> {
        let rust_doc = Document {
            id: "1".to_string(),
            url: "http://example.com/1".to_string(),
            title: "Rust Programming Language".to_string(),
            body: "Rust is a systems programming language focused on safety, speed, and concurrency.".to_string(),
            description: Some("Learn about Rust programming".to_string()),
            summarization: None,
            stub: None,
            tags: Some(vec!["programming".to_string(), "systems".to_string()]),
            rank: None,
            source_haystack: None,
            doc_type: terraphim_types::DocumentType::KgEntry,
            synonyms: None,
            route: None,
            priority: None,
        };
        let python_doc = Document {
            id: "2".to_string(),
            url: "http://example.com/2".to_string(),
            title: "Python Programming Tutorial".to_string(),
            body: "Python is a high-level programming language known for its readability.".to_string(),
            description: Some("Learn Python programming".to_string()),
            summarization: None,
            stub: None,
            tags: Some(vec!["programming".to_string(), "tutorial".to_string()]),
            rank: None,
            source_haystack: None,
            doc_type: terraphim_types::DocumentType::KgEntry,
            synonyms: None,
            route: None,
            priority: None,
        };
        vec![rust_doc, python_doc]
    }

    #[test]
    fn test_bm25f_scorer() {
        let documents = sample_documents();
        let mut scorer = BM25FScorer::new();
        scorer.initialize(&documents);

        // A Rust query should rank the Rust document above the Python one...
        assert!(
            scorer.score("rust programming", &documents[0])
                > scorer.score("rust programming", &documents[1])
        );
        // ...and a Python query should rank the Python document above the Rust one.
        assert!(
            scorer.score("python tutorial", &documents[1])
                > scorer.score("python tutorial", &documents[0])
        );
    }

    #[test]
    fn test_bm25plus_scorer() {
        let documents = sample_documents();
        let mut scorer = BM25PlusScorer::new();
        scorer.initialize(&documents);

        // Same relative-ordering expectations as the BM25F test.
        assert!(
            scorer.score("rust programming", &documents[0])
                > scorer.score("rust programming", &documents[1])
        );
        assert!(
            scorer.score("python tutorial", &documents[1])
                > scorer.score("python tutorial", &documents[0])
        );
    }

    #[test]
    fn test_count_term_occurrences() {
        let text = "Rust is a systems programming language. Rust is safe and fast.";
        assert_eq!(count_term_occurrences(text, "rust"), 2);
        assert_eq!(count_term_occurrences(text, "is"), 2);
        assert_eq!(count_term_occurrences(text, "programming"), 1);
        assert_eq!(count_term_occurrences(text, "python"), 0);
    }
}