Skip to main content

bamboo_engine/runtime/hooks/
mod.rs

1//! Hook runner — dispatches registered hooks at lifecycle points.
2
3use std::sync::Arc;
4
5use bamboo_agent_core::AgentHook;
6use bamboo_agent_core::Session;
7use bamboo_domain::{AgentHookPoint, AgentRuntimeState, HookCheckpoint, HookResult};
8use chrono::Utc;
9
10/// Runs registered hooks at a given hook point.
11#[derive(Clone)]
12pub struct HookRunner {
13    hooks: Vec<Arc<dyn AgentHook>>,
14}
15
16impl HookRunner {
17    pub fn new() -> Self {
18        Self { hooks: Vec::new() }
19    }
20
21    /// Register a hook. Hooks are sorted by priority (lower runs first).
22    pub fn register(&mut self, hook: Arc<dyn AgentHook>) {
23        self.hooks.push(hook);
24        self.hooks.sort_by_key(|h| h.priority());
25    }
26
27    /// Run all hooks matching the given point.
28    ///
29    /// Records checkpoints in `runtime_state`. Returns the first
30    /// `Suspend` or `Abort` result, or the aggregate result otherwise.
31    pub async fn run_hooks(
32        &self,
33        point: AgentHookPoint,
34        session: &Session,
35        runtime_state: &mut AgentRuntimeState,
36    ) -> HookResult {
37        let mut final_result = HookResult::Continue;
38
39        for hook in &self.hooks {
40            if hook.point() != point {
41                continue;
42            }
43
44            let start = std::time::Instant::now();
45            let result = hook.run(point, session).await;
46            let elapsed = start.elapsed();
47
48            runtime_state.checkpoints.push(HookCheckpoint {
49                hook_point: format!("{:?}", point),
50                timestamp: Utc::now(),
51                result: format!("{:?}", result),
52                duration_ms: elapsed.as_millis() as u64,
53            });
54
55            match &result {
56                HookResult::Abort { .. } | HookResult::Suspend { .. } => return result,
57                HookResult::Mutated => final_result = HookResult::Mutated,
58                HookResult::Continue => {}
59            }
60        }
61
62        final_result
63    }
64
65    /// Check if any hooks are registered for the given point.
66    pub fn has_hooks_for(&self, point: AgentHookPoint) -> bool {
67        self.hooks.iter().any(|h| h.point() == point)
68    }
69
70    /// Number of registered hooks.
71    pub fn len(&self) -> usize {
72        self.hooks.len()
73    }
74
75    /// Whether any hooks are registered.
76    pub fn is_empty(&self) -> bool {
77        self.hooks.is_empty()
78    }
79}
80
81impl Default for HookRunner {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Shared, ordered record of the names of hooks that ran.
    type RunLog = Arc<Mutex<Vec<String>>>;

    /// A hook that appends its name to a shared log when run, then returns
    /// `Mutated` if `mutate` is set and `Continue` otherwise.
    struct RecordingHook {
        point: AgentHookPoint,
        pri: u32,
        name: String,
        mutate: bool,
        log: RunLog,
    }

    #[async_trait::async_trait]
    impl AgentHook for RecordingHook {
        fn point(&self) -> AgentHookPoint {
            self.point
        }

        async fn run(&self, _point: AgentHookPoint, _session: &Session) -> HookResult {
            self.log.lock().unwrap().push(self.name.clone());
            if self.mutate {
                HookResult::Mutated
            } else {
                HookResult::Continue
            }
        }

        fn priority(&self) -> u32 {
            self.pri
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Build a `RecordingHook` that writes into `log`.
    fn recorder(
        point: AgentHookPoint,
        pri: u32,
        name: &str,
        mutate: bool,
        log: &RunLog,
    ) -> Arc<RecordingHook> {
        Arc::new(RecordingHook {
            point,
            pri,
            name: name.to_string(),
            mutate,
            log: Arc::clone(log),
        })
    }

    /// Snapshot of the hook names recorded so far, in execution order.
    fn logged(log: &RunLog) -> Vec<String> {
        log.lock().unwrap().clone()
    }

    /// A hook that always returns Abort.
    struct AbortHook;

    #[async_trait::async_trait]
    impl AgentHook for AbortHook {
        fn point(&self) -> AgentHookPoint {
            AgentHookPoint::BeforeLlmCall
        }

        async fn run(&self, _point: AgentHookPoint, _session: &Session) -> HookResult {
            HookResult::Abort {
                reason: "test abort".to_string(),
            }
        }

        fn name(&self) -> &str {
            "abort_hook"
        }
    }

    fn test_session() -> Session {
        Session::new("test", "test-model")
    }

    #[tokio::test]
    async fn empty_runner_returns_continue() {
        let runner = HookRunner::new();
        let mut state = AgentRuntimeState::new("run-1");
        let session = test_session();

        let result = runner
            .run_hooks(AgentHookPoint::BeforeRound, &session, &mut state)
            .await;

        assert_eq!(result, HookResult::Continue);
        assert!(state.checkpoints.is_empty());
    }

    #[tokio::test]
    async fn hooks_run_in_priority_order() {
        let log = RunLog::default();
        let mut runner = HookRunner::new();
        runner.register(recorder(AgentHookPoint::BeforeRound, 200, "late", false, &log));
        runner.register(recorder(AgentHookPoint::BeforeRound, 50, "early", false, &log));

        let mut state = AgentRuntimeState::new("run-2");
        let session = test_session();

        let result = runner
            .run_hooks(AgentHookPoint::BeforeRound, &session, &mut state)
            .await;

        assert_eq!(result, HookResult::Continue);
        assert_eq!(state.checkpoints.len(), 2);
        // Lower priority runs first, regardless of registration order.
        assert_eq!(logged(&log), ["early", "late"]);
    }

    #[tokio::test]
    async fn equal_priorities_run_in_registration_order() {
        let log = RunLog::default();
        let mut runner = HookRunner::new();
        runner.register(recorder(AgentHookPoint::BeforeRound, 100, "first", false, &log));
        runner.register(recorder(AgentHookPoint::BeforeRound, 100, "second", false, &log));
        runner.register(recorder(AgentHookPoint::BeforeRound, 10, "zeroth", false, &log));

        let mut state = AgentRuntimeState::new("run-5");
        let session = test_session();

        runner
            .run_hooks(AgentHookPoint::BeforeRound, &session, &mut state)
            .await;

        assert_eq!(logged(&log), ["zeroth", "first", "second"]);
    }

    #[tokio::test]
    async fn mutated_result_is_aggregated() {
        let log = RunLog::default();
        let mut runner = HookRunner::new();
        runner.register(recorder(AgentHookPoint::AfterRound, 10, "mutator", true, &log));
        runner.register(recorder(AgentHookPoint::AfterRound, 20, "observer", false, &log));

        let mut state = AgentRuntimeState::new("run-6");
        let session = test_session();

        let result = runner
            .run_hooks(AgentHookPoint::AfterRound, &session, &mut state)
            .await;

        // A later `Continue` must not overwrite an earlier `Mutated`.
        assert_eq!(result, HookResult::Mutated);
        assert_eq!(state.checkpoints.len(), 2);
        assert_eq!(logged(&log), ["mutator", "observer"]);
    }

    #[tokio::test]
    async fn abort_short_circuits() {
        let log = RunLog::default();
        let mut runner = HookRunner::new();
        runner.register(Arc::new(AbortHook));
        // Registered after AbortHook with the maximum priority, so it sorts after it
        // whatever AbortHook's default priority turns out to be.
        runner.register(recorder(
            AgentHookPoint::BeforeLlmCall,
            u32::MAX,
            "after_abort",
            false,
            &log,
        ));

        let mut state = AgentRuntimeState::new("run-3");
        let session = test_session();

        let result = runner
            .run_hooks(AgentHookPoint::BeforeLlmCall, &session, &mut state)
            .await;

        assert!(matches!(result, HookResult::Abort { .. }));
        assert_eq!(state.checkpoints.len(), 1);
        // The hook queued after the abort never ran.
        assert!(logged(&log).is_empty());
    }

    #[tokio::test]
    async fn wrong_point_hooks_are_skipped() {
        let mut runner = HookRunner::new();
        runner.register(Arc::new(AbortHook)); // registered for BeforeLlmCall

        let mut state = AgentRuntimeState::new("run-4");
        let session = test_session();

        let result = runner
            .run_hooks(AgentHookPoint::AfterRound, &session, &mut state)
            .await;

        assert_eq!(result, HookResult::Continue);
        assert!(state.checkpoints.is_empty());
    }

    #[test]
    fn registry_queries_reflect_registrations() {
        let mut runner = HookRunner::new();
        assert!(runner.is_empty());
        assert_eq!(runner.len(), 0);
        assert!(!runner.has_hooks_for(AgentHookPoint::BeforeLlmCall));

        runner.register(Arc::new(AbortHook));

        assert!(!runner.is_empty());
        assert_eq!(runner.len(), 1);
        assert!(runner.has_hooks_for(AgentHookPoint::BeforeLlmCall));
        assert!(!runner.has_hooks_for(AgentHookPoint::AfterRound));
    }
}