1use super::events::{HookEvent, HookEventType};
6use super::matcher::HookMatcher;
7use super::{HookAction, HookResponse};
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use tokio::sync::mpsc;
13
14use crate::error::{read_or_recover, write_or_recover};
15
/// Per-hook execution settings.
///
/// Every field carries a serde default, so a partially specified config
/// deserializes with the same fallback values that [`HookConfig::default`]
/// produces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookConfig {
    /// Ordering key for matching hooks. `HookEngine::matching_hooks` sorts in
    /// ascending order, so lower values run first. Defaults to `100`.
    #[serde(default = "default_priority")]
    pub priority: i32,

    /// Timeout in milliseconds that the engine applies when running a
    /// synchronous handler. If it elapses, the hook's outcome is treated as
    /// `Continue`. Defaults to `30000`.
    #[serde(default = "default_timeout")]
    pub timeout_ms: u64,

    /// When `true`, the engine starts the handler in the background and does
    /// not wait for it. Its response is ignored and the hook counts as
    /// `Continue`. Defaults to `false`.
    #[serde(default)]
    pub async_execution: bool,

    /// Retry budget for this hook. Defaults to `0`.
    /// NOTE(review): `HookEngine` in this module never reads this field —
    /// confirm whether another component honors it.
    #[serde(default)]
    pub max_retries: u32,
}
35
/// Serde/`Default` fallback for `HookConfig::priority`.
fn default_priority() -> i32 {
    const DEFAULT_PRIORITY: i32 = 100;
    DEFAULT_PRIORITY
}

/// Serde/`Default` fallback for `HookConfig::timeout_ms` (30 seconds).
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_MS: u64 = 30_000;
    DEFAULT_TIMEOUT_MS
}
43
44impl Default for HookConfig {
45 fn default() -> Self {
46 Self {
47 priority: default_priority(),
48 timeout_ms: default_timeout(),
49 async_execution: false,
50 max_retries: 0,
51 }
52 }
53}
54
/// A hook registration: the event type it listens to, an optional
/// finer-grained matcher, and its execution settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hook {
    /// Unique identifier. `HookEngine` also uses it as the key for looking up
    /// this hook's handler.
    pub id: String,

    /// Event type this hook responds to.
    pub event_type: HookEventType,

    /// Optional extra filter, applied only after the event type matches.
    /// `None` accepts every event of `event_type`. Omitted from serialized
    /// output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matcher: Option<HookMatcher>,

    /// Execution settings. Falls back to `HookConfig::default()` when absent
    /// from serialized input.
    #[serde(default)]
    pub config: HookConfig,
}
72
73impl Hook {
74 pub fn new(id: impl Into<String>, event_type: HookEventType) -> Self {
76 Self {
77 id: id.into(),
78 event_type,
79 matcher: None,
80 config: HookConfig::default(),
81 }
82 }
83
84 pub fn with_matcher(mut self, matcher: HookMatcher) -> Self {
86 self.matcher = Some(matcher);
87 self
88 }
89
90 pub fn with_config(mut self, config: HookConfig) -> Self {
92 self.config = config;
93 self
94 }
95
96 pub fn matches(&self, event: &HookEvent) -> bool {
98 if event.event_type() != self.event_type {
100 return false;
101 }
102
103 if let Some(ref matcher) = self.matcher {
105 matcher.matches(event)
106 } else {
107 true
108 }
109 }
110}
111
/// Outcome of running a single hook, or of firing an event through all
/// matching hooks.
#[derive(Debug, Clone)]
pub enum HookResult {
    /// Proceed. May carry a modified payload produced by a hook.
    Continue(Option<serde_json::Value>),
    /// Stop processing. Carries the reason for blocking.
    Block(String),
    /// Ask the caller to retry after the given delay, in milliseconds.
    Retry(u64),
    /// Stop evaluating the remaining hooks.
    Skip,
    /// Escalate the decision.
    Escalate {
        /// Why escalation is needed.
        reason: String,
        /// Optional escalation target.
        /// NOTE(review): the format is not defined in this module — confirm with callers.
        target: Option<String>,
    },
}
129
130impl HookResult {
131 pub fn continue_() -> Self {
133 Self::Continue(None)
134 }
135
136 pub fn continue_with(modified: serde_json::Value) -> Self {
138 Self::Continue(Some(modified))
139 }
140
141 pub fn block(reason: impl Into<String>) -> Self {
143 Self::Block(reason.into())
144 }
145
146 pub fn retry(delay_ms: u64) -> Self {
148 Self::Retry(delay_ms)
149 }
150
151 pub fn skip() -> Self {
153 Self::Skip
154 }
155
156 pub fn escalate(reason: impl Into<String>, target: Option<String>) -> Self {
158 Self::Escalate {
159 reason: reason.into(),
160 target,
161 }
162 }
163
164 pub fn is_continue(&self) -> bool {
166 matches!(self, Self::Continue(_))
167 }
168
169 pub fn is_block(&self) -> bool {
171 matches!(self, Self::Block(_))
172 }
173}
174
/// Synchronous callback that decides how a hook responds to an event.
///
/// Handlers are registered on `HookEngine` under a hook id. `Send + Sync` is
/// required because the engine may invoke a handler from a spawned task.
pub trait HookHandler: Send + Sync {
    /// Inspects `event` and returns this hook's response: continue (optionally
    /// with a modified payload), block, retry or skip.
    fn handle(&self, event: &HookEvent) -> HookResponse;
}
180
/// Async interface for dispatching hook events. `HookEngine` implements it.
///
/// The `record_*` methods are optional notification callbacks. Their default
/// implementations do nothing, so implementors override only what they need.
#[async_trait::async_trait]
pub trait HookExecutor: Send + Sync + std::fmt::Debug {
    /// Dispatches `event` to matching hooks and returns the aggregated outcome.
    async fn fire(&self, event: &HookEvent) -> HookResult;

