1use std::time::Duration;
2
3use futures_util::{FutureExt, future::join_all};
4use serde_json::Value;
5use std::panic::AssertUnwindSafe;
6use tracing::info;
7
8use crate::channels::traits::ChannelMessage;
9use crate::providers::traits::{ChatMessage, ChatResponse};
10use crate::tools::traits::ToolResult;
11
12use super::traits::{HookHandler, HookResult};
13
14pub struct HookRunner {
20 handlers: Vec<Box<dyn HookHandler>>,
21}
22
23impl HookRunner {
24 pub fn new() -> Self {
26 Self {
27 handlers: Vec::new(),
28 }
29 }
30
31 pub fn register(&mut self, handler: Box<dyn HookHandler>) {
33 self.handlers.push(handler);
34 self.handlers
35 .sort_by_key(|h| std::cmp::Reverse(h.priority()));
36 }
37
38 pub async fn fire_gateway_start(&self, host: &str, port: u16) {
43 let futs: Vec<_> = self
44 .handlers
45 .iter()
46 .map(|h| h.on_gateway_start(host, port))
47 .collect();
48 join_all(futs).await;
49 }
50
51 pub async fn fire_gateway_stop(&self) {
52 let futs: Vec<_> = self.handlers.iter().map(|h| h.on_gateway_stop()).collect();
53 join_all(futs).await;
54 }
55
56 pub async fn fire_session_start(&self, session_id: &str, channel: &str) {
57 let futs: Vec<_> = self
58 .handlers
59 .iter()
60 .map(|h| h.on_session_start(session_id, channel))
61 .collect();
62 join_all(futs).await;
63 }
64
65 pub async fn fire_session_end(&self, session_id: &str, channel: &str) {
66 let futs: Vec<_> = self
67 .handlers
68 .iter()
69 .map(|h| h.on_session_end(session_id, channel))
70 .collect();
71 join_all(futs).await;
72 }
73
74 pub async fn fire_llm_input(&self, messages: &[ChatMessage], model: &str) {
75 let futs: Vec<_> = self
76 .handlers
77 .iter()
78 .map(|h| h.on_llm_input(messages, model))
79 .collect();
80 join_all(futs).await;
81 }
82
83 pub async fn fire_llm_output(&self, response: &ChatResponse) {
84 let futs: Vec<_> = self
85 .handlers
86 .iter()
87 .map(|h| h.on_llm_output(response))
88 .collect();
89 join_all(futs).await;
90 }
91
92 pub async fn fire_after_tool_call(&self, tool: &str, result: &ToolResult, duration: Duration) {
93 let futs: Vec<_> = self
94 .handlers
95 .iter()
96 .map(|h| h.on_after_tool_call(tool, result, duration))
97 .collect();
98 join_all(futs).await;
99 }
100
101 pub async fn fire_message_sent(&self, channel: &str, recipient: &str, content: &str) {
102 let futs: Vec<_> = self
103 .handlers
104 .iter()
105 .map(|h| h.on_message_sent(channel, recipient, content))
106 .collect();
107 join_all(futs).await;
108 }
109
110 pub async fn fire_heartbeat_tick(&self) {
111 let futs: Vec<_> = self
112 .handlers
113 .iter()
114 .map(|h| h.on_heartbeat_tick())
115 .collect();
116 join_all(futs).await;
117 }
118
119 pub async fn run_before_model_resolve(
124 &self,
125 mut provider: String,
126 mut model: String,
127 ) -> HookResult<(String, String)> {
128 for h in &self.handlers {
129 let hook_name = h.name();
130 match AssertUnwindSafe(h.before_model_resolve(provider.clone(), model.clone()))
131 .catch_unwind()
132 .await
133 {
134 Ok(HookResult::Continue((p, m))) => {
135 provider = p;
136 model = m;
137 }
138 Ok(HookResult::Cancel(reason)) => {
139 info!(
140 hook = hook_name,
141 reason, "before_model_resolve cancelled by hook"
142 );
143 return HookResult::Cancel(reason);
144 }
145 Err(_) => {
146 tracing::error!(
147 hook = hook_name,
148 "before_model_resolve hook panicked; continuing with previous values"
149 );
150 }
151 }
152 }
153 HookResult::Continue((provider, model))
154 }
155
156 pub async fn run_before_prompt_build(&self, mut prompt: String) -> HookResult<String> {
157 for h in &self.handlers {
158 let hook_name = h.name();
159 match AssertUnwindSafe(h.before_prompt_build(prompt.clone()))
160 .catch_unwind()
161 .await
162 {
163 Ok(HookResult::Continue(p)) => prompt = p,
164 Ok(HookResult::Cancel(reason)) => {
165 info!(
166 hook = hook_name,
167 reason, "before_prompt_build cancelled by hook"
168 );
169 return HookResult::Cancel(reason);
170 }
171 Err(_) => {
172 tracing::error!(
173 hook = hook_name,
174 "before_prompt_build hook panicked; continuing with previous value"
175 );
176 }
177 }
178 }
179 HookResult::Continue(prompt)
180 }
181
182 pub async fn run_before_llm_call(
183 &self,
184 mut messages: Vec<ChatMessage>,
185 mut model: String,
186 ) -> HookResult<(Vec<ChatMessage>, String)> {
187 for h in &self.handlers {
188 let hook_name = h.name();
189 match AssertUnwindSafe(h.before_llm_call(messages.clone(), model.clone()))
190 .catch_unwind()
191 .await
192 {
193 Ok(HookResult::Continue((m, mdl))) => {
194 messages = m;
195 model = mdl;
196 }
197 Ok(HookResult::Cancel(reason)) => {
198 info!(
199 hook = hook_name,
200 reason, "before_llm_call cancelled by hook"
201 );
202 return HookResult::Cancel(reason);
203 }
204 Err(_) => {
205 tracing::error!(
206 hook = hook_name,
207 "before_llm_call hook panicked; continuing with previous values"
208 );
209 }
210 }
211 }
212 HookResult::Continue((messages, model))
213 }
214
215 pub async fn run_before_tool_call(
216 &self,
217 mut name: String,
218 mut args: Value,
219 ) -> HookResult<(String, Value)> {
220 for h in &self.handlers {
221 let hook_name = h.name();
222 match AssertUnwindSafe(h.before_tool_call(name.clone(), args.clone()))
223 .catch_unwind()
224 .await
225 {
226 Ok(HookResult::Continue((n, a))) => {
227 name = n;
228 args = a;
229 }
230 Ok(HookResult::Cancel(reason)) => {
231 info!(
232 hook = hook_name,
233 reason, "before_tool_call cancelled by hook"
234 );
235 return HookResult::Cancel(reason);
236 }
237 Err(_) => {
238 tracing::error!(
239 hook = hook_name,
240 "before_tool_call hook panicked; continuing with previous values"
241 );
242 }
243 }
244 }
245 HookResult::Continue((name, args))
246 }
247
248 pub async fn run_on_message_received(
249 &self,
250 mut message: ChannelMessage,
251 ) -> HookResult<ChannelMessage> {
252 for h in &self.handlers {
253 let hook_name = h.name();
254 match AssertUnwindSafe(h.on_message_received(message.clone()))
255 .catch_unwind()
256 .await
257 {
258 Ok(HookResult::Continue(m)) => message = m,
259 Ok(HookResult::Cancel(reason)) => {
260 info!(
261 hook = hook_name,
262 reason, "on_message_received cancelled by hook"
263 );
264 return HookResult::Cancel(reason);
265 }
266 Err(_) => {
267 tracing::error!(
268 hook = hook_name,
269 "on_message_received hook panicked; continuing with previous message"
270 );
271 }
272 }
273 }
274 HookResult::Continue(message)
275 }
276
277 pub async fn run_on_message_sending(
278 &self,
279 mut channel: String,
280 mut recipient: String,
281 mut content: String,
282 ) -> HookResult<(String, String, String)> {
283 for h in &self.handlers {
284 let hook_name = h.name();
285 match AssertUnwindSafe(h.on_message_sending(
286 channel.clone(),
287 recipient.clone(),
288 content.clone(),
289 ))
290 .catch_unwind()
291 .await
292 {
293 Ok(HookResult::Continue((c, r, ct))) => {
294 channel = c;
295 recipient = r;
296 content = ct;
297 }
298 Ok(HookResult::Cancel(reason)) => {
299 info!(
300 hook = hook_name,
301 reason, "on_message_sending cancelled by hook"
302 );
303 return HookResult::Cancel(reason);
304 }
305 Err(_) => {
306 tracing::error!(
307 hook = hook_name,
308 "on_message_sending hook panicked; continuing with previous message"
309 );
310 }
311 }
312 }
313 HookResult::Continue((channel, recipient, content))
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Void-hook handler that counts how many times `on_heartbeat_tick` fired.
    struct CountingHook {
        name: String,
        priority: i32,
        fire_count: Arc<AtomicU32>,
    }

    impl CountingHook {
        /// Returns the hook plus a shared handle to its fire counter.
        fn new(name: &str, priority: i32) -> (Self, Arc<AtomicU32>) {
            let count = Arc::new(AtomicU32::new(0));
            (
                Self {
                    name: name.to_string(),
                    priority,
                    fire_count: count.clone(),
                },
                count,
            )
        }
    }

    #[async_trait]
    impl HookHandler for CountingHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn on_heartbeat_tick(&self) {
            self.fire_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Modifying hook that upper-cases the prompt.
    struct UppercasePromptHook {
        name: String,
        priority: i32,
    }

    #[async_trait]
    impl HookHandler for UppercasePromptHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
            HookResult::Continue(prompt.to_uppercase())
        }
    }

    /// Modifying hook that always cancels the prompt build.
    struct CancelPromptHook {
        name: String,
        priority: i32,
    }

    #[async_trait]
    impl HookHandler for CancelPromptHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn before_prompt_build(&self, _prompt: String) -> HookResult<String> {
            HookResult::Cancel("blocked by policy".into())
        }
    }

    /// Modifying hook that appends a fixed suffix to the prompt.
    struct SuffixPromptHook {
        name: String,
        priority: i32,
        suffix: String,
    }

    #[async_trait]
    impl HookHandler for SuffixPromptHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
            HookResult::Continue(format!("{}{}", prompt, self.suffix))
        }
    }

    /// Modifying hook that panics, used to exercise the runner's panic recovery.
    struct PanickingPromptHook {
        name: String,
        priority: i32,
    }

    #[async_trait]
    impl HookHandler for PanickingPromptHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn before_prompt_build(&self, _prompt: String) -> HookResult<String> {
            panic!("simulated hook failure");
        }
    }

    #[test]
    fn register_and_sort_by_priority() {
        let mut runner = HookRunner::new();
        let (low, _) = CountingHook::new("low", 1);
        let (high, _) = CountingHook::new("high", 10);
        let (mid, _) = CountingHook::new("mid", 5);

        runner.register(Box::new(low));
        runner.register(Box::new(high));
        runner.register(Box::new(mid));

        let names: Vec<&str> = runner.handlers.iter().map(|h| h.name()).collect();
        assert_eq!(names, vec!["high", "mid", "low"]);
    }

    #[test]
    fn equal_priority_keeps_registration_order() {
        let mut runner = HookRunner::new();
        let (first, _) = CountingHook::new("first", 3);
        let (second, _) = CountingHook::new("second", 3);
        let (top, _) = CountingHook::new("top", 7);

        runner.register(Box::new(first));
        runner.register(Box::new(second));
        runner.register(Box::new(top));

        let names: Vec<&str> = runner.handlers.iter().map(|h| h.name()).collect();
        assert_eq!(names, vec!["top", "first", "second"]);
    }

    #[tokio::test]
    async fn void_hooks_fire_all_handlers() {
        let mut runner = HookRunner::new();
        let (h1, c1) = CountingHook::new("hook_a", 0);
        let (h2, c2) = CountingHook::new("hook_b", 0);

        runner.register(Box::new(h1));
        runner.register(Box::new(h2));

        runner.fire_heartbeat_tick().await;

        assert_eq!(c1.load(Ordering::SeqCst), 1);
        assert_eq!(c2.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn modifying_hook_can_cancel() {
        let mut runner = HookRunner::new();
        runner.register(Box::new(CancelPromptHook {
            name: "blocker".into(),
            priority: 10,
        }));
        runner.register(Box::new(UppercasePromptHook {
            name: "upper".into(),
            priority: 0,
        }));

        let result = runner.run_before_prompt_build("hello".into()).await;
        assert!(result.is_cancel());
    }

    #[tokio::test]
    async fn modifying_hook_pipelines_data() {
        let mut runner = HookRunner::new();

        runner.register(Box::new(UppercasePromptHook {
            name: "upper".into(),
            priority: 10,
        }));
        runner.register(Box::new(SuffixPromptHook {
            name: "suffix".into(),
            priority: 0,
            suffix: "_done".into(),
        }));

        match runner.run_before_prompt_build("hello".into()).await {
            HookResult::Continue(result) => assert_eq!(result, "HELLO_done"),
            HookResult::Cancel(_) => panic!("should not cancel"),
        }
    }

    #[tokio::test]
    async fn panicking_modifying_hook_is_skipped() {
        let mut runner = HookRunner::new();

        runner.register(Box::new(UppercasePromptHook {
            name: "upper".into(),
            priority: 10,
        }));
        runner.register(Box::new(PanickingPromptHook {
            name: "panicker".into(),
            priority: 5,
        }));
        runner.register(Box::new(SuffixPromptHook {
            name: "suffix".into(),
            priority: 0,
            suffix: "_done".into(),
        }));

        // The panicking hook is skipped; its predecessor's output flows on.
        match runner.run_before_prompt_build("hello".into()).await {
            HookResult::Continue(result) => assert_eq!(result, "HELLO_done"),
            HookResult::Cancel(_) => panic!("a panicking hook must not cancel the pipeline"),
        }
    }
}