khive_pack_memory/
tunable.rs1use khive_pack_brain::state::BrainState;
2use khive_pack_brain::tunable::{PackTunable, ParameterDef, ParameterSpace};
3use khive_runtime::RuntimeError;
4use serde_json::Value;
5
6use crate::config::RecallConfig;
7use crate::MemoryPack;
8
9impl PackTunable for MemoryPack {
20 fn parameter_space(&self) -> ParameterSpace {
21 ParameterSpace {
22 parameters: vec![
23 ParameterDef {
24 name: "memory::relevance_weight".into(),
25 prior_alpha: 7.0,
28 prior_beta: 3.0,
29 bounds: (0.0, 1.0),
30 },
31 ParameterDef {
32 name: "memory::importance_weight".into(),
33 prior_alpha: 2.0,
35 prior_beta: 8.0,
36 bounds: (0.0, 1.0),
37 },
38 ParameterDef {
39 name: "memory::temporal_weight".into(),
40 prior_alpha: 1.0,
42 prior_beta: 9.0,
43 bounds: (0.0, 1.0),
44 },
45 ],
46 }
47 }
48
49 fn project_config(&self, state: &BrainState) -> Value {
54 let current = self.active_config();
55
56 let relevance = state
57 .parameters
58 .get("memory::relevance_weight")
59 .map(|p| p.mean())
60 .unwrap_or(current.relevance_weight);
61
62 let importance = state
63 .parameters
64 .get("memory::importance_weight")
65 .map(|p| p.mean())
66 .unwrap_or(current.importance_weight);
67
68 let temporal = state
69 .parameters
70 .get("memory::temporal_weight")
71 .map(|p| p.mean())
72 .unwrap_or(current.temporal_weight);
73
74 let projected = RecallConfig {
75 relevance_weight: relevance,
76 importance_weight: importance,
77 temporal_weight: temporal,
78 ..current
79 };
80
81 serde_json::to_value(projected).unwrap_or_else(|_| serde_json::json!({}))
82 }
83
84 fn apply_config(&self, config: Value) -> Result<(), RuntimeError> {
90 let new_cfg: RecallConfig = serde_json::from_value(config)
91 .map_err(|e| RuntimeError::InvalidInput(format!("invalid RecallConfig: {e}")))?;
92 new_cfg.validate()?;
93 *self.config.lock().unwrap() = new_cfg;
94 Ok(())
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use khive_pack_brain::state::BetaPosterior;
    use khive_runtime::KhiveRuntime;
    use std::collections::HashMap;

    /// Builds a `MemoryPack` on top of an in-memory runtime.
    fn make_pack() -> MemoryPack {
        let rt = KhiveRuntime::memory().expect("in-memory runtime");
        MemoryPack::new(rt)
    }

    /// Wraps `params` in a `BrainState` for feeding to `project_config`.
    fn brain_state_with_params(params: HashMap<String, BetaPosterior>) -> BrainState {
        // NOTE(review): the meaning of the second argument (100) is not visible here; confirm.
        BrainState::new(params, 100)
    }

    #[test]
    fn parameter_space_has_three_params() {
        let pack = make_pack();
        let space = pack.parameter_space();
        assert_eq!(space.parameters.len(), 3);
        let names: Vec<&str> = space.parameters.iter().map(|p| p.name.as_str()).collect();
        assert!(names.contains(&"memory::relevance_weight"));
        assert!(names.contains(&"memory::importance_weight"));
        assert!(names.contains(&"memory::temporal_weight"));
    }

    #[test]
    fn project_config_reads_posterior_means() {
        let pack = make_pack();
        let mut params = HashMap::new();
        // mean 6 / (6 + 4) = 0.6
        params.insert(
            "memory::relevance_weight".into(),
            BetaPosterior::new(6.0, 4.0),
        );
        // mean 3 / (3 + 7) = 0.3
        params.insert(
            "memory::importance_weight".into(),
            BetaPosterior::new(3.0, 7.0),
        );
        // mean 1 / (1 + 9) = 0.1
        params.insert(
            "memory::temporal_weight".into(),
            BetaPosterior::new(1.0, 9.0),
        );
        let state = brain_state_with_params(params);
        let projected = pack.project_config(&state);

        let cfg: RecallConfig = serde_json::from_value(projected).unwrap();
        assert!((cfg.relevance_weight - 0.6).abs() < 1e-10);
        assert!((cfg.importance_weight - 0.3).abs() < 1e-10);
        assert!((cfg.temporal_weight - 0.1).abs() < 1e-10);
    }

    #[test]
    fn project_config_falls_back_to_active_when_param_absent() {
        let pack = make_pack();
        let state = brain_state_with_params(HashMap::new());
        let projected = pack.project_config(&state);

        let cfg: RecallConfig = serde_json::from_value(projected).unwrap();
        assert!((cfg.relevance_weight - 0.70).abs() < 1e-10);
        assert!((cfg.importance_weight - 0.20).abs() < 1e-10);
        assert!((cfg.temporal_weight - 0.10).abs() < 1e-10);
    }

    #[test]
    fn apply_config_updates_active_config() {
        let pack = make_pack();
        let new_cfg = RecallConfig {
            relevance_weight: 0.5,
            importance_weight: 0.3,
            temporal_weight: 0.2,
            ..RecallConfig::default()
        };
        let config_value = serde_json::to_value(&new_cfg).unwrap();
        pack.apply_config(config_value)
            .expect("apply_config succeeds");

        let active = pack.active_config();
        assert!((active.relevance_weight - 0.5).abs() < 1e-10);
        assert!((active.importance_weight - 0.3).abs() < 1e-10);
        assert!((active.temporal_weight - 0.2).abs() < 1e-10);
    }

    #[test]
    fn apply_config_rejects_all_zero_weights() {
        let pack = make_pack();
        let bad_cfg = RecallConfig {
            relevance_weight: 0.0,
            importance_weight: 0.0,
            temporal_weight: 0.0,
            ..RecallConfig::default()
        };
        let config_value = serde_json::to_value(&bad_cfg).unwrap();
        assert!(pack.apply_config(config_value).is_err());

        // A rejected config must leave the fresh pack's weights in place.
        let active = pack.active_config();
        assert!((active.relevance_weight - 0.70).abs() < 1e-10);
        assert!((active.importance_weight - 0.20).abs() < 1e-10);
        assert!((active.temporal_weight - 0.10).abs() < 1e-10);
    }

    #[test]
    fn apply_config_rejects_malformed_json() {
        let pack = make_pack();

        // Sparse object with a wrongly typed field.
        let bad = serde_json::json!({ "relevance_weight": "not_a_number" });
        assert!(pack.apply_config(bad).is_err());

        // Full, otherwise-valid config with one wrongly typed field. This
        // checks the type error itself, not just the missing-field rejection.
        let mut bad_typed = serde_json::to_value(RecallConfig::default()).unwrap();
        bad_typed["relevance_weight"] = serde_json::json!("not_a_number");
        assert!(pack.apply_config(bad_typed).is_err());

        // Neither rejection may have touched the active config.
        let active = pack.active_config();
        assert!((active.relevance_weight - 0.70).abs() < 1e-10);
    }

    #[test]
    fn prior_for_relevance_weight_matches_fold_priors() {
        let pack = make_pack();
        let space = pack.parameter_space();
        let def = space
            .parameters
            .iter()
            .find(|p| p.name == "memory::relevance_weight")
            .unwrap();
        assert!((def.prior_alpha - 7.0).abs() < 1e-12);
        assert!((def.prior_beta - 3.0).abs() < 1e-12);
    }

    #[test]
    fn priors_for_all_weights_are_pinned() {
        let space = make_pack().parameter_space();
        let expected = [
            ("memory::relevance_weight", 7.0, 3.0),
            ("memory::importance_weight", 2.0, 8.0),
            ("memory::temporal_weight", 1.0, 9.0),
        ];
        for (name, alpha, beta) in expected {
            let def = space
                .parameters
                .iter()
                .find(|p| p.name == name)
                .unwrap_or_else(|| panic!("missing parameter {name}"));
            assert!((def.prior_alpha - alpha).abs() < 1e-12, "{name}: prior_alpha");
            assert!((def.prior_beta - beta).abs() < 1e-12, "{name}: prior_beta");
            assert_eq!(def.bounds, (0.0, 1.0), "{name}: bounds");
        }
    }
}