use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// Default BM25 `k1` (term-frequency saturation) parameter; the common
/// default in the BM25 literature.
pub const DEFAULT_K1: f32 = 1.2;
/// Default BM25 `b` (document-length normalization) parameter.
pub const DEFAULT_B: f32 = 0.75;
/// Aggregate corpus statistics needed by the BM25 formula.
///
/// `#[derive(Default)]` produces the all-zero value (0.0 / 0 / 0 / 0),
/// which is exactly what the hand-written `Default` impl returned, so the
/// derive replaces it.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CorpusStats {
    /// Mean document length (in terms) over the corpus.
    pub avg_doc_length: f32,
    /// Number of documents in the corpus (`N` in the IDF formula).
    pub doc_count: u64,
    /// Total term occurrences across all documents.
    pub total_terms: u64,
    /// Timestamp of the last statistics refresh — presumably Unix seconds;
    /// TODO confirm with the producer of these stats.
    pub last_update: i64,
}
/// Tunable BM25 parameters.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct BM25Config {
    /// Term-frequency saturation; higher values let repeated terms keep
    /// contributing to the score. Floored at 0 by `BM25Config::new`.
    pub k1: f32,
    /// Length-normalization strength in [0, 1]; 0 disables normalization.
    /// Clamped into range by `BM25Config::new`.
    pub b: f32,
}
impl Default for BM25Config {
fn default() -> Self {
Self {
k1: DEFAULT_K1,
b: DEFAULT_B,
}
}
}
impl BM25Config {
pub fn new(k1: f32, b: f32) -> Self {
Self {
k1: k1.max(0.0),
b: b.clamp(0.0, 1.0),
}
}
}
/// Per-document term counts plus the total token count of the document.
#[derive(Debug, Clone)]
pub struct TermFrequencies {
    /// Term → number of occurrences within the document.
    pub frequencies: HashMap<String, u32>,
    /// Total term occurrences: the sum of every value in `frequencies`.
    pub doc_length: u32,
}

impl TermFrequencies {
    /// Builds the structure, deriving `doc_length` from the frequency map.
    pub fn new(frequencies: HashMap<String, u32>) -> Self {
        let doc_length: u32 = frequencies.values().copied().sum();
        TermFrequencies {
            doc_length,
            frequencies,
        }
    }

    /// Occurrence count for `term`, or `None` if it never appears.
    pub fn get(&self, term: &str) -> Option<u32> {
        self.frequencies.get(term).copied()
    }
}
/// A lightweight borrowed view over a document's term statistics, used
/// during scoring.
pub struct Document<'a> {
    pub term_freqs: &'a TermFrequencies,
}

impl<'a> Document<'a> {
    /// Wraps a borrowed `TermFrequencies`.
    pub fn new(term_freqs: &'a TermFrequencies) -> Self {
        Document { term_freqs }
    }

    /// Occurrence count of `term` in this document, if present.
    pub fn term_freq(&self, term: &str) -> Option<u32> {
        self.term_freqs.get(term)
    }

    /// Total number of term occurrences in the document.
    pub fn term_count(&self) -> u32 {
        self.term_freqs.doc_length
    }
}
/// Stateful BM25 scorer: holds the tuning parameters, corpus statistics,
/// and interior-mutable caches for document frequencies and IDF values.
pub struct BM25Scorer {
    config: BM25Config,
    corpus_stats: CorpusStats,
    // Cached IDF per term; entries are dropped when corpus stats change or
    // a term's document frequency is updated.
    idf_cache: Arc<RwLock<HashMap<String, f32>>>,
    // Document frequency per term, populated via `set_doc_freq`.
    // NOTE(review): the Arc wrappers look unnecessary while the scorer is
    // never cloned or shared across threads in this file — confirm before
    // simplifying.
    df_cache: Arc<RwLock<HashMap<String, u64>>>,
}
impl BM25Scorer {
    /// Creates a scorer over `corpus_stats` using the default `k1`/`b`.
    pub fn new(corpus_stats: CorpusStats) -> Self {
        Self::with_config(corpus_stats, BM25Config::default())
    }

