1use khive_fold::{Fold, FoldContext};
2use khive_storage::event::Event;
3
4use crate::event::{entity_signal, interpret, is_recall_positive};
5use crate::state::{BetaPosterior, BrainState};
6
/// [`Fold`] that replays a stream of [`Event`]s into a [`BrainState`] of Beta
/// posteriors: a set of global recall weights plus per-entity posteriors.
#[derive(Debug, Clone)]
pub struct EventFold {
    /// Capacity handed to `BrainState::new` when the initial state is built.
    /// NOTE(review): presumably bounds how many per-entity posteriors are
    /// retained — confirm against `crate::state`.
    entity_capacity: usize,
}
15
16impl EventFold {
17 pub fn new(entity_capacity: usize) -> Self {
18 Self { entity_capacity }
19 }
20}
21
22impl Fold<Event, BrainState> for EventFold {
23 fn initial(&self, _context: &FoldContext) -> BrainState {
24 BrainState::new(
25 [
26 (
27 "recall::relevance_weight".into(),
28 BetaPosterior::new(7.0, 3.0),
29 ),
30 (
31 "recall::importance_weight".into(),
32 BetaPosterior::new(2.0, 8.0),
33 ),
34 (
35 "recall::temporal_weight".into(),
36 BetaPosterior::new(1.0, 9.0),
37 ),
38 ]
39 .into_iter()
40 .collect(),
41 self.entity_capacity,
42 )
43 }
44
45 fn step(&self, mut state: BrainState, event: &Event, _ctx: &FoldContext) -> BrainState {
46 let signal = interpret(event);
47
48 state.total_events += 1;
49
50 if let Some(positive) = is_recall_positive(&signal) {
52 if let Some(posterior) = state.parameters.get_mut("recall::relevance_weight") {
53 if positive {
54 posterior.update_success();
55 } else {
56 posterior.update_failure();
57 }
58 }
59 }
60
61 if let Some((entity_id, positive)) = entity_signal(&signal) {
63 let posterior = state
64 .entity_posteriors
65 .get_or_insert(entity_id, || BetaPosterior::new(1.0, 1.0));
66 if positive {
67 posterior.update_success();
68 } else {
69 posterior.update_failure();
70 }
71 }
72
73 state
74 }
75
76 fn finalize(&self, state: BrainState, _context: &FoldContext) -> BrainState {
77 state
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use khive_types::{EventOutcome, SubstrateKind};
    use uuid::Uuid;

    /// Builds a test event with the given verb, outcome and optional target entity.
    fn make_event(verb: &str, outcome: EventOutcome, target: Option<Uuid>) -> Event {
        let mut e = Event::new("test", verb, SubstrateKind::Note, "brain");
        e.outcome = outcome;
        e.target_id = target;
        e
    }

    #[test]
    fn initial_state_has_recall_priors() {
        let fold = EventFold::new(100);
        let ctx = FoldContext::new();
        let state = fold.initial(&ctx);

        // Every global weight seeded by `initial`, with its (alpha, beta) prior.
        for (key, alpha, beta) in [
            ("recall::relevance_weight", 7.0, 3.0),
            ("recall::importance_weight", 2.0, 8.0),
            ("recall::temporal_weight", 1.0, 9.0),
        ] {
            let p = state
                .parameters
                .get(key)
                .unwrap_or_else(|| panic!("missing prior for {key}"));
            assert!((p.alpha - alpha).abs() < 1e-12, "{key}: alpha = {}", p.alpha);
            assert!((p.beta - beta).abs() < 1e-12, "{key}: beta = {}", p.beta);
        }

        // Nothing has been folded yet.
        assert_eq!(state.total_events, 0);
        assert!(state.entity_posteriors.is_empty());
    }

    #[test]
    fn recall_hit_updates_global_and_entity() {
        let fold = EventFold::new(100);
        let ctx = FoldContext::new();
        let mut state = fold.initial(&ctx);

        let id = Uuid::new_v4();
        let event = make_event("recall", EventOutcome::Success, Some(id));
        state = fold.step(state, &event, &ctx);

        assert_eq!(state.total_events, 1);
        let p = &state.parameters["recall::relevance_weight"];
        assert!((p.alpha - 8.0).abs() < 1e-12);
        let ep = state.entity_posteriors.get(&id).unwrap();
        assert!((ep.alpha - 2.0).abs() < 1e-12);
    }

    #[test]
    fn recall_miss_updates_global_only() {
        let fold = EventFold::new(100);
        let ctx = FoldContext::new();
        let mut state = fold.initial(&ctx);

        let event = make_event("recall", EventOutcome::Success, None);
        state = fold.step(state, &event, &ctx);

        let p = &state.parameters["recall::relevance_weight"];
        assert!((p.beta - 4.0).abs() < 1e-12);
        assert!(state.entity_posteriors.is_empty());
    }

    #[test]
    fn irrelevant_event_increments_counter_only() {
        let fold = EventFold::new(100);
        let ctx = FoldContext::new();
        let mut state = fold.initial(&ctx);

        let event = make_event("link", EventOutcome::Success, Some(Uuid::new_v4()));
        state = fold.step(state, &event, &ctx);

        assert_eq!(state.total_events, 1);
        // Neither side of the global relevance posterior moved...
        let p = &state.parameters["recall::relevance_weight"];
        assert!((p.alpha - 7.0).abs() < 1e-12);
        assert!((p.beta - 3.0).abs() < 1e-12);
        // ...and no entity posterior was created for the event's target.
        assert!(state.entity_posteriors.is_empty());
    }

    #[test]
    fn feedback_not_useful_increments_entity_beta() {
        let fold = EventFold::new(100);
        let ctx = FoldContext::new();
        let mut state = fold.initial(&ctx);

        let id = Uuid::new_v4();
        let mut event = make_event("brain.emit", EventOutcome::Success, Some(id));
        event.data = Some(serde_json::json!({"signal": "not_useful"}));
        state = fold.step(state, &event, &ctx);

        assert_eq!(state.total_events, 1);
        let ep = state.entity_posteriors.get(&id).unwrap();
        assert!((ep.alpha - 1.0).abs() < 1e-12);
        assert!((ep.beta - 2.0).abs() < 1e-12);
    }

    #[test]
    fn deterministic_replay() {
        let fold = EventFold::new(100);
        let ctx = FoldContext::new();

        let id = Uuid::new_v4();
        let events = vec![
            make_event("recall", EventOutcome::Success, Some(id)),
            make_event("recall", EventOutcome::Success, None),
            make_event("search", EventOutcome::Success, None),
            make_event("recall", EventOutcome::Success, Some(id)),
        ];

        let mut s1 = fold.initial(&ctx);
        for e in &events {
            s1 = fold.step(s1, e, &ctx);
        }

        let mut s2 = fold.initial(&ctx);
        for e in &events {
            s2 = fold.step(s2, e, &ctx);
        }

        let snap1 = s1.to_snapshot();
        let snap2 = s2.to_snapshot();
        assert_eq!(snap1.total_events, snap2.total_events);
        assert_eq!(snap1.parameters, snap2.parameters);
        assert_eq!(snap1.entity_posteriors, snap2.entity_posteriors);
    }
}