use std::collections::HashMap;
use super::common::BM25Params;
use terraphim_types::Document;
/// Okapi BM25 relevance scorer.
///
/// Call `initialize` with the corpus before `score`; it records the corpus
/// size, the average document length (in whitespace-separated tokens) and, for
/// each lowercased token, the number of documents that contain it.
///
/// `Default` produces the same state as `OkapiBM25Scorer::new()`.
#[derive(Default)]
pub struct OkapiBM25Scorer {
    /// Tuning parameters (`k1`, `b`) used by `score`.
    params: BM25Params,
    /// Mean number of whitespace-separated tokens per document in the corpus.
    avg_doc_length: f64,
    /// Number of documents passed to the last `initialize` call.
    doc_count: usize,
    /// Lowercased token -> number of corpus documents containing it at least once.
    term_doc_frequencies: HashMap<String, usize>,
}
impl OkapiBM25Scorer {
    /// Creates an uninitialised scorer using `BM25Params::default()`.
    ///
    /// `score` returns `0.0` until `initialize` has been called with at least
    /// one document.
    pub fn new() -> Self {
        Self::with_params(BM25Params::default())
    }

    /// Creates an uninitialised scorer with explicit BM25 tuning parameters.
    pub fn with_params(params: BM25Params) -> Self {
        Self {
            params,
            avg_doc_length: 0.0,
            doc_count: 0,
            term_doc_frequencies: HashMap::new(),
        }
    }

    /// Computes corpus statistics from `documents`.
    ///
    /// Records the document count, the average body length in
    /// whitespace-separated tokens, and per-term document frequencies.
    ///
    /// Terms are whitespace-separated tokens of `Document::body`, lowercased;
    /// punctuation stays attached to the token. Statistics from any previous
    /// call are fully replaced.
    pub fn initialize(&mut self, documents: &[Document]) {
        self.doc_count = documents.len();
        let total_length: usize = documents
            .iter()
            .map(|doc| doc.body.split_whitespace().count())
            .sum();
        // Reset explicitly on an empty corpus so an average from an earlier
        // call does not linger alongside the new (zero) document count.
        self.avg_doc_length = if self.doc_count > 0 {
            total_length as f64 / self.doc_count as f64
        } else {
            0.0
        };

        let mut term_doc_frequencies: HashMap<String, usize> = HashMap::new();
        for doc in documents {
            // Deduplicate first: a term counts at most once per document.
            let doc_terms: std::collections::HashSet<String> = doc
                .body
                .split_whitespace()
                .map(str::to_lowercase)
                .collect();
            for term in doc_terms {
                *term_doc_frequencies.entry(term).or_default() += 1;
            }
        }
        self.term_doc_frequencies = term_doc_frequencies;
    }