    /// Creates a scorer with an explicit parameter configuration.
    pub fn with_config(corpus_stats: CorpusStats, config: BM25Config) -> Self {
        Self {
            config,
            corpus_stats,
            idf_cache: Arc::new(RwLock::new(HashMap::new())),
            df_cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Replaces the corpus statistics and invalidates cached IDF values.
    ///
    /// The document-frequency cache is kept: per-term document frequencies
    /// do not depend on the aggregate stats, only the derived IDFs do.
    pub fn update_corpus_stats(&mut self, stats: CorpusStats) {
        self.corpus_stats = stats;
        self.idf_cache.write().clear();
    }

    /// Records the document frequency for `term`, dropping any IDF cached
    /// under the previous frequency.
    pub fn set_doc_freq(&self, term: &str, doc_freq: u64) {
        self.df_cache.write().insert(term.to_string(), doc_freq);
        self.idf_cache.write().remove(term);
    }

    /// Returns the IDF for `term`, computing and caching it on a miss.
    ///
    /// Terms with no recorded document frequency are treated as `df = 0`.
    /// Concurrent callers may race to fill the same cache entry; the
    /// duplicated work is benign since both compute the same value.
    pub fn idf(&self, term: &str) -> f32 {
        if let Some(&cached) = self.idf_cache.read().get(term) {
            return cached;
        }
        let df = self.df_cache.read().get(term).copied().unwrap_or(0);
        // Delegate to the single shared formula so `idf` and `idf_with_df`
        // cannot drift apart (previously the formula was duplicated here).
        let idf = self.idf_with_df(df);
        self.idf_cache.write().insert(term.to_string(), idf);
        idf
    }

    /// Computes the IDF from a raw document frequency, bypassing the caches.
    ///
    /// Uses the non-negative variant `ln((N - df + 0.5) / (df + 0.5) + 1)`.
    /// `df == 0` (unseen term) is special-cased to `ln(N + 0.5)`, the
    /// largest value this function produces for a given corpus.
    pub fn idf_with_df(&self, doc_freq: u64) -> f32 {
        let n = self.corpus_stats.doc_count as f32;
        let df = doc_freq as f32;
        if doc_freq == 0 {
            (n + 0.5).ln()
        } else {
            ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
        }
    }

    /// Scores `doc` against `query_terms` with BM25.
    ///
    /// Query terms absent from the document contribute nothing. The average
    /// document length is floored at 1.0 to avoid division by zero on an
    /// empty/uninitialized corpus.
    pub fn score(&self, doc: &Document, query_terms: &[String]) -> f32 {
        let doc_len = doc.term_count() as f32;
        let avg_doc_len = self.corpus_stats.avg_doc_length.max(1.0);
        // Length-normalization factor, shared by every term of this doc.
        let len_norm = 1.0 - self.config.b + self.config.b * (doc_len / avg_doc_len);
        query_terms
            .iter()
            .filter_map(|term| {
                let tf = doc.term_freq(term)? as f32;
                let idf = self.idf(term);
                let numerator = tf * (self.config.k1 + 1.0);
                let denominator = tf + self.config.k1 * len_norm;
                Some(idf * numerator / denominator)
            })
            .sum()
    }

    /// Scores a document from pre-joined `(term, tf, df)` triples, bypassing
    /// the per-term caches (useful when frequencies come straight from an
    /// index lookup).
    pub fn score_with_freqs(
        &self,
        term_freqs: &[(String, u32, u64)],
        doc_length: u32,
    ) -> f32 {
        let doc_len = doc_length as f32;
        let avg_doc_len = self.corpus_stats.avg_doc_length.max(1.0);
        let len_norm = 1.0 - self.config.b + self.config.b * (doc_len / avg_doc_len);
        term_freqs
            .iter()
            .map(|(_, tf, df)| {
                let tf = *tf as f32;
                let idf = self.idf_with_df(*df);
                let numerator = tf * (self.config.k1 + 1.0);
                let denominator = tf + self.config.k1 * len_norm;
                idf * numerator / denominator
            })
            .sum()
    }

    /// Scores each document in `docs` against the same query terms, in order.
    pub fn score_batch<'a>(
        &self,
        docs: impl Iterator<Item = &'a Document<'a>>,
        query_terms: &[String],
    ) -> Vec<f32> {
        docs.map(|doc| self.score(doc, query_terms)).collect()
    }

    /// The active BM25 parameters.
    pub fn config(&self) -> &BM25Config {
        &self.config
    }

    /// The corpus statistics currently in use.
    pub fn corpus_stats(&self) -> &CorpusStats {
        &self.corpus_stats
    }

    /// Empties both the IDF and document-frequency caches.
    pub fn clear_cache(&self) {
        self.idf_cache.write().clear();
        self.df_cache.write().clear();
    }
}
/// Lowercases `text`, splits on whitespace, strips leading/trailing
/// non-alphanumeric characters from each token, and keeps only tokens
/// longer than one character.
///
/// Trimming now happens BEFORE the length filter; previously the raw token
/// length was checked first, so punctuation-wrapped single characters
/// (e.g. `"a!"`) leaked through as `"a"` while a bare `"a"` was dropped.
pub fn tokenize_query(text: &str) -> Vec<String> {
    text.to_lowercase()
        .split_whitespace()
        .map(|raw| raw.trim_matches(|c: char| !c.is_alphanumeric()))
        .filter(|token| token.len() > 1)
        .map(str::to_string)
        .collect()
}
/// Parses a PostgreSQL `tsvector` string (e.g. `'database':1,3,5 'query':2,4`)
/// into a map from term to occurrence count.
///
/// The count is the number of listed positions; a quoted term with no
/// position list counts as 1. Lexemes containing whitespace are NOT
/// supported by this whitespace-splitting parser — TODO confirm such
/// lexemes never reach this function.
///
/// Fix: the previous version panicked on a lone `'` token — it satisfied
/// both `starts_with('\'')` and `ends_with('\'')` and then sliced
/// `part[1..0]`. `strip_prefix`/`strip_suffix` make that case fall through
/// harmlessly.
pub fn parse_tsvector(tsvector_str: &str) -> HashMap<String, u32> {
    let mut frequencies = HashMap::new();
    for part in tsvector_str.split_whitespace() {
        // Every lexeme must start with a single quote; skip anything else.
        let rest = match part.strip_prefix('\'') {
            Some(r) => r,
            None => continue,
        };
        if let Some(quote_end) = rest.find("':") {
            // `'term':1,2,3` — frequency is the number of positions listed.
            let term = &rest[..quote_end];
            let positions = &rest[quote_end + 2..];
            let freq = positions.split(',').count() as u32;
            frequencies.insert(term.to_string(), freq.max(1));
        } else if let Some(term) = rest.strip_suffix('\'') {
            // `'term'` with no position list — count it once.
            frequencies.insert(term.to_string(), 1);
        }
    }
    frequencies
}
#[cfg(test)]
mod tests {
    //! Unit tests for IDF computation, BM25 scoring, query tokenization,
    //! and tsvector parsing, over a small synthetic corpus.
    use super::*;

