bamboo_engine/runtime/hooks/
mod.rs1use std::sync::Arc;
4
5use bamboo_agent_core::AgentHook;
6use bamboo_agent_core::Session;
7use bamboo_domain::{AgentHookPoint, AgentRuntimeState, HookCheckpoint, HookResult};
8use chrono::Utc;
9
10#[derive(Clone)]
12pub struct HookRunner {
13 hooks: Vec<Arc<dyn AgentHook>>,
14}
15
16impl HookRunner {
17 pub fn new() -> Self {
18 Self { hooks: Vec::new() }
19 }
20
21 pub fn register(&mut self, hook: Arc<dyn AgentHook>) {
23 self.hooks.push(hook);
24 self.hooks.sort_by_key(|h| h.priority());
25 }
26
27 pub async fn run_hooks(
32 &self,
33 point: AgentHookPoint,
34 session: &Session,
35 runtime_state: &mut AgentRuntimeState,
36 ) -> HookResult {
37 let mut final_result = HookResult::Continue;
38
39 for hook in &self.hooks {
40 if hook.point() != point {
41 continue;
42 }
43
44 let start = std::time::Instant::now();
45 let result = hook.run(point, session).await;
46 let elapsed = start.elapsed();
47
48 runtime_state.checkpoints.push(HookCheckpoint {
49 hook_point: format!("{:?}", point),
50 timestamp: Utc::now(),
51 result: format!("{:?}", result),
52 duration_ms: elapsed.as_millis() as u64,
53 });
54
55 match &result {
56 HookResult::Abort { .. } | HookResult::Suspend { .. } => return result,
57 HookResult::Mutated => final_result = HookResult::Mutated,
58 HookResult::Continue => {}
59 }
60 }
61
62 final_result
63 }
64
65 pub fn has_hooks_for(&self, point: AgentHookPoint) -> bool {
67 self.hooks.iter().any(|h| h.point() == point)
68 }
69
70 pub fn len(&self) -> usize {
72 self.hooks.len()
73 }
74
75 pub fn is_empty(&self) -> bool {
77 self.hooks.is_empty()
78 }
79}
80
81impl Default for HookRunner {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Ordered record of hook names, appended to as hooks run.
    type RunLog = Arc<Mutex<Vec<String>>>;

    /// Test hook that appends its `name` to `log`, then returns `Continue`.
    struct ContinueHook {
        point: AgentHookPoint,
        pri: u32,
        name: String,
        log: RunLog,
    }

    impl ContinueHook {
        fn new(point: AgentHookPoint, pri: u32, name: &str, log: &RunLog) -> Self {
            Self {
                point,
                pri,
                name: name.to_string(),
                log: Arc::clone(log),
            }
        }
    }

    #[async_trait::async_trait]
    impl AgentHook for ContinueHook {
        fn point(&self) -> AgentHookPoint {
            self.point
        }

        async fn run(&self, _point: AgentHookPoint, _session: &Session) -> HookResult {
            self.log.lock().unwrap().push(self.name.clone());
            HookResult::Continue
        }

        fn priority(&self) -> u32 {
            self.pri
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Test hook that always aborts at `BeforeLlmCall`; keeps the trait's default priority.
    struct AbortHook;

    #[async_trait::async_trait]
    impl AgentHook for AbortHook {
        fn point(&self) -> AgentHookPoint {
            AgentHookPoint::BeforeLlmCall
        }

        async fn run(&self, _point: AgentHookPoint, _session: &Session) -> HookResult {
            HookResult::Abort {
                reason: "test abort".to_string(),
            }
        }

        fn name(&self) -> &str {
            "abort_hook"
        }
    }

    /// Test hook that always reports a mutation at `BeforeRound`; keeps the trait's default priority.
    struct MutateHook;

    #[async_trait::async_trait]
    impl AgentHook for MutateHook {
        fn point(&self) -> AgentHookPoint {
            AgentHookPoint::BeforeRound
        }

        async fn run(&self, _point: AgentHookPoint, _session: &Session) -> HookResult {
            HookResult::Mutated
        }

        fn name(&self) -> &str {
            "mutate_hook"
        }
    }

    fn test_session() -> Session {
        Session::new("test", "test-model")
    }

    /// Snapshot of the hook names recorded in `log` so far.
    fn ran(log: &RunLog) -> Vec<String> {
        log.lock().unwrap().clone()
    }

    #[tokio::test]
    async fn empty_runner_returns_continue() {
        let runner = HookRunner::new();
        let mut state = AgentRuntimeState::new("run-1");
        let session = test_session();

        let result = runner
            .run_hooks(AgentHookPoint::BeforeRound, &session, &mut state)
            .await;

        assert_eq!(result, HookResult::Continue);
        assert!(state.checkpoints.is_empty());
    }

    #[tokio::test]
    async fn hooks_run_in_priority_order() {
        let log = RunLog::default();
        let mut runner = HookRunner::new();
        runner.register(Arc::new(ContinueHook::new(
            AgentHookPoint::BeforeRound,
            200,
            "slow",
            &log,
        )));
        runner.register(Arc::new(ContinueHook::new(
            AgentHookPoint::BeforeRound,
            50,
            "fast",
            &log,
        )));

        let mut state = AgentRuntimeState::new("run-2");
        let session = test_session();

        let result = runner
            .run_hooks(AgentHookPoint::BeforeRound, &session, &mut state)
            .await;

        assert_eq!(result, HookResult::Continue);
        assert_eq!(state.checkpoints.len(), 2);
        assert!(state
            .checkpoints
            .iter()
            .all(|c| c.result.contains("Continue")));
        // The lower priority value runs first, regardless of registration order.
        assert_eq!(ran(&log), ["fast", "slow"]);
    }

    #[tokio::test]
    async fn equal_priorities_keep_registration_order() {
        let log = RunLog::default();
        let mut runner = HookRunner::new();
        for name in ["first", "second", "third"] {
            runner.register(Arc::new(ContinueHook::new(
                AgentHookPoint::BeforeRound,
                100,
                name,
                &log,
            )));
        }

        let mut state = AgentRuntimeState::new("run-5");
        let session = test_session();

        runner
            .run_hooks(AgentHookPoint::BeforeRound, &session, &mut state)
            .await;

        assert_eq!(ran(&log), ["first", "second", "third"]);
    }

    #[tokio::test]
    async fn abort_short_circuits() {
        let log = RunLog::default();
        let mut runner = HookRunner::new();
        runner.register(Arc::new(AbortHook));
        // Registered later with the maximum priority, so it is ordered after
        // AbortHook whatever the default priority is.
        runner.register(Arc::new(ContinueHook::new(
            AgentHookPoint::BeforeLlmCall,
            u32::MAX,
            "never",
            &log,
        )));

        let mut state = AgentRuntimeState::new("run-3");
        let session = test_session();

        let result = runner
            .run_hooks(AgentHookPoint::BeforeLlmCall, &session, &mut state)
            .await;

        assert!(matches!(result, HookResult::Abort { .. }));
        assert_eq!(state.checkpoints.len(), 1);
        assert!(ran(&log).is_empty());
    }

    #[tokio::test]
    async fn mutated_survives_later_continue() {
        let log = RunLog::default();
        let mut runner = HookRunner::new();
        runner.register(Arc::new(MutateHook));
        // Ordered after MutateHook (see abort_short_circuits for why u32::MAX).
        runner.register(Arc::new(ContinueHook::new(
            AgentHookPoint::BeforeRound,
            u32::MAX,
            "after",
            &log,
        )));

        let mut state = AgentRuntimeState::new("run-6");
        let session = test_session();

        let result = runner
            .run_hooks(AgentHookPoint::BeforeRound, &session, &mut state)
            .await;

        assert_eq!(result, HookResult::Mutated);
        assert_eq!(state.checkpoints.len(), 2);
        assert_eq!(ran(&log), ["after"]);
    }

    #[tokio::test]
    async fn wrong_point_hooks_are_skipped() {
        let mut runner = HookRunner::new();
        runner.register(Arc::new(AbortHook));

        let mut state = AgentRuntimeState::new("run-4");
        let session = test_session();

        let result = runner
            .run_hooks(AgentHookPoint::AfterRound, &session, &mut state)
            .await;

        assert_eq!(result, HookResult::Continue);
        assert!(state.checkpoints.is_empty());
    }
}