use std::collections::HashMap;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::error::Result;
/// Tunable parameters shared by the scoring functions in this module.
///
/// Each scorer reads only the fields relevant to it; the remaining fields are
/// ignored by that scorer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoringConfig {
/// BM25 `k1`: scales the term-frequency part of the BM25 formula.
pub k1: f32,
/// BM25 `b`: weight given to the document-length ratio (`doc_len / avg_len`).
pub b: f32,
/// Multiplier applied to the `1 + ln(tf)` component of the TF-IDF scorer.
pub tf_idf_boost: f32,
/// When `true`, the TF-IDF scorer scales each term by `sqrt(avg_len / doc_len)`.
/// Note that the lengths used are whole-document lengths, not per-field ones.
pub enable_field_norm: bool,
/// Per-field boost multipliers keyed by field name, consumed by
/// `AdvancedScorer` when post-processing a score.
pub field_boosts: HashMap<String, f32>,
/// When `true`, `AdvancedScorer` multiplies the score by the fraction of
/// query terms that the document matches.
pub enable_coord: bool,
}
impl Default for ScoringConfig {
fn default() -> Self {
ScoringConfig {
k1: 1.2,
b: 0.75,
tf_idf_boost: 1.0,
enable_field_norm: true,
field_boosts: HashMap::new(),
enable_coord: true,
}
}
}
/// Per-document statistics handed to a `ScoringFunction`.
#[derive(Debug, Clone)]
pub struct DocumentStats {
/// Identifier of the document; not read by any scorer in this module.
pub doc_id: u32,
/// Length of the whole document, used for length normalisation.
/// NOTE(review): the unit (tokens, bytes, ...) is not visible here — confirm with the producer.
pub doc_length: u64,
/// Length of each field keyed by field name; not read by any scorer in this module.
pub field_lengths: HashMap<String, u64>,
/// Number of occurrences of each term in the document, keyed by term.
pub term_frequencies: HashMap<String, u64>,
/// Term occurrence counts per field: outer key is the field name,
/// inner key is presumably the term (verify against the producer).
pub field_term_frequencies: HashMap<String, HashMap<String, u64>>,
}
/// Collection-wide statistics handed to a `ScoringFunction`.
#[derive(Debug, Clone)]
pub struct CollectionStats {
/// Total number of documents in the collection.
pub total_docs: u64,
/// Average document length across the collection, compared against
/// `DocumentStats::doc_length` for length normalisation.
pub avg_doc_length: f64,
/// Average length of each field keyed by field name; not read by any scorer in this module.
pub avg_field_lengths: HashMap<String, f64>,
/// Document frequency of each term (number of documents containing it), keyed by term.
/// Scorers fall back to a frequency of 1 for terms missing from this map.
pub document_frequencies: HashMap<String, u64>,
/// Per-field document frequencies; not read by any scorer in this module.
/// Presumably keyed by field name, then term — TODO confirm.
pub field_document_frequencies: HashMap<String, HashMap<String, u64>>,
}
/// A pluggable relevance-scoring strategy.
///
/// Implementations must be `Send + Sync + Debug` so they can be stored as
/// boxed trait objects (see `ScoringRegistry` and `AdvancedScorer`).
pub trait ScoringFunction: Send + Sync + std::fmt::Debug {
/// Computes the relevance score of one document for the given query.
///
/// * `query_terms` - the terms of the query.
/// * `doc_stats` - statistics of the document being scored.
/// * `collection_stats` - statistics of the whole collection.
/// * `config` - tunable parameters; implementations read only what they need.
///
/// Returns the score, or an error if the implementation fails.
fn score(
&self,
query_terms: &[String],
doc_stats: &DocumentStats,
collection_stats: &CollectionStats,
config: &ScoringConfig,
) -> Result<f32>;
/// Short human-readable name of the scoring function.
fn name(&self) -> &str;
/// One-line human-readable description of the scoring function.
fn description(&self) -> &str;
}
/// Okapi BM25 relevance scorer.
#[derive(Debug, Clone)]
pub struct BM25ScoringFunction;

