1use std::collections::{HashMap, VecDeque};
2
3use serde::{Deserialize, Serialize};
4use uuid::Uuid;
5
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
8pub struct BetaPosterior {
9 pub alpha: f64,
10 pub beta: f64,
11}
12
13impl BetaPosterior {
14 pub fn new(alpha: f64, beta: f64) -> Self {
15 Self { alpha, beta }
16 }
17
18 pub fn mean(&self) -> f64 {
19 self.alpha / (self.alpha + self.beta)
20 }
21
22 pub fn variance(&self) -> f64 {
23 let n = self.alpha + self.beta;
24 (self.alpha * self.beta) / (n * n * (n + 1.0))
25 }
26
27 pub fn effective_sample_size(&self) -> f64 {
28 self.alpha + self.beta
29 }
30
31 pub fn update_success(&mut self) {
32 self.alpha += 1.0;
33 }
34
35 pub fn update_failure(&mut self) {
36 self.beta += 1.0;
37 }
38}
39
40impl Default for BetaPosterior {
41 fn default() -> Self {
42 Self::new(1.0, 1.0)
43 }
44}
45
46pub struct EntityPosteriors {
49 map: HashMap<Uuid, BetaPosterior>,
50 order: VecDeque<Uuid>,
51 capacity: usize,
52}
53
54impl EntityPosteriors {
55 pub fn new(capacity: usize) -> Self {
56 Self {
57 map: HashMap::with_capacity(capacity),
58 order: VecDeque::with_capacity(capacity),
59 capacity,
60 }
61 }
62
63 pub fn get_or_insert(
64 &mut self,
65 id: Uuid,
66 default: impl FnOnce() -> BetaPosterior,
67 ) -> &mut BetaPosterior {
68 if !self.map.contains_key(&id) {
69 if self.map.len() >= self.capacity {
70 if let Some(evicted) = self.order.pop_front() {
71 self.map.remove(&evicted);
72 }
73 }
74 self.map.insert(id, default());
75 self.order.push_back(id);
76 }
77 self.map.get_mut(&id).unwrap()
78 }
79
80 pub fn get(&self, id: &Uuid) -> Option<&BetaPosterior> {
81 self.map.get(id)
82 }
83
84 pub fn len(&self) -> usize {
85 self.map.len()
86 }
87
88 pub fn is_empty(&self) -> bool {
89 self.map.is_empty()
90 }
91
92 pub fn clear(&mut self) {
93 self.map.clear();
94 self.order.clear();
95 }
96
97 pub fn to_snapshot(&self) -> HashMap<Uuid, BetaPosterior> {
98 self.map.clone()
99 }
100
101 pub fn from_snapshot(snapshot: HashMap<Uuid, BetaPosterior>, capacity: usize) -> Self {
102 let mut ep = Self::new(capacity);
103 for (id, posterior) in snapshot {
104 ep.map.insert(id, posterior);
105 ep.order.push_back(id);
106 }
107 ep
108 }
109}
110
111pub struct BrainState {
113 pub parameters: HashMap<String, BetaPosterior>,
114 pub entity_posteriors: EntityPosteriors,
115 pub total_events: u64,
116 pub exploration_epoch: u64,
117}
118
119impl BrainState {
120 pub fn new(parameters: HashMap<String, BetaPosterior>, entity_capacity: usize) -> Self {
121 Self {
122 parameters,
123 entity_posteriors: EntityPosteriors::new(entity_capacity),
124 total_events: 0,
125 exploration_epoch: 0,
126 }
127 }
128
129 pub fn to_snapshot(&self) -> BrainStateSnapshot {
130 BrainStateSnapshot {
131 parameters: self.parameters.clone(),
132 entity_posteriors: self.entity_posteriors.to_snapshot(),
133 total_events: self.total_events,
134 exploration_epoch: self.exploration_epoch,
135 }
136 }
137
138 pub fn from_snapshot(snapshot: BrainStateSnapshot, entity_capacity: usize) -> Self {
139 Self {
140 parameters: snapshot.parameters,
141 entity_posteriors: EntityPosteriors::from_snapshot(
142 snapshot.entity_posteriors,
143 entity_capacity,
144 ),
145 total_events: snapshot.total_events,
146 exploration_epoch: snapshot.exploration_epoch,
147 }
148 }
149
150 pub fn reset_posteriors(&mut self) {
151 for posterior in self.parameters.values_mut() {
152 *posterior = BetaPosterior::new(1.0, 1.0);
153 }
154 self.entity_posteriors.clear();
155 self.exploration_epoch += 1;
156 }
157}
158
159#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct BrainStateSnapshot {
162 pub parameters: HashMap<String, BetaPosterior>,
163 pub entity_posteriors: HashMap<Uuid, BetaPosterior>,
164 pub total_events: u64,
165 pub exploration_epoch: u64,
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every floating-point comparison below.
    const TOLERANCE: f64 = 1e-12;

    /// Fails the calling test unless `actual` lies strictly within
    /// `TOLERANCE` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {} (within {}), got {}",
            expected,
            TOLERANCE,
            actual
        );
    }

    #[test]
    fn beta_posterior_mean() {
        let posterior = BetaPosterior::new(7.0, 3.0);
        assert_close(posterior.mean(), 0.7);
    }

    #[test]
    fn beta_posterior_variance() {
        // alpha*beta / ((alpha+beta)^2 * (alpha+beta+1)) = 21 / (100 * 11)
        let posterior = BetaPosterior::new(7.0, 3.0);
        assert_close(posterior.variance(), 21.0 / 1100.0);
    }

    #[test]
    fn beta_posterior_ess() {
        assert_close(BetaPosterior::new(7.0, 3.0).effective_sample_size(), 10.0);
    }

    #[test]
    fn beta_posterior_update() {
        let mut posterior = BetaPosterior::new(1.0, 1.0);
        for _ in 0..2 {
            posterior.update_success();
        }
        posterior.update_failure();

        assert_close(posterior.alpha, 3.0);
        assert_close(posterior.beta, 2.0);
        assert_close(posterior.mean(), 0.6);
    }

    #[test]
    fn entity_posteriors_eviction() {
        let mut store = EntityPosteriors::new(3);
        let ids: Vec<Uuid> = std::iter::repeat_with(Uuid::new_v4).take(5).collect();
        for &id in &ids {
            store.get_or_insert(id, BetaPosterior::default);
        }

        assert_eq!(store.len(), 3);
        // The two oldest insertions are evicted; the three newest survive.
        let (evicted, retained) = ids.split_at(2);
        for id in evicted {
            assert!(store.get(id).is_none());
        }
        for id in retained {
            assert!(store.get(id).is_some());
        }
    }

    #[test]
    fn entity_posteriors_get_or_insert_existing() {
        let mut store = EntityPosteriors::new(10);
        let id = Uuid::new_v4();
        store.get_or_insert(id, BetaPosterior::default).update_success();

        // A second lookup must return the already-updated entry, not a fresh default.
        let existing = store.get_or_insert(id, BetaPosterior::default);
        assert_close(existing.alpha, 2.0);
    }

    #[test]
    fn brain_state_snapshot_roundtrip() {
        let mut state = BrainState::new(HashMap::new(), 100);
        state
            .parameters
            .insert("memory::relevance_weight".into(), BetaPosterior::new(7.0, 3.0));
        state.total_events = 42;
        let entity = Uuid::new_v4();
        state
            .entity_posteriors
            .get_or_insert(entity, BetaPosterior::default)
            .update_success();

        let encoded = serde_json::to_string(&state.to_snapshot()).unwrap();
        let decoded: BrainStateSnapshot = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.total_events, 42);
        assert!(decoded.parameters.contains_key("memory::relevance_weight"));
        assert!(decoded.entity_posteriors.contains_key(&entity));
    }

    #[test]
    fn beta_posterior_default_has_uniform_prior() {
        let prior = BetaPosterior::default();
        assert_close(prior.alpha, 1.0);
        assert_close(prior.beta, 1.0);
        assert_close(prior.mean(), 0.5);
    }

    #[test]
    fn entity_posteriors_from_snapshot_rebuilds_map() {
        let (first, second) = (Uuid::new_v4(), Uuid::new_v4());
        let snapshot = HashMap::from([
            (first, BetaPosterior::new(3.0, 2.0)),
            (second, BetaPosterior::new(5.0, 1.0)),
        ]);

        let store = EntityPosteriors::from_snapshot(snapshot, 100);

        assert_eq!(store.len(), 2);
        assert_close(store.get(&first).unwrap().alpha, 3.0);
        assert_close(store.get(&second).unwrap().alpha, 5.0);
    }

    #[test]
    fn brain_state_from_snapshot_roundtrip() {
        let params = HashMap::from([(
            "recall::relevance_weight".to_string(),
            BetaPosterior::new(7.0, 3.0),
        )]);
        let mut state = BrainState::new(params, 100);
        state.total_events = 55;
        state.exploration_epoch = 2;
        let entity = Uuid::new_v4();
        state
            .entity_posteriors
            .get_or_insert(entity, || BetaPosterior::new(4.0, 6.0))
            .update_success();

        let restored = BrainState::from_snapshot(state.to_snapshot(), 100);
        let snapshot = restored.to_snapshot();

        assert_eq!(snapshot.total_events, 55);
        assert_eq!(snapshot.exploration_epoch, 2);
        let param = &snapshot.parameters["recall::relevance_weight"];
        assert_close(param.alpha, 7.0);
        assert_close(param.beta, 3.0);
        let entity_posterior = &snapshot.entity_posteriors[&entity];
        assert_close(entity_posterior.alpha, 5.0);
        assert_close(entity_posterior.beta, 6.0);
    }

    #[test]
    fn reset_posteriors_preserves_event_count() {
        let params = HashMap::from([("test".to_string(), BetaPosterior::new(7.0, 3.0))]);
        let mut state = BrainState::new(params, 10);
        state.total_events = 100;

        state.reset_posteriors();

        assert_eq!(state.total_events, 100);
        assert_eq!(state.exploration_epoch, 1);
        let reset = &state.parameters["test"];
        assert_close(reset.alpha, 1.0);
        assert_close(reset.beta, 1.0);
    }
}