use std::collections::HashMap;
use crate::CupelError;
use crate::model::{ContextItem, ContextKind};
use crate::scorer::Scorer;
/// Scores context items solely by their [`ContextKind`], looked up in a per-kind weight table.
///
/// Use [`KindScorer::with_default_weights`] for the built-in table, or
/// [`KindScorer::new`] to supply (and validate) a custom one.
#[derive(Debug, Clone)]
pub struct KindScorer {
    /// Weight associated with each kind. Kinds missing from this map receive no weight.
    weights: HashMap<ContextKind, f64>,
}
impl KindScorer {
    /// Creates a scorer using the built-in weight table:
    ///
    /// | kind            | weight |
    /// |-----------------|--------|
    /// | `SYSTEM_PROMPT` | 1.0    |
    /// | `MEMORY`        | 0.8    |
    /// | `TOOL_OUTPUT`   | 0.6    |
    /// | `DOCUMENT`      | 0.4    |
    /// | `MESSAGE`       | 0.2    |
    pub fn with_default_weights() -> Self {
        let weights = HashMap::from([
            (ContextKind::from_static(ContextKind::SYSTEM_PROMPT), 1.0),
            (ContextKind::from_static(ContextKind::MEMORY), 0.8),
            (ContextKind::from_static(ContextKind::TOOL_OUTPUT), 0.6),
            (ContextKind::from_static(ContextKind::DOCUMENT), 0.4),
            (ContextKind::from_static(ContextKind::MESSAGE), 0.2),
        ]);
        Self { weights }
    }

    /// Creates a scorer from a caller-supplied weight table after validating every entry.
    ///
    /// # Errors
    ///
    /// Returns [`CupelError::ScorerConfig`] if any weight is negative (including
    /// negative infinity) or otherwise non-finite (NaN, positive infinity).
    /// Only the first offending entry encountered is reported. Because entries
    /// are visited in `HashMap` iteration order, which entry that is may vary
    /// when several are invalid.
    pub fn new(weights: HashMap<ContextKind, f64>) -> Result<Self, CupelError> {
        weights.iter().try_for_each(|(kind, &weight)| {
            // The sign is checked first. NaN compares false against 0.0, so it
            // falls through to the finiteness check below.
            if weight < 0.0 {
                Err(CupelError::ScorerConfig(format!(
                    "weight for kind '{}' must be non-negative",
                    kind,
                )))
            } else if !weight.is_finite() {
                Err(CupelError::ScorerConfig(format!(
                    "weight for kind '{}' must be finite",
                    kind,
                )))
            } else {
                Ok(())
            }
        })?;
        Ok(Self { weights })
    }
}
impl Scorer for KindScorer {
    /// Returns the weight configured for `item`'s kind, or `0.0` if that kind has no entry.
    ///
    /// The score depends only on the item itself; the surrounding items are ignored.
    fn score(&self, item: &ContextItem, _all_items: &[ContextItem]) -> f64 {
        match self.weights.get(item.kind()) {
            Some(&weight) => weight,
            None => 0.0,
        }
    }
}