Skip to main content

neuron_hooks/
lib.rs

1#![deny(missing_docs)]
//! Hook registry and composition for neuron.
//!
//! The [`HookRegistry`] collects multiple [`Hook`] implementations into
//! an ordered pipeline. At each hook point, the hooks registered for that
//! point are dispatched in registration order. The pipeline short-circuits
//! on the first action other than `Continue` (such as `Halt`, `SkipTool`,
//! or `ModifyToolInput`) — subsequent hooks are not called. Hook errors
//! are currently ignored (not logged) and the pipeline continues: an error
//! never halts dispatch.
9
10use layer0::hook::{Hook, HookAction, HookContext};
11use std::sync::Arc;
12
13/// A registry that dispatches hook events to an ordered pipeline of hooks.
14///
15/// Hooks are called in the order they were registered. The pipeline
16/// short-circuits on any action other than `Continue` (except errors,
17/// which are logged and ignored).
18pub struct HookRegistry {
19    hooks: Vec<Arc<dyn Hook>>,
20}
21
22impl HookRegistry {
23    /// Create a new empty hook registry.
24    pub fn new() -> Self {
25        Self { hooks: Vec::new() }
26    }
27
28    /// Add a hook to the end of the pipeline.
29    pub fn add(&mut self, hook: Arc<dyn Hook>) {
30        self.hooks.push(hook);
31    }
32
33    /// Dispatch a hook event through the pipeline.
34    ///
35    /// Returns the final action. If all hooks return `Continue`, the
36    /// result is `Continue`. If any hook returns `Halt`, `SkipTool`,
37    /// or `ModifyToolInput`, the pipeline stops and that action is returned.
38    /// Hook errors are logged and treated as `Continue`.
39    pub async fn dispatch(&self, ctx: &HookContext) -> HookAction {
40        for hook in &self.hooks {
41            // Only dispatch to hooks registered for this point
42            if !hook.points().contains(&ctx.point) {
43                continue;
44            }
45
46            match hook.on_event(ctx).await {
47                Ok(HookAction::Continue) => continue,
48                Ok(action) => return action,
49                Err(_e) => {
50                    // Hook errors are logged but don't halt the pipeline.
51                    // In a real system, this would go to tracing/logging.
52                    continue;
53                }
54            }
55        }
56
57        HookAction::Continue
58    }
59}
60
61impl Default for HookRegistry {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use layer0::error::HookError;
    use layer0::hook::HookPoint;

    /// Test hook that always returns `Continue`.
    struct ContinueHook {
        points: Vec<HookPoint>,
    }

    #[async_trait]
    impl Hook for ContinueHook {
        fn points(&self) -> &[HookPoint] {
            &self.points
        }
        async fn on_event(&self, _ctx: &HookContext) -> Result<HookAction, HookError> {
            Ok(HookAction::Continue)
        }
    }

    /// Test hook that always returns `Halt` with a fixed reason.
    struct HaltHook {
        points: Vec<HookPoint>,
        reason: String,
    }

    #[async_trait]
    impl Hook for HaltHook {
        fn points(&self) -> &[HookPoint] {
            &self.points
        }
        async fn on_event(&self, _ctx: &HookContext) -> Result<HookAction, HookError> {
            Ok(HookAction::Halt {
                reason: self.reason.clone(),
            })
        }
    }

    /// Test hook that always fails.
    struct ErrorHook {
        points: Vec<HookPoint>,
    }

    #[async_trait]
    impl Hook for ErrorHook {
        fn points(&self) -> &[HookPoint] {
            &self.points
        }
        async fn on_event(&self, _ctx: &HookContext) -> Result<HookAction, HookError> {
            Err(HookError::Failed("hook error".into()))
        }
    }

    #[tokio::test]
    async fn empty_registry_returns_continue() {
        let registry = HookRegistry::new();
        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        assert!(matches!(action, HookAction::Continue));
    }

    #[tokio::test]
    async fn continue_hook_returns_continue() {
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(ContinueHook {
            points: vec![HookPoint::PreInference],
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        assert!(matches!(action, HookAction::Continue));
    }

    #[tokio::test]
    async fn halt_hook_short_circuits() {
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(HaltHook {
            points: vec![HookPoint::PreInference],
            reason: "budget exceeded".into(),
        }));
        registry.add(Arc::new(ContinueHook {
            points: vec![HookPoint::PreInference],
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        match action {
            HookAction::Halt { reason } => assert_eq!(reason, "budget exceeded"),
            _ => panic!("expected Halt"),
        }
    }

    #[tokio::test]
    async fn earlier_hook_wins_over_later_hook() {
        // Two halting hooks: only the first registered one may decide.
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(HaltHook {
            points: vec![HookPoint::PreInference],
            reason: "first".into(),
        }));
        registry.add(Arc::new(HaltHook {
            points: vec![HookPoint::PreInference],
            reason: "second".into(),
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        match registry.dispatch(&ctx).await {
            HookAction::Halt { reason } => assert_eq!(reason, "first"),
            _ => panic!("expected Halt"),
        }
    }

    #[tokio::test]
    async fn hook_not_matching_point_is_skipped() {
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(HaltHook {
            points: vec![HookPoint::PostInference],
            reason: "should not trigger".into(),
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        assert!(matches!(action, HookAction::Continue));
    }

    #[tokio::test]
    async fn non_matching_hook_does_not_block_later_matching_hook() {
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(HaltHook {
            points: vec![HookPoint::PostInference],
            reason: "wrong point".into(),
        }));
        registry.add(Arc::new(HaltHook {
            points: vec![HookPoint::PreInference],
            reason: "right point".into(),
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        match registry.dispatch(&ctx).await {
            HookAction::Halt { reason } => assert_eq!(reason, "right point"),
            _ => panic!("expected Halt from the hook registered for PreInference"),
        }
    }

    #[tokio::test]
    async fn error_hook_treated_as_continue() {
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(ErrorHook {
            points: vec![HookPoint::PreInference],
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        assert!(matches!(action, HookAction::Continue));
    }

    #[tokio::test]
    async fn error_hook_does_not_stop_later_hooks() {
        // A failing hook must not prevent later hooks from running.
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(ErrorHook {
            points: vec![HookPoint::PreInference],
        }));
        registry.add(Arc::new(HaltHook {
            points: vec![HookPoint::PreInference],
            reason: "after error".into(),
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        match registry.dispatch(&ctx).await {
            HookAction::Halt { reason } => assert_eq!(reason, "after error"),
            _ => panic!("expected Halt from the hook registered after the failing one"),
        }
    }

    #[tokio::test]
    async fn multiple_continue_hooks_all_pass() {
        let mut registry = HookRegistry::new();
        registry.add(Arc::new(ContinueHook {
            points: vec![HookPoint::PreInference],
        }));
        registry.add(Arc::new(ContinueHook {
            points: vec![HookPoint::PreInference],
        }));

        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        assert!(matches!(action, HookAction::Continue));
    }

    #[tokio::test]
    async fn default_registry_is_empty() {
        // An empty registry has no hook that could return anything but Continue.
        let registry = HookRegistry::default();
        let ctx = HookContext::new(HookPoint::PreInference);
        let action = registry.dispatch(&ctx).await;
        assert!(matches!(action, HookAction::Continue));
    }
}