Skip to main content

khive_pack_brain/
fold.rs

1use khive_fold::{Fold, FoldContext};
2use khive_storage::event::Event;
3
4use crate::event::{entity_signal, interpret, is_recall_positive};
5use crate::state::{BetaPosterior, BrainState};
6
/// The brain as a meta-fold: `Fold<Event, BrainState>`.
///
/// Processes the existing Event substrate stream. Each event is interpreted
/// via `event::interpret()` and routed to the relevant posteriors.
/// Deterministic: same events in the same order → same BrainState.
#[derive(Debug, Clone)]
pub struct EventFold {
    /// Capacity handed to `BrainState::new` when the initial state is built;
    /// presumably bounds how many per-entity posteriors are kept — confirm
    /// against `BrainState`.
    entity_capacity: usize,
}

impl EventFold {
    /// Creates a fold whose initial state is built with `entity_capacity`.
    #[must_use]
    pub fn new(entity_capacity: usize) -> Self {
        Self { entity_capacity }
    }
}
21
22impl Fold<Event, BrainState> for EventFold {
23    fn initial(&self, _context: &FoldContext) -> BrainState {
24        BrainState::new(
25            [
26                (
27                    "recall::relevance_weight".into(),
28                    BetaPosterior::new(7.0, 3.0),
29                ),
30                (
31                    "recall::importance_weight".into(),
32                    BetaPosterior::new(2.0, 8.0),
33                ),
34                (
35                    "recall::temporal_weight".into(),
36                    BetaPosterior::new(1.0, 9.0),
37                ),
38            ]
39            .into_iter()
40            .collect(),
41            self.entity_capacity,
42        )
43    }
44
45    fn step(&self, mut state: BrainState, event: &Event, _ctx: &FoldContext) -> BrainState {
46        let signal = interpret(event);
47
48        state.total_events += 1;
49
50        // Global recall parameter updates
51        if let Some(positive) = is_recall_positive(&signal) {
52            if let Some(posterior) = state.parameters.get_mut("recall::relevance_weight") {
53                if positive {
54                    posterior.update_success();
55                } else {
56                    posterior.update_failure();
57                }
58            }
59        }
60
61        // Per-entity posterior updates
62        if let Some((entity_id, positive)) = entity_signal(&signal) {
63            let posterior = state
64                .entity_posteriors
65                .get_or_insert(entity_id, || BetaPosterior::new(1.0, 1.0));
66            if positive {
67                posterior.update_success();
68            } else {
69                posterior.update_failure();
70            }
71        }
72
73        state
74    }
75
76    fn finalize(&self, state: BrainState, _context: &FoldContext) -> BrainState {
77        state
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use khive_types::{EventOutcome, SubstrateKind};
    use uuid::Uuid;

    /// Asserts two floating-point values agree to within 1e-12; a macro so the
    /// float type is inferred from the operands, as with inline comparisons.
    macro_rules! assert_close {
        ($actual:expr, $expected:expr) => {
            assert!((($actual) - ($expected)).abs() < 1e-12)
        };
    }

    /// Key of the only global parameter the fold updates.
    const RELEVANCE: &str = "recall::relevance_weight";

    /// Builds a test event with the given verb, outcome and optional target.
    fn make_event(verb: &str, outcome: EventOutcome, target: Option<Uuid>) -> Event {
        let mut built = Event::new("test", verb, SubstrateKind::Note, "brain");
        built.target_id = target;
        built.outcome = outcome;
        built
    }

    /// Returns a fold with capacity 100, a fresh context, and its initial state.
    fn setup() -> (EventFold, FoldContext, BrainState) {
        let brain = EventFold::new(100);
        let ctx = FoldContext::new();
        let state = brain.initial(&ctx);
        (brain, ctx, state)
    }

    /// Folds `events` in order, starting from the initial state.
    fn replay(brain: &EventFold, ctx: &FoldContext, events: &[Event]) -> BrainState {
        let mut acc = brain.initial(ctx);
        for ev in events {
            acc = brain.step(acc, ev, ctx);
        }
        acc
    }

    #[test]
    fn initial_state_has_recall_priors() {
        let (_, _, state) = setup();
        assert!(state.parameters.contains_key(RELEVANCE));
        let prior = &state.parameters[RELEVANCE];
        assert_close!(prior.alpha, 7.0);
        assert_close!(prior.beta, 3.0);
    }

    #[test]
    fn recall_hit_updates_global_and_entity() {
        let (brain, ctx, mut state) = setup();
        let target = Uuid::new_v4();
        let hit = make_event("recall", EventOutcome::Success, Some(target));
        state = brain.step(state, &hit, &ctx);

        assert_eq!(state.total_events, 1);
        // Prior alpha 7, plus one success.
        assert_close!(state.parameters[RELEVANCE].alpha, 8.0);
        let entity = state.entity_posteriors.get(&target).unwrap();
        // Entity prior alpha 1, plus one success.
        assert_close!(entity.alpha, 2.0);
    }

    #[test]
    fn recall_miss_updates_global_only() {
        let (brain, ctx, mut state) = setup();
        let miss = make_event("recall", EventOutcome::Success, None);
        state = brain.step(state, &miss, &ctx);

        // Prior beta 3, plus one failure.
        assert_close!(state.parameters[RELEVANCE].beta, 4.0);
        assert!(state.entity_posteriors.is_empty());
    }

    #[test]
    fn irrelevant_event_increments_counter_only() {
        let (brain, ctx, mut state) = setup();
        let link = make_event("link", EventOutcome::Success, Some(Uuid::new_v4()));
        state = brain.step(state, &link, &ctx);

        assert_eq!(state.total_events, 1);
        // The relevance prior is untouched.
        assert_close!(state.parameters[RELEVANCE].alpha, 7.0);
    }

    #[test]
    fn feedback_not_useful_increments_entity_beta() {
        let (brain, ctx, mut state) = setup();
        let target = Uuid::new_v4();
        let mut feedback = make_event("brain.emit", EventOutcome::Success, Some(target));
        feedback.data = Some(serde_json::json!({"signal": "not_useful"}));
        state = brain.step(state, &feedback, &ctx);

        assert_eq!(state.total_events, 1);
        // Starts from Beta(1, 1); "not_useful" records a failure, so only beta moves.
        let entity = state.entity_posteriors.get(&target).unwrap();
        assert_close!(entity.alpha, 1.0);
        assert_close!(entity.beta, 2.0);
    }

    #[test]
    fn deterministic_replay() {
        let brain = EventFold::new(100);
        let ctx = FoldContext::new();
        let target = Uuid::new_v4();
        let events = [
            make_event("recall", EventOutcome::Success, Some(target)),
            make_event("recall", EventOutcome::Success, None),
            make_event("search", EventOutcome::Success, None),
            make_event("recall", EventOutcome::Success, Some(target)),
        ];

        let first = replay(&brain, &ctx, &events).to_snapshot();
        let second = replay(&brain, &ctx, &events).to_snapshot();

        assert_eq!(first.total_events, second.total_events);
        assert_eq!(first.parameters, second.parameters);
        assert_eq!(first.entity_posteriors, second.entity_posteriors);
    }
}
197}