// construct/hooks/runner.rs

1use std::time::Duration;
2
3use futures_util::{FutureExt, future::join_all};
4use serde_json::Value;
5use std::panic::AssertUnwindSafe;
6use tracing::info;
7
8use crate::channels::traits::ChannelMessage;
9use crate::providers::traits::{ChatMessage, ChatResponse};
10use crate::tools::traits::ToolResult;
11
12use super::traits::{HookHandler, HookResult};
13
14/// Dispatcher that manages registered hook handlers.
15///
16/// Void hooks are dispatched in parallel via `join_all`.
17/// Modifying hooks run sequentially by priority (higher first), piping output
18/// and short-circuiting on `Cancel`.
19pub struct HookRunner {
20    handlers: Vec<Box<dyn HookHandler>>,
21}
22
23impl HookRunner {
24    /// Create an empty runner with no handlers.
25    pub fn new() -> Self {
26        Self {
27            handlers: Vec::new(),
28        }
29    }
30
31    /// Register a handler and re-sort by descending priority.
32    pub fn register(&mut self, handler: Box<dyn HookHandler>) {
33        self.handlers.push(handler);
34        self.handlers
35            .sort_by_key(|h| std::cmp::Reverse(h.priority()));
36    }
37
38    // ---------------------------------------------------------------
39    // Void dispatchers (parallel, fire-and-forget)
40    // ---------------------------------------------------------------
41
42    pub async fn fire_gateway_start(&self, host: &str, port: u16) {
43        let futs: Vec<_> = self
44            .handlers
45            .iter()
46            .map(|h| h.on_gateway_start(host, port))
47            .collect();
48        join_all(futs).await;
49    }
50
51    pub async fn fire_gateway_stop(&self) {
52        let futs: Vec<_> = self.handlers.iter().map(|h| h.on_gateway_stop()).collect();
53        join_all(futs).await;
54    }
55
56    pub async fn fire_session_start(&self, session_id: &str, channel: &str) {
57        let futs: Vec<_> = self
58            .handlers
59            .iter()
60            .map(|h| h.on_session_start(session_id, channel))
61            .collect();
62        join_all(futs).await;
63    }
64
65    pub async fn fire_session_end(&self, session_id: &str, channel: &str) {
66        let futs: Vec<_> = self
67            .handlers
68            .iter()
69            .map(|h| h.on_session_end(session_id, channel))
70            .collect();
71        join_all(futs).await;
72    }
73
74    pub async fn fire_llm_input(&self, messages: &[ChatMessage], model: &str) {
75        let futs: Vec<_> = self
76            .handlers
77            .iter()
78            .map(|h| h.on_llm_input(messages, model))
79            .collect();
80        join_all(futs).await;
81    }
82
83    pub async fn fire_llm_output(&self, response: &ChatResponse) {
84        let futs: Vec<_> = self
85            .handlers
86            .iter()
87            .map(|h| h.on_llm_output(response))
88            .collect();
89        join_all(futs).await;
90    }
91
92    pub async fn fire_after_tool_call(&self, tool: &str, result: &ToolResult, duration: Duration) {
93        let futs: Vec<_> = self
94            .handlers
95            .iter()
96            .map(|h| h.on_after_tool_call(tool, result, duration))
97            .collect();
98        join_all(futs).await;
99    }
100
101    pub async fn fire_message_sent(&self, channel: &str, recipient: &str, content: &str) {
102        let futs: Vec<_> = self
103            .handlers
104            .iter()
105            .map(|h| h.on_message_sent(channel, recipient, content))
106            .collect();
107        join_all(futs).await;
108    }
109
110    pub async fn fire_heartbeat_tick(&self) {
111        let futs: Vec<_> = self
112            .handlers
113            .iter()
114            .map(|h| h.on_heartbeat_tick())
115            .collect();
116        join_all(futs).await;
117    }
118
119    // ---------------------------------------------------------------
120    // Modifying dispatchers (sequential by priority, short-circuit on Cancel)
121    // ---------------------------------------------------------------
122
123    pub async fn run_before_model_resolve(
124        &self,
125        mut provider: String,
126        mut model: String,
127    ) -> HookResult<(String, String)> {
128        for h in &self.handlers {
129            let hook_name = h.name();
130            match AssertUnwindSafe(h.before_model_resolve(provider.clone(), model.clone()))
131                .catch_unwind()
132                .await
133            {
134                Ok(HookResult::Continue((p, m))) => {
135                    provider = p;
136                    model = m;
137                }
138                Ok(HookResult::Cancel(reason)) => {
139                    info!(
140                        hook = hook_name,
141                        reason, "before_model_resolve cancelled by hook"
142                    );
143                    return HookResult::Cancel(reason);
144                }
145                Err(_) => {
146                    tracing::error!(
147                        hook = hook_name,
148                        "before_model_resolve hook panicked; continuing with previous values"
149                    );
150                }
151            }
152        }
153        HookResult::Continue((provider, model))
154    }
155
156    pub async fn run_before_prompt_build(&self, mut prompt: String) -> HookResult<String> {
157        for h in &self.handlers {
158            let hook_name = h.name();
159            match AssertUnwindSafe(h.before_prompt_build(prompt.clone()))
160                .catch_unwind()
161                .await
162            {
163                Ok(HookResult::Continue(p)) => prompt = p,
164                Ok(HookResult::Cancel(reason)) => {
165                    info!(
166                        hook = hook_name,
167                        reason, "before_prompt_build cancelled by hook"
168                    );
169                    return HookResult::Cancel(reason);
170                }
171                Err(_) => {
172                    tracing::error!(
173                        hook = hook_name,
174                        "before_prompt_build hook panicked; continuing with previous value"
175                    );
176                }
177            }
178        }
179        HookResult::Continue(prompt)
180    }
181
182    pub async fn run_before_llm_call(
183        &self,
184        mut messages: Vec<ChatMessage>,
185        mut model: String,
186    ) -> HookResult<(Vec<ChatMessage>, String)> {
187        for h in &self.handlers {
188            let hook_name = h.name();
189            match AssertUnwindSafe(h.before_llm_call(messages.clone(), model.clone()))
190                .catch_unwind()
191                .await
192            {
193                Ok(HookResult::Continue((m, mdl))) => {
194                    messages = m;
195                    model = mdl;
196                }
197                Ok(HookResult::Cancel(reason)) => {
198                    info!(
199                        hook = hook_name,
200                        reason, "before_llm_call cancelled by hook"
201                    );
202                    return HookResult::Cancel(reason);
203                }
204                Err(_) => {
205                    tracing::error!(
206                        hook = hook_name,
207                        "before_llm_call hook panicked; continuing with previous values"
208                    );
209                }
210            }
211        }
212        HookResult::Continue((messages, model))
213    }
214
215    pub async fn run_before_tool_call(
216        &self,
217        mut name: String,
218        mut args: Value,
219    ) -> HookResult<(String, Value)> {
220        for h in &self.handlers {
221            let hook_name = h.name();
222            match AssertUnwindSafe(h.before_tool_call(name.clone(), args.clone()))
223                .catch_unwind()
224                .await
225            {
226                Ok(HookResult::Continue((n, a))) => {
227                    name = n;
228                    args = a;
229                }
230                Ok(HookResult::Cancel(reason)) => {
231                    info!(
232                        hook = hook_name,
233                        reason, "before_tool_call cancelled by hook"
234                    );
235                    return HookResult::Cancel(reason);
236                }
237                Err(_) => {
238                    tracing::error!(
239                        hook = hook_name,
240                        "before_tool_call hook panicked; continuing with previous values"
241                    );
242                }
243            }
244        }
245        HookResult::Continue((name, args))
246    }
247
248    pub async fn run_on_message_received(
249        &self,
250        mut message: ChannelMessage,
251    ) -> HookResult<ChannelMessage> {
252        for h in &self.handlers {
253            let hook_name = h.name();
254            match AssertUnwindSafe(h.on_message_received(message.clone()))
255                .catch_unwind()
256                .await
257            {
258                Ok(HookResult::Continue(m)) => message = m,
259                Ok(HookResult::Cancel(reason)) => {
260                    info!(
261                        hook = hook_name,
262                        reason, "on_message_received cancelled by hook"
263                    );
264                    return HookResult::Cancel(reason);
265                }
266                Err(_) => {
267                    tracing::error!(
268                        hook = hook_name,
269                        "on_message_received hook panicked; continuing with previous message"
270                    );
271                }
272            }
273        }
274        HookResult::Continue(message)
275    }
276
277    pub async fn run_on_message_sending(
278        &self,
279        mut channel: String,
280        mut recipient: String,
281        mut content: String,
282    ) -> HookResult<(String, String, String)> {
283        for h in &self.handlers {
284            let hook_name = h.name();
285            match AssertUnwindSafe(h.on_message_sending(
286                channel.clone(),
287                recipient.clone(),
288                content.clone(),
289            ))
290            .catch_unwind()
291            .await
292            {
293                Ok(HookResult::Continue((c, r, ct))) => {
294                    channel = c;
295                    recipient = r;
296                    content = ct;
297                }
298                Ok(HookResult::Cancel(reason)) => {
299                    info!(
300                        hook = hook_name,
301                        reason, "on_message_sending cancelled by hook"
302                    );
303                    return HookResult::Cancel(reason);
304                }
305                Err(_) => {
306                    tracing::error!(
307                        hook = hook_name,
308                        "on_message_sending hook panicked; continuing with previous message"
309                    );
310                }
311            }
312        }
313        HookResult::Continue((channel, recipient, content))
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A hook that records how many times void events fire.
    struct CountingHook {
        name: String,
        priority: i32,
        fire_count: Arc<AtomicU32>,
    }

    impl CountingHook {
        /// Build a hook plus a shared handle to its counter, so the test can
        /// read the count after the hook has been moved into the runner.
        fn new(name: &str, priority: i32) -> (Self, Arc<AtomicU32>) {
            let count = Arc::new(AtomicU32::new(0));
            (
                Self {
                    name: name.to_string(),
                    priority,
                    fire_count: count.clone(),
                },
                count,
            )
        }
    }

    #[async_trait]
    impl HookHandler for CountingHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn on_heartbeat_tick(&self) {
            self.fire_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// A modifying hook that uppercases the prompt.
    struct UppercasePromptHook {
        name: String,
        priority: i32,
    }

    #[async_trait]
    impl HookHandler for UppercasePromptHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
            HookResult::Continue(prompt.to_uppercase())
        }
    }

    /// A modifying hook that cancels before_prompt_build.
    struct CancelPromptHook {
        name: String,
        priority: i32,
    }

    #[async_trait]
    impl HookHandler for CancelPromptHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn before_prompt_build(&self, _prompt: String) -> HookResult<String> {
            HookResult::Cancel("blocked by policy".into())
        }
    }

    /// A modifying hook that appends a suffix to the prompt.
    struct SuffixPromptHook {
        name: String,
        priority: i32,
        suffix: String,
    }

    #[async_trait]
    impl HookHandler for SuffixPromptHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
            HookResult::Continue(format!("{}{}", prompt, self.suffix))
        }
    }

    /// A modifying hook that panics whenever it is handed a non-empty prompt.
    struct PanickingPromptHook {
        name: String,
        priority: i32,
    }

    #[async_trait]
    impl HookHandler for PanickingPromptHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
            assert!(prompt.is_empty(), "intentional panic from test hook");
            HookResult::Continue(prompt)
        }
    }

    #[test]
    fn register_and_sort_by_priority() {
        let mut runner = HookRunner::new();
        let (low, _) = CountingHook::new("low", 1);
        let (high, _) = CountingHook::new("high", 10);
        let (mid, _) = CountingHook::new("mid", 5);

        runner.register(Box::new(low));
        runner.register(Box::new(high));
        runner.register(Box::new(mid));

        let names: Vec<&str> = runner.handlers.iter().map(|h| h.name()).collect();
        assert_eq!(names, vec!["high", "mid", "low"]);
    }

    /// Handlers that share a priority must keep their registration order,
    /// which relies on `register` using a stable sort.
    #[test]
    fn equal_priority_keeps_registration_order() {
        let mut runner = HookRunner::new();
        let (first, _) = CountingHook::new("first", 3);
        let (second, _) = CountingHook::new("second", 3);
        let (top, _) = CountingHook::new("top", 9);
        let (third, _) = CountingHook::new("third", 3);

        runner.register(Box::new(first));
        runner.register(Box::new(second));
        runner.register(Box::new(top));
        runner.register(Box::new(third));

        let names: Vec<&str> = runner.handlers.iter().map(|h| h.name()).collect();
        assert_eq!(names, vec!["top", "first", "second", "third"]);
    }

    #[tokio::test]
    async fn void_hooks_fire_all_handlers() {
        let mut runner = HookRunner::new();
        let (h1, c1) = CountingHook::new("hook_a", 0);
        let (h2, c2) = CountingHook::new("hook_b", 0);

        runner.register(Box::new(h1));
        runner.register(Box::new(h2));

        runner.fire_heartbeat_tick().await;

        assert_eq!(c1.load(Ordering::SeqCst), 1);
        assert_eq!(c2.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn modifying_hook_can_cancel() {
        let mut runner = HookRunner::new();
        runner.register(Box::new(CancelPromptHook {
            name: "blocker".into(),
            priority: 10,
        }));
        runner.register(Box::new(UppercasePromptHook {
            name: "upper".into(),
            priority: 0,
        }));

        let result = runner.run_before_prompt_build("hello".into()).await;
        assert!(result.is_cancel());
    }

    #[tokio::test]
    async fn modifying_hook_pipelines_data() {
        let mut runner = HookRunner::new();

        // Priority 10 runs first: uppercases
        runner.register(Box::new(UppercasePromptHook {
            name: "upper".into(),
            priority: 10,
        }));
        // Priority 0 runs second: appends suffix
        runner.register(Box::new(SuffixPromptHook {
            name: "suffix".into(),
            priority: 0,
            suffix: "_done".into(),
        }));

        match runner.run_before_prompt_build("hello".into()).await {
            HookResult::Continue(result) => assert_eq!(result, "HELLO_done"),
            HookResult::Cancel(_) => panic!("should not cancel"),
        }
    }

    /// A panicking modifying hook must be skipped: the next handler receives
    /// the prompt as it was before the panicking handler ran.
    #[tokio::test]
    async fn modifying_hook_panic_is_skipped() {
        let mut runner = HookRunner::new();

        // Priority 10 runs first and panics on the non-empty prompt.
        runner.register(Box::new(PanickingPromptHook {
            name: "panicker".into(),
            priority: 10,
        }));
        // Priority 0 runs second and must still see the untouched prompt.
        runner.register(Box::new(SuffixPromptHook {
            name: "suffix".into(),
            priority: 0,
            suffix: "_done".into(),
        }));

        match runner.run_before_prompt_build("hello".into()).await {
            HookResult::Continue(result) => assert_eq!(result, "hello_done"),
            HookResult::Cancel(_) => panic!("should not cancel"),
        }
    }
}