Skip to main content

khive_pack_memory/
tunable.rs

1use khive_pack_brain::state::BrainState;
2use khive_pack_brain::tunable::{PackTunable, ParameterDef, ParameterSpace};
3use khive_runtime::RuntimeError;
4use serde_json::Value;
5
6use crate::config::RecallConfig;
7use crate::MemoryPack;
8
9/// `MemoryPack` implements `PackTunable` so that the brain can adjust the
10/// recall scoring pipeline based on observed usage patterns (Issue #159).
11///
12/// Parameter names (`memory::relevance_weight`, `memory::importance_weight`,
13/// `memory::temporal_weight`) match the keys that brain's `EventFold` tracks,
14/// so posteriors from real-time dispatch events flow directly into these params.
15///
16/// `project_config` reads posterior means → `RecallConfig`.
17/// `apply_config` validates and stores the new config; future recall calls
18/// pick it up via `MemoryPack::active_config()`.
19impl PackTunable for MemoryPack {
20    fn parameter_space(&self) -> ParameterSpace {
21        ParameterSpace {
22            parameters: vec![
23                ParameterDef {
24                    name: "memory::relevance_weight".into(),
25                    // Prior: relevance is the dominant signal (7:3), matching
26                    // EventFold's initial "recall::relevance_weight" posterior.
27                    prior_alpha: 7.0,
28                    prior_beta: 3.0,
29                    bounds: (0.0, 1.0),
30                },
31                ParameterDef {
32                    name: "memory::importance_weight".into(),
33                    // Prior: importance is secondary (2:8).
34                    prior_alpha: 2.0,
35                    prior_beta: 8.0,
36                    bounds: (0.0, 1.0),
37                },
38                ParameterDef {
39                    name: "memory::temporal_weight".into(),
40                    // Prior: temporal is weakest signal (1:9).
41                    prior_alpha: 1.0,
42                    prior_beta: 9.0,
43                    bounds: (0.0, 1.0),
44                },
45            ],
46        }
47    }
48
49    /// Project the current `BrainState` posteriors into a `RecallConfig` value.
50    ///
51    /// Reads `memory::*_weight` posterior means from `state`. Falls back to the
52    /// current active config if a parameter is absent (brain not yet warmed up).
53    fn project_config(&self, state: &BrainState) -> Value {
54        let current = self.active_config();
55
56        let relevance = state
57            .parameters
58            .get("memory::relevance_weight")
59            .map(|p| p.mean())
60            .unwrap_or(current.relevance_weight);
61
62        let importance = state
63            .parameters
64            .get("memory::importance_weight")
65            .map(|p| p.mean())
66            .unwrap_or(current.importance_weight);
67
68        let temporal = state
69            .parameters
70            .get("memory::temporal_weight")
71            .map(|p| p.mean())
72            .unwrap_or(current.temporal_weight);
73
74        let projected = RecallConfig {
75            relevance_weight: relevance,
76            importance_weight: importance,
77            temporal_weight: temporal,
78            ..current
79        };
80
81        serde_json::to_value(projected).unwrap_or_else(|_| serde_json::json!({}))
82    }
83
84    /// Apply a projected config to the pack.
85    ///
86    /// Deserializes the JSON value into a `RecallConfig`, validates it, and
87    /// stores it as the active config. Future recall calls pick up the new
88    /// weights via `MemoryPack::active_config()`.
89    fn apply_config(&self, config: Value) -> Result<(), RuntimeError> {
90        let new_cfg: RecallConfig = serde_json::from_value(config)
91            .map_err(|e| RuntimeError::InvalidInput(format!("invalid RecallConfig: {e}")))?;
92        new_cfg.validate()?;
93        *self.config.lock().unwrap() = new_cfg;
94        Ok(())
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use khive_pack_brain::state::BetaPosterior;
    use khive_runtime::KhiveRuntime;
    use std::collections::HashMap;

    /// Build a `MemoryPack` backed by an in-memory runtime.
    fn make_pack() -> MemoryPack {
        let rt = KhiveRuntime::memory().expect("in-memory runtime");
        MemoryPack::new(rt)
    }

    /// Wrap the given posteriors in a `BrainState`.
    fn brain_state_with_params(params: HashMap<String, BetaPosterior>) -> BrainState {
        BrainState::new(params, 100)
    }

    #[test]
    fn parameter_space_has_three_params() {
        let pack = make_pack();
        let space = pack.parameter_space();
        assert_eq!(space.parameters.len(), 3);
        let names: Vec<&str> = space.parameters.iter().map(|p| p.name.as_str()).collect();
        assert!(names.contains(&"memory::relevance_weight"));
        assert!(names.contains(&"memory::importance_weight"));
        assert!(names.contains(&"memory::temporal_weight"));
    }

    #[test]
    fn project_config_reads_posterior_means() {
        let pack = make_pack();
        let mut params = HashMap::new();
        params.insert(
            "memory::relevance_weight".into(),
            BetaPosterior::new(6.0, 4.0), // mean = 0.6
        );
        params.insert(
            "memory::importance_weight".into(),
            BetaPosterior::new(3.0, 7.0), // mean = 0.3
        );
        params.insert(
            "memory::temporal_weight".into(),
            BetaPosterior::new(1.0, 9.0), // mean = 0.1
        );
        let state = brain_state_with_params(params);
        let projected = pack.project_config(&state);

        let cfg: RecallConfig = serde_json::from_value(projected).unwrap();
        assert!((cfg.relevance_weight - 0.6).abs() < 1e-10);
        assert!((cfg.importance_weight - 0.3).abs() < 1e-10);
        assert!((cfg.temporal_weight - 0.1).abs() < 1e-10);
    }

    #[test]
    fn project_config_falls_back_to_active_when_param_absent() {
        let pack = make_pack();
        let state = brain_state_with_params(HashMap::new());
        let projected = pack.project_config(&state);

        let cfg: RecallConfig = serde_json::from_value(projected).unwrap();
        assert!((cfg.relevance_weight - 0.70).abs() < 1e-10);
        assert!((cfg.importance_weight - 0.20).abs() < 1e-10);
        assert!((cfg.temporal_weight - 0.10).abs() < 1e-10);
    }

    #[test]
    fn project_config_mixes_posteriors_and_fallbacks() {
        let pack = make_pack();
        let mut params = HashMap::new();
        params.insert(
            "memory::relevance_weight".into(),
            BetaPosterior::new(6.0, 4.0), // mean = 0.6
        );
        let state = brain_state_with_params(params);
        let cfg: RecallConfig = serde_json::from_value(pack.project_config(&state)).unwrap();

        // Relevance comes from the posterior. The other weights keep the
        // fresh pack's active values.
        assert!((cfg.relevance_weight - 0.6).abs() < 1e-10);
        assert!((cfg.importance_weight - 0.20).abs() < 1e-10);
        assert!((cfg.temporal_weight - 0.10).abs() < 1e-10);
    }

    #[test]
    fn project_config_falls_back_to_applied_config_not_defaults() {
        let pack = make_pack();
        let applied = RecallConfig {
            relevance_weight: 0.5,
            importance_weight: 0.3,
            temporal_weight: 0.2,
            ..RecallConfig::default()
        };
        pack.apply_config(serde_json::to_value(&applied).unwrap())
            .expect("apply_config succeeds");

        // With no posteriors, the projection must echo the *applied* weights.
        let state = brain_state_with_params(HashMap::new());
        let cfg: RecallConfig = serde_json::from_value(pack.project_config(&state)).unwrap();
        assert!((cfg.relevance_weight - 0.5).abs() < 1e-10);
        assert!((cfg.importance_weight - 0.3).abs() < 1e-10);
        assert!((cfg.temporal_weight - 0.2).abs() < 1e-10);
    }

    #[test]
    fn apply_config_updates_active_config() {
        let pack = make_pack();
        let new_cfg = RecallConfig {
            relevance_weight: 0.5,
            importance_weight: 0.3,
            temporal_weight: 0.2,
            ..RecallConfig::default()
        };
        let config_value = serde_json::to_value(&new_cfg).unwrap();
        pack.apply_config(config_value)
            .expect("apply_config succeeds");

        let active = pack.active_config();
        assert!((active.relevance_weight - 0.5).abs() < 1e-10);
        assert!((active.importance_weight - 0.3).abs() < 1e-10);
        assert!((active.temporal_weight - 0.2).abs() < 1e-10);
    }

    #[test]
    fn apply_config_rejects_all_zero_weights() {
        let pack = make_pack();
        let before = pack.active_config();
        let bad_cfg = RecallConfig {
            relevance_weight: 0.0,
            importance_weight: 0.0,
            temporal_weight: 0.0,
            ..RecallConfig::default()
        };
        let config_value = serde_json::to_value(&bad_cfg).unwrap();
        assert!(pack.apply_config(config_value).is_err());

        // A rejected config must leave the active one untouched.
        let after = pack.active_config();
        assert!((after.relevance_weight - before.relevance_weight).abs() < 1e-10);
        assert!((after.importance_weight - before.importance_weight).abs() < 1e-10);
        assert!((after.temporal_weight - before.temporal_weight).abs() < 1e-10);
    }

    #[test]
    fn apply_config_rejects_malformed_json() {
        let pack = make_pack();
        let before = pack.active_config();
        let bad = serde_json::json!({ "relevance_weight": "not_a_number" });
        assert!(pack.apply_config(bad).is_err());

        // A rejected config must leave the active one untouched.
        let after = pack.active_config();
        assert!((after.relevance_weight - before.relevance_weight).abs() < 1e-10);
        assert!((after.importance_weight - before.importance_weight).abs() < 1e-10);
        assert!((after.temporal_weight - before.temporal_weight).abs() < 1e-10);
    }

    #[test]
    fn prior_for_relevance_weight_matches_fold_priors() {
        let pack = make_pack();
        let space = pack.parameter_space();
        let def = space
            .parameters
            .iter()
            .find(|p| p.name == "memory::relevance_weight")
            .unwrap();
        assert!((def.prior_alpha - 7.0).abs() < 1e-12);
        assert!((def.prior_beta - 3.0).abs() < 1e-12);
    }

    #[test]
    fn prior_means_match_fresh_pack_weights() {
        let pack = make_pack();
        let space = pack.parameter_space();
        // Expected values are the fresh pack's active weights (see
        // `project_config_falls_back_to_active_when_param_absent`).
        for (name, expected) in [
            ("memory::relevance_weight", 0.7),
            ("memory::importance_weight", 0.2),
            ("memory::temporal_weight", 0.1),
        ] {
            let def = space
                .parameters
                .iter()
                .find(|p| p.name == name)
                .unwrap_or_else(|| panic!("missing parameter {name}"));
            let mean = def.prior_alpha / (def.prior_alpha + def.prior_beta);
            assert!(
                (mean - expected).abs() < 1e-12,
                "{name}: prior mean {mean} != {expected}"
            );
        }
    }
}
213}