    /// Returns the BM25 score of `doc` for `query`.
    ///
    /// The query is split on whitespace and lowercased; each term (duplicates
    /// included) adds `idf * tf * (k1 + 1) / (k1 * norm + tf)`. The idf is
    /// `ln(1 + (N - n + 0.5) / (n + 0.5))`, and `norm` is
    /// `1 - b + b * doc_len / avg_doc_len`.
    ///
    /// Terms absent from the initialised corpus contribute nothing. Returns
    /// `0.0` for an empty query or when no documents were initialised.
    pub fn score(&self, query: &str, doc: &Document) -> f64 {
        let query_terms: Vec<String> = query.split_whitespace().map(str::to_lowercase).collect();
        if query_terms.is_empty() || self.doc_count == 0 {
            return 0.0;
        }

        // Document length is independent of the query term: compute it once.
        // Dividing by `avg_doc_length` is safe whenever the result is used,
        // because a term with non-zero document frequency implies the corpus
        // had at least one token, hence a positive average.
        let doc_length = doc.body.split_whitespace().count() as f64;
        let length_norm =
            1.0 - self.params.b + self.params.b * (doc_length / self.avg_doc_length);
        let doc_count = self.doc_count as f64;

        let mut score = 0.0;
        for term in &query_terms {
            let n_docs_with_term = self.term_doc_frequencies.get(term).copied().unwrap_or(0);
            if n_docs_with_term == 0 {
                continue;
            }
            let n = n_docs_with_term as f64;
            let idf = f64::ln((doc_count - n + 0.5) / (n + 0.5) + 1.0);
            let tf = count_term_occurrences(&doc.body, term) as f64;
            let term_score =
                idf * ((tf * (self.params.k1 + 1.0)) / (self.params.k1 * length_norm + tf));
            score += term_score;
        }
        score
    }
}
/// TF-IDF relevance scorer using raw term counts and `ln(N / df)` idf.
///
/// `Default` produces the same state as `TFIDFScorer::new()`.
#[derive(Debug, Clone, Default)]
pub struct TFIDFScorer {
    /// Number of documents passed to the last `initialize` call.
    doc_count: usize,
    /// Lowercased token -> number of corpus documents containing it at least once.
    term_doc_frequencies: HashMap<String, usize>,
}
impl TFIDFScorer {
pub fn new() -> Self {
Self {
doc_count: 0,
term_doc_frequencies: HashMap::new(),
}
}
pub fn initialize(&mut self, documents: &[Document]) {
self.doc_count = documents.len();
let mut term_doc_frequencies = HashMap::new();
for doc in documents {
let mut terms = Vec::new();
terms.extend(doc.body.split_whitespace().map(|s| s.to_lowercase()));
let mut doc_terms = std::collections::HashSet::new();
for term in terms {
doc_terms.insert(term);
}
for term in doc_terms {
*term_doc_frequencies.entry(term).or_insert(0) += 1;
}
}
self.term_doc_frequencies = term_doc_frequencies;
}
pub fn score(&self, query: &str, doc: &Document) -> f64 {
let query_terms: Vec<String> = query.split_whitespace().map(|s| s.to_lowercase()).collect();
if query_terms.is_empty() || self.doc_count == 0 {
return 0.0;
}
let mut score = 0.0;
for term in &query_terms {
let n_docs_with_term = self.term_doc_frequencies.get(term).copied().unwrap_or(0);
if n_docs_with_term == 0 {
continue;
}
let idf = f64::ln((self.doc_count as f64) / (n_docs_with_term as f64));
let tf = count_term_occurrences(&doc.body, term) as f64;
let term_score = tf * idf;
score += term_score;
}
score
}
}
/// Jaccard-similarity scorer over distinct lowercased whitespace tokens.
///
/// `Default` produces the same state as `JaccardScorer::new()`.
#[derive(Debug, Clone, Default)]
pub struct JaccardScorer {
    /// Number of documents passed to the last `initialize` call; scoring
    /// returns `0.0` while this is zero.
    doc_count: usize,
}
impl JaccardScorer {
pub fn new() -> Self {
Self { doc_count: 0 }
}
pub fn initialize(&mut self, documents: &[Document]) {
self.doc_count = documents.len();
}
pub fn score(&self, query: &str, doc: &Document) -> f64 {
let query_terms: Vec<String> = query.split_whitespace().map(|s| s.to_lowercase()).collect();
if query_terms.is_empty() || self.doc_count == 0 {
return 0.0;
}
let doc_terms: Vec<String> = doc
.body
.split_whitespace()
.map(|s| s.to_lowercase())
.collect();
if doc_terms.is_empty() {
return 0.0;
}
let query_set: std::collections::HashSet<&String> = query_terms.iter().collect();
let doc_set: std::collections::HashSet<&String> = doc_terms.iter().collect();
let intersection_size = query_set.intersection(&doc_set).count();
let union_size = query_set.union(&doc_set).count();
if union_size > 0 {
intersection_size as f64 / union_size as f64
} else {
0.0
}
}
}
/// Scores a document by the fraction of distinct query terms it contains.
///
/// `Default` produces the same state as `QueryRatioScorer::new()`.
#[derive(Debug, Clone, Default)]
pub struct QueryRatioScorer {
    /// Number of documents passed to the last `initialize` call; scoring
    /// returns `0.0` while this is zero.
    doc_count: usize,
}
impl QueryRatioScorer {
pub fn new() -> Self {
Self { doc_count: 0 }
}
pub fn initialize(&mut self, documents: &[Document]) {
self.doc_count = documents.len();
}
pub fn score(&self, query: &str, doc: &Document) -> f64 {
let query_terms: Vec<String> = query.split_whitespace().map(|s| s.to_lowercase()).collect();
if query_terms.is_empty() || self.doc_count == 0 {
return 0.0;
}
let doc_terms: Vec<String> = doc
.body
.split_whitespace()
.map(|s| s.to_lowercase())
.collect();
if doc_terms.is_empty() {
return 0.0;
}
let query_set: std::collections::HashSet<&String> = query_terms.iter().collect();
let doc_set: std::collections::HashSet<&String> = doc_terms.iter().collect();
let intersection_size = query_set.intersection(&doc_set).count();
if !query_set.is_empty() {
intersection_size as f64 / query_set.len() as f64
} else {
0.0
}
}
}
/// Counts the whitespace-separated tokens of `text` that equal `term`.
///
/// `text` is lowercased before splitting, but `term` is compared exactly as
/// given, so a term containing upper-case letters never matches. Punctuation
/// attached to a token is part of that token (`"safety,"` is not `"safety"`).
fn count_term_occurrences(text: &str, term: &str) -> usize {
    let lowered = text.to_lowercase();
    let mut hits = 0;
    for token in lowered.split_whitespace() {
        if token == term {
            hits += 1;
        }
    }
    hits
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test `Document` from the fields these tests vary.
    ///
    /// `url` is derived from `id`, `doc_type` is `KgEntry`, and every other
    /// optional field is `None`.
    fn make_doc(id: &str, title: &str, body: &str, description: &str, tags: &[&str]) -> Document {
        Document {
            id: id.to_string(),
            url: format!("http://example.com/{id}"),
            title: title.to_string(),
            body: body.to_string(),
            description: Some(description.to_string()),
            summarization: None,
            stub: None,
            tags: Some(tags.iter().map(|&tag| tag.to_string()).collect()),
            rank: None,
            source_haystack: None,
            doc_type: terraphim_types::DocumentType::KgEntry,
            synonyms: None,
            route: None,
            priority: None,
        }
    }

    /// Two-document corpus shared by every test: index 0 is about Rust,
    /// index 1 about Python.
    fn sample_documents() -> Vec<Document> {
        vec![
            make_doc(
                "1",
                "Rust Programming Language",
                "Rust is a systems programming language focused on safety, speed, and concurrency.",
                "Learn about Rust programming",
                &["programming", "systems"],
            ),
            make_doc(
                "2",
                "Python Programming Tutorial",
                "Python is a high-level programming language known for its readability.",
                "Learn Python programming",
                &["programming", "tutorial"],
            ),
        ]
    }

    /// Asserts the expected topical ranking for `sample_documents()`.
    ///
    /// A Rust query must score the Rust document higher, and a Python query
    /// must score the Python document higher.
    fn assert_topical_ranking(documents: &[Document], score: impl Fn(&str, &Document) -> f64) {
        let rust_query = "rust programming";
        assert!(score(rust_query, &documents[0]) > score(rust_query, &documents[1]));
        let python_query = "python tutorial";
        assert!(score(python_query, &documents[1]) > score(python_query, &documents[0]));
    }

    #[test]
    fn test_okapi_bm25_scorer() {
        let documents = sample_documents();
        let mut scorer = OkapiBM25Scorer::new();
        scorer.initialize(&documents);
        assert_topical_ranking(&documents, |query, doc| scorer.score(query, doc));
    }

    #[test]
    fn test_tfidf_scorer() {
        let documents = sample_documents();
        let mut scorer = TFIDFScorer::new();
        scorer.initialize(&documents);
        assert_topical_ranking(&documents, |query, doc| scorer.score(query, doc));
    }

    #[test]
    fn test_jaccard_scorer() {
        let documents = sample_documents();
        let mut scorer = JaccardScorer::new();
        scorer.initialize(&documents);
        assert_topical_ranking(&documents, |query, doc| scorer.score(query, doc));

        // Document 0 has 12 distinct tokens and contains both query terms: 2 / 12.
        let expected = 2.0 / 12.0;
        assert!((scorer.score("rust programming", &documents[0]) - expected).abs() < 1e-12);
    }

    #[test]
    fn test_query_ratio_scorer() {
        let documents = sample_documents();
        let mut scorer = QueryRatioScorer::new();
        scorer.initialize(&documents);
        assert_topical_ranking(&documents, |query, doc| scorer.score(query, doc));

        assert_eq!(scorer.score("rust programming", &documents[0]), 1.0);
        assert_eq!(scorer.score("rust programming", &documents[1]), 0.5);
        // Duplicate query terms (after lowercasing) count once.
        assert_eq!(scorer.score("RUST rust", &documents[0]), 1.0);
    }

    #[test]
    fn test_zero_for_empty_query_or_uninitialized_scorer() {
        let documents = sample_documents();
        let doc = &documents[0];

        // Before `initialize` the corpus size is zero, so every scorer yields 0.
        assert_eq!(OkapiBM25Scorer::new().score("rust", doc), 0.0);
        assert_eq!(TFIDFScorer::new().score("rust", doc), 0.0);
        assert_eq!(JaccardScorer::new().score("rust", doc), 0.0);
        assert_eq!(QueryRatioScorer::new().score("rust", doc), 0.0);

        let mut bm25 = OkapiBM25Scorer::new();
        bm25.initialize(&documents);
        let mut tfidf = TFIDFScorer::new();
        tfidf.initialize(&documents);
        let mut jaccard = JaccardScorer::new();
        jaccard.initialize(&documents);
        let mut ratio = QueryRatioScorer::new();
        ratio.initialize(&documents);

        // Queries with no terms.
        for query in ["", "   \t"] {
            assert_eq!(bm25.score(query, doc), 0.0);
            assert_eq!(tfidf.score(query, doc), 0.0);
            assert_eq!(jaccard.score(query, doc), 0.0);
            assert_eq!(ratio.score(query, doc), 0.0);
        }

        // A term absent from the corpus contributes nothing.
        assert_eq!(bm25.score("haskell", doc), 0.0);
        assert_eq!(tfidf.score("haskell", doc), 0.0);
    }

    #[test]
    fn test_idf_for_term_in_every_document() {
        let documents = sample_documents();
        let mut bm25 = OkapiBM25Scorer::new();
        bm25.initialize(&documents);
        let mut tfidf = TFIDFScorer::new();
        tfidf.initialize(&documents);

        // "programming" occurs in both documents. TF-IDF's ln(N / df) is then
        // ln(1) = 0, whereas BM25's smoothed idf ln(1 + 0.5 / 2.5) stays positive.
        assert_eq!(tfidf.score("programming", &documents[0]), 0.0);
        assert!(bm25.score("programming", &documents[0]) > 0.0);
    }
}