// khive_pack_memory/config.rs

1use serde::{Deserialize, Serialize};
2
3use khive_runtime::{FusionStrategy, RuntimeError};
4
5/// Configuration for the recall scoring pipeline.
6/// All fields have sensible defaults matching current behavior.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8#[serde(default)]
9pub struct RecallConfig {
10    // --- Fusion weights ---
11    /// Weight of RRF/fusion score. Default 0.70.
12    pub relevance_weight: f64,
13    /// Weight of decay-adjusted salience. Default 0.20.
14    pub importance_weight: f64,
15    /// Weight of pure recency. Default 0.10.
16    pub temporal_weight: f64,
17
18    // --- Temporal parameters ---
19    /// Days for temporal score to halve. Default 30.0.
20    pub temporal_half_life_days: f64,
21    /// Decay model to apply to salience. Default Exponential.
22    pub decay_model: DecayModel,
23
24    // --- Retrieval parameters ---
25    /// Candidates per retrieval path before fusion = limit × this. Default 20.
26    pub candidate_multiplier: u32,
27    /// Explicit max candidates per retrieval path before fusion. When None,
28    /// candidate_multiplier keeps the legacy behavior.
29    pub candidate_limit: Option<u32>,
30    /// Strategy used to fuse retrieval-source candidate lists. Default RRF k=60.
31    pub fuse_strategy: FusionStrategy,
32    /// Minimum composite score to include in results. Default 0.0.
33    pub min_score: f64,
34    /// Minimum raw salience to include in results. Default 0.0.
35    pub min_salience: f64,
36    /// Include per-component score breakdowns in recall responses. Default false.
37    pub include_breakdown: bool,
38}
39
40impl Default for RecallConfig {
41    fn default() -> Self {
42        Self {
43            relevance_weight: 0.70,
44            importance_weight: 0.20,
45            temporal_weight: 0.10,
46            temporal_half_life_days: 30.0,
47            decay_model: DecayModel::default(),
48            candidate_multiplier: 20,
49            candidate_limit: None,
50            fuse_strategy: FusionStrategy::default(),
51            min_score: 0.0,
52            min_salience: 0.0,
53            include_breakdown: false,
54        }
55    }
56}
57
58impl RecallConfig {
59    /// Validate that the config is internally consistent.
60    ///
61    /// Rejects:
62    /// - Negative weights
63    /// - All three weights summing to zero (no scoring signal)
64    /// - Non-positive temporal half-life
65    pub fn validate(&self) -> Result<(), RuntimeError> {
66        if self.relevance_weight < 0.0 {
67            return Err(RuntimeError::InvalidInput(
68                "relevance_weight must be non-negative".to_string(),
69            ));
70        }
71        if self.importance_weight < 0.0 {
72            return Err(RuntimeError::InvalidInput(
73                "importance_weight must be non-negative".to_string(),
74            ));
75        }
76        if self.temporal_weight < 0.0 {
77            return Err(RuntimeError::InvalidInput(
78                "temporal_weight must be non-negative".to_string(),
79            ));
80        }
81        let weight_sum = self.relevance_weight + self.importance_weight + self.temporal_weight;
82        if weight_sum <= 0.0 {
83            return Err(RuntimeError::InvalidInput(
84                "at least one of relevance_weight / importance_weight / temporal_weight must be positive".to_string(),
85            ));
86        }
87        if self.temporal_half_life_days <= 0.0 {
88            return Err(RuntimeError::InvalidInput(
89                "temporal_half_life_days must be positive".to_string(),
90            ));
91        }
92        if self.candidate_limit == Some(0) {
93            return Err(RuntimeError::InvalidInput(
94                "candidate_limit must be positive when provided".to_string(),
95            ));
96        }
97        if !self.min_score.is_finite() {
98            return Err(RuntimeError::InvalidInput(
99                "min_score must be finite".to_string(),
100            ));
101        }
102        if !self.min_salience.is_finite() {
103            return Err(RuntimeError::InvalidInput(
104                "min_salience must be finite".to_string(),
105            ));
106        }
107        Ok(())
108    }
109}
110
111/// How salience decays over time.
112#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
113#[serde(rename_all = "snake_case")]
114pub enum DecayModel {
115    /// `salience * exp(-age * ln2 / half_life)`
116    ///
117    /// This is the original formula; it is the default.
118    #[default]
119    Exponential,
120    /// `salience / (1 + decay_factor * age_days)`
121    Hyperbolic,
122    /// `salience * half_life / (half_life + age_days)`
123    PowerLaw {
124        /// Override half-life (days) for the power-law model.
125        /// Falls back to RecallConfig.temporal_half_life_days when absent.
126        half_life_days: f64,
127    },
128    /// No decay — salience is used as-is.
129    None,
130}
131
132impl DecayModel {
133    /// Apply decay to a salience value.
134    ///
135    /// - `salience`    — raw importance in [0, 1]
136    /// - `age_days`    — age of the note in days
137    /// - `decay_factor`— per-note decay rate stored on the note (used by Exponential and Hyperbolic)
138    /// - `half_life`   — config half-life, used by Exponential (as formula half-life) and PowerLaw
139    pub fn apply(&self, salience: f64, age_days: f64, decay_factor: f64, half_life: f64) -> f64 {
140        match self {
141            DecayModel::Exponential => {
142                // Uses the proper half-life formula: exp(-age * ln2 / half_life)
143                // This gives exactly 0.5 at age == half_life.
144                let k = std::f64::consts::LN_2 / half_life;
145                salience * (-k * age_days).exp()
146            }
147            DecayModel::Hyperbolic => salience / (1.0 + decay_factor * age_days),
148            DecayModel::PowerLaw { half_life_days } => {
149                let hl = *half_life_days;
150                salience * hl / (hl + age_days)
151            }
152            DecayModel::None => salience,
153        }
154    }
155}
156
157/// Per-component score contributions for a single recall result.
158#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct ScoreBreakdown {
160    /// Raw RRF fusion score (before weighting).
161    pub relevance: f64,
162    /// Raw salience from the note (before decay).
163    pub importance_raw: f64,
164    /// Salience after applying the decay model.
165    pub importance_decayed: f64,
166    /// Temporal recency score (half-life decay, independent of note's own decay_factor).
167    pub temporal: f64,
168    /// Weighted contributions summing to the total score.
169    pub weighted: WeightedContributions,
170}
171
172impl ScoreBreakdown {
173    /// Total composite score.
174    pub fn total(&self) -> f64 {
175        self.weighted.relevance_contribution
176            + self.weighted.importance_contribution
177            + self.weighted.temporal_contribution
178    }
179}
180
181/// The three weighted components that make up the final score.
182#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct WeightedContributions {
184    pub relevance_contribution: f64,
185    pub importance_contribution: f64,
186    pub temporal_contribution: f64,
187}
188
189// ── Tests ─────────────────────────────────────────────────────────────────────
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics with `"{what}, got {actual}"` unless `actual` is strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64, what: &str) {
        assert!((actual - expected).abs() < tol, "{what}, got {actual}");
    }

    /// Asserts that `cfg` is rejected by `RecallConfig::validate`.
    fn assert_rejected(cfg: RecallConfig) {
        assert!(cfg.validate().is_err());
    }

    /// Serializes `value` to JSON and parses it straight back.
    fn json_roundtrip<T>(value: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let json = serde_json::to_string(value).expect("serialize");
        serde_json::from_str(&json).expect("deserialize")
    }

    /// Checks the defaults of the fields added after the original config shape.
    fn assert_new_field_defaults(cfg: &RecallConfig) {
        assert_eq!(cfg.candidate_limit, None);
        assert_eq!(cfg.fuse_strategy, FusionStrategy::Rrf { k: 60 });
        assert!(!cfg.include_breakdown);
    }

    // ── DecayModel ────────────────────────────────────────────────────────────

    #[test]
    fn exponential_halves_at_half_life() {
        let half_life = 30.0;
        let result = DecayModel::Exponential.apply(1.0, half_life, 0.01, half_life);
        assert_close(result, 0.5, 1e-10, "exponential should give 0.5 at half-life");
    }

    #[test]
    fn exponential_full_salience_at_zero_age() {
        let result = DecayModel::Exponential.apply(0.8, 0.0, 0.01, 30.0);
        assert_close(result, 0.8, 1e-12, "at age=0 salience should be unchanged");
    }

    #[test]
    fn hyperbolic_halves_at_one_over_decay_factor() {
        // salience / (1 + k * age) = 0.5 when age = 1/k (here 20 days).
        let k = 0.05;
        let result = DecayModel::Hyperbolic.apply(1.0, 1.0 / k, k, 30.0);
        assert_close(result, 0.5, 1e-10, "hyperbolic at age=1/k should give 0.5");
    }

    #[test]
    fn hyperbolic_full_salience_at_zero_age() {
        let result = DecayModel::Hyperbolic.apply(0.7, 0.0, 0.05, 30.0);
        assert_close(result, 0.7, 1e-12, "at age=0 salience should be unchanged");
    }

    #[test]
    fn powerlaw_halves_at_half_life() {
        // salience * hl / (hl + age) = 0.5 when age = hl.
        let hl = 30.0;
        let result = DecayModel::PowerLaw { half_life_days: hl }.apply(1.0, hl, 0.01, hl);
        assert_close(result, 0.5, 1e-10, "power-law should give 0.5 at half-life");
    }

    #[test]
    fn decay_none_returns_salience_unchanged() {
        let result = DecayModel::None.apply(0.6, 100.0, 0.99, 30.0);
        assert_close(result, 0.6, 1e-12, "None model must not alter salience");
    }

    // ── RecallConfig ──────────────────────────────────────────────────────────

    #[test]
    fn default_config_validates() {
        assert!(RecallConfig::default().validate().is_ok());
    }

    #[test]
    fn negative_relevance_weight_fails_validation() {
        assert_rejected(RecallConfig {
            relevance_weight: -0.1,
            ..RecallConfig::default()
        });
    }

    #[test]
    fn negative_importance_weight_fails_validation() {
        assert_rejected(RecallConfig {
            importance_weight: -1.0,
            ..RecallConfig::default()
        });
    }

    #[test]
    fn negative_temporal_weight_fails_validation() {
        assert_rejected(RecallConfig {
            temporal_weight: -0.5,
            ..RecallConfig::default()
        });
    }

    #[test]
    fn all_zero_weights_fails_validation() {
        assert_rejected(RecallConfig {
            relevance_weight: 0.0,
            importance_weight: 0.0,
            temporal_weight: 0.0,
            ..RecallConfig::default()
        });
    }

    #[test]
    fn zero_half_life_fails_validation() {
        assert_rejected(RecallConfig {
            temporal_half_life_days: 0.0,
            ..RecallConfig::default()
        });
    }

    #[test]
    fn negative_half_life_fails_validation() {
        assert_rejected(RecallConfig {
            temporal_half_life_days: -5.0,
            ..RecallConfig::default()
        });
    }

    #[test]
    fn non_uniform_weights_validate() {
        let cfg = RecallConfig {
            relevance_weight: 0.5,
            importance_weight: 0.3,
            temporal_weight: 0.2,
            ..RecallConfig::default()
        };
        assert!(cfg.validate().is_ok());
    }

    // ── Serde roundtrips ──────────────────────────────────────────────────────

    #[test]
    fn default_config_roundtrip() {
        let cfg = RecallConfig::default();
        let back = json_roundtrip(&cfg);
        assert!((cfg.relevance_weight - back.relevance_weight).abs() < 1e-12);
        assert_eq!(cfg.decay_model, back.decay_model);
    }

    #[test]
    fn decay_model_exponential_roundtrip() {
        let m = DecayModel::Exponential;
        assert_eq!(m, json_roundtrip(&m));
    }

    #[test]
    fn decay_model_hyperbolic_roundtrip() {
        let m = DecayModel::Hyperbolic;
        assert_eq!(m, json_roundtrip(&m));
    }

    #[test]
    fn decay_model_powerlaw_roundtrip() {
        let m = DecayModel::PowerLaw {
            half_life_days: 14.0,
        };
        assert_eq!(m, json_roundtrip(&m));
    }

    #[test]
    fn decay_model_none_roundtrip() {
        let m = DecayModel::None;
        assert_eq!(m, json_roundtrip(&m));
    }

    #[test]
    fn partial_config_deserializes_with_defaults() {
        // Only one field is overridden; everything else must come from Default.
        let cfg: RecallConfig =
            serde_json::from_str(r#"{"relevance_weight": 0.5}"#).expect("deserialize partial");
        assert!((cfg.relevance_weight - 0.5).abs() < 1e-12);
        assert!((cfg.importance_weight - 0.20).abs() < 1e-12);
        assert_eq!(cfg.decay_model, DecayModel::Exponential);
    }

    // ── RecallConfig new fields ───────────────────────────────────────────────

    #[test]
    fn new_fields_have_correct_defaults() {
        assert_new_field_defaults(&RecallConfig::default());
    }

    #[test]
    fn candidate_limit_zero_fails_validation() {
        assert_rejected(RecallConfig {
            candidate_limit: Some(0),
            ..RecallConfig::default()
        });
    }

    #[test]
    fn candidate_limit_some_positive_validates() {
        let cfg = RecallConfig {
            candidate_limit: Some(100),
            ..RecallConfig::default()
        };
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn min_score_nan_fails_validation() {
        assert_rejected(RecallConfig {
            min_score: f64::NAN,
            ..RecallConfig::default()
        });
    }

    #[test]
    fn min_salience_nan_fails_validation() {
        assert_rejected(RecallConfig {
            min_salience: f64::NAN,
            ..RecallConfig::default()
        });
    }

    #[test]
    fn new_fields_roundtrip() {
        let back = json_roundtrip(&RecallConfig {
            candidate_limit: Some(50),
            fuse_strategy: FusionStrategy::Union,
            include_breakdown: true,
            ..RecallConfig::default()
        });
        assert_eq!(back.candidate_limit, Some(50));
        assert_eq!(back.fuse_strategy, FusionStrategy::Union);
        assert!(back.include_breakdown);
    }

    #[test]
    fn partial_config_new_fields_use_defaults() {
        // JSON that omits every new field; each must fall back to its default.
        let cfg: RecallConfig =
            serde_json::from_str(r#"{"temporal_weight": 0.15}"#).expect("deserialize partial");
        assert_new_field_defaults(&cfg);
    }

    // ── ScoreBreakdown ────────────────────────────────────────────────────────

    #[test]
    fn score_breakdown_total_sums_contributions() {
        let bd = ScoreBreakdown {
            relevance: 0.5,
            importance_raw: 0.8,
            importance_decayed: 0.6,
            temporal: 0.3,
            weighted: WeightedContributions {
                relevance_contribution: 0.35,
                importance_contribution: 0.12,
                temporal_contribution: 0.03,
            },
        };
        assert_close(
            bd.total(),
            0.35 + 0.12 + 0.03,
            1e-12,
            "total() should sum weighted contributions",
        );
    }
}