1use super::events::{HookEvent, HookEventType};
6use super::matcher::HookMatcher;
7use super::{HookAction, HookResponse};
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use tokio::sync::mpsc;
13
14use crate::error::{read_or_recover, write_or_recover};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct HookConfig {
19 #[serde(default = "default_priority")]
21 pub priority: i32,
22
23 #[serde(default = "default_timeout")]
25 pub timeout_ms: u64,
26
27 #[serde(default)]
29 pub async_execution: bool,
30
31 #[serde(default)]
33 pub max_retries: u32,
34}
35
/// Serde fallback for `HookConfig::priority` when the field is absent.
fn default_priority() -> i32 {
    const DEFAULT_PRIORITY: i32 = 100;
    DEFAULT_PRIORITY
}
39
/// Serde fallback for `HookConfig::timeout_ms` when the field is absent
/// (30 seconds, expressed in milliseconds).
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_MS: u64 = 30_000;
    DEFAULT_TIMEOUT_MS
}
43
44impl Default for HookConfig {
45 fn default() -> Self {
46 Self {
47 priority: default_priority(),
48 timeout_ms: default_timeout(),
49 async_execution: false,
50 max_retries: 0,
51 }
52 }
53}
54
/// A registered hook: an id, the event type it listens for, an optional
/// matcher that narrows which events apply, and its execution settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hook {
    /// Identifier of the hook; the engine stores hooks and handlers under
    /// this key.
    pub id: String,

    /// Event type this hook applies to.
    pub event_type: HookEventType,

    /// Optional extra filter. When absent, the hook applies to every event
    /// of `event_type`. Left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matcher: Option<HookMatcher>,

    /// Execution settings; falls back to `HookConfig::default()` when the
    /// field is missing during deserialization.
    #[serde(default)]
    pub config: HookConfig,
}
72
73impl Hook {
74 pub fn new(id: impl Into<String>, event_type: HookEventType) -> Self {
76 Self {
77 id: id.into(),
78 event_type,
79 matcher: None,
80 config: HookConfig::default(),
81 }
82 }
83
84 pub fn with_matcher(mut self, matcher: HookMatcher) -> Self {
86 self.matcher = Some(matcher);
87 self
88 }
89
90 pub fn with_config(mut self, config: HookConfig) -> Self {
92 self.config = config;
93 self
94 }
95
96 pub fn matches(&self, event: &HookEvent) -> bool {
98 if event.event_type() != self.event_type {
100 return false;
101 }
102
103 if let Some(ref matcher) = self.matcher {
105 matcher.matches(event)
106 } else {
107 true
108 }
109 }
110}
111
112#[derive(Debug, Clone)]
114pub enum HookResult {
115 Continue(Option<serde_json::Value>),
117 Block(String),
119 Retry(u64),
121 Skip,
123 Escalate {
125 reason: String,
126 target: Option<String>,
127 },
128}
129
130impl HookResult {
131 pub fn continue_() -> Self {
133 Self::Continue(None)
134 }
135
136 pub fn continue_with(modified: serde_json::Value) -> Self {
138 Self::Continue(Some(modified))
139 }
140
141 pub fn block(reason: impl Into<String>) -> Self {
143 Self::Block(reason.into())
144 }
145
146 pub fn retry(delay_ms: u64) -> Self {
148 Self::Retry(delay_ms)
149 }
150
151 pub fn skip() -> Self {
153 Self::Skip
154 }
155
156 pub fn escalate(reason: impl Into<String>, target: Option<String>) -> Self {
158 Self::Escalate {
159 reason: reason.into(),
160 target,
161 }
162 }
163
164 pub fn is_continue(&self) -> bool {
166 matches!(self, Self::Continue(_))
167 }
168
169 pub fn is_block(&self) -> bool {
171 matches!(self, Self::Block(_))
172 }
173}
174
/// Callback attached to a hook id.
///
/// `handle` is synchronous; implementations must be `Send + Sync` because
/// the engine shares them behind an `Arc` and may call them from spawned
/// tasks.
pub trait HookHandler: Send + Sync {
    /// Inspects `event` and returns the response the engine should act on.
    fn handle(&self, event: &HookEvent) -> HookResponse;
}
180
/// Abstraction over "something that can fire a hook event and report the
/// outcome"; `HookEngine` implements it in this file.
#[async_trait::async_trait]
pub trait HookExecutor: Send + Sync + std::fmt::Debug {
    /// Runs whatever hooks apply to `event` and returns the combined result.
    async fn fire(&self, event: &HookEvent) -> HookResult;
}
190
/// Registry of hooks and their handlers, plus the logic to fire events
/// through them.
pub struct HookEngine {
    /// Registered hooks, keyed by hook id.
    hooks: Arc<RwLock<HashMap<String, Hook>>>,