impl ScoringFunction for BM25ScoringFunction {
    /// Sums `idf * tf_component` over every query term present in the document.
    ///
    /// * `idf = ln((N - df + 0.5) / (df + 0.5))`, where `N` is the collection
    ///   size and `df` the term's document frequency (1 when unknown). Terms
    ///   whose ratio is not positive are skipped; a ratio below 1 yields a
    ///   negative contribution.
    /// * `tf_component = tf * (k1 + 1) / (tf + k1 * (1 - b + b * len_ratio))`,
    ///   with `len_ratio = doc_len / avg_len` (or 1 when `avg_len` is not
    ///   positive).
    fn score(
        &self,
        query_terms: &[String],
        doc_stats: &DocumentStats,
        collection_stats: &CollectionStats,
        config: &ScoringConfig,
    ) -> Result<f32> {
        let collection_size = collection_stats.total_docs as f32;

        // The length normalisation does not depend on the term, so compute it once.
        let average_length = collection_stats.avg_doc_length as f32;
        let length_ratio = if average_length > 0.0 {
            doc_stats.doc_length as f32 / average_length
        } else {
            1.0
        };
        let length_penalty = config.k1 * (1.0 - config.b + config.b * length_ratio);

        let score = query_terms
            .iter()
            .filter_map(|term| {
                let term_freq =
                    doc_stats.term_frequencies.get(term).copied().unwrap_or(0) as f32;
                if term_freq == 0.0 {
                    return None;
                }

                let doc_freq = collection_stats
                    .document_frequencies
                    .get(term)
                    .copied()
                    .unwrap_or(1) as f32;
                let ratio = (collection_size - doc_freq + 0.5) / (doc_freq + 0.5);
                if ratio <= 0.0 {
                    // ln() would be undefined; such terms contribute nothing.
                    return None;
                }

                let saturation = (term_freq * (config.k1 + 1.0)) / (term_freq + length_penalty);
                Some(ratio.ln() * saturation)
            })
            .fold(0.0_f32, |running, contribution| running + contribution);

        Ok(score)
    }

    /// Returns `"BM25"`.
    fn name(&self) -> &str {
        "BM25"
    }

    /// Returns a one-line description of the algorithm.
    fn description(&self) -> &str {
        "Best Matching 25 probabilistic ranking function"
    }
}
/// Classic TF-IDF relevance scorer with optional length normalisation.
#[derive(Debug, Clone)]
pub struct TfIdfScoringFunction;

impl ScoringFunction for TfIdfScoringFunction {
    /// Sums `(1 + ln(tf)) * tf_idf_boost * ln(N / df) * norm` over every query
    /// term present in the document.
    ///
    /// * `N` is the collection size; an empty collection scores `0.0`.
    /// * `df` is the term's document frequency. A missing entry or an explicit
    ///   zero is treated as 1 so the division can never produce infinity.
    /// * `norm` is `sqrt(avg_len / doc_len)` when `config.enable_field_norm`
    ///   is set and both lengths are positive, otherwise `1.0`. Falling back
    ///   to `1.0` when the average length is unavailable mirrors the BM25
    ///   scorer instead of silently zeroing every score.
    fn score(
        &self,
        query_terms: &[String],
        doc_stats: &DocumentStats,
        collection_stats: &CollectionStats,
        config: &ScoringConfig,
    ) -> Result<f32> {
        let total_docs = collection_stats.total_docs as f32;
        if total_docs == 0.0 {
            // No collection statistics: every term would be skipped anyway.
            return Ok(0.0);
        }

        // The normalisation factor is identical for every term; compute it once.
        let norm_factor = if config.enable_field_norm {
            let doc_len = doc_stats.doc_length as f32;
            let avg_len = collection_stats.avg_doc_length as f32;
            if doc_len > 0.0 && avg_len > 0.0 {
                (avg_len / doc_len).sqrt()
            } else {
                1.0
            }
        } else {
            1.0
        };

        let mut total_score = 0.0_f32;
        for term in query_terms {
            let tf = doc_stats.term_frequencies.get(term).copied().unwrap_or(0);
            if tf == 0 {
                continue;
            }

            // Clamp to 1 so an explicit zero behaves like a missing entry.
            let df = collection_stats
                .document_frequencies
                .get(term)
                .copied()
                .unwrap_or(1)
                .max(1);

            let tf_component = (1.0 + (tf as f32).ln()) * config.tf_idf_boost;
            let idf = (total_docs / df as f32).ln();
            total_score += tf_component * idf * norm_factor;
        }
        Ok(total_score)
    }

    /// Returns `"TF-IDF"`.
    fn name(&self) -> &str {
        "TF-IDF"
    }

    /// Returns a one-line description of the algorithm.
    fn description(&self) -> &str {
        "Term Frequency - Inverse Document Frequency"
    }
}
/// Cosine-similarity scorer restricted to the query's term dimensions.
#[derive(Debug, Clone)]
pub struct VectorSpaceScoringFunction;

