1use super::{Hook, HookContext, HookEvent, HookInput, HookOutput};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::time::{Duration, timeout};
7
8#[derive(Clone)]
9pub struct HookManager {
10 hooks: Vec<Arc<dyn Hook>>,
11 cache: HashMap<HookEvent, Vec<usize>>,
12 default_timeout_secs: u64,
13}
14
15impl Default for HookManager {
16 fn default() -> Self {
17 Self::new()
18 }
19}
20
21impl HookManager {
22 pub fn new() -> Self {
23 Self {
24 hooks: Vec::new(),
25 cache: HashMap::new(),
26 default_timeout_secs: 60,
27 }
28 }
29
30 pub fn with_timeout(timeout_secs: u64) -> Self {
31 Self {
32 hooks: Vec::new(),
33 cache: HashMap::new(),
34 default_timeout_secs: timeout_secs,
35 }
36 }
37
38 fn rebuild_cache(&mut self) {
39 self.cache.clear();
40 for event in HookEvent::all() {
41 let mut indices: Vec<usize> = self
42 .hooks
43 .iter()
44 .enumerate()
45 .filter(|(_, h)| h.events().contains(event))
46 .map(|(i, _)| i)
47 .collect();
48 indices.sort_by_key(|&i| std::cmp::Reverse(self.hooks[i].priority()));
49 self.cache.insert(*event, indices);
50 }
51 }
52
53 pub fn register<H: Hook + 'static>(&mut self, hook: H) {
54 self.hooks.push(Arc::new(hook));
55 self.rebuild_cache();
56 }
57
58 pub fn register_arc(&mut self, hook: Arc<dyn Hook>) {
59 self.hooks.push(hook);
60 self.rebuild_cache();
61 }
62
63 pub fn unregister(&mut self, name: &str) {
64 self.hooks.retain(|h| h.name() != name);
65 self.rebuild_cache();
66 }
67
68 pub fn hook_names(&self) -> Vec<&str> {
69 self.hooks.iter().map(|h| h.name()).collect()
70 }
71
72 pub fn has_hook(&self, name: &str) -> bool {
73 self.hooks.iter().any(|h| h.name() == name)
74 }
75
76 #[inline]
77 pub fn hooks_for_event(&self, event: HookEvent) -> Vec<&Arc<dyn Hook>> {
78 self.cache
79 .get(&event)
80 .map(|indices| indices.iter().map(|&i| &self.hooks[i]).collect())
81 .unwrap_or_default()
82 }
83
84 pub async fn execute(
85 &self,
86 event: HookEvent,
87 input: HookInput,
88 hook_context: &HookContext,
89 ) -> Result<HookOutput, crate::Error> {
90 self.execute_hooks::<fn(&str, &HookOutput)>(event, input, hook_context, None)
91 .await
92 }
93
94 pub async fn execute_with_handler<F>(
95 &self,
96 event: HookEvent,
97 input: HookInput,
98 hook_context: &HookContext,
99 handler: F,
100 ) -> Result<HookOutput, crate::Error>
101 where
102 F: FnMut(&str, &HookOutput),
103 {
104 self.execute_hooks(event, input, hook_context, Some(handler))
105 .await
106 }
107
108 async fn execute_hooks<F>(
109 &self,
110 event: HookEvent,
111 input: HookInput,
112 hook_context: &HookContext,
113 mut handler: Option<F>,
114 ) -> Result<HookOutput, crate::Error>
115 where
116 F: FnMut(&str, &HookOutput),
117 {
118 let hooks = self.hooks_for_event(event);
119
120 if hooks.is_empty() {
121 return Ok(HookOutput::allow());
122 }
123
124 let mut merged_output = HookOutput::allow();
125
126 for hook in hooks {
127 if let (Some(matcher), Some(tool_name)) = (hook.tool_matcher(), input.tool_name())
128 && !matcher.is_match(tool_name)
129 {
130 continue;
131 }
132
133 let hook_timeout = hook.timeout_secs().min(self.default_timeout_secs);
134 let result = timeout(
135 Duration::from_secs(hook_timeout),
136 hook.execute(input.clone(), hook_context),
137 )
138 .await;
139
140 let output = match result {
141 Ok(Ok(output)) => output,
142 Ok(Err(e)) => {
143 if event.can_block() {
144 return Err(crate::Error::HookFailed {
146 hook: hook.name().to_string(),
147 reason: e.to_string(),
148 });
149 }
150 tracing::warn!(hook = hook.name(), error = %e, "Hook execution failed");
152 continue;
153 }
154 Err(_) => {
155 if event.can_block() {
156 return Err(crate::Error::HookTimeout {
158 hook: hook.name().to_string(),
159 duration_secs: hook_timeout,
160 });
161 }
162 tracing::warn!(
164 hook = hook.name(),
165 timeout_secs = hook_timeout,
166 "Hook timed out"
167 );
168 continue;
169 }
170 };
171
172 if let Some(ref mut h) = handler {
173 h(hook.name(), &output);
174 }
175 merged_output = Self::merge_outputs(merged_output, output);
176
177 if !merged_output.continue_execution {
178 break;
179 }
180 }
181
182 Ok(merged_output)
183 }
184
185 fn merge_outputs(base: HookOutput, new: HookOutput) -> HookOutput {
186 HookOutput {
187 continue_execution: base.continue_execution && new.continue_execution,
188 stop_reason: new.stop_reason.or(base.stop_reason),
189 suppress_logging: base.suppress_logging || new.suppress_logging,
190 system_message: new.system_message.or(base.system_message),
191 updated_input: new.updated_input.or(base.updated_input),
192 additional_context: match (base.additional_context, new.additional_context) {
193 (Some(a), Some(b)) => Some(format!("{}\n{}", a, b)),
194 (a, b) => a.or(b),
195 },
196 }
197 }
198}
199
200impl std::fmt::Debug for HookManager {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 f.debug_struct("HookManager")
203 .field("hook_count", &self.hooks.len())
204 .field("hook_names", &self.hook_names())
205 .field("default_timeout_secs", &self.default_timeout_secs)
206 .finish()
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Hook with configurable events and priority.
    ///
    /// It either allows execution or blocks it with the message "Blocked by <name>".
    struct TestHook {
        name: String,
        events: Vec<HookEvent>,
        priority: i32,
        block: bool,
    }

    impl TestHook {
        /// Builds a hook that always allows execution.
        fn new(name: impl Into<String>, events: Vec<HookEvent>, priority: i32) -> Self {
            Self {
                name: name.into(),
                events,
                priority,
                block: false,
            }
        }

        /// Builds a hook that always blocks execution.
        fn blocking(name: impl Into<String>, events: Vec<HookEvent>, priority: i32) -> Self {
            Self {
                name: name.into(),
                events,
                priority,
                block: true,
            }
        }
    }

    #[async_trait]
    impl Hook for TestHook {
        fn name(&self) -> &str {
            &self.name
        }

        fn events(&self) -> &[HookEvent] {
            &self.events
        }

        fn priority(&self) -> i32 {
            self.priority
        }

        async fn execute(
            &self,
            _input: HookInput,
            _hook_context: &HookContext,
        ) -> Result<HookOutput, crate::Error> {
            if self.block {
                Ok(HookOutput::block(format!("Blocked by {}", self.name)))
            } else {
                Ok(HookOutput::allow())
            }
        }
    }

    #[tokio::test]
    async fn test_hook_registration() {
        let mut manager = HookManager::new();
        manager.register(TestHook::new("hook1", vec![HookEvent::PreToolUse], 0));
        manager.register(TestHook::new("hook2", vec![HookEvent::PostToolUse], 0));

        assert!(manager.has_hook("hook1"));
        assert!(manager.has_hook("hook2"));
        assert!(!manager.has_hook("hook3"));
        assert_eq!(manager.hook_names().len(), 2);
    }

    #[tokio::test]
    async fn test_hook_unregistration() {
        let mut manager = HookManager::new();
        manager.register(TestHook::new("hook1", vec![HookEvent::PreToolUse], 0));
        manager.register(TestHook::new("hook2", vec![HookEvent::PreToolUse], 0));

        manager.unregister("hook1");

        assert!(!manager.has_hook("hook1"));
        assert!(manager.has_hook("hook2"));
    }

    #[tokio::test]
    async fn test_hooks_for_event() {
        let mut manager = HookManager::new();
        manager.register(TestHook::new("hook1", vec![HookEvent::PreToolUse], 10));
        manager.register(TestHook::new(
            "hook2",
            vec![HookEvent::PreToolUse, HookEvent::PostToolUse],
            5,
        ));
        manager.register(TestHook::new("hook3", vec![HookEvent::SessionStart], 0));

        let pre_hooks = manager.hooks_for_event(HookEvent::PreToolUse);
        assert_eq!(pre_hooks.len(), 2);
        // Higher priority comes first.
        assert_eq!(pre_hooks[0].name(), "hook1");
        assert_eq!(pre_hooks[1].name(), "hook2");

        let session_hooks = manager.hooks_for_event(HookEvent::SessionStart);
        assert_eq!(session_hooks.len(), 1);
        assert_eq!(session_hooks[0].name(), "hook3");
    }

    /// Equal priorities must keep registration order (the cache sort is stable).
    #[tokio::test]
    async fn test_equal_priority_keeps_registration_order() {
        let mut manager = HookManager::new();
        manager.register(TestHook::new("first", vec![HookEvent::PreToolUse], 0));
        manager.register(TestHook::new("second", vec![HookEvent::PreToolUse], 0));
        manager.register(TestHook::new("third", vec![HookEvent::PreToolUse], 0));

        let names: Vec<&str> = manager
            .hooks_for_event(HookEvent::PreToolUse)
            .into_iter()
            .map(|hook| hook.name())
            .collect();
        assert_eq!(names, vec!["first", "second", "third"]);
    }

    #[tokio::test]
    async fn test_execute_allows() {
        let mut manager = HookManager::new();
        manager.register(TestHook::new("hook1", vec![HookEvent::PreToolUse], 0));
        manager.register(TestHook::new("hook2", vec![HookEvent::PreToolUse], 0));

        let input = HookInput::pre_tool_use("session-1", "Read", serde_json::json!({}));
        let hook_context = HookContext::new("session-1");
        let output = manager
            .execute(HookEvent::PreToolUse, input, &hook_context)
            .await
            .unwrap();

        assert!(output.continue_execution);
    }

    #[tokio::test]
    async fn test_execute_blocks() {
        let mut manager = HookManager::new();
        manager.register(TestHook::new("hook1", vec![HookEvent::PreToolUse], 0));
        manager.register(TestHook::blocking(
            "hook2",
            vec![HookEvent::PreToolUse],
            10, // Higher priority, so it runs (and blocks) first.
        ));

        let input = HookInput::pre_tool_use("session-1", "Read", serde_json::json!({}));
        let hook_context = HookContext::new("session-1");
        let output = manager
            .execute(HookEvent::PreToolUse, input, &hook_context)
            .await
            .unwrap();

        assert!(!output.continue_execution);
        assert_eq!(output.stop_reason, Some("Blocked by hook2".to_string()));
    }

    /// The handler sees each successful hook, in the same priority order used for dispatch.
    #[tokio::test]
    async fn test_execute_with_handler_reports_hooks_in_priority_order() {
        let mut manager = HookManager::new();
        manager.register(TestHook::new("low", vec![HookEvent::PreToolUse], 1));
        manager.register(TestHook::new("high", vec![HookEvent::PreToolUse], 10));

        let input = HookInput::pre_tool_use("session-1", "Read", serde_json::json!({}));
        let hook_context = HookContext::new("session-1");
        let mut seen: Vec<String> = Vec::new();
        let output = manager
            .execute_with_handler(
                HookEvent::PreToolUse,
                input,
                &hook_context,
                |name: &str, _output: &HookOutput| seen.push(name.to_string()),
            )
            .await
            .unwrap();

        assert!(output.continue_execution);
        assert_eq!(seen, vec!["high".to_string(), "low".to_string()]);
    }

    #[tokio::test]
    async fn test_no_hooks_allows() {
        let manager = HookManager::new();

        let input = HookInput::pre_tool_use("session-1", "Read", serde_json::json!({}));
        let hook_context = HookContext::new("session-1");
        let output = manager
            .execute(HookEvent::PreToolUse, input, &hook_context)
            .await
            .unwrap();

        assert!(output.continue_execution);
    }

    /// `merge_outputs` combines two outputs field by field.
    ///
    /// It ANDs `continue_execution`, prefers the newer `stop_reason`, and joins
    /// both `additional_context` values with a newline.
    #[test]
    fn test_merge_outputs_combines_fields() {
        let base = HookOutput {
            additional_context: Some("first".to_string()),
            ..HookOutput::allow()
        };
        let new = HookOutput {
            additional_context: Some("second".to_string()),
            ..HookOutput::block("stop".to_string())
        };

        let merged = HookManager::merge_outputs(base, new);

        assert!(!merged.continue_execution);
        assert_eq!(merged.stop_reason, Some("stop".to_string()));
        assert_eq!(merged.additional_context, Some("first\nsecond".to_string()));
    }

    /// Hook whose execution always returns an error.
    struct FailingHook {
        name: String,
        events: Vec<HookEvent>,
    }

    impl FailingHook {
        fn new(name: impl Into<String>, events: Vec<HookEvent>) -> Self {
            Self {
                name: name.into(),
                events,
            }
        }
    }

    #[async_trait]
    impl Hook for FailingHook {
        fn name(&self) -> &str {
            &self.name
        }

        fn events(&self) -> &[HookEvent] {
            &self.events
        }

        async fn execute(
            &self,
            _input: HookInput,
            _hook_context: &HookContext,
        ) -> Result<HookOutput, crate::Error> {
            Err(crate::Error::Config("Hook failed intentionally".into()))
        }
    }

    /// Hook that sleeps longer than its own 1-second timeout.
    struct SlowHook {
        name: String,
        events: Vec<HookEvent>,
    }

    impl SlowHook {
        fn new(name: impl Into<String>, events: Vec<HookEvent>) -> Self {
            Self {
                name: name.into(),
                events,
            }
        }
    }

    #[async_trait]
    impl Hook for SlowHook {
        fn name(&self) -> &str {
            &self.name
        }

        fn events(&self) -> &[HookEvent] {
            &self.events
        }

        fn timeout_secs(&self) -> u64 {
            1
        }

        async fn execute(
            &self,
            _input: HookInput,
            _hook_context: &HookContext,
        ) -> Result<HookOutput, crate::Error> {
            tokio::time::sleep(Duration::from_secs(5)).await;
            Ok(HookOutput::allow())
        }
    }

    #[tokio::test]
    async fn test_blockable_hook_failure_returns_error() {
        let mut manager = HookManager::new();
        manager.register(FailingHook::new("failing", vec![HookEvent::PreToolUse]));

        let input = HookInput::pre_tool_use("session-1", "Read", serde_json::json!({}));
        let hook_context = HookContext::new("session-1");
        let result = manager
            .execute(HookEvent::PreToolUse, input, &hook_context)
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, crate::Error::HookFailed { .. }));
    }

    #[tokio::test]
    async fn test_blockable_hook_timeout_returns_error() {
        let mut manager = HookManager::with_timeout(1);
        manager.register(SlowHook::new("slow", vec![HookEvent::UserPromptSubmit]));

        let input = HookInput::user_prompt_submit("session-1", "test prompt");
        let hook_context = HookContext::new("session-1");
        let result = manager
            .execute(HookEvent::UserPromptSubmit, input, &hook_context)
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, crate::Error::HookTimeout { .. }));
    }

    #[tokio::test]
    async fn test_non_blockable_hook_failure_continues() {
        let mut manager = HookManager::new();
        manager.register(FailingHook::new("failing", vec![HookEvent::SessionEnd]));
        manager.register(TestHook::new("success", vec![HookEvent::SessionEnd], 0));

        let input = HookInput::session_end("session-1");
        let hook_context = HookContext::new("session-1");
        let result = manager
            .execute(HookEvent::SessionEnd, input, &hook_context)
            .await;

        assert!(result.is_ok());
        assert!(result.unwrap().continue_execution);
    }

    #[tokio::test]
    async fn test_non_blockable_hook_timeout_continues() {
        let mut manager = HookManager::with_timeout(1);
        manager.register(SlowHook::new("slow", vec![HookEvent::PostToolUse]));

        let input = HookInput::post_tool_use(
            "session-1",
            "Read",
            crate::types::ToolOutput::success("result"),
        );
        let hook_context = HookContext::new("session-1");
        let result = manager
            .execute(HookEvent::PostToolUse, input, &hook_context)
            .await;

        assert!(result.is_ok());
        assert!(result.unwrap().continue_execution);
    }
}