    /// Builds a scorer over a synthetic corpus: 1000 docs, 100 terms each.
    fn create_test_scorer() -> BM25Scorer {
        let stats = CorpusStats {
            avg_doc_length: 100.0,
            doc_count: 1000,
            total_terms: 100000,
            last_update: 0,
        };
        BM25Scorer::new(stats)
    }

    // A term appearing in 900 of 1000 docs gets a small but positive IDF.
    #[test]
    fn test_idf_common_term() {
        let scorer = create_test_scorer();
        scorer.set_doc_freq("the", 900);
        let idf = scorer.idf("the");
        assert!(idf > 0.0, "IDF should be positive");
        assert!(idf < 1.0, "IDF for common term should be low");
    }

    // A term appearing in only 5 docs gets a large IDF.
    #[test]
    fn test_idf_rare_term() {
        let scorer = create_test_scorer();
        scorer.set_doc_freq("xyzzy", 5);
        let idf = scorer.idf("xyzzy");
        assert!(idf > 4.0, "IDF for rare term should be high");
    }

    // A term with no recorded document frequency hits the df == 0 branch,
    // which yields ln(N + 0.5) — the largest IDF for this corpus.
    #[test]
    fn test_idf_unknown_term() {
        let scorer = create_test_scorer();
        let idf = scorer.idf("unknown_term_xyz");
        assert!(idf > 5.0, "IDF for unknown term should be maximum");
    }

