use crate::CupelError;
use crate::model::ContextItem;
use crate::scorer::Scorer;
/// A [`Scorer`] that combines several child scorers into one weighted score.
///
/// Each child's score is multiplied by its normalized weight and the products
/// are summed. Normalized weights are the configured weights divided by their
/// total, so they sum to 1.0 (up to floating-point rounding).
pub struct CompositeScorer {
    /// Child scorers, in the order they were passed to [`CompositeScorer::new`].
    scorers: Vec<Box<dyn Scorer>>,
    /// Normalized weight for the scorer at the same index in `scorers`.
    normalized_weights: Vec<f64>,
}

impl CompositeScorer {
    /// Builds a composite scorer from `(scorer, weight)` pairs.
    ///
    /// Weights are relative: only their ratios matter, so `[(a, 2.0), (b, 6.0)]`
    /// and `[(a, 1.0), (b, 3.0)]` both yield normalized weights `[0.25, 0.75]`.
    ///
    /// # Errors
    ///
    /// Returns [`CupelError::ScorerConfig`] if `entries` is empty, or if any
    /// weight is not strictly positive or not finite (NaN and ±infinity are
    /// rejected).
    pub fn new(entries: Vec<(Box<dyn Scorer>, f64)>) -> Result<Self, CupelError> {
        if entries.is_empty() {
            return Err(CupelError::ScorerConfig(
                "at least one scorer entry is required".to_owned(),
            ));
        }
        for (_, weight) in &entries {
            // Check order decides the message: -inf reports "positive", while NaN
            // fails every comparison and is only caught by the finiteness check.
            if *weight <= 0.0 {
                return Err(CupelError::ScorerConfig(
                    "weight must be positive".to_owned(),
                ));
            }
            if !weight.is_finite() {
                return Err(CupelError::ScorerConfig("weight must be finite".to_owned()));
            }
        }

        let raw_total: f64 = entries.iter().map(|(_, w)| *w).sum();
        // Every weight is finite, but their sum can still overflow to +inf (e.g. two
        // weights near `f64::MAX`); dividing by inf would zero every weight. In that
        // case rescale by the largest weight first, which keeps the ratios intact.
        // In the common case `scale` is 1.0, so results are bit-for-bit unchanged.
        let (scale, total_weight) = if raw_total.is_finite() {
            (1.0, raw_total)
        } else {
            let max_weight = entries.iter().map(|(_, w)| *w).fold(0.0_f64, f64::max);
            let scaled_total: f64 = entries.iter().map(|(_, w)| *w / max_weight).sum();
            (max_weight, scaled_total)
        };

        let (scorers, normalized_weights): (Vec<Box<dyn Scorer>>, Vec<f64>) = entries
            .into_iter()
            .map(|(scorer, weight)| (scorer, weight / scale / total_weight))
            .unzip();

        Ok(Self {
            scorers,
            normalized_weights,
        })
    }
}
impl Scorer for CompositeScorer {
    /// Returns the weighted sum of every child scorer's score for `item`.
    ///
    /// Each child receives the same `item` and `all_items`, and its score is
    /// multiplied by the normalized weight stored at the same index. An empty
    /// composite would score 0.0.
    fn score(&self, item: &ContextItem, all_items: &[ContextItem]) -> f64 {
        // `scorers` and `normalized_weights` are built together with equal length,
        // so `zip` pairs every scorer with its weight without index bounds checks.
        self.scorers
            .iter()
            .zip(&self.normalized_weights)
            .fold(0.0, |acc, (scorer, weight)| {
                acc + scorer.score(item, all_items) * weight
            })
    }
}
impl std::fmt::Debug for CompositeScorer {
    /// Writes the number of child scorers and their normalized weights.
    ///
    /// The child scorers themselves are summarized by count only; their
    /// individual contents are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let scorer_count = self.scorers.len();
        let mut builder = f.debug_struct("CompositeScorer");
        builder.field("num_scorers", &scorer_count);
        builder.field("normalized_weights", &self.normalized_weights);
        builder.finish()
    }
}