impl ScoringFunction for VectorSpaceScoringFunction {
    /// Computes the cosine between a unit-weight query vector (one component
    /// per query term) and the document's `tf * ln(N / df)` weights for those
    /// same terms.
    ///
    /// * Terms absent from the document get weight `0.0` without evaluating
    ///   the IDF, so `0 * inf` can never inject a NaN.
    /// * A missing or zero document frequency is treated as 1.
    /// * An empty collection gives every term weight `0.0`.
    ///
    /// Returns `0.0` when the query is empty or the document has no non-zero
    /// weight. `_config` is unused.
    fn score(
        &self,
        query_terms: &[String],
        doc_stats: &DocumentStats,
        collection_stats: &CollectionStats,
        _config: &ScoringConfig,
    ) -> Result<f32> {
        let total_docs = collection_stats.total_docs as f32;

        let doc_weights: Vec<f32> = query_terms
            .iter()
            .map(|term| {
                let tf = doc_stats.term_frequencies.get(term).copied().unwrap_or(0);
                if tf == 0 || total_docs <= 0.0 {
                    return 0.0;
                }
                // Clamp to 1 so an explicit zero behaves like a missing entry.
                let df = collection_stats
                    .document_frequencies
                    .get(term)
                    .copied()
                    .unwrap_or(1)
                    .max(1);
                tf as f32 * (total_docs / df as f32).ln()
            })
            .collect();

        // Every query component is 1.0, so the dot product is the plain sum of
        // the document weights and the query norm is sqrt(number of terms).
        let dot_product: f32 = doc_weights.iter().sum();
        let query_magnitude = (query_terms.len() as f32).sqrt();
        let doc_magnitude = doc_weights.iter().map(|w| w * w).sum::<f32>().sqrt();

        if query_magnitude == 0.0 || doc_magnitude == 0.0 {
            Ok(0.0)
        } else {
            Ok(dot_product / (query_magnitude * doc_magnitude))
        }
    }

    /// Returns `"Vector Space"`.
    fn name(&self) -> &str {
        "Vector Space"
    }

    /// Returns a one-line description of the algorithm.
    fn description(&self) -> &str {
        "Vector Space Model with cosine similarity"
    }
}
/// A `ScoringFunction` backed by a user-supplied closure.
///
/// `Debug` is implemented by hand because the stored closure has no `Debug`
/// representation.
pub struct CustomScoringFunction {
/// Name reported through `ScoringFunction::name`.
name: String,
/// Description reported through `ScoringFunction::description`.
description: String,
/// The scoring closure; it receives the same arguments as
/// `ScoringFunction::score` and is reference-counted via `Arc`.
#[allow(clippy::type_complexity)]
scorer: Arc<
dyn Fn(&[String], &DocumentStats, &CollectionStats, &ScoringConfig) -> Result<f32>
+ Send
+ Sync,
>,
}
impl std::fmt::Debug for CustomScoringFunction {
    /// Formats the scorer as a struct, substituting a fixed placeholder for
    /// the closure, which cannot be printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("CustomScoringFunction");
        repr.field("name", &self.name);
        repr.field("description", &self.description);
        repr.field("scorer", &"<custom function>");
        repr.finish()
    }
}
impl CustomScoringFunction {
    /// Wraps `scorer` in a named scoring function.
    ///
    /// * `name` / `description` - values later returned by
    ///   `ScoringFunction::name` and `ScoringFunction::description`.
    /// * `scorer` - closure invoked for every `score` call; it must be
    ///   `Send + Sync + 'static` because it is stored behind an `Arc`.
    pub fn new<F>(name: String, description: String, scorer: F) -> Self
    where
        F: Fn(&[String], &DocumentStats, &CollectionStats, &ScoringConfig) -> Result<f32>
            + Send
            + Sync
            + 'static,
    {
        Self {
            scorer: Arc::new(scorer),
            description,
            name,
        }
    }
}
impl ScoringFunction for CustomScoringFunction {
    /// Forwards all arguments to the wrapped closure and returns its result
    /// unchanged.
    fn score(
        &self,
        query_terms: &[String],
        doc_stats: &DocumentStats,
        collection_stats: &CollectionStats,
        config: &ScoringConfig,
    ) -> Result<f32> {
        let scorer = &self.scorer;
        scorer(query_terms, doc_stats, collection_stats, config)
    }

