claude_agent/hooks/
manager.rs

1//! Hook manager for registering and executing hooks.
2
3use super::{Hook, HookContext, HookEvent, HookInput, HookOutput};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::time::{Duration, timeout};
7
8#[derive(Clone)]
9pub struct HookManager {
10    hooks: Vec<Arc<dyn Hook>>,
11    cache: HashMap<HookEvent, Vec<usize>>,
12    default_timeout_secs: u64,
13}
14
15impl Default for HookManager {
16    fn default() -> Self {
17        Self::new()
18    }
19}
20
21impl HookManager {
22    pub fn new() -> Self {
23        Self {
24            hooks: Vec::new(),
25            cache: HashMap::new(),
26            default_timeout_secs: 60,
27        }
28    }
29
30    pub fn with_timeout(timeout_secs: u64) -> Self {
31        Self {
32            hooks: Vec::new(),
33            cache: HashMap::new(),
34            default_timeout_secs: timeout_secs,
35        }
36    }
37
38    fn rebuild_cache(&mut self) {
39        self.cache.clear();
40        for event in HookEvent::all() {
41            let mut indices: Vec<usize> = self
42                .hooks
43                .iter()
44                .enumerate()
45                .filter(|(_, h)| h.events().contains(event))
46                .map(|(i, _)| i)
47                .collect();
48            indices.sort_by_key(|&i| std::cmp::Reverse(self.hooks[i].priority()));
49            self.cache.insert(*event, indices);
50        }
51    }
52
53    pub fn register<H: Hook + 'static>(&mut self, hook: H) {
54        self.hooks.push(Arc::new(hook));
55        self.rebuild_cache();
56    }
57
58    pub fn register_arc(&mut self, hook: Arc<dyn Hook>) {
59        self.hooks.push(hook);
60        self.rebuild_cache();
61    }
62
63    pub fn unregister(&mut self, name: &str) {
64        self.hooks.retain(|h| h.name() != name);
65        self.rebuild_cache();
66    }
67
68    pub fn hook_names(&self) -> Vec<&str> {
69        self.hooks.iter().map(|h| h.name()).collect()
70    }
71
72    pub fn has_hook(&self, name: &str) -> bool {
73        self.hooks.iter().any(|h| h.name() == name)
74    }
75
76    #[inline]
77    pub fn hooks_for_event(&self, event: HookEvent) -> Vec<&Arc<dyn Hook>> {
78        self.cache
79            .get(&event)
80            .map(|indices| indices.iter().map(|&i| &self.hooks[i]).collect())
81            .unwrap_or_default()
82    }
83
84    pub async fn execute(
85        &self,
86        event: HookEvent,
87        input: HookInput,
88        hook_context: &HookContext,
89    ) -> Result<HookOutput, crate::Error> {
90        self.execute_hooks::<fn(&str, &HookOutput)>(event, input, hook_context, None)
91            .await
92    }
93
94    pub async fn execute_with_handler<F>(
95        &self,
96        event: HookEvent,
97        input: HookInput,
98        hook_context: &HookContext,
99        handler: F,
100    ) -> Result<HookOutput, crate::Error>
101    where
102        F: FnMut(&str, &HookOutput),
103    {
104        self.execute_hooks(event, input, hook_context, Some(handler))
105            .await
106    }
107
108    async fn execute_hooks<F>(
109        &self,
110        event: HookEvent,
111        input: HookInput,
112        hook_context: &HookContext,
113        mut handler: Option<F>,
114    ) -> Result<HookOutput, crate::Error>
115    where
116        F: FnMut(&str, &HookOutput),
117    {
118        let hooks = self.hooks_for_event(event);
119
120        if hooks.is_empty() {
121            return Ok(HookOutput::allow());
122        }
123
124        let mut merged_output = HookOutput::allow();
125
126        for hook in hooks {
127            if let (Some(matcher), Some(tool_name)) = (hook.tool_matcher(), input.tool_name())
128                && !matcher.is_match(tool_name)
129            {
130                continue;
131            }
132
133            let hook_timeout = hook.timeout_secs().min(self.default_timeout_secs);
134            let result = timeout(
135                Duration::from_secs(hook_timeout),
136                hook.execute(input.clone(), hook_context),
137            )
138            .await;
139
140            let output = match result {
141                Ok(Ok(output)) => output,
142                Ok(Err(e)) => {
143                    if event.can_block() {
144                        // Blockable hooks use fail-closed: errors propagate
145                        return Err(crate::Error::HookFailed {
146                            hook: hook.name().to_string(),
147                            reason: e.to_string(),
148                        });
149                    }
150                    // Non-blockable hooks use fail-open: log and continue
151                    tracing::warn!(hook = hook.name(), error = %e, "Hook execution failed");
152                    continue;
153                }
154                Err(_) => {
155                    if event.can_block() {
156                        // Blockable hooks use fail-closed: timeouts propagate
157                        return Err(crate::Error::HookTimeout {
158                            hook: hook.name().to_string(),
159                            duration_secs: hook_timeout,
160                        });
161                    }
162                    // Non-blockable hooks use fail-open: log and continue
163                    tracing::warn!(
164                        hook = hook.name(),
165                        timeout_secs = hook_timeout,
166                        "Hook timed out"
167                    );
168                    continue;
169                }
170            };
171
172            if let Some(ref mut h) = handler {
173                h(hook.name(), &output);
174            }
175            merged_output = Self::merge_outputs(merged_output, output);
176
177            if !merged_output.continue_execution {
178                break;
179            }
180        }
181
182        Ok(merged_output)
183    }
184
185    fn merge_outputs(base: HookOutput, new: HookOutput) -> HookOutput {
186        HookOutput {
187            continue_execution: base.continue_execution && new.continue_execution,
188            stop_reason: new.stop_reason.or(base.stop_reason),
189            suppress_logging: base.suppress_logging || new.suppress_logging,
190            system_message: new.system_message.or(base.system_message),
191            updated_input: new.updated_input.or(base.updated_input),
192            additional_context: match (base.additional_context, new.additional_context) {
193                (Some(a), Some(b)) => Some(format!("{}\n{}", a, b)),
194                (a, b) => a.or(b),
195            },
196        }
197    }
198}
199
200impl std::fmt::Debug for HookManager {
201    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202        f.debug_struct("HookManager")
203            .field("hook_count", &self.hooks.len())
204            .field("hook_names", &self.hook_names())
205            .field("default_timeout_secs", &self.default_timeout_secs)
206            .finish()
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Hook that either allows or blocks, depending on `block`.
    struct TestHook {
        name: String,
        events: Vec<HookEvent>,
        priority: i32,
        block: bool,
    }

    impl TestHook {
        fn new(name: impl Into<String>, events: Vec<HookEvent>, priority: i32) -> Self {
            Self {
                name: name.into(),
                events,
                priority,
                block: false,
            }
        }

        fn blocking(name: impl Into<String>, events: Vec<HookEvent>, priority: i32) -> Self {
            Self {
                name: name.into(),
                events,
                priority,
                block: true,
            }
        }
    }

    #[async_trait]
    impl Hook for TestHook {
        fn name(&self) -> &str {
            &self.name
        }

        fn events(&self) -> &[HookEvent] {
            &self.events
        }

        fn priority(&self) -> i32 {
            self.priority
        }

        async fn execute(
            &self,
            _input: HookInput,
            _hook_context: &HookContext,
        ) -> Result<HookOutput, crate::Error> {
            if self.block {
                Ok(HookOutput::block(format!("Blocked by {}", self.name)))
            } else {
                Ok(HookOutput::allow())
            }
        }
    }

    /// Hook that allows execution and attaches a fixed `additional_context` string.
    struct ContextHook {
        name: String,
        events: Vec<HookEvent>,
        priority: i32,
        context: String,
    }

    impl ContextHook {
        fn new(
            name: impl Into<String>,
            events: Vec<HookEvent>,
            priority: i32,
            context: impl Into<String>,
        ) -> Self {
            Self {
                name: name.into(),
                events,
                priority,
                context: context.into(),
            }
        }
    }

    #[async_trait]
    impl Hook for ContextHook {
        fn name(&self) -> &str {
            &self.name
        }

        fn events(&self) -> &[HookEvent] {
            &self.events
        }

        fn priority(&self) -> i32 {
            self.priority
        }

        async fn execute(
            &self,
            _input: HookInput,
            _hook_context: &HookContext,
        ) -> Result<HookOutput, crate::Error> {
            let mut output = HookOutput::allow();
            output.additional_context = Some(self.context.clone());
            Ok(output)
        }
    }

    /// Builds a `PreToolUse` input for the `Read` tool, shared by several tests.
    fn read_input() -> HookInput {
        HookInput::pre_tool_use("session-1", "Read", serde_json::json!({}))
    }

    #[test]
    fn test_hook_registration() {
        let mut manager = HookManager::new();
        manager.register(TestHook::new("hook1", vec![HookEvent::PreToolUse], 0));
        manager.register(TestHook::new("hook2", vec![HookEvent::PostToolUse], 0));

        assert!(manager.has_hook("hook1"));
        assert!(manager.has_hook("hook2"));
        assert!(!manager.has_hook("hook3"));
        assert_eq!(manager.hook_names().len(), 2);
    }

    #[test]
    fn test_hook_unregistration() {
        let mut manager = HookManager::new();
        manager.register(TestHook::new("hook1", vec![HookEvent::PreToolUse], 0));
        manager.register(TestHook::new("hook2", vec![HookEvent::PreToolUse], 0));

        manager.unregister("hook1");

        assert!(!manager.has_hook("hook1"));
        assert!(manager.has_hook("hook2"));

        // The dispatch cache must be refreshed, not just the hook list.
        let pre_hooks = manager.hooks_for_event(HookEvent::PreToolUse);
        assert_eq!(pre_hooks.len(), 1);
        assert_eq!(pre_hooks[0].name(), "hook2");
    }

    #[test]
    fn test_hooks_for_event() {
        let mut manager = HookManager::new();
        manager.register(TestHook::new("hook1", vec![HookEvent::PreToolUse], 10));
        manager.register(TestHook::new(
            "hook2",
            vec![HookEvent::PreToolUse, HookEvent::PostToolUse],
            5,
        ));
        manager.register(TestHook::new("hook3", vec![HookEvent::SessionStart], 0));

        let pre_hooks = manager.hooks_for_event(HookEvent::PreToolUse);
        assert_eq!(pre_hooks.len(), 2);
        // Check priority order (hook1 has higher priority)
        assert_eq!(pre_hooks[0].name(), "hook1");
        assert_eq!(pre_hooks[1].name(), "hook2");

        let post_hooks = manager.hooks_for_event(HookEvent::PostToolUse);
        assert_eq!(post_hooks.len(), 1);
        assert_eq!(post_hooks[0].name(), "hook2");

        let session_hooks = manager.hooks_for_event(HookEvent::SessionStart);
        assert_eq!(session_hooks.len(), 1);
        assert_eq!(session_hooks[0].name(), "hook3");
    }

    #[test]
    fn test_equal_priority_keeps_registration_order() {
        let mut manager = HookManager::new();
        manager.register(TestHook::new("first", vec![HookEvent::PreToolUse], 0));
        manager.register(TestHook::new("second", vec![HookEvent::PreToolUse], 0));
        manager.register(TestHook::new("third", vec![HookEvent::PreToolUse], 0));

        let names: Vec<&str> = manager
            .hooks_for_event(HookEvent::PreToolUse)
            .into_iter()
            .map(|h| h.name())
            .collect();
        assert_eq!(names, vec!["first", "second", "third"]);
    }

    #[tokio::test]
    async fn test_execute_allows() {
        let mut manager = HookManager::new();
        manager.register(TestHook::new("hook1", vec![HookEvent::PreToolUse], 0));
        manager.register(TestHook::new("hook2", vec![HookEvent::PreToolUse], 0));

        let hook_context = HookContext::new("session-1");
        let output = manager
            .execute(HookEvent::PreToolUse, read_input(), &hook_context)
            .await
            .unwrap();

        assert!(output.continue_execution);
    }

    #[tokio::test]
    async fn test_execute_blocks() {
        let mut manager = HookManager::new();
        manager.register(TestHook::new("hook1", vec![HookEvent::PreToolUse], 0));
        manager.register(TestHook::blocking(
            "hook2",
            vec![HookEvent::PreToolUse],
            10, // Higher priority, runs first
        ));

        let hook_context = HookContext::new("session-1");
        let output = manager
            .execute(HookEvent::PreToolUse, read_input(), &hook_context)
            .await
            .unwrap();

        assert!(!output.continue_execution);
        assert_eq!(output.stop_reason, Some("Blocked by hook2".to_string()));
    }

    #[tokio::test]
    async fn test_handler_sees_hooks_in_order_until_block() {
        let mut manager = HookManager::new();
        manager.register(TestHook::new("low", vec![HookEvent::PreToolUse], 0));
        manager.register(TestHook::blocking("mid", vec![HookEvent::PreToolUse], 5));
        manager.register(TestHook::new("high", vec![HookEvent::PreToolUse], 10));

        let hook_context = HookContext::new("session-1");
        let mut seen: Vec<String> = Vec::new();
        let output = manager
            .execute_with_handler(
                HookEvent::PreToolUse,
                read_input(),
                &hook_context,
                |name, _| seen.push(name.to_string()),
            )
            .await
            .unwrap();

        assert!(!output.continue_execution);
        assert_eq!(output.stop_reason, Some("Blocked by mid".to_string()));
        // "low" must never run: execution stops right after the blocking hook.
        assert_eq!(seen, vec!["high".to_string(), "mid".to_string()]);
    }

    #[tokio::test]
    async fn test_additional_context_is_joined_in_priority_order() {
        let mut manager = HookManager::new();
        manager.register(ContextHook::new(
            "low",
            vec![HookEvent::PreToolUse],
            0,
            "second",
        ));
        manager.register(ContextHook::new(
            "high",
            vec![HookEvent::PreToolUse],
            10,
            "first",
        ));

        let hook_context = HookContext::new("session-1");
        let output = manager
            .execute(HookEvent::PreToolUse, read_input(), &hook_context)
            .await
            .unwrap();

        assert!(output.continue_execution);
        assert_eq!(output.additional_context.as_deref(), Some("first\nsecond"));
    }

    #[tokio::test]
    async fn test_no_hooks_allows() {
        let manager = HookManager::new();

        let hook_context = HookContext::new("session-1");
        let output = manager
            .execute(HookEvent::PreToolUse, read_input(), &hook_context)
            .await
            .unwrap();

        assert!(output.continue_execution);
    }

    /// Hook that always fails.
    struct FailingHook {
        name: String,
        events: Vec<HookEvent>,
    }

    impl FailingHook {
        fn new(name: impl Into<String>, events: Vec<HookEvent>) -> Self {
            Self {
                name: name.into(),
                events,
            }
        }
    }

    #[async_trait]
    impl Hook for FailingHook {
        fn name(&self) -> &str {
            &self.name
        }

        fn events(&self) -> &[HookEvent] {
            &self.events
        }

        async fn execute(
            &self,
            _input: HookInput,
            _hook_context: &HookContext,
        ) -> Result<HookOutput, crate::Error> {
            Err(crate::Error::Config("Hook failed intentionally".into()))
        }
    }

    /// Hook that sleeps longer than its own 1-second timeout.
    struct SlowHook {
        name: String,
        events: Vec<HookEvent>,
    }

    impl SlowHook {
        fn new(name: impl Into<String>, events: Vec<HookEvent>) -> Self {
            Self {
                name: name.into(),
                events,
            }
        }
    }

    #[async_trait]
    impl Hook for SlowHook {
        fn name(&self) -> &str {
            &self.name
        }

        fn events(&self) -> &[HookEvent] {
            &self.events
        }

        fn timeout_secs(&self) -> u64 {
            1 // Short timeout for testing
        }

        async fn execute(
            &self,
            _input: HookInput,
            _hook_context: &HookContext,
        ) -> Result<HookOutput, crate::Error> {
            // Sleep longer than timeout
            tokio::time::sleep(Duration::from_secs(5)).await;
            Ok(HookOutput::allow())
        }
    }

    #[tokio::test]
    async fn test_blockable_hook_failure_returns_error() {
        let mut manager = HookManager::new();
        manager.register(FailingHook::new("failing", vec![HookEvent::PreToolUse]));

        let hook_context = HookContext::new("session-1");
        let result = manager
            .execute(HookEvent::PreToolUse, read_input(), &hook_context)
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, crate::Error::HookFailed { .. }));
    }

    #[tokio::test]
    async fn test_blockable_hook_timeout_returns_error() {
        let mut manager = HookManager::with_timeout(1);
        manager.register(SlowHook::new("slow", vec![HookEvent::UserPromptSubmit]));

        let input = HookInput::user_prompt_submit("session-1", "test prompt");
        let hook_context = HookContext::new("session-1");
        let result = manager
            .execute(HookEvent::UserPromptSubmit, input, &hook_context)
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, crate::Error::HookTimeout { .. }));
    }

    #[tokio::test]
    async fn test_non_blockable_hook_failure_continues() {
        let mut manager = HookManager::new();
        // SessionEnd is non-blockable
        manager.register(FailingHook::new("failing", vec![HookEvent::SessionEnd]));
        manager.register(TestHook::new("success", vec![HookEvent::SessionEnd], 0));

        let input = HookInput::session_end("session-1");
        let hook_context = HookContext::new("session-1");
        let result = manager
            .execute(HookEvent::SessionEnd, input, &hook_context)
            .await;

        // Should succeed despite the failing hook
        assert!(result.is_ok());
        assert!(result.unwrap().continue_execution);
    }

    #[tokio::test]
    async fn test_non_blockable_hook_timeout_continues() {
        let mut manager = HookManager::with_timeout(1);
        // PostToolUse is non-blockable
        manager.register(SlowHook::new("slow", vec![HookEvent::PostToolUse]));

        let input = HookInput::post_tool_use(
            "session-1",
            "Read",
            crate::types::ToolOutput::success("result"),
        );
        let hook_context = HookContext::new("session-1");
        let result = manager
            .execute(HookEvent::PostToolUse, input, &hook_context)
            .await;

        // Should succeed despite the timeout
        assert!(result.is_ok());
        assert!(result.unwrap().continue_execution);
    }
}