Skip to main content

khive_pack_brain/
lib.rs

1pub mod event;
2pub mod fold;
3pub mod state;
4pub mod tunable;
5
6use std::sync::Mutex;
7
8use async_trait::async_trait;
9use serde::Deserialize;
10use serde_json::{json, Value};
11
12use khive_fold::{Fold, FoldContext};
13use khive_runtime::pack::PackRuntime;
14use khive_runtime::{DispatchHook, KhiveRuntime, RuntimeError, VerbRegistry};
15use khive_storage::event::{Event, EventFilter};
16use khive_storage::types::PageRequest;
17use khive_types::{Pack, VerbDef};
18
19use crate::fold::EventFold;
20use crate::state::BrainState;
21
/// Capacity handed to [`EventFold::new`] when a [`BrainPack`] is constructed
/// (the fold's entity-cache size; exact eviction policy lives in `fold.rs`).
const ENTITY_CACHE_CAPACITY: usize = 10_000;
23
24pub struct BrainPack {
25    runtime: KhiveRuntime,
26    state: Mutex<BrainState>,
27    fold: EventFold,
28}
29
30impl Pack for BrainPack {
31    const NAME: &'static str = "brain";
32    const NOTE_KINDS: &'static [&'static str] = &[];
33    const ENTITY_KINDS: &'static [&'static str] = &[];
34    const VERBS: &'static [VerbDef] = &BRAIN_VERBS;
35    const REQUIRES: &'static [&'static str] = &["kg"];
36}
37
38static BRAIN_VERBS: [VerbDef; 5] = [
39    VerbDef {
40        name: "brain.state",
41        description: "Return current BrainState snapshot for inspection",
42    },
43    VerbDef {
44        name: "brain.config",
45        description: "Return projected config for a named pack parameter",
46    },
47    VerbDef {
48        name: "brain.events",
49        description: "List recent brain-relevant events for debugging",
50    },
51    VerbDef {
52        name: "brain.reset",
53        description: "Reset posteriors to priors (preserves event history)",
54    },
55    VerbDef {
56        name: "brain.emit",
57        description: "Manually emit a feedback event for a specific entity",
58    },
59];
60
61impl BrainPack {
62    pub fn new(runtime: KhiveRuntime) -> Self {
63        let fold = EventFold::new(ENTITY_CACHE_CAPACITY);
64        let ctx = FoldContext::new();
65        let state = fold.initial(&ctx);
66        Self {
67            runtime,
68            state: Mutex::new(state),
69            fold,
70        }
71    }
72
73    async fn handle_state(&self, _params: Value) -> Result<Value, RuntimeError> {
74        let state = self.state.lock().unwrap();
75        let snapshot = state.to_snapshot();
76        serde_json::to_value(&snapshot).map_err(|e| RuntimeError::InvalidInput(e.to_string()))
77    }
78
79    /// Public snapshot of the current `BrainState`.
80    ///
81    /// Equivalent to dispatching the `brain.state` verb but callable directly
82    /// when you hold an `Arc<BrainPack>` (e.g. a test that registered the pack
83    /// as a `DispatchHook` and wants to verify posteriors updated).
84    pub fn snapshot(&self) -> crate::state::BrainStateSnapshot {
85        self.state.lock().unwrap().to_snapshot()
86    }
87
88    async fn handle_config(&self, params: Value) -> Result<Value, RuntimeError> {
89        #[derive(Deserialize)]
90        struct ConfigParams {
91            parameter: Option<String>,
92        }
93        let p: ConfigParams = serde_json::from_value(params)
94            .map_err(|e| RuntimeError::InvalidInput(e.to_string()))?;
95
96        let state = self.state.lock().unwrap();
97        match p.parameter {
98            Some(key) => {
99                let posterior = state
100                    .parameters
101                    .get(&key)
102                    .ok_or_else(|| RuntimeError::NotFound(format!("parameter {key:?}")))?;
103                Ok(json!({
104                    "parameter": key,
105                    "mean": posterior.mean(),
106                    "variance": posterior.variance(),
107                    "ess": posterior.effective_sample_size(),
108                    "alpha": posterior.alpha,
109                    "beta": posterior.beta,
110                }))
111            }
112            None => {
113                let configs: serde_json::Map<String, Value> = state
114                    .parameters
115                    .iter()
116                    .map(|(k, p)| {
117                        (
118                            k.clone(),
119                            json!({
120                                "mean": p.mean(),
121                                "variance": p.variance(),
122                                "ess": p.effective_sample_size(),
123                            }),
124                        )
125                    })
126                    .collect();
127                Ok(Value::Object(configs))
128            }
129        }
130    }
131
132    async fn handle_events(&self, params: Value) -> Result<Value, RuntimeError> {
133        #[derive(Deserialize)]
134        struct EventsParams {
135            namespace: Option<String>,
136            limit: Option<u32>,
137        }
138        let p: EventsParams = serde_json::from_value(params)
139            .map_err(|e| RuntimeError::InvalidInput(e.to_string()))?;
140
141        let limit = p.limit.unwrap_or(20).min(100);
142        let ns = self.runtime.ns(p.namespace.as_deref()).to_string();
143
144        let store = self.runtime.events(p.namespace.as_deref())?;
145        let filter = EventFilter {
146            verbs: vec![
147                "recall".into(),
148                "search".into(),
149                "brain.emit".into(),
150                "get".into(),
151                "remember".into(),
152            ],
153            namespaces: vec![ns],
154            ..EventFilter::default()
155        };
156        let page = store
157            .query_events(filter, PageRequest { offset: 0, limit })
158            .await
159            .map_err(|e| RuntimeError::InvalidInput(e.to_string()))?;
160
161        let events: Vec<Value> = page
162            .items
163            .iter()
164            .map(|e| {
165                json!({
166                    "id": e.id.to_string(),
167                    "verb": e.verb,
168                    "outcome": e.outcome,
169                    "target_id": e.target_id.map(|t| t.to_string()),
170                    "duration_us": e.duration_us,
171                    "created_at": e.created_at,
172                })
173            })
174            .collect();
175
176        Ok(json!({
177            "count": events.len(),
178            "events": events,
179        }))
180    }
181
182    async fn handle_reset(&self, _params: Value) -> Result<Value, RuntimeError> {
183        let mut state = self.state.lock().unwrap();
184        state.reset_posteriors();
185        Ok(json!({
186            "reset": true,
187            "exploration_epoch": state.exploration_epoch,
188        }))
189    }
190
191    async fn handle_emit(&self, params: Value) -> Result<Value, RuntimeError> {
192        #[derive(Deserialize)]
193        struct EmitParams {
194            target_id: String,
195            signal: String,
196            namespace: Option<String>,
197        }
198        let p: EmitParams = serde_json::from_value(params)
199            .map_err(|e| RuntimeError::InvalidInput(e.to_string()))?;
200
201        let target: uuid::Uuid = p
202            .target_id
203            .parse()
204            .map_err(|e| RuntimeError::InvalidInput(format!("invalid target_id: {e}")))?;
205
206        let signal = match p.signal.as_str() {
207            "useful" => "useful",
208            "not_useful" => "not_useful",
209            "wrong" => "wrong",
210            other => {
211                return Err(RuntimeError::InvalidInput(format!(
212                    "unknown signal {other:?}; valid: useful | not_useful | wrong"
213                )))
214            }
215        };
216
217        let event = khive_storage::event::Event::new(
218            self.runtime.ns(p.namespace.as_deref()).to_string(),
219            "brain.emit",
220            khive_types::SubstrateKind::Event,
221            "brain",
222        )
223        .with_target(target)
224        .with_data(json!({"signal": signal}));
225
226        let store = self.runtime.events(p.namespace.as_deref())?;
227        store
228            .append_event(event.clone())
229            .await
230            .map_err(|e| RuntimeError::InvalidInput(e.to_string()))?;
231
232        // Update brain state from this event
233        let ctx = FoldContext::new();
234        let mut state = self.state.lock().unwrap();
235        let current = std::mem::replace(
236            &mut *state,
237            BrainState::new(std::collections::HashMap::new(), 0),
238        );
239        *state = self.fold.step(current, &event, &ctx);
240
241        Ok(json!({
242            "emitted": true,
243            "event_id": event.id.to_string(),
244            "signal": signal,
245            "target_id": target.to_string(),
246        }))
247    }
248}
249
250// ── ADR-063: inventory self-registration ─────────────────────────────────────
251
252struct BrainPackFactory;
253
254impl khive_runtime::PackFactory for BrainPackFactory {
255    fn name(&self) -> &'static str {
256        "brain"
257    }
258
259    fn requires(&self) -> &'static [&'static str] {
260        &["kg"]
261    }
262
263    fn create(&self, runtime: KhiveRuntime) -> Box<dyn PackRuntime> {
264        Box::new(BrainPack::new(runtime))
265    }
266}
267
268inventory::submit! { khive_runtime::PackRegistration(&BrainPackFactory) }
269
270#[async_trait]
271impl PackRuntime for BrainPack {
272    fn name(&self) -> &str {
273        <BrainPack as Pack>::NAME
274    }
275
276    fn note_kinds(&self) -> &'static [&'static str] {
277        <BrainPack as Pack>::NOTE_KINDS
278    }
279
280    fn entity_kinds(&self) -> &'static [&'static str] {
281        <BrainPack as Pack>::ENTITY_KINDS
282    }
283
284    fn verbs(&self) -> &'static [VerbDef] {
285        &BRAIN_VERBS
286    }
287
288    fn requires(&self) -> &'static [&'static str] {
289        <BrainPack as Pack>::REQUIRES
290    }
291
292    async fn dispatch(
293        &self,
294        verb: &str,
295        params: Value,
296        _registry: &VerbRegistry,
297    ) -> Result<Value, RuntimeError> {
298        match verb {
299            "brain.state" => self.handle_state(params).await,
300            "brain.config" => self.handle_config(params).await,
301            "brain.events" => self.handle_events(params).await,
302            "brain.reset" => self.handle_reset(params).await,
303            "brain.emit" => self.handle_emit(params).await,
304            _ => Err(RuntimeError::InvalidInput(format!(
305                "brain pack does not handle verb {verb:?}"
306            ))),
307        }
308    }
309}
310
311/// `BrainPack` as a post-dispatch hook (Issue #158).
312///
313/// When registered via `VerbRegistryBuilder::with_dispatch_hook`, every
314/// successful verb dispatch calls `on_dispatch` with a synthesized `Event`.
315/// The event is fed into `EventFold::step`, updating the brain's posteriors
316/// in real time — no polling required.
317///
318/// This is opt-in: the hook must be explicitly registered. Registries that do
319/// not load the brain pack are unaffected.
320#[async_trait]
321impl DispatchHook for BrainPack {
322    async fn on_dispatch(&self, event: &Event) {
323        let ctx = FoldContext::new();
324        let mut state = self.state.lock().unwrap();
325        // Replace state with fold result. BrainState is not Clone, so we
326        // use mem::replace with a sentinel and immediately overwrite.
327        let current = std::mem::replace(
328            &mut *state,
329            BrainState::new(std::collections::HashMap::new(), 0),
330        );
331        *state = self.fold.step(current, event, &ctx);
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use khive_runtime::VerbRegistryBuilder;
    use serde_json::json;

    /// Build a pack over a fresh in-memory runtime.
    fn make_pack() -> BrainPack {
        let rt = KhiveRuntime::memory().expect("in-memory runtime");
        BrainPack::new(rt)
    }

    /// Registry with no packs; the brain handlers ignore it.
    fn empty_registry() -> VerbRegistry {
        VerbRegistryBuilder::new()
            .build()
            .expect("empty registry builds successfully")
    }

    #[tokio::test]
    async fn dispatch_unknown_verb_returns_invalid_input() {
        let pack = make_pack();
        let registry = empty_registry();
        let err = pack
            .dispatch("brain.unknown", json!({}), &registry)
            .await
            .unwrap_err();
        match &err {
            RuntimeError::InvalidInput(msg) => assert!(
                msg.contains("brain.unknown"),
                "expected verb name in error: {msg}"
            ),
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn dispatch_reset_returns_true_and_increments_epoch() {
        let pack = make_pack();
        let registry = empty_registry();
        let result = pack
            .dispatch("brain.reset", json!({}), &registry)
            .await
            .unwrap();
        assert_eq!(result["reset"], json!(true));
        assert_eq!(result["exploration_epoch"], json!(1u64));
    }

    #[tokio::test]
    async fn dispatch_emit_invalid_signal_returns_invalid_input() {
        let pack = make_pack();
        let registry = empty_registry();
        let target = "00000000-0000-0000-0000-000000000001";
        let err = pack
            .dispatch(
                "brain.emit",
                json!({"target_id": target, "signal": "bad_signal"}),
                &registry,
            )
            .await
            .unwrap_err();
        match &err {
            RuntimeError::InvalidInput(msg) => {
                assert!(
                    msg.contains("bad_signal"),
                    "expected signal name in error: {msg}"
                );
                assert!(
                    msg.contains("valid"),
                    "expected hint about valid values: {msg}"
                );
            }
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn dispatch_emit_invalid_target_id_returns_invalid_input() {
        let pack = make_pack();
        let registry = empty_registry();
        let err = pack
            .dispatch(
                "brain.emit",
                json!({"target_id": "not-a-uuid", "signal": "useful"}),
                &registry,
            )
            .await
            .unwrap_err();
        match &err {
            RuntimeError::InvalidInput(msg) => assert!(
                msg.contains("invalid target_id"),
                "expected target_id hint in error: {msg}"
            ),
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn dispatch_config_unknown_parameter_returns_not_found() {
        let pack = make_pack();
        let registry = empty_registry();
        let err = pack
            .dispatch(
                "brain.config",
                json!({"parameter": "__no_such_parameter__"}),
                &registry,
            )
            .await
            .unwrap_err();
        assert!(
            matches!(err, RuntimeError::NotFound(_)),
            "expected NotFound, got {err:?}"
        );
    }

    #[tokio::test]
    async fn dispatch_config_malformed_params_returns_invalid_input() {
        let pack = make_pack();
        let registry = empty_registry();
        let err = pack
            .dispatch("brain.config", json!({"parameter": 42}), &registry)
            .await
            .unwrap_err();
        assert!(
            matches!(err, RuntimeError::InvalidInput(_)),
            "expected InvalidInput, got {err:?}"
        );
    }

    #[tokio::test]
    async fn dispatch_state_returns_snapshot_fields() {
        let pack = make_pack();
        let registry = empty_registry();
        let result = pack
            .dispatch("brain.state", json!({}), &registry)
            .await
            .unwrap();
        assert!(result.get("total_events").is_some(), "missing total_events");
        assert!(
            result.get("exploration_epoch").is_some(),
            "missing exploration_epoch"
        );
        assert!(result.get("parameters").is_some(), "missing parameters");
    }

    #[tokio::test]
    async fn snapshot_matches_state_verb() {
        let pack = make_pack();
        let registry = empty_registry();
        let via_verb = pack
            .dispatch("brain.state", json!({}), &registry)
            .await
            .unwrap();
        let direct = serde_json::to_value(pack.snapshot()).expect("snapshot serializes");
        assert_eq!(via_verb, direct);
    }
}