1use serde::{Deserialize, Serialize};
2
3use khive_runtime::{FusionStrategy, RuntimeError};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
8#[serde(default)]
9pub struct RecallConfig {
10 pub relevance_weight: f64,
13 pub importance_weight: f64,
15 pub temporal_weight: f64,
17
18 pub temporal_half_life_days: f64,
21 pub decay_model: DecayModel,
23
24 pub candidate_multiplier: u32,
27 pub candidate_limit: Option<u32>,
30 pub fuse_strategy: FusionStrategy,
32 pub min_score: f64,
34 pub min_salience: f64,
36 pub include_breakdown: bool,
38}
39
40impl Default for RecallConfig {
41 fn default() -> Self {
42 Self {
43 relevance_weight: 0.70,
44 importance_weight: 0.20,
45 temporal_weight: 0.10,
46 temporal_half_life_days: 30.0,
47 decay_model: DecayModel::default(),
48 candidate_multiplier: 20,
49 candidate_limit: None,
50 fuse_strategy: FusionStrategy::default(),
51 min_score: 0.0,
52 min_salience: 0.0,
53 include_breakdown: false,
54 }
55 }
56}
57
58impl RecallConfig {
59 pub fn validate(&self) -> Result<(), RuntimeError> {
66 if self.relevance_weight < 0.0 {
67 return Err(RuntimeError::InvalidInput(
68 "relevance_weight must be non-negative".to_string(),
69 ));
70 }
71 if self.importance_weight < 0.0 {
72 return Err(RuntimeError::InvalidInput(
73 "importance_weight must be non-negative".to_string(),
74 ));
75 }
76 if self.temporal_weight < 0.0 {
77 return Err(RuntimeError::InvalidInput(
78 "temporal_weight must be non-negative".to_string(),
79 ));
80 }
81 let weight_sum = self.relevance_weight + self.importance_weight + self.temporal_weight;
82 if weight_sum <= 0.0 {
83 return Err(RuntimeError::InvalidInput(
84 "at least one of relevance_weight / importance_weight / temporal_weight must be positive".to_string(),
85 ));
86 }
87 if self.temporal_half_life_days <= 0.0 {
88 return Err(RuntimeError::InvalidInput(
89 "temporal_half_life_days must be positive".to_string(),
90 ));
91 }
92 if self.candidate_limit == Some(0) {
93 return Err(RuntimeError::InvalidInput(
94 "candidate_limit must be positive when provided".to_string(),
95 ));
96 }
97 if !self.min_score.is_finite() {
98 return Err(RuntimeError::InvalidInput(
99 "min_score must be finite".to_string(),
100 ));
101 }
102 if !self.min_salience.is_finite() {
103 return Err(RuntimeError::InvalidInput(
104 "min_salience must be finite".to_string(),
105 ));
106 }
107 Ok(())
108 }
109}
110
111#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
113#[serde(rename_all = "snake_case")]
114pub enum DecayModel {
115 #[default]
119 Exponential,
120 Hyperbolic,
122 PowerLaw {
124 half_life_days: f64,
127 },
128 None,
130}
131
132impl DecayModel {
133 pub fn apply(&self, salience: f64, age_days: f64, decay_factor: f64, half_life: f64) -> f64 {
140 match self {
141 DecayModel::Exponential => {
142 let k = std::f64::consts::LN_2 / half_life;
145 salience * (-k * age_days).exp()
146 }
147 DecayModel::Hyperbolic => salience / (1.0 + decay_factor * age_days),
148 DecayModel::PowerLaw { half_life_days } => {
149 let hl = *half_life_days;
150 salience * hl / (hl + age_days)
151 }
152 DecayModel::None => salience,
153 }
154 }
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct ScoreBreakdown {
160 pub relevance: f64,
162 pub importance_raw: f64,
164 pub importance_decayed: f64,
166 pub temporal: f64,
168 pub weighted: WeightedContributions,
170}
171
172impl ScoreBreakdown {
173 pub fn total(&self) -> f64 {
175 self.weighted.relevance_contribution
176 + self.weighted.importance_contribution
177 + self.weighted.temporal_contribution
178 }
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct WeightedContributions {
184 pub relevance_contribution: f64,
185 pub importance_contribution: f64,
186 pub temporal_contribution: f64,
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `|actual - expected| < tol`; `what` names the quantity in the failure message.
    ///
    /// A NaN `actual` always fails, because every comparison involving NaN is false.
    fn assert_close(actual: f64, expected: f64, tol: f64, what: &str) {
        assert!(
            (actual - expected).abs() < tol,
            "{what}: expected {expected} (tolerance {tol}), got {actual}"
        );
    }

    /// Serializes `model` to JSON and back, asserting the value survives unchanged.
    fn assert_decay_model_roundtrips(model: DecayModel) {
        let json = serde_json::to_string(&model).expect("serialize");
        let back: DecayModel = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(model, back, "round-trip through {json}");
    }

    // --- DecayModel::apply -------------------------------------------------------------

    #[test]
    fn exponential_halves_at_half_life() {
        let half_life = 30.0;
        let result = DecayModel::Exponential.apply(1.0, half_life, 0.01, half_life);
        assert_close(result, 0.5, 1e-10, "exponential decay at one half-life");
    }

    #[test]
    fn exponential_full_salience_at_zero_age() {
        let result = DecayModel::Exponential.apply(0.8, 0.0, 0.01, 30.0);
        assert_close(result, 0.8, 1e-12, "exponential decay at age 0");
    }

    #[test]
    fn hyperbolic_halves_at_one_over_decay_factor() {
        let k = 0.05;
        let age = 1.0 / k;
        let result = DecayModel::Hyperbolic.apply(1.0, age, k, 30.0);
        assert_close(result, 0.5, 1e-10, "hyperbolic decay at age 1/k");
    }

    #[test]
    fn hyperbolic_full_salience_at_zero_age() {
        let result = DecayModel::Hyperbolic.apply(0.7, 0.0, 0.05, 30.0);
        assert_close(result, 0.7, 1e-12, "hyperbolic decay at age 0");
    }

    #[test]
    fn powerlaw_halves_at_half_life() {
        let hl = 30.0;
        let model = DecayModel::PowerLaw { half_life_days: hl };
        let result = model.apply(1.0, hl, 0.01, hl);
        assert_close(result, 0.5, 1e-10, "power-law decay at one half-life");
    }

    #[test]
    fn powerlaw_full_salience_at_zero_age() {
        let model = DecayModel::PowerLaw {
            half_life_days: 14.0,
        };
        let result = model.apply(0.9, 0.0, 0.01, 30.0);
        assert_close(result, 0.9, 1e-12, "power-law decay at age 0");
    }

    #[test]
    fn decay_none_returns_salience_unchanged() {
        let result = DecayModel::None.apply(0.6, 100.0, 0.99, 30.0);
        assert_close(result, 0.6, 1e-12, "DecayModel::None");
    }

    // --- RecallConfig::validate --------------------------------------------------------

    #[test]
    fn default_config_validates() {
        assert!(RecallConfig::default().validate().is_ok());
    }

    #[test]
    fn negative_relevance_weight_fails_validation() {
        let cfg = RecallConfig {
            relevance_weight: -0.1,
            ..RecallConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn negative_importance_weight_fails_validation() {
        let cfg = RecallConfig {
            importance_weight: -1.0,
            ..RecallConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn negative_temporal_weight_fails_validation() {
        let cfg = RecallConfig {
            temporal_weight: -0.5,
            ..RecallConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn all_zero_weights_fails_validation() {
        let cfg = RecallConfig {
            relevance_weight: 0.0,
            importance_weight: 0.0,
            temporal_weight: 0.0,
            ..RecallConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn zero_half_life_fails_validation() {
        let cfg = RecallConfig {
            temporal_half_life_days: 0.0,
            ..RecallConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn negative_half_life_fails_validation() {
        let cfg = RecallConfig {
            temporal_half_life_days: -5.0,
            ..RecallConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn non_uniform_weights_validate() {
        let cfg = RecallConfig {
            relevance_weight: 0.5,
            importance_weight: 0.3,
            temporal_weight: 0.2,
            ..RecallConfig::default()
        };
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn candidate_limit_zero_fails_validation() {
        let cfg = RecallConfig {
            candidate_limit: Some(0),
            ..RecallConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn candidate_limit_some_positive_validates() {
        let cfg = RecallConfig {
            candidate_limit: Some(100),
            ..RecallConfig::default()
        };
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn min_score_nan_fails_validation() {
        let cfg = RecallConfig {
            min_score: f64::NAN,
            ..RecallConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn min_score_infinite_fails_validation() {
        let cfg = RecallConfig {
            min_score: f64::INFINITY,
            ..RecallConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn min_salience_nan_fails_validation() {
        let cfg = RecallConfig {
            min_salience: f64::NAN,
            ..RecallConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    // --- Defaults and serde ------------------------------------------------------------

    #[test]
    fn new_fields_have_correct_defaults() {
        let cfg = RecallConfig::default();
        assert_eq!(cfg.candidate_limit, None);
        assert_eq!(cfg.fuse_strategy, FusionStrategy::Rrf { k: 60 });
        assert!(!cfg.include_breakdown);
    }

    #[test]
    fn default_config_roundtrip() {
        let cfg = RecallConfig::default();
        let json = serde_json::to_string(&cfg).expect("serialize");
        let back: RecallConfig = serde_json::from_str(&json).expect("deserialize");

        assert_close(back.relevance_weight, cfg.relevance_weight, 1e-12, "relevance_weight");
        assert_close(back.importance_weight, cfg.importance_weight, 1e-12, "importance_weight");
        assert_close(back.temporal_weight, cfg.temporal_weight, 1e-12, "temporal_weight");
        assert_close(
            back.temporal_half_life_days,
            cfg.temporal_half_life_days,
            1e-12,
            "temporal_half_life_days",
        );
        assert_close(back.min_score, cfg.min_score, 1e-12, "min_score");
        assert_close(back.min_salience, cfg.min_salience, 1e-12, "min_salience");
        assert_eq!(back.decay_model, cfg.decay_model);
        assert_eq!(back.candidate_multiplier, cfg.candidate_multiplier);
        assert_eq!(back.candidate_limit, cfg.candidate_limit);
        assert_eq!(back.fuse_strategy, cfg.fuse_strategy);
        assert_eq!(back.include_breakdown, cfg.include_breakdown);
    }

    #[test]
    fn decay_model_exponential_roundtrip() {
        assert_decay_model_roundtrips(DecayModel::Exponential);
    }

    #[test]
    fn decay_model_hyperbolic_roundtrip() {
        assert_decay_model_roundtrips(DecayModel::Hyperbolic);
    }

    #[test]
    fn decay_model_powerlaw_roundtrip() {
        assert_decay_model_roundtrips(DecayModel::PowerLaw {
            half_life_days: 14.0,
        });
    }

    #[test]
    fn decay_model_none_roundtrip() {
        assert_decay_model_roundtrips(DecayModel::None);
    }

    #[test]
    fn decay_model_uses_snake_case_wire_names() {
        let to_json = |m: &DecayModel| serde_json::to_string(m).expect("serialize");
        assert_eq!(to_json(&DecayModel::Exponential), r#""exponential""#);
        assert_eq!(to_json(&DecayModel::Hyperbolic), r#""hyperbolic""#);
        assert_eq!(to_json(&DecayModel::None), r#""none""#);
        assert_eq!(
            to_json(&DecayModel::PowerLaw {
                half_life_days: 14.0
            }),
            r#"{"power_law":{"half_life_days":14.0}}"#
        );
    }

    #[test]
    fn partial_config_deserializes_with_defaults() {
        let json = r#"{"relevance_weight": 0.5}"#;
        let cfg: RecallConfig = serde_json::from_str(json).expect("deserialize partial");
        assert_close(cfg.relevance_weight, 0.5, 1e-12, "explicit relevance_weight");
        assert_close(cfg.importance_weight, 0.20, 1e-12, "defaulted importance_weight");
        assert_eq!(cfg.decay_model, DecayModel::Exponential);
    }

    #[test]
    fn new_fields_roundtrip() {
        let cfg = RecallConfig {
            candidate_limit: Some(50),
            fuse_strategy: FusionStrategy::Union,
            include_breakdown: true,
            ..RecallConfig::default()
        };
        let json = serde_json::to_string(&cfg).expect("serialize");
        let back: RecallConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back.candidate_limit, Some(50));
        assert_eq!(back.fuse_strategy, FusionStrategy::Union);
        assert!(back.include_breakdown);
    }

    #[test]
    fn partial_config_new_fields_use_defaults() {
        let json = r#"{"temporal_weight": 0.15}"#;
        let cfg: RecallConfig = serde_json::from_str(json).expect("deserialize partial");
        assert_eq!(cfg.candidate_limit, None);
        assert_eq!(cfg.fuse_strategy, FusionStrategy::Rrf { k: 60 });
        assert!(!cfg.include_breakdown);
    }

    // --- ScoreBreakdown ----------------------------------------------------------------

    #[test]
    fn score_breakdown_total_sums_contributions() {
        let bd = ScoreBreakdown {
            relevance: 0.5,
            importance_raw: 0.8,
            importance_decayed: 0.6,
            temporal: 0.3,
            weighted: WeightedContributions {
                relevance_contribution: 0.35,
                importance_contribution: 0.12,
                temporal_contribution: 0.03,
            },
        };
        assert_close(bd.total(), 0.35 + 0.12 + 0.03, 1e-12, "ScoreBreakdown::total");
    }
}
500}