    // End-to-end score over a small doc; only the sign is checked, not the
    // exact value.
    #[test]
    fn test_bm25_score() {
        let scorer = create_test_scorer();
        scorer.set_doc_freq("database", 100);
        scorer.set_doc_freq("query", 50);
        let mut freqs = HashMap::new();
        freqs.insert("database".to_string(), 3);
        freqs.insert("query".to_string(), 2);
        freqs.insert("other".to_string(), 5);
        let term_freqs = TermFrequencies::new(freqs);
        let doc = Document::new(&term_freqs);
        let query_terms = vec!["database".to_string(), "query".to_string()];
        let score = scorer.score(&doc, &query_terms);
        assert!(score > 0.0, "Score should be positive");
    }

    // With equal TF, the shorter doc (2 + 48 = 50 terms) must outrank the
    // longer one (2 + 198 = 200 terms) because b > 0 penalizes length.
    #[test]
    fn test_length_normalization() {
        let scorer = create_test_scorer();
        scorer.set_doc_freq("test", 100);
        let mut short_freqs = HashMap::new();
        short_freqs.insert("test".to_string(), 2);
        for i in 0..48 {
            short_freqs.insert(format!("filler{}", i), 1);
        }
        let short_tf = TermFrequencies::new(short_freqs);
        let short_doc = Document::new(&short_tf);
        let mut long_freqs = HashMap::new();
        long_freqs.insert("test".to_string(), 2);
        for i in 0..198 {
            long_freqs.insert(format!("filler{}", i), 1);
        }
        let long_tf = TermFrequencies::new(long_freqs);
        let long_doc = Document::new(&long_tf);
        let query_terms = vec!["test".to_string()];
        let short_score = scorer.score(&short_doc, &query_terms);
        let long_score = scorer.score(&long_doc, &query_terms);
        assert!(
            short_score > long_score,
            "Short doc should score higher than long doc with same TF"
        );
    }

    // Lowercasing plus trailing-punctuation trimming.
    #[test]
    fn test_tokenize_query() {
        let tokens = tokenize_query("Hello World! Database Query.");
        assert_eq!(tokens, vec!["hello", "world", "database", "query"]);
    }

    // Frequency equals the number of listed positions in the tsvector.
    #[test]
    fn test_parse_tsvector() {
        let tsvector = "'database':1,3,5 'query':2,4";
        let freqs = parse_tsvector(tsvector);
        assert_eq!(freqs.get("database"), Some(&3));
        assert_eq!(freqs.get("query"), Some(&2));
    }

    // Out-of-range parameters are sanitized by BM25Config::new.
    #[test]
    fn test_config_clamping() {
        let config = BM25Config::new(-1.0, 1.5);
        assert_eq!(config.k1, 0.0, "k1 should be clamped to 0");
        assert_eq!(config.b, 1.0, "b should be clamped to 1");
    }
}