Skip to main content

khive_pack_brain/
state.rs

1use std::collections::{HashMap, VecDeque};
2
3use serde::{Deserialize, Serialize};
4use uuid::Uuid;
5
6/// Beta-Binomial posterior for a single parameter.
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
8pub struct BetaPosterior {
9    pub alpha: f64,
10    pub beta: f64,
11}
12
13impl BetaPosterior {
14    pub fn new(alpha: f64, beta: f64) -> Self {
15        Self { alpha, beta }
16    }
17
18    pub fn mean(&self) -> f64 {
19        self.alpha / (self.alpha + self.beta)
20    }
21
22    pub fn variance(&self) -> f64 {
23        let n = self.alpha + self.beta;
24        (self.alpha * self.beta) / (n * n * (n + 1.0))
25    }
26
27    pub fn effective_sample_size(&self) -> f64 {
28        self.alpha + self.beta
29    }
30
31    pub fn update_success(&mut self) {
32        self.alpha += 1.0;
33    }
34
35    pub fn update_failure(&mut self) {
36        self.beta += 1.0;
37    }
38}
39
40impl Default for BetaPosterior {
41    fn default() -> Self {
42        Self::new(1.0, 1.0)
43    }
44}
45
46/// Bounded LRU map for per-entity posteriors.
47/// Uses a VecDeque to track access order; evicts oldest on insert when full.
48pub struct EntityPosteriors {
49    map: HashMap<Uuid, BetaPosterior>,
50    order: VecDeque<Uuid>,
51    capacity: usize,
52}
53
54impl EntityPosteriors {
55    pub fn new(capacity: usize) -> Self {
56        Self {
57            map: HashMap::with_capacity(capacity),
58            order: VecDeque::with_capacity(capacity),
59            capacity,
60        }
61    }
62
63    pub fn get_or_insert(
64        &mut self,
65        id: Uuid,
66        default: impl FnOnce() -> BetaPosterior,
67    ) -> &mut BetaPosterior {
68        if !self.map.contains_key(&id) {
69            if self.map.len() >= self.capacity {
70                if let Some(evicted) = self.order.pop_front() {
71                    self.map.remove(&evicted);
72                }
73            }
74            self.map.insert(id, default());
75            self.order.push_back(id);
76        }
77        self.map.get_mut(&id).unwrap()
78    }
79
80    pub fn get(&self, id: &Uuid) -> Option<&BetaPosterior> {
81        self.map.get(id)
82    }
83
84    pub fn len(&self) -> usize {
85        self.map.len()
86    }
87
88    pub fn is_empty(&self) -> bool {
89        self.map.is_empty()
90    }
91
92    pub fn clear(&mut self) {
93        self.map.clear();
94        self.order.clear();
95    }
96
97    pub fn to_snapshot(&self) -> HashMap<Uuid, BetaPosterior> {
98        self.map.clone()
99    }
100
101    pub fn from_snapshot(snapshot: HashMap<Uuid, BetaPosterior>, capacity: usize) -> Self {
102        let mut ep = Self::new(capacity);
103        for (id, posterior) in snapshot {
104            ep.map.insert(id, posterior);
105            ep.order.push_back(id);
106        }
107        ep
108    }
109}
110
111/// Runtime brain state — not directly serializable (contains LRU).
112pub struct BrainState {
113    pub parameters: HashMap<String, BetaPosterior>,
114    pub entity_posteriors: EntityPosteriors,
115    pub total_events: u64,
116    pub exploration_epoch: u64,
117}
118
119impl BrainState {
120    pub fn new(parameters: HashMap<String, BetaPosterior>, entity_capacity: usize) -> Self {
121        Self {
122            parameters,
123            entity_posteriors: EntityPosteriors::new(entity_capacity),
124            total_events: 0,
125            exploration_epoch: 0,
126        }
127    }
128
129    pub fn to_snapshot(&self) -> BrainStateSnapshot {
130        BrainStateSnapshot {
131            parameters: self.parameters.clone(),
132            entity_posteriors: self.entity_posteriors.to_snapshot(),
133            total_events: self.total_events,
134            exploration_epoch: self.exploration_epoch,
135        }
136    }
137
138    pub fn from_snapshot(snapshot: BrainStateSnapshot, entity_capacity: usize) -> Self {
139        Self {
140            parameters: snapshot.parameters,
141            entity_posteriors: EntityPosteriors::from_snapshot(
142                snapshot.entity_posteriors,
143                entity_capacity,
144            ),
145            total_events: snapshot.total_events,
146            exploration_epoch: snapshot.exploration_epoch,
147        }
148    }
149
150    pub fn reset_posteriors(&mut self) {
151        for posterior in self.parameters.values_mut() {
152            *posterior = BetaPosterior::new(1.0, 1.0);
153        }
154        self.entity_posteriors.clear();
155        self.exploration_epoch += 1;
156    }
157}
158
159/// Serializable snapshot of BrainState for persistence and inspection.
160#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct BrainStateSnapshot {
162    pub parameters: HashMap<String, BetaPosterior>,
163    pub entity_posteriors: HashMap<Uuid, BetaPosterior>,
164    pub total_events: u64,
165    pub exploration_epoch: u64,
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every floating-point comparison below.
    const TOLERANCE: f64 = 1e-12;

    /// Fails the test unless `actual` lies within `TOLERANCE` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn beta_posterior_mean() {
        let posterior = BetaPosterior::new(7.0, 3.0);
        assert_close(posterior.mean(), 0.7);
    }

    #[test]
    fn beta_posterior_variance() {
        let posterior = BetaPosterior::new(7.0, 3.0);
        // 7 * 3 / (10 * 10 * 11) = 21 / 1100 ≈ 0.01909
        assert_close(posterior.variance(), 21.0 / 1100.0);
    }

    #[test]
    fn beta_posterior_ess() {
        let posterior = BetaPosterior::new(7.0, 3.0);
        assert_close(posterior.effective_sample_size(), 10.0);
    }

    #[test]
    fn beta_posterior_update() {
        let mut posterior = BetaPosterior::new(1.0, 1.0);
        posterior.update_success();
        posterior.update_success();
        posterior.update_failure();
        assert_close(posterior.alpha, 3.0);
        assert_close(posterior.beta, 2.0);
        assert_close(posterior.mean(), 0.6);
    }

    #[test]
    fn entity_posteriors_eviction() {
        let mut store = EntityPosteriors::new(3);
        let ids: Vec<Uuid> = (0..5).map(|_| Uuid::new_v4()).collect();
        for id in ids.iter().copied() {
            store.get_or_insert(id, BetaPosterior::default);
        }
        assert_eq!(store.len(), 3);
        // The two oldest insertions are gone; the three newest remain.
        for (index, id) in ids.iter().enumerate() {
            assert_eq!(store.get(id).is_some(), index >= 2);
        }
    }

    #[test]
    fn entity_posteriors_get_or_insert_existing() {
        let mut store = EntityPosteriors::new(10);
        let id = Uuid::new_v4();
        let first = store.get_or_insert(id, BetaPosterior::default);
        first.update_success();
        let second = store.get_or_insert(id, BetaPosterior::default);
        assert_close(second.alpha, 2.0);
    }

    #[test]
    fn brain_state_snapshot_roundtrip() {
        let mut state = BrainState::new(HashMap::new(), 100);
        state.parameters.insert(
            "memory::relevance_weight".into(),
            BetaPosterior::new(7.0, 3.0),
        );
        state.total_events = 42;
        let id = Uuid::new_v4();
        let tracked = state
            .entity_posteriors
            .get_or_insert(id, BetaPosterior::default);
        tracked.update_success();

        let json = serde_json::to_string(&state.to_snapshot()).unwrap();
        let decoded: BrainStateSnapshot = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.total_events, 42);
        assert!(decoded.parameters.contains_key("memory::relevance_weight"));
        assert!(decoded.entity_posteriors.contains_key(&id));
    }

    #[test]
    fn beta_posterior_default_has_uniform_prior() {
        let posterior = BetaPosterior::default();
        assert_close(posterior.alpha, 1.0);
        assert_close(posterior.beta, 1.0);
        assert_close(posterior.mean(), 0.5);
    }

    #[test]
    fn entity_posteriors_from_snapshot_rebuilds_map() {
        let first_id = Uuid::new_v4();
        let second_id = Uuid::new_v4();
        let mut snapshot = HashMap::new();
        snapshot.insert(first_id, BetaPosterior::new(3.0, 2.0));
        snapshot.insert(second_id, BetaPosterior::new(5.0, 1.0));

        let store = EntityPosteriors::from_snapshot(snapshot, 100);
        assert_eq!(store.len(), 2);
        assert_close(store.get(&first_id).unwrap().alpha, 3.0);
        assert_close(store.get(&second_id).unwrap().alpha, 5.0);
    }

    #[test]
    fn brain_state_from_snapshot_roundtrip() {
        let mut params = HashMap::new();
        params.insert(
            "recall::relevance_weight".into(),
            BetaPosterior::new(7.0, 3.0),
        );
        let mut state = BrainState::new(params, 100);
        state.total_events = 55;
        state.exploration_epoch = 2;
        let id = Uuid::new_v4();
        state
            .entity_posteriors
            .get_or_insert(id, || BetaPosterior::new(4.0, 6.0))
            .update_success();

        let first = state.to_snapshot();
        let restored = BrainState::from_snapshot(first.clone(), 100);
        let second = restored.to_snapshot();

        assert_eq!(second.total_events, 55);
        assert_eq!(second.exploration_epoch, 2);
        let weight = &second.parameters["recall::relevance_weight"];
        assert_close(weight.alpha, 7.0);
        assert_close(weight.beta, 3.0);
        // Created as Beta(4, 6), then one success: alpha becomes 5, beta stays 6.
        let entity = second.entity_posteriors.get(&id).unwrap();
        assert_close(entity.alpha, 5.0);
        assert_close(entity.beta, 6.0);
    }

    #[test]
    fn reset_posteriors_preserves_event_count() {
        let mut params = HashMap::new();
        params.insert("test".into(), BetaPosterior::new(7.0, 3.0));
        let mut state = BrainState::new(params, 10);
        state.total_events = 100;
        state.reset_posteriors();
        assert_eq!(state.total_events, 100);
        assert_eq!(state.exploration_epoch, 1);
        let posterior = &state.parameters["test"];
        assert_close(posterior.alpha, 1.0);
        assert_close(posterior.beta, 1.0);
    }
}