    /// Handlers, keyed by the id of the hook they serve.
    handlers: Arc<RwLock<HashMap<String, Arc<dyn HookHandler>>>>,

    /// Optional channel; when set, a clone of every fired event is sent on it.
    event_tx: Option<mpsc::Sender<HookEvent>>,
}
202
203impl std::fmt::Debug for HookEngine {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 f.debug_struct("HookEngine")
206 .field("hooks_count", &read_or_recover(&self.hooks).len())
207 .field("handlers_count", &read_or_recover(&self.handlers).len())
208 .field("has_event_channel", &self.event_tx.is_some())
209 .finish()
210 }
211}
212
213impl Default for HookEngine {
214 fn default() -> Self {
215 Self::new()
216 }
217}
218
219impl HookEngine {
220 pub fn new() -> Self {
222 Self {
223 hooks: Arc::new(RwLock::new(HashMap::new())),
224 handlers: Arc::new(RwLock::new(HashMap::new())),
225 event_tx: None,
226 }
227 }
228
229 pub fn with_event_channel(mut self, tx: mpsc::Sender<HookEvent>) -> Self {
231 self.event_tx = Some(tx);
232 self
233 }
234
235 pub fn register(&self, hook: Hook) {
237 let mut hooks = write_or_recover(&self.hooks);
238 hooks.insert(hook.id.clone(), hook);
239 }
240
241 pub fn unregister(&self, hook_id: &str) -> Option<Hook> {
243 let mut hooks = write_or_recover(&self.hooks);
244 hooks.remove(hook_id)
245 }
246
247 pub fn register_handler(&self, hook_id: &str, handler: Arc<dyn HookHandler>) {
249 let mut handlers = write_or_recover(&self.handlers);
250 handlers.insert(hook_id.to_string(), handler);
251 }
252
253 pub fn unregister_handler(&self, hook_id: &str) {
255 let mut handlers = write_or_recover(&self.handlers);
256 handlers.remove(hook_id);
257 }
258
259 pub fn matching_hooks(&self, event: &HookEvent) -> Vec<Hook> {
261 let hooks = read_or_recover(&self.hooks);
262 let mut matching: Vec<Hook> = hooks
263 .values()
264 .filter(|h| h.matches(event))
265 .cloned()
266 .collect();
267
268 matching.sort_by_key(|h| h.config.priority);
270 matching
271 }
272
273 pub async fn fire(&self, event: &HookEvent) -> HookResult {
275 if let Some(ref tx) = self.event_tx {
277 let _ = tx.send(event.clone()).await;
278 }
279
280 let matching_hooks = self.matching_hooks(event);
282
283 if matching_hooks.is_empty() {
284 return HookResult::continue_();
285 }
286
287 let mut last_modified: Option<serde_json::Value> = None;
289 for hook in matching_hooks {
290 let result = self.execute_hook(&hook, event).await;
291
292 match result {
293 HookResult::Continue(modified) => {
294 if modified.is_some() {
296 last_modified = modified;
297 }
298 }
299 HookResult::Block(reason) => {
300 return HookResult::Block(reason);
301 }
302 HookResult::Retry(delay) => {
303 return HookResult::Retry(delay);
304 }
305 HookResult::Skip => {
306 return HookResult::Continue(None);
307 }
308 HookResult::Escalate { reason, target } => {
309 return HookResult::Escalate { reason, target };
310 }
311 }
312 }
313
314 HookResult::Continue(last_modified)
315 }
316
317 async fn execute_hook(&self, hook: &Hook, event: &HookEvent) -> HookResult {
319 let handler = {
321 let handlers = read_or_recover(&self.handlers);
322 handlers.get(&hook.id).cloned()
323 };
324
325 match handler {
326 Some(h) => {
327 let response = if hook.config.async_execution {
329 let h = h.clone();
331 let event = event.clone();
332 tokio::spawn(async move {
333 h.handle(&event);
334 });
335 HookResponse::continue_()
336 } else {
337 let timeout = std::time::Duration::from_millis(hook.config.timeout_ms);
339 let h = h.clone();
340 let event = event.clone();
341
342 match tokio::time::timeout(timeout, async move { h.handle(&event) }).await {
343 Ok(response) => response,
344 Err(_) => {
345 HookResponse::continue_()
347 }
348 }
349 };
350
351 self.response_to_result(response)
352 }
353 None => {
354 HookResult::continue_()
356 }
357 }
358 }
359
360 fn response_to_result(&self, response: HookResponse) -> HookResult {
362 match response.action {
363 HookAction::Continue => HookResult::Continue(response.modified),
364 HookAction::Block => {
365 HookResult::Block(response.reason.unwrap_or_else(|| "Blocked".to_string()))
366 }
367 HookAction::Retry => HookResult::Retry(response.retry_delay_ms.unwrap_or(1000)),
368 HookAction::Skip => HookResult::Skip,
369 }
370 }
371
372 pub fn hook_count(&self) -> usize {
374 read_or_recover(&self.hooks).len()
375 }
376
377 pub fn get_hook(&self, id: &str) -> Option<Hook> {
379 read_or_recover(&self.hooks).get(id).cloned()
380 }
381
382 pub fn all_hooks(&self) -> Vec<Hook> {
384 read_or_recover(&self.hooks).values().cloned().collect()
385 }
386}
387
388#[async_trait]
390impl HookExecutor for HookEngine {
391 async fn fire(&self, event: &HookEvent) -> HookResult {
392 self.fire(event).await
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::events::PreToolUseEvent;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Builds a `PreToolUse` event for `tool` with empty args.
    fn make_pre_tool_event(session_id: &str, tool: &str) -> HookEvent {
        let payload = PreToolUseEvent {
            session_id: String::from(session_id),
            tool: String::from(tool),
            args: serde_json::json!({}),
            working_directory: String::from("/workspace"),
            recent_tools: Vec::new(),
        };
        HookEvent::PreToolUse(payload)
    }

    /// Builds a `SkillLoad` event for `skill_name` exposing `tools`.
    fn make_skill_load_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        let tool_names = tools.into_iter().map(String::from).collect();
        HookEvent::SkillLoad(crate::hooks::events::SkillLoadEvent {
            skill_name: String::from(skill_name),
            tool_names,
            version: Some(String::from("1.0.0")),
            description: Some(String::from("Test skill")),
            loaded_at: 1234567890,
        })
    }

    /// Builds a `SkillUnload` event for `skill_name` exposing `tools`.
    fn make_skill_unload_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        let tool_names = tools.into_iter().map(String::from).collect();
        HookEvent::SkillUnload(crate::hooks::events::SkillUnloadEvent {
            skill_name: String::from(skill_name),
            tool_names,
            duration_ms: 60000,
        })
    }

    /// Default configuration with only the priority overridden.
    fn config_with_priority(priority: i32) -> HookConfig {
        HookConfig {
            priority,
            ..Default::default()
        }
    }

    /// Handler that always answers "continue".
    struct ContinueHandler;

    impl HookHandler for ContinueHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::continue_()
        }
    }

    /// Handler that always blocks with a fixed reason.
    struct BlockHandler {
        reason: String,
    }

    impl HookHandler for BlockHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::block(&self.reason)
        }
    }

    /// Shorthand for a shared `BlockHandler` with the given reason.
    fn blocking_handler(reason: &str) -> Arc<dyn HookHandler> {
        Arc::new(BlockHandler {
            reason: String::from(reason),
        })
    }

    // ---------------------------------------------------------------------
    // HookConfig / Hook / HookResult
    // ---------------------------------------------------------------------

    #[test]
    fn test_hook_config_default() {
        let HookConfig {
            priority,
            timeout_ms,
            async_execution,
            max_retries,
        } = HookConfig::default();

        assert_eq!(priority, 100);
        assert_eq!(timeout_ms, 30000);
        assert!(!async_execution);
        assert_eq!(max_retries, 0);
    }

    #[test]
    fn test_hook_new() {
        let created = Hook::new("test-hook", HookEventType::PreToolUse);

        assert_eq!(created.id, "test-hook");
        assert_eq!(created.event_type, HookEventType::PreToolUse);
        assert!(created.matcher.is_none());
    }

    #[test]
    fn test_hook_with_matcher() {
        let bash_only = HookMatcher::tool("Bash");
        let hook = Hook::new("test-hook", HookEventType::PreToolUse).with_matcher(bash_only);

        assert!(hook.matcher.is_some());
        let tool = hook.matcher.and_then(|matcher| matcher.tool);
        assert_eq!(tool, Some("Bash".to_string()));
    }

    #[test]
    fn test_hook_matches_event_type() {
        use crate::hooks::events::{PostToolUseEvent, ToolResultData};

        let hook = Hook::new("test-hook", HookEventType::PreToolUse);

        // Same event type: matches.
        assert!(hook.matches(&make_pre_tool_event("s1", "Bash")));

        // Different event type: rejected before any matcher is consulted.
        let post_event = HookEvent::PostToolUse(PostToolUseEvent {
            session_id: String::from("s1"),
            tool: String::from("Bash"),
            args: serde_json::json!({}),
            result: ToolResultData {
                success: true,
                output: String::from(""),
                exit_code: Some(0),
                duration_ms: 100,
            },
        });
        assert!(!hook.matches(&post_event));
    }

    #[test]
    fn test_hook_matches_with_matcher() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"));

        assert!(hook.matches(&make_pre_tool_event("s1", "Bash")));
        assert!(!hook.matches(&make_pre_tool_event("s1", "Read")));
    }

    #[test]
    fn test_hook_result_constructors() {
        let plain = HookResult::continue_();
        assert!(plain.is_continue());
        assert!(!plain.is_block());

        let with_payload = HookResult::continue_with(serde_json::json!({"key": "value"}));
        assert!(with_payload.is_continue());

        let blocked = HookResult::block("Blocked");
        assert!(blocked.is_block());
        assert!(!blocked.is_continue());

        // Retry and Skip are neither "continue" nor "block".
        for other in [HookResult::retry(1000), HookResult::skip()] {
            assert!(!other.is_continue());
            assert!(!other.is_block());
        }
    }

    #[test]
    fn test_hook_serialization() {
        let config = HookConfig {
            priority: 50,
            timeout_ms: 5000,
            async_execution: true,
            max_retries: 3,
        };
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"))
            .with_config(config);

        let json = serde_json::to_string(&hook).unwrap();
        for needle in ["test-hook", "pre_tool_use", "Bash"] {
            assert!(json.contains(needle));
        }

        let round_tripped: Hook = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped.id, "test-hook");
        assert_eq!(round_tripped.event_type, HookEventType::PreToolUse);
        assert_eq!(round_tripped.config.priority, 50);
    }

    // ---------------------------------------------------------------------
    // HookEngine: registry
    // ---------------------------------------------------------------------

    #[test]
    fn test_engine_register_unregister() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));

        assert_eq!(engine.hook_count(), 1);
        assert!(engine.get_hook("test-hook").is_some());

        assert!(engine.unregister("test-hook").is_some());
        assert_eq!(engine.hook_count(), 0);
    }

    #[test]
    fn test_engine_matching_hooks() {
        let engine = HookEngine::new();

        let unfiltered = Hook::new("hook-1", HookEventType::PreToolUse)
            .with_config(config_with_priority(10));
        let bash_only = Hook::new("hook-2", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"))
            .with_config(config_with_priority(5));
        let other_type = Hook::new("hook-3", HookEventType::PostToolUse);

        engine.register(unfiltered);
        engine.register(bash_only);
        engine.register(other_type);

        let matching = engine.matching_hooks(&make_pre_tool_event("s1", "Bash"));

        assert_eq!(matching.len(), 2);
        // Lower priority value comes first.
        assert_eq!(matching[0].id, "hook-2");
        assert_eq!(matching[1].id, "hook-1");
    }

    #[test]
    fn test_all_hooks() {
        let engine = HookEngine::new();
        engine.register(Hook::new("hook-1", HookEventType::PreToolUse));
        engine.register(Hook::new("hook-2", HookEventType::PostToolUse));

        assert_eq!(engine.all_hooks().len(), 2);
    }

    // ---------------------------------------------------------------------
    // HookEngine: firing
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_engine_fire_no_hooks() {
        let engine = HookEngine::new();

        let outcome = engine.fire(&make_pre_tool_event("s1", "Bash")).await;

        assert!(outcome.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_no_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));

        let outcome = engine.fire(&make_pre_tool_event("s1", "Bash")).await;

        // A hook without a handler is a no-op.
        assert!(outcome.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_with_continue_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler("test-hook", Arc::new(ContinueHandler));

        let outcome = engine.fire(&make_pre_tool_event("s1", "Bash")).await;

        assert!(outcome.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_with_block_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler("test-hook", blocking_handler("Dangerous command"));

        let outcome = engine.fire(&make_pre_tool_event("s1", "Bash")).await;

        assert!(outcome.is_block());
        if let HookResult::Block(why) = outcome {
            assert_eq!(why, "Dangerous command");
        }
    }

    #[tokio::test]
    async fn test_engine_fire_priority_order() {
        let engine = HookEngine::new();

        // The blocking hook has the lower priority value, so it runs first.
        engine.register(
            Hook::new("block-hook", HookEventType::PreToolUse)
                .with_config(config_with_priority(5)),
        );
        engine.register(
            Hook::new("continue-hook", HookEventType::PreToolUse)
                .with_config(config_with_priority(10)),
        );
        engine.register_handler("block-hook", blocking_handler("Blocked first"));
        engine.register_handler("continue-hook", Arc::new(ContinueHandler));

        let outcome = engine.fire(&make_pre_tool_event("s1", "Bash")).await;

        assert!(outcome.is_block());
    }

    // ---------------------------------------------------------------------
    // HookEngine: skill events
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_engine_fire_skill_load() {
        let engine = HookEngine::new();
        engine.register(Hook::new("skill-load-hook", HookEventType::SkillLoad));
        engine.register_handler("skill-load-hook", Arc::new(ContinueHandler));

        let event = make_skill_load_event("my-skill", vec!["tool1", "tool2"]);

        assert!(engine.fire(&event).await.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_skill_unload() {
        let engine = HookEngine::new();
        engine.register(Hook::new("skill-unload-hook", HookEventType::SkillUnload));
        engine.register_handler("skill-unload-hook", Arc::new(ContinueHandler));

        let event = make_skill_unload_event("my-skill", vec!["tool1", "tool2"]);

        assert!(engine.fire(&event).await.is_continue());
    }

    #[tokio::test]
    async fn test_engine_skill_hook_with_matcher() {
        let engine = HookEngine::new();
        engine.register(
            Hook::new("specific-skill-hook", HookEventType::SkillLoad)
                .with_matcher(HookMatcher::skill("my-skill")),
        );
        engine.register_handler("specific-skill-hook", blocking_handler("Skill blocked"));

        // The named skill is blocked.
        let targeted = make_skill_load_event("my-skill", vec!["tool1"]);
        assert!(engine.fire(&targeted).await.is_block());

        // Any other skill passes through.
        let unrelated = make_skill_load_event("other-skill", vec!["tool1"]);
        assert!(engine.fire(&unrelated).await.is_continue());
    }

    #[tokio::test]
    async fn test_engine_skill_hook_pattern_matcher() {
        let engine = HookEngine::new();
        engine.register(
            Hook::new("test-skill-hook", HookEventType::SkillLoad)
                .with_matcher(HookMatcher::skill("test-*")),
        );
        engine.register_handler("test-skill-hook", blocking_handler("Test skill blocked"));

        // Skills matching the pattern are blocked.
        let alpha = make_skill_load_event("test-alpha", vec!["tool1"]);
        assert!(engine.fire(&alpha).await.is_block());

        let beta = make_skill_load_event("test-beta", vec!["tool1"]);
        assert!(engine.fire(&beta).await.is_block());

        // A skill outside the pattern is not.
        let prod = make_skill_load_event("prod-skill", vec!["tool1"]);
        assert!(engine.fire(&prod).await.is_continue());
    }
}