    /// Receives an agent event for the given run and session. Default: no-op.
    async fn record_agent_event(
        &self,
        _event: &crate::agent::AgentEvent,
        _run_id: &str,
        _session_id: &str,
    ) {
    }

    /// Receives notice that a run was cancelled, with an optional reason.
    /// Default: no-op.
    async fn record_run_cancelled(&self, _run_id: &str, _session_id: &str, _reason: Option<&str>) {}
}
206
/// Registry of hooks and their handlers. Dispatches events to the hooks that
/// match them.
///
/// Both registries sit behind `Arc<RwLock<..>>` and all registration methods
/// take `&self`, so an engine can be populated through a shared reference.
pub struct HookEngine {
    /// Registered hooks, keyed by `Hook::id`.
    hooks: Arc<RwLock<HashMap<String, Hook>>>,

    /// Handlers, keyed by the id of the hook they serve.
    handlers: Arc<RwLock<HashMap<String, Arc<dyn HookHandler>>>>,

    /// Optional bounded channel that receives a clone of every fired event.
    event_tx: Option<mpsc::Sender<HookEvent>>,
}
218
219impl std::fmt::Debug for HookEngine {
220 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221 f.debug_struct("HookEngine")
222 .field("hooks_count", &read_or_recover(&self.hooks).len())
223 .field("handlers_count", &read_or_recover(&self.handlers).len())
224 .field("has_event_channel", &self.event_tx.is_some())
225 .finish()
226 }
227}
228
229impl Default for HookEngine {
230 fn default() -> Self {
231 Self::new()
232 }
233}
234
235impl HookEngine {
236 pub fn new() -> Self {
238 Self {
239 hooks: Arc::new(RwLock::new(HashMap::new())),
240 handlers: Arc::new(RwLock::new(HashMap::new())),
241 event_tx: None,
242 }
243 }
244
245 pub fn with_event_channel(mut self, tx: mpsc::Sender<HookEvent>) -> Self {
247 self.event_tx = Some(tx);
248 self
249 }
250
251 pub fn register(&self, hook: Hook) {
253 let mut hooks = write_or_recover(&self.hooks);
254 hooks.insert(hook.id.clone(), hook);
255 }
256
257 pub fn unregister(&self, hook_id: &str) -> Option<Hook> {
259 let mut hooks = write_or_recover(&self.hooks);
260 hooks.remove(hook_id)
261 }
262
263 pub fn register_handler(&self, hook_id: &str, handler: Arc<dyn HookHandler>) {
265 let mut handlers = write_or_recover(&self.handlers);
266 handlers.insert(hook_id.to_string(), handler);
267 }
268
269 pub fn unregister_handler(&self, hook_id: &str) {
271 let mut handlers = write_or_recover(&self.handlers);
272 handlers.remove(hook_id);
273 }
274
275 pub fn matching_hooks(&self, event: &HookEvent) -> Vec<Hook> {
277 let hooks = read_or_recover(&self.hooks);
278 let mut matching: Vec<Hook> = hooks
279 .values()
280 .filter(|h| h.matches(event))
281 .cloned()
282 .collect();
283
284 matching.sort_by_key(|h| h.config.priority);
286 matching
287 }
288
289 pub async fn fire(&self, event: &HookEvent) -> HookResult {
291 if let Some(ref tx) = self.event_tx {
293 let _ = tx.send(event.clone()).await;
294 }
295
296 let matching_hooks = self.matching_hooks(event);
298
299 if matching_hooks.is_empty() {
300 return HookResult::continue_();
301 }
302
303 let mut last_modified: Option<serde_json::Value> = None;
305 for hook in matching_hooks {
306 let result = self.execute_hook(&hook, event).await;
307
308 match result {
309 HookResult::Continue(modified) => {
310 if modified.is_some() {
312 last_modified = modified;
313 }
314 }
315 HookResult::Block(reason) => {
316 return HookResult::Block(reason);
317 }
318 HookResult::Retry(delay) => {
319 return HookResult::Retry(delay);
320 }
321 HookResult::Skip => {
322 return HookResult::Continue(None);
323 }
324 HookResult::Escalate { reason, target } => {
325 return HookResult::Escalate { reason, target };
326 }
327 }
328 }
329
330 HookResult::Continue(last_modified)
331 }
332
333 async fn execute_hook(&self, hook: &Hook, event: &HookEvent) -> HookResult {
335 let handler = {
337 let handlers = read_or_recover(&self.handlers);
338 handlers.get(&hook.id).cloned()
339 };
340
341 match handler {
342 Some(h) => {
343 let response = if hook.config.async_execution {
345 let h = h.clone();
347 let event = event.clone();
348 tokio::spawn(async move {
349 h.handle(&event);
350 });
351 HookResponse::continue_()
352 } else {
353 let timeout = std::time::Duration::from_millis(hook.config.timeout_ms);
355 let h = h.clone();
356 let event = event.clone();
357
358 match tokio::time::timeout(timeout, async move { h.handle(&event) }).await {
359 Ok(response) => response,
360 Err(_) => {
361 HookResponse::continue_()
363 }
364 }
365 };
366
367 self.response_to_result(response)
368 }
369 None => {
370 HookResult::continue_()
372 }
373 }
374 }
375
376 fn response_to_result(&self, response: HookResponse) -> HookResult {
378 match response.action {
379 HookAction::Continue => HookResult::Continue(response.modified),
380 HookAction::Block => {
381 HookResult::Block(response.reason.unwrap_or_else(|| "Blocked".to_string()))
382 }
383 HookAction::Retry => HookResult::Retry(response.retry_delay_ms.unwrap_or(1000)),
384 HookAction::Skip => HookResult::Skip,
385 }
386 }
387
388 pub fn hook_count(&self) -> usize {
390 read_or_recover(&self.hooks).len()
391 }
392
393 pub fn get_hook(&self, id: &str) -> Option<Hook> {
395 read_or_recover(&self.hooks).get(id).cloned()
396 }
397
398 pub fn all_hooks(&self) -> Vec<Hook> {
400 read_or_recover(&self.hooks).values().cloned().collect()
401 }
402}
403
404#[async_trait]
406impl HookExecutor for HookEngine {
407 async fn fire(&self, event: &HookEvent) -> HookResult {
408 self.fire(event).await
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::events::PreToolUseEvent;

    /// Builds a `PreToolUse` event for `tool` in session `session_id`, with
    /// empty args and no recent tools.
    fn make_pre_tool_event(session_id: &str, tool: &str) -> HookEvent {
        HookEvent::PreToolUse(PreToolUseEvent {
            session_id: session_id.to_string(),
            tool: tool.to_string(),
            args: serde_json::json!({}),
            working_directory: "/workspace".to_string(),
            recent_tools: vec![],
        })
    }

    #[test]
    fn test_hook_config_default() {
        let config = HookConfig::default();
        assert_eq!(config.priority, 100);
        assert_eq!(config.timeout_ms, 30000);
        assert!(!config.async_execution);
        assert_eq!(config.max_retries, 0);
    }

    #[test]
    fn test_hook_new() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse);
        assert_eq!(hook.id, "test-hook");
        assert_eq!(hook.event_type, HookEventType::PreToolUse);
        assert!(hook.matcher.is_none());
    }

    #[test]
    fn test_hook_with_matcher() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"));

        assert!(hook.matcher.is_some());
        assert_eq!(hook.matcher.unwrap().tool, Some("Bash".to_string()));
    }

    #[test]
    fn test_hook_matches_event_type() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse);

        let pre_event = make_pre_tool_event("s1", "Bash");
        assert!(hook.matches(&pre_event));

        // A PostToolUse event must not match a PreToolUse hook, even for the same tool.
        let post_event = HookEvent::PostToolUse(crate::hooks::events::PostToolUseEvent {
            session_id: "s1".to_string(),
            tool: "Bash".to_string(),
            args: serde_json::json!({}),
            result: crate::hooks::events::ToolResultData {
                success: true,
                output: "".to_string(),
                exit_code: Some(0),
                duration_ms: 100,
            },
        });
        assert!(!hook.matches(&post_event));
    }

    #[test]
    fn test_hook_matches_with_matcher() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"));

        let bash_event = make_pre_tool_event("s1", "Bash");
        let read_event = make_pre_tool_event("s1", "Read");

        assert!(hook.matches(&bash_event));
        assert!(!hook.matches(&read_event));
    }

    #[test]
    fn test_hook_result_constructors() {
        let cont = HookResult::continue_();
        assert!(cont.is_continue());
        assert!(!cont.is_block());

        let cont_with = HookResult::continue_with(serde_json::json!({"key": "value"}));
        assert!(cont_with.is_continue());

        let block = HookResult::block("Blocked");
        assert!(block.is_block());
        assert!(!block.is_continue());

        // Retry and Skip are neither Continue nor Block.
        let retry = HookResult::retry(1000);
        assert!(!retry.is_continue());
        assert!(!retry.is_block());

        let skip = HookResult::skip();
        assert!(!skip.is_continue());
        assert!(!skip.is_block());
    }

    #[test]
    fn test_engine_register_unregister() {
        let engine = HookEngine::new();

        let hook = Hook::new("test-hook", HookEventType::PreToolUse);
        engine.register(hook);

        assert_eq!(engine.hook_count(), 1);
        assert!(engine.get_hook("test-hook").is_some());

        let removed = engine.unregister("test-hook");
        assert!(removed.is_some());
        assert_eq!(engine.hook_count(), 0);
    }

    #[test]
    fn test_engine_matching_hooks() {
        let engine = HookEngine::new();

        // Matches every PreToolUse event.
        engine.register(
            Hook::new("hook-1", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 10,
                ..Default::default()
            }),
        );
        // Matches only Bash, with a lower (earlier) priority.
        engine.register(
            Hook::new("hook-2", HookEventType::PreToolUse)
                .with_matcher(HookMatcher::tool("Bash"))
                .with_config(HookConfig {
                    priority: 5,
                    ..Default::default()
                }),
        );
        // Different event type: never matches a PreToolUse event.
        engine.register(Hook::new("hook-3", HookEventType::PostToolUse));

        let event = make_pre_tool_event("s1", "Bash");
        let matching = engine.matching_hooks(&event);

        assert_eq!(matching.len(), 2);

        // Sorted by ascending priority: hook-2 (5) before hook-1 (10).
        assert_eq!(matching[0].id, "hook-2");
        assert_eq!(matching[1].id, "hook-1");
    }

    #[tokio::test]
    async fn test_engine_fire_no_hooks() {
        let engine = HookEngine::new();
        let event = make_pre_tool_event("s1", "Bash");

        let result = engine.fire(&event).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_no_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        // A matching hook without a handler is treated as a no-op.
        assert!(result.is_continue());
    }

    /// Handler that always lets the event through unchanged.
    struct ContinueHandler;

    impl HookHandler for ContinueHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::continue_()
        }
    }

    /// Handler that always blocks with a fixed reason.
    struct BlockHandler {
        reason: String,
    }

    impl HookHandler for BlockHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::block(&self.reason)
        }
    }

    #[tokio::test]
    async fn test_engine_fire_with_continue_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler("test-hook", Arc::new(ContinueHandler));

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_with_block_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler(
            "test-hook",
            Arc::new(BlockHandler {
                reason: "Dangerous command".to_string(),
            }),
        );

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        assert!(result.is_block());
        if let HookResult::Block(reason) = result {
            assert_eq!(reason, "Dangerous command");
        }
    }

    #[tokio::test]
    async fn test_engine_fire_priority_order() {
        let engine = HookEngine::new();

        // Lower priority value runs first.
        engine.register(
            Hook::new("block-hook", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 5,
                ..Default::default()
            }),
        );
        engine.register(
            Hook::new("continue-hook", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 10,
                ..Default::default()
            }),
        );

        engine.register_handler(
            "block-hook",
            Arc::new(BlockHandler {
                reason: "Blocked first".to_string(),
            }),
        );
        engine.register_handler("continue-hook", Arc::new(ContinueHandler));

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        // block-hook runs before continue-hook, so the overall result is Block.
        assert!(result.is_block());
    }

    #[test]
    fn test_hook_serialization() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"))
            .with_config(HookConfig {
                priority: 50,
                timeout_ms: 5000,
                async_execution: true,
                max_retries: 3,
            });

        let json = serde_json::to_string(&hook).unwrap();
        assert!(json.contains("test-hook"));
        assert!(json.contains("pre_tool_use"));
        assert!(json.contains("Bash"));

        // Round-trip back into a Hook.
        let parsed: Hook = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "test-hook");
        assert_eq!(parsed.event_type, HookEventType::PreToolUse);
        assert_eq!(parsed.config.priority, 50);
    }

    #[test]
    fn test_all_hooks() {
        let engine = HookEngine::new();
        engine.register(Hook::new("hook-1", HookEventType::PreToolUse));
        engine.register(Hook::new("hook-2", HookEventType::PostToolUse));

        let all = engine.all_hooks();
        assert_eq!(all.len(), 2);
    }

    /// Builds a `SkillLoad` event for `skill_name` exposing `tools`.
    fn make_skill_load_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        HookEvent::SkillLoad(crate::hooks::events::SkillLoadEvent {
            skill_name: skill_name.to_string(),
            tool_names: tools.iter().map(|s| s.to_string()).collect(),
            version: Some("1.0.0".to_string()),
            description: Some("Test skill".to_string()),
            loaded_at: 1234567890,
        })
    }

    /// Builds a `SkillUnload` event for `skill_name` exposing `tools`.
    fn make_skill_unload_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        HookEvent::SkillUnload(crate::hooks::events::SkillUnloadEvent {
            skill_name: skill_name.to_string(),
            tool_names: tools.iter().map(|s| s.to_string()).collect(),
            duration_ms: 60000,
        })
    }

    #[tokio::test]
    async fn test_engine_fire_skill_load() {
        let engine = HookEngine::new();

        engine.register(Hook::new("skill-load-hook", HookEventType::SkillLoad));
        engine.register_handler("skill-load-hook", Arc::new(ContinueHandler));

        let event = make_skill_load_event("my-skill", vec!["tool1", "tool2"]);
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_skill_unload() {
        let engine = HookEngine::new();

        engine.register(Hook::new("skill-unload-hook", HookEventType::SkillUnload));
        engine.register_handler("skill-unload-hook", Arc::new(ContinueHandler));

        let event = make_skill_unload_event("my-skill", vec!["tool1", "tool2"]);
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_skill_hook_with_matcher() {
        let engine = HookEngine::new();

        // Blocks only the skill named exactly "my-skill".
        engine.register(
            Hook::new("specific-skill-hook", HookEventType::SkillLoad)
                .with_matcher(HookMatcher::skill("my-skill")),
        );
        engine.register_handler(
            "specific-skill-hook",
            Arc::new(BlockHandler {
                reason: "Skill blocked".to_string(),
            }),
        );

        // Matching skill name: the hook runs and blocks.
        let matching_event = make_skill_load_event("my-skill", vec!["tool1"]);
        let result = engine.fire(&matching_event).await;
        assert!(result.is_block());

        // Different skill name: the hook does not match, so the event continues.
        let non_matching_event = make_skill_load_event("other-skill", vec!["tool1"]);
        let result = engine.fire(&non_matching_event).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_skill_hook_pattern_matcher() {
        let engine = HookEngine::new();

        // Glob-style pattern: blocks any skill whose name starts with "test-".
        engine.register(
            Hook::new("test-skill-hook", HookEventType::SkillLoad)
                .with_matcher(HookMatcher::skill("test-*")),
        );
        engine.register_handler(
            "test-skill-hook",
            Arc::new(BlockHandler {
                reason: "Test skill blocked".to_string(),
            }),
        );

        // Both "test-" skills match the pattern and are blocked.
        let test_skill = make_skill_load_event("test-alpha", vec!["tool1"]);
        let result = engine.fire(&test_skill).await;
        assert!(result.is_block());

        let test_skill2 = make_skill_load_event("test-beta", vec!["tool1"]);
        let result = engine.fire(&test_skill2).await;
        assert!(result.is_block());

        // A non-matching name passes through.
        let prod_skill = make_skill_load_event("prod-skill", vec!["tool1"]);
        let result = engine.fire(&prod_skill).await;
        assert!(result.is_continue());
    }
}