// File: khive_pack_brain/tunable.rs


use khive_runtime::pack::PackRuntime;
use khive_runtime::RuntimeError;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::state::{BetaPosterior, BrainState};

8/// Packs that want auto-tuning implement this trait.
9/// The brain discovers tunable packs at startup via the PackRegistry.
10pub trait PackTunable: PackRuntime {
11    fn parameter_space(&self) -> ParameterSpace;
12    fn project_config(&self, state: &BrainState) -> Value;
13    fn apply_config(&self, config: Value) -> Result<(), RuntimeError>;
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct ParameterSpace {
18    pub parameters: Vec<ParameterDef>,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ParameterDef {
23    pub name: String,
24    pub prior_alpha: f64,
25    pub prior_beta: f64,
26    pub bounds: (f64, f64),
27}
28
29impl ParameterDef {
30    pub fn prior(&self) -> BetaPosterior {
31        BetaPosterior::new(self.prior_alpha, self.prior_beta)
32    }
33}
34
#[cfg(test)]
mod tests {
    use super::*;

    /// `prior()` must forward both shape parameters unchanged.
    #[test]
    fn parameter_def_prior_returns_matching_beta_posterior() {
        let def = ParameterDef {
            name: "recall::relevance_weight".into(),
            prior_alpha: 2.0,
            prior_beta: 8.0,
            bounds: (0.0, 1.0),
        };
        let prior = def.prior();
        assert!((prior.alpha - 2.0).abs() < 1e-12);
        assert!((prior.beta - 8.0).abs() < 1e-12);
        // 2 / (2 + 8) = 0.2
        assert!((prior.mean() - 0.2).abs() < 1e-12);
    }

    /// A JSON round trip must preserve every field, not just the name.
    #[test]
    fn parameter_space_serializes() {
        let space = ParameterSpace {
            parameters: vec![ParameterDef {
                name: "p".into(),
                prior_alpha: 1.0,
                prior_beta: 1.0,
                bounds: (0.0, 1.0),
            }],
        };
        let json = serde_json::to_string(&space).unwrap();
        let back: ParameterSpace = serde_json::from_str(&json).unwrap();
        assert_eq!(back.parameters.len(), 1);

        let param = &back.parameters[0];
        assert_eq!(param.name, "p");
        assert!((param.prior_alpha - 1.0).abs() < 1e-12);
        assert!((param.prior_beta - 1.0).abs() < 1e-12);
        let (lower, upper) = param.bounds;
        assert!((lower - 0.0).abs() < 1e-12);
        assert!((upper - 1.0).abs() < 1e-12);
    }
}