    /// Returns the name supplied at construction time.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the description supplied at construction time.
    fn description(&self) -> &str {
        self.description.as_str()
    }
}
/// Scores documents for one fixed query by combining a base
/// `ScoringFunction` with field boosts and an optional coordination factor.
#[derive(Debug)]
pub struct AdvancedScorer {
/// The base scoring strategy.
scoring_function: Box<dyn ScoringFunction>,
/// Parameters passed to the base scorer and used for post-processing.
config: ScoringConfig,
/// Collection-wide statistics passed to the base scorer.
collection_stats: CollectionStats,
/// The query terms every document is scored against.
query_terms: Vec<String>,
}
impl AdvancedScorer {
    /// Creates a scorer bound to one query.
    ///
    /// * `scoring_function` - base strategy used to compute the raw score.
    /// * `config` - parameters for the base scorer, field boosts and the
    ///   coordination switch.
    /// * `collection_stats` - collection-wide statistics.
    /// * `query_terms` - the terms of the query.
    pub fn new(
        scoring_function: Box<dyn ScoringFunction>,
        config: ScoringConfig,
        collection_stats: CollectionStats,
        query_terms: Vec<String>,
    ) -> Self {
        Self {
            scoring_function,
            config,
            collection_stats,
            query_terms,
        }
    }

    /// Scores one document: base score, then field boost, then (when
    /// `config.enable_coord` is set) the coordination factor.
    ///
    /// Propagates any error returned by the base scoring function.
    pub fn score_document(&self, doc_stats: &DocumentStats) -> Result<f32> {
        let base_score = self.scoring_function.score(
            &self.query_terms,
            doc_stats,
            &self.collection_stats,
            &self.config,
        )?;
        let boosted = self.apply_field_boosts(base_score, doc_stats)?;
        if self.config.enable_coord {
            self.apply_coordination_factor(boosted, doc_stats)
        } else {
            Ok(boosted)
        }
    }

    /// Multiplies `base_score` by the largest configured boost among the
    /// fields in which at least one query term occurs.
    ///
    /// The multiplier never drops below `1.0`, so boosts smaller than one
    /// have no effect. A field only qualifies when a *query* term has a
    /// positive frequency in it; unrelated content in the field does not
    /// trigger the boost.
    fn apply_field_boosts(&self, base_score: f32, doc_stats: &DocumentStats) -> Result<f32> {
        let max_boost = self
            .config
            .field_boosts
            .iter()
            .filter(|(field, _)| self.query_matches_field(doc_stats, field.as_str()))
            .map(|(_, boost)| *boost)
            .fold(1.0_f32, f32::max);
        Ok(base_score * max_boost)
    }

    /// Returns `true` when any query term has a positive frequency in `field`.
    fn query_matches_field(&self, doc_stats: &DocumentStats, field: &str) -> bool {
        doc_stats
            .field_term_frequencies
            .get(field)
            .map_or(false, |freqs| {
                self.query_terms
                    .iter()
                    .any(|term| freqs.get(term).map_or(false, |&tf| tf > 0))
            })
    }

    /// Multiplies `base_score` by the fraction of query terms that occur in
    /// the document (`matched / total`); an empty query leaves the score
    /// untouched.
    ///
    /// A term counts as matched only when its stored frequency is positive,
    /// which is the same notion of "present" the scoring functions use.
    fn apply_coordination_factor(&self, base_score: f32, doc_stats: &DocumentStats) -> Result<f32> {
        if self.query_terms.is_empty() {
            return Ok(base_score);
        }
        let matched_terms = self
            .query_terms
            .iter()
            .filter(|term| {
                doc_stats
                    .term_frequencies
                    .get(*term)
                    .map_or(false, |&tf| tf > 0)
            })
            .count();
        let coord_factor = matched_terms as f32 / self.query_terms.len() as f32;
        Ok(base_score * coord_factor)
    }
}
/// Name-indexed collection of scoring functions.
#[derive(Debug)]
pub struct ScoringRegistry {
/// Registered scoring functions keyed by their registration name.
functions: HashMap<String, Box<dyn ScoringFunction>>,
}
impl ScoringRegistry {
    /// Creates a registry pre-populated with the built-in scorers under the
    /// keys `"bm25"`, `"tf_idf"` and `"vector_space"`.
    pub fn new() -> Self {
        let mut functions: HashMap<String, Box<dyn ScoringFunction>> = HashMap::new();
        functions.insert("bm25".to_owned(), Box::new(BM25ScoringFunction));
        functions.insert("tf_idf".to_owned(), Box::new(TfIdfScoringFunction));
        functions.insert(
            "vector_space".to_owned(),
            Box::new(VectorSpaceScoringFunction),
        );
        Self { functions }
    }

    /// Registers `function` under `name`, replacing any function previously
    /// registered under the same name.
    pub fn register(&mut self, name: &str, function: Box<dyn ScoringFunction>) {
        let key = name.to_owned();
        self.functions.insert(key, function);
    }

