use serde::{Deserialize, Serialize};
use khive_runtime::{FusionStrategy, RuntimeError};
/// Tuning knobs for recall scoring.
///
/// Because of `#[serde(default)]`, any field missing from serialized input takes its value
/// from `RecallConfig::default()`. Call `validate` before use; the constraints it enforces
/// are noted on the individual fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RecallConfig {
/// Weight of the relevance component. Must be non-negative, and the three weights may not
/// all be zero. Default `0.70`.
pub relevance_weight: f64,
/// Weight of the importance component. Must be non-negative. Default `0.20`.
pub importance_weight: f64,
/// Weight of the temporal component. Must be non-negative. Default `0.10`.
pub temporal_weight: f64,
/// Half-life in days used for temporal decay. Must be positive. Default `30.0`.
/// NOTE(review): presumably passed as the `half_life` argument of `DecayModel::apply`
/// (which `DecayModel::PowerLaw` ignores in favor of its own field) — confirm at call site.
pub temporal_half_life_days: f64,
/// Curve used to decay salience with age. Default `DecayModel::Exponential`.
pub decay_model: DecayModel,
/// Default `20`. Not checked by `validate`.
/// NOTE(review): presumably scales the requested result count to size the candidate
/// pool — confirm in the caller.
pub candidate_multiplier: u32,
/// Optional candidate cap; `Some(0)` is rejected by `validate`. Default `None`.
/// NOTE(review): presumably overrides `candidate_multiplier` when set — confirm.
pub candidate_limit: Option<u32>,
/// Strategy from `khive_runtime` for fusing candidate lists. Default
/// `FusionStrategy::default()` (the tests expect `Rrf { k: 60 }`).
pub fuse_strategy: FusionStrategy,
/// Score threshold; must be finite. Default `0.0`.
/// NOTE(review): presumably results scoring below this are dropped — confirm.
pub min_score: f64,
/// Salience threshold; must be finite. Default `0.0`.
/// NOTE(review): presumably items below this salience are dropped — confirm.
pub min_salience: f64,
/// Whether to produce a per-result `ScoreBreakdown` (presumably). Default `false`.
pub include_breakdown: bool,
}
impl Default for RecallConfig {
    /// Baseline tuning.
    ///
    /// - Relevance-dominated weights (0.70 / 0.20 / 0.10) with a 30-day temporal half-life.
    /// - Exponential decay and the default fusion strategy.
    /// - A 20x candidate multiplier and no explicit candidate cap.
    /// - Zero score/salience floors and no breakdown.
    ///
    /// `#[serde(default)]` on the struct uses these values for any field missing from input.
    fn default() -> Self {
        // The three component weights nominally sum to 1.0.
        let (relevance_weight, importance_weight, temporal_weight) = (0.70, 0.20, 0.10);

        Self {
            relevance_weight,
            importance_weight,
            temporal_weight,
            temporal_half_life_days: 30.0,
            decay_model: DecayModel::default(),
            fuse_strategy: FusionStrategy::default(),
            candidate_multiplier: 20,
            candidate_limit: None,
            min_salience: 0.0,
            min_score: 0.0,
            include_breakdown: false,
        }
    }
}
impl RecallConfig {
pub fn validate(&self) -> Result<(), RuntimeError> {
if self.relevance_weight < 0.0 {
return Err(RuntimeError::InvalidInput(
"relevance_weight must be non-negative".to_string(),
));
}
if self.importance_weight < 0.0 {
return Err(RuntimeError::InvalidInput(
"importance_weight must be non-negative".to_string(),
));
}
if self.temporal_weight < 0.0 {
return Err(RuntimeError::InvalidInput(
"temporal_weight must be non-negative".to_string(),
));
}
let weight_sum = self.relevance_weight + self.importance_weight + self.temporal_weight;
if weight_sum <= 0.0 {
return Err(RuntimeError::InvalidInput(
"at least one of relevance_weight / importance_weight / temporal_weight must be positive".to_string(),
));
}
if self.temporal_half_life_days <= 0.0 {
return Err(RuntimeError::InvalidInput(
"temporal_half_life_days must be positive".to_string(),
));
}
if self.candidate_limit == Some(0) {
return Err(RuntimeError::InvalidInput(
"candidate_limit must be positive when provided".to_string(),
));
}
if !self.min_score.is_finite() {
return Err(RuntimeError::InvalidInput(
"min_score must be finite".to_string(),
));
}
if !self.min_salience.is_finite() {
return Err(RuntimeError::InvalidInput(
"min_salience must be finite".to_string(),
));
}
Ok(())
}
}
/// Curve used by `DecayModel::apply` to attenuate salience as an item ages.
///
/// Serialized with serde's default externally-tagged representation in snake_case:
/// `"exponential"`, `"hyperbolic"`, `"none"`, and
/// `{"power_law": {"half_life_days": <f64>}}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum DecayModel {
/// `salience * exp(-ln(2) / half_life * age_days)`: halves every `half_life` days.
/// This is the default variant.
#[default]
Exponential,
/// `salience / (1 + decay_factor * age_days)`: halves at `age_days = 1 / decay_factor`.
Hyperbolic,
/// `salience * hl / (hl + age_days)` using this variant's own `half_life_days`; the
/// `half_life` argument of `apply` is ignored. Halves at `age_days = half_life_days`.
/// NOTE(review): despite the name, this equals `1 / (1 + age_days / hl)`. That is a
/// hyperbolic-form curve, not `age^-alpha`.
PowerLaw {
/// Age in days at which salience is halved.
half_life_days: f64,
},
/// No decay: salience is returned unchanged.
None,
}
impl DecayModel {
pub fn apply(&self, salience: f64, age_days: f64, decay_factor: f64, half_life: f64) -> f64 {
match self {
DecayModel::Exponential => {
let k = std::f64::consts::LN_2 / half_life;
salience * (-k * age_days).exp()
}
DecayModel::Hyperbolic => salience / (1.0 + decay_factor * age_days),
DecayModel::PowerLaw { half_life_days } => {
let hl = *half_life_days;
salience * hl / (hl + age_days)
}
DecayModel::None => salience,
}
}
}
/// Per-component scores for a single recall result.
///
/// NOTE(review): presumably produced only when `RecallConfig::include_breakdown` is set —
/// confirm at the construction site.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreBreakdown {
/// Unweighted relevance component.
pub relevance: f64,
/// Importance (salience) before decay.
pub importance_raw: f64,
/// Importance after decay (presumably `DecayModel::apply` on `importance_raw`).
pub importance_decayed: f64,
/// Unweighted temporal component.
pub temporal: f64,
/// Each component multiplied by its weight; `total()` sums these.
pub weighted: WeightedContributions,
}
impl ScoreBreakdown {
pub fn total(&self) -> f64 {
self.weighted.relevance_contribution
+ self.weighted.importance_contribution
+ self.weighted.temporal_contribution
}
}
/// Weighted per-component contributions to a result's score; `ScoreBreakdown::total`
/// returns their sum.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightedContributions {
/// Relevance component times its weight (presumably `RecallConfig::relevance_weight`).
pub relevance_contribution: f64,
/// Importance component times its weight (presumably `RecallConfig::importance_weight`).
pub importance_contribution: f64,
/// Temporal component times its weight (presumably `RecallConfig::temporal_weight`).
pub temporal_contribution: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics with "`{context}, got {actual}`" unless `actual` is within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64, context: &str) {
        let gap = (actual - expected).abs();
        assert!(gap < tol, "{context}, got {actual}");
    }

    /// Serializes `value` to JSON and parses it back into the same type.
    fn json_roundtrip<T>(value: &T) -> T
    where
        T: Serialize + serde::de::DeserializeOwned,
    {
        let json = serde_json::to_string(value).expect("serialize");
        serde_json::from_str(&json).expect("deserialize")
    }

    fn assert_valid(cfg: &RecallConfig) {
        assert!(cfg.validate().is_ok());
    }

    fn assert_invalid(cfg: &RecallConfig) {
        assert!(cfg.validate().is_err());
    }

    // ---- DecayModel::apply ----

    #[test]
    fn exponential_halves_at_half_life() {
        let half_life = 30.0;
        let decayed = DecayModel::Exponential.apply(1.0, half_life, 0.01, half_life);
        assert_close(decayed, 0.5, 1e-10, "exponential should give 0.5 at half-life");
    }

    #[test]
    fn exponential_full_salience_at_zero_age() {
        let decayed = DecayModel::Exponential.apply(0.8, 0.0, 0.01, 30.0);
        assert_close(decayed, 0.8, 1e-12, "at age=0 salience should be unchanged");
    }

    #[test]
    fn hyperbolic_halves_at_one_over_decay_factor() {
        let rate = 0.05;
        let decayed = DecayModel::Hyperbolic.apply(1.0, 1.0 / rate, rate, 30.0);
        assert_close(decayed, 0.5, 1e-10, "hyperbolic at age=1/k should give 0.5");
    }

    #[test]
    fn hyperbolic_full_salience_at_zero_age() {
        let decayed = DecayModel::Hyperbolic.apply(0.7, 0.0, 0.05, 30.0);
        assert_close(decayed, 0.7, 1e-12, "at age=0 salience should be unchanged");
    }

    #[test]
    fn powerlaw_halves_at_half_life() {
        let hl = 30.0;
        let decayed = DecayModel::PowerLaw { half_life_days: hl }.apply(1.0, hl, 0.01, hl);
        assert_close(decayed, 0.5, 1e-10, "power-law should give 0.5 at half-life");
    }

    #[test]
    fn decay_none_returns_salience_unchanged() {
        let decayed = DecayModel::None.apply(0.6, 100.0, 0.99, 30.0);
        assert_close(decayed, 0.6, 1e-12, "None model must not alter salience");
    }

    // ---- RecallConfig::validate ----

    #[test]
    fn default_config_validates() {
        assert_valid(&RecallConfig::default());
    }

    #[test]
    fn negative_relevance_weight_fails_validation() {
        assert_invalid(&RecallConfig {
            relevance_weight: -0.1,
            ..Default::default()
        });
    }

    #[test]
    fn negative_importance_weight_fails_validation() {
        assert_invalid(&RecallConfig {
            importance_weight: -1.0,
            ..Default::default()
        });
    }

    #[test]
    fn negative_temporal_weight_fails_validation() {
        assert_invalid(&RecallConfig {
            temporal_weight: -0.5,
            ..Default::default()
        });
    }

    #[test]
    fn all_zero_weights_fails_validation() {
        assert_invalid(&RecallConfig {
            relevance_weight: 0.0,
            importance_weight: 0.0,
            temporal_weight: 0.0,
            ..Default::default()
        });
    }

    #[test]
    fn zero_half_life_fails_validation() {
        assert_invalid(&RecallConfig {
            temporal_half_life_days: 0.0,
            ..Default::default()
        });
    }

    #[test]
    fn negative_half_life_fails_validation() {
        assert_invalid(&RecallConfig {
            temporal_half_life_days: -5.0,
            ..Default::default()
        });
    }

    #[test]
    fn non_uniform_weights_validate() {
        assert_valid(&RecallConfig {
            relevance_weight: 0.5,
            importance_weight: 0.3,
            temporal_weight: 0.2,
            ..Default::default()
        });
    }

    // ---- serde round-trips and defaults ----

    #[test]
    fn default_config_roundtrip() {
        let original = RecallConfig::default();
        let restored = json_roundtrip(&original);
        assert!((original.relevance_weight - restored.relevance_weight).abs() < 1e-12);
        assert_eq!(original.decay_model, restored.decay_model);
    }

    #[test]
    fn decay_model_exponential_roundtrip() {
        let model = DecayModel::Exponential;
        assert_eq!(model, json_roundtrip(&model));
    }

    #[test]
    fn decay_model_hyperbolic_roundtrip() {
        let model = DecayModel::Hyperbolic;
        assert_eq!(model, json_roundtrip(&model));
    }

    #[test]
    fn decay_model_powerlaw_roundtrip() {
        let model = DecayModel::PowerLaw {
            half_life_days: 14.0,
        };
        assert_eq!(model, json_roundtrip(&model));
    }

    #[test]
    fn decay_model_none_roundtrip() {
        let model = DecayModel::None;
        assert_eq!(model, json_roundtrip(&model));
    }

    #[test]
    fn partial_config_deserializes_with_defaults() {
        let cfg: RecallConfig =
            serde_json::from_str(r#"{"relevance_weight": 0.5}"#).expect("deserialize partial");
        assert!((cfg.relevance_weight - 0.5).abs() < 1e-12);
        assert!((cfg.importance_weight - 0.20).abs() < 1e-12);
        assert_eq!(cfg.decay_model, DecayModel::Exponential);
    }

    #[test]
    fn new_fields_have_correct_defaults() {
        let cfg = RecallConfig::default();
        assert_eq!(cfg.candidate_limit, None);
        assert_eq!(cfg.fuse_strategy, FusionStrategy::Rrf { k: 60 });
        assert!(!cfg.include_breakdown);
    }

    #[test]
    fn candidate_limit_zero_fails_validation() {
        assert_invalid(&RecallConfig {
            candidate_limit: Some(0),
            ..Default::default()
        });
    }

    #[test]
    fn candidate_limit_some_positive_validates() {
        assert_valid(&RecallConfig {
            candidate_limit: Some(100),
            ..Default::default()
        });
    }

    #[test]
    fn min_score_nan_fails_validation() {
        assert_invalid(&RecallConfig {
            min_score: f64::NAN,
            ..Default::default()
        });
    }

    #[test]
    fn min_salience_nan_fails_validation() {
        assert_invalid(&RecallConfig {
            min_salience: f64::NAN,
            ..Default::default()
        });
    }

    #[test]
    fn new_fields_roundtrip() {
        let restored = json_roundtrip(&RecallConfig {
            candidate_limit: Some(50),
            fuse_strategy: FusionStrategy::Union,
            include_breakdown: true,
            ..Default::default()
        });
        assert_eq!(restored.candidate_limit, Some(50));
        assert_eq!(restored.fuse_strategy, FusionStrategy::Union);
        assert!(restored.include_breakdown);
    }

    #[test]
    fn partial_config_new_fields_use_defaults() {
        let cfg: RecallConfig =
            serde_json::from_str(r#"{"temporal_weight": 0.15}"#).expect("deserialize partial");
        assert_eq!(cfg.candidate_limit, None);
        assert_eq!(cfg.fuse_strategy, FusionStrategy::Rrf { k: 60 });
        assert!(!cfg.include_breakdown);
    }

    // ---- ScoreBreakdown ----

    #[test]
    fn score_breakdown_total_sums_contributions() {
        let breakdown = ScoreBreakdown {
            relevance: 0.5,
            importance_raw: 0.8,
            importance_decayed: 0.6,
            temporal: 0.3,
            weighted: WeightedContributions {
                relevance_contribution: 0.35,
                importance_contribution: 0.12,
                temporal_contribution: 0.03,
            },
        };
        assert_close(
            breakdown.total(),
            0.35 + 0.12 + 0.03,
            1e-12,
            "total() should sum weighted contributions",
        );
    }
}