    /// Looks up the function registered under `name`, if any.
    pub fn get(&self, name: &str) -> Option<&dyn ScoringFunction> {
        let boxed = self.functions.get(name)?;
        Some(boxed.as_ref())
    }

    /// Lists every registered function as a `(registration name, description)`
    /// pair, in the map's iteration order.
    pub fn list_functions(&self) -> Vec<(&str, &str)> {
        let mut listing = Vec::with_capacity(self.functions.len());
        for (key, function) in &self.functions {
            listing.push((key.as_str(), function.description()));
        }
        listing
    }
}
impl Default for ScoringRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned term list the scorers expect.
    fn terms(words: &[&str]) -> Vec<String> {
        words.iter().map(|word| word.to_string()).collect()
    }

    /// A 100-unit document containing "test" three times and "query" twice.
    fn create_test_doc_stats() -> DocumentStats {
        let term_frequencies: HashMap<String, u64> = [("test", 3), ("query", 2)]
            .iter()
            .map(|&(term, count)| (term.to_string(), count))
            .collect();
        DocumentStats {
            doc_id: 1,
            doc_length: 100,
            field_lengths: HashMap::new(),
            term_frequencies,
            field_term_frequencies: HashMap::new(),
        }
    }

    /// A 1000-document collection in which "test" occurs in 50 documents and
    /// "query" in 30.
    fn create_test_collection_stats() -> CollectionStats {
        let document_frequencies: HashMap<String, u64> = [("test", 50), ("query", 30)]
            .iter()
            .map(|&(term, count)| (term.to_string(), count))
            .collect();
        CollectionStats {
            total_docs: 1000,
            avg_doc_length: 120.0,
            avg_field_lengths: HashMap::new(),
            document_frequencies,
            field_document_frequencies: HashMap::new(),
        }
    }

    #[test]
    fn test_bm25_scoring() {
        let scorer = BM25ScoringFunction;
        let score = scorer
            .score(
                &terms(&["test", "query"]),
                &create_test_doc_stats(),
                &create_test_collection_stats(),
                &ScoringConfig::default(),
            )
            .unwrap();
        assert!(score > 0.0);
        assert_eq!(scorer.name(), "BM25");
    }

    #[test]
    fn test_tf_idf_scoring() {
        let scorer = TfIdfScoringFunction;
        let score = scorer
            .score(
                &terms(&["test", "query"]),
                &create_test_doc_stats(),
                &create_test_collection_stats(),
                &ScoringConfig::default(),
            )
            .unwrap();
        assert!(score > 0.0);
        assert_eq!(scorer.name(), "TF-IDF");
    }

    #[test]
    fn test_vector_space_scoring() {
        let scorer = VectorSpaceScoringFunction;
        let score = scorer
            .score(
                &terms(&["test", "query"]),
                &create_test_doc_stats(),
                &create_test_collection_stats(),
                &ScoringConfig::default(),
            )
            .unwrap();
        // Cosine similarity of non-negative weights must stay within [0, 1].
        assert!((0.0..=1.0).contains(&score));
        assert_eq!(scorer.name(), "Vector Space");
    }

    #[test]
    fn test_custom_scoring_function() {
        let custom_scorer = CustomScoringFunction::new(
            "custom".to_string(),
            "Custom test scorer".to_string(),
            |_terms, _doc_stats, _collection_stats, _config| Ok(42.0),
        );
        let score = custom_scorer
            .score(
                &terms(&["test"]),
                &create_test_doc_stats(),
                &create_test_collection_stats(),
                &ScoringConfig::default(),
            )
            .unwrap();
        assert_eq!(score, 42.0);
        assert_eq!(custom_scorer.name(), "custom");
    }

    #[test]
    fn test_scoring_registry() {
        let mut registry = ScoringRegistry::new();

        // The built-in scorers are available out of the box.
        for builtin in ["bm25", "tf_idf", "vector_space"].iter() {
            assert!(registry.get(builtin).is_some());
        }

        let custom_scorer = Box::new(CustomScoringFunction::new(
            "test".to_string(),
            "Test scorer".to_string(),
            |_terms, _doc_stats, _collection_stats, _config| Ok(1.0),
        ));
        registry.register("test", custom_scorer);
        assert!(registry.get("test").is_some());

        let functions = registry.list_functions();
        assert!(functions.len() >= 4);
    }

    #[test]
    fn test_advanced_scorer() {
        let scorer = AdvancedScorer::new(
            Box::new(BM25ScoringFunction),
            ScoringConfig::default(),
            create_test_collection_stats(),
            terms(&["test", "query"]),
        );
        let score = scorer.score_document(&create_test_doc_stats()).unwrap();
        assert!(score > 0.0);
    }
}