1use super::events::{HookEvent, HookEventType};
6use super::matcher::HookMatcher;
7use super::{HookAction, HookResponse};
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use tokio::sync::mpsc;
13
14use crate::error::{read_or_recover, write_or_recover};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct HookConfig {
19 #[serde(default = "default_priority")]
21 pub priority: i32,
22
23 #[serde(default = "default_timeout")]
25 pub timeout_ms: u64,
26
27 #[serde(default)]
29 pub async_execution: bool,
30
31 #[serde(default)]
33 pub max_retries: u32,
34}
35
/// Serde default for `HookConfig::priority`; `HookConfig::default()` uses it too.
fn default_priority() -> i32 {
    100_i32
}
39
/// Serde default for `HookConfig::timeout_ms`; `HookConfig::default()` uses it too.
fn default_timeout() -> u64 {
    // 30 seconds, expressed in milliseconds to match the field's unit.
    30 * 1_000
}
43
44impl Default for HookConfig {
45 fn default() -> Self {
46 Self {
47 priority: default_priority(),
48 timeout_ms: default_timeout(),
49 async_execution: false,
50 max_retries: 0,
51 }
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct Hook {
58 pub id: String,
60
61 pub event_type: HookEventType,
63
64 #[serde(skip_serializing_if = "Option::is_none")]
66 pub matcher: Option<HookMatcher>,
67
68 #[serde(default)]
70 pub config: HookConfig,
71}
72
73impl Hook {
74 pub fn new(id: impl Into<String>, event_type: HookEventType) -> Self {
76 Self {
77 id: id.into(),
78 event_type,
79 matcher: None,
80 config: HookConfig::default(),
81 }
82 }
83
84 pub fn with_matcher(mut self, matcher: HookMatcher) -> Self {
86 self.matcher = Some(matcher);
87 self
88 }
89
90 pub fn with_config(mut self, config: HookConfig) -> Self {
92 self.config = config;
93 self
94 }
95
96 pub fn matches(&self, event: &HookEvent) -> bool {
98 if event.event_type() != self.event_type {
100 return false;
101 }
102
103 if let Some(ref matcher) = self.matcher {
105 matcher.matches(event)
106 } else {
107 true
108 }
109 }
110}
111
112#[derive(Debug, Clone)]
114pub enum HookResult {
115 Continue(Option<serde_json::Value>),
117 Block(String),
119 Retry(u64),
121 Skip,
123}
124
125impl HookResult {
126 pub fn continue_() -> Self {
128 Self::Continue(None)
129 }
130
131 pub fn continue_with(modified: serde_json::Value) -> Self {
133 Self::Continue(Some(modified))
134 }
135
136 pub fn block(reason: impl Into<String>) -> Self {
138 Self::Block(reason.into())
139 }
140
141 pub fn retry(delay_ms: u64) -> Self {
143 Self::Retry(delay_ms)
144 }
145
146 pub fn skip() -> Self {
148 Self::Skip
149 }
150
151 pub fn is_continue(&self) -> bool {
153 matches!(self, Self::Continue(_))
154 }
155
156 pub fn is_block(&self) -> bool {
158 matches!(self, Self::Block(_))
159 }
160}
161
162pub trait HookHandler: Send + Sync {
164 fn handle(&self, event: &HookEvent) -> HookResponse;
166}
167
168#[async_trait::async_trait]
173pub trait HookExecutor: Send + Sync + std::fmt::Debug {
174 async fn fire(&self, event: &HookEvent) -> HookResult;
176}
177
178pub struct HookEngine {
180 hooks: Arc<RwLock<HashMap<String, Hook>>>,
182
183 handlers: Arc<RwLock<HashMap<String, Arc<dyn HookHandler>>>>,
185
186 event_tx: Option<mpsc::Sender<HookEvent>>,
188}
189
190impl std::fmt::Debug for HookEngine {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 f.debug_struct("HookEngine")
193 .field("hooks_count", &read_or_recover(&self.hooks).len())
194 .field("handlers_count", &read_or_recover(&self.handlers).len())
195 .field("has_event_channel", &self.event_tx.is_some())
196 .finish()
197 }
198}
199
200impl Default for HookEngine {
201 fn default() -> Self {
202 Self::new()
203 }
204}
205
206impl HookEngine {
207 pub fn new() -> Self {
209 Self {
210 hooks: Arc::new(RwLock::new(HashMap::new())),
211 handlers: Arc::new(RwLock::new(HashMap::new())),
212 event_tx: None,
213 }
214 }
215
216 pub fn with_event_channel(mut self, tx: mpsc::Sender<HookEvent>) -> Self {
218 self.event_tx = Some(tx);
219 self
220 }
221
222 pub fn register(&self, hook: Hook) {
224 let mut hooks = write_or_recover(&self.hooks);
225 hooks.insert(hook.id.clone(), hook);
226 }
227
228 pub fn unregister(&self, hook_id: &str) -> Option<Hook> {
230 let mut hooks = write_or_recover(&self.hooks);
231 hooks.remove(hook_id)
232 }
233
234 pub fn register_handler(&self, hook_id: &str, handler: Arc<dyn HookHandler>) {
236 let mut handlers = write_or_recover(&self.handlers);
237 handlers.insert(hook_id.to_string(), handler);
238 }
239
240 pub fn unregister_handler(&self, hook_id: &str) {
242 let mut handlers = write_or_recover(&self.handlers);
243 handlers.remove(hook_id);
244 }
245
246 pub fn matching_hooks(&self, event: &HookEvent) -> Vec<Hook> {
248 let hooks = read_or_recover(&self.hooks);
249 let mut matching: Vec<Hook> = hooks
250 .values()
251 .filter(|h| h.matches(event))
252 .cloned()
253 .collect();
254
255 matching.sort_by_key(|h| h.config.priority);
257 matching
258 }
259
260 pub async fn fire(&self, event: &HookEvent) -> HookResult {
262 if let Some(ref tx) = self.event_tx {
264 let _ = tx.send(event.clone()).await;
265 }
266
267 let matching_hooks = self.matching_hooks(event);
269
270 if matching_hooks.is_empty() {
271 return HookResult::continue_();
272 }
273
274 for hook in matching_hooks {
276 let result = self.execute_hook(&hook, event).await;
277
278 match result {
279 HookResult::Continue(modified) => {
280 if modified.is_some() {
283 return HookResult::Continue(modified);
284 }
285 }
286 HookResult::Block(reason) => {
287 return HookResult::Block(reason);
288 }
289 HookResult::Retry(delay) => {
290 return HookResult::Retry(delay);
291 }
292 HookResult::Skip => {
293 return HookResult::Continue(None);
294 }
295 }
296 }
297
298 HookResult::continue_()
299 }
300
301 async fn execute_hook(&self, hook: &Hook, event: &HookEvent) -> HookResult {
303 let handler = {
305 let handlers = read_or_recover(&self.handlers);
306 handlers.get(&hook.id).cloned()
307 };
308
309 match handler {
310 Some(h) => {
311 let response = if hook.config.async_execution {
313 let h = h.clone();
315 let event = event.clone();
316 tokio::spawn(async move {
317 h.handle(&event);
318 });
319 HookResponse::continue_()
320 } else {
321 let timeout = std::time::Duration::from_millis(hook.config.timeout_ms);
323 let h = h.clone();
324 let event = event.clone();
325
326 match tokio::time::timeout(timeout, async move { h.handle(&event) }).await {
327 Ok(response) => response,
328 Err(_) => {
329 HookResponse::continue_()
331 }
332 }
333 };
334
335 self.response_to_result(response)
336 }
337 None => {
338 HookResult::continue_()
340 }
341 }
342 }
343
344 fn response_to_result(&self, response: HookResponse) -> HookResult {
346 match response.action {
347 HookAction::Continue => HookResult::Continue(response.modified),
348 HookAction::Block => {
349 HookResult::Block(response.reason.unwrap_or_else(|| "Blocked".to_string()))
350 }
351 HookAction::Retry => HookResult::Retry(response.retry_delay_ms.unwrap_or(1000)),
352 HookAction::Skip => HookResult::Skip,
353 }
354 }
355
356 pub fn hook_count(&self) -> usize {
358 read_or_recover(&self.hooks).len()
359 }
360
361 pub fn get_hook(&self, id: &str) -> Option<Hook> {
363 read_or_recover(&self.hooks).get(id).cloned()
364 }
365
366 pub fn all_hooks(&self) -> Vec<Hook> {
368 read_or_recover(&self.hooks).values().cloned().collect()
369 }
370}
371
372#[async_trait]
374impl HookExecutor for HookEngine {
375 async fn fire(&self, event: &HookEvent) -> HookResult {
376 self.fire(event).await
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::events::PreToolUseEvent;

    /// Builds a `PreToolUse` event for `tool` in session `session_id`, with empty args.
    fn make_pre_tool_event(session_id: &str, tool: &str) -> HookEvent {
        HookEvent::PreToolUse(PreToolUseEvent {
            session_id: session_id.to_string(),
            tool: tool.to_string(),
            args: serde_json::json!({}),
            working_directory: "/workspace".to_string(),
            recent_tools: vec![],
        })
    }

    /// Builds a `SkillLoad` event for `skill_name` exposing `tools`.
    fn make_skill_load_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        HookEvent::SkillLoad(crate::hooks::events::SkillLoadEvent {
            skill_name: skill_name.to_string(),
            tool_names: tools.iter().map(|s| s.to_string()).collect(),
            version: Some("1.0.0".to_string()),
            description: Some("Test skill".to_string()),
            loaded_at: 1234567890,
        })
    }

    /// Builds a `SkillUnload` event for `skill_name` exposing `tools`.
    fn make_skill_unload_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        HookEvent::SkillUnload(crate::hooks::events::SkillUnloadEvent {
            skill_name: skill_name.to_string(),
            tool_names: tools.iter().map(|s| s.to_string()).collect(),
            duration_ms: 60000,
        })
    }

    /// A `HookConfig` that differs from the default only in `priority`.
    fn with_priority(priority: i32) -> HookConfig {
        HookConfig {
            priority,
            ..Default::default()
        }
    }

    /// Returns the reason carried by a `Block` result; fails the test on any other variant.
    fn expect_block(result: HookResult) -> String {
        match result {
            HookResult::Block(reason) => reason,
            other => panic!("expected HookResult::Block, got {:?}", other),
        }
    }

    /// Handler that always lets the event through.
    struct ContinueHandler;

    impl HookHandler for ContinueHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::continue_()
        }
    }

    /// Handler that always blocks with a fixed reason.
    struct BlockHandler {
        reason: String,
    }

    impl HookHandler for BlockHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::block(&self.reason)
        }
    }

    #[test]
    fn test_hook_config_default() {
        let config = HookConfig::default();
        assert_eq!(config.priority, 100);
        assert_eq!(config.timeout_ms, 30000);
        assert!(!config.async_execution);
        assert_eq!(config.max_retries, 0);
    }

    #[test]
    fn test_hook_config_deserialize_fills_defaults() {
        let config: HookConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(config.priority, 100);
        assert_eq!(config.timeout_ms, 30000);
        assert!(!config.async_execution);
        assert_eq!(config.max_retries, 0);
    }

    #[test]
    fn test_hook_new() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse);
        assert_eq!(hook.id, "test-hook");
        assert_eq!(hook.event_type, HookEventType::PreToolUse);
        assert!(hook.matcher.is_none());
    }

    #[test]
    fn test_hook_with_matcher() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"));

        assert!(hook.matcher.is_some());
        assert_eq!(hook.matcher.unwrap().tool, Some("Bash".to_string()));
    }

    #[test]
    fn test_hook_matches_event_type() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse);

        let pre_event = make_pre_tool_event("s1", "Bash");
        assert!(hook.matches(&pre_event));

        let post_event = HookEvent::PostToolUse(crate::hooks::events::PostToolUseEvent {
            session_id: "s1".to_string(),
            tool: "Bash".to_string(),
            args: serde_json::json!({}),
            result: crate::hooks::events::ToolResultData {
                success: true,
                output: "".to_string(),
                exit_code: Some(0),
                duration_ms: 100,
            },
        });
        assert!(!hook.matches(&post_event));
    }

    #[test]
    fn test_hook_matches_with_matcher() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"));

        let bash_event = make_pre_tool_event("s1", "Bash");
        let read_event = make_pre_tool_event("s1", "Read");

        assert!(hook.matches(&bash_event));
        assert!(!hook.matches(&read_event));
    }

    #[test]
    fn test_hook_result_constructors() {
        let cont = HookResult::continue_();
        assert!(cont.is_continue());
        assert!(!cont.is_block());

        let cont_with = HookResult::continue_with(serde_json::json!({"key": "value"}));
        assert!(cont_with.is_continue());

        let block = HookResult::block("Blocked");
        assert!(block.is_block());
        assert!(!block.is_continue());

        let retry = HookResult::retry(1000);
        assert!(!retry.is_continue());
        assert!(!retry.is_block());

        let skip = HookResult::skip();
        assert!(!skip.is_continue());
        assert!(!skip.is_block());
    }

    #[test]
    fn test_engine_register_unregister() {
        let engine = HookEngine::new();

        let hook = Hook::new("test-hook", HookEventType::PreToolUse);
        engine.register(hook);

        assert_eq!(engine.hook_count(), 1);
        assert!(engine.get_hook("test-hook").is_some());

        let removed = engine.unregister("test-hook");
        assert!(removed.is_some());
        assert_eq!(engine.hook_count(), 0);
    }

    #[test]
    fn test_engine_matching_hooks() {
        let engine = HookEngine::new();

        engine.register(Hook::new("hook-1", HookEventType::PreToolUse).with_config(with_priority(10)));
        engine.register(
            Hook::new("hook-2", HookEventType::PreToolUse)
                .with_matcher(HookMatcher::tool("Bash"))
                .with_config(with_priority(5)),
        );
        engine.register(Hook::new("hook-3", HookEventType::PostToolUse));

        let event = make_pre_tool_event("s1", "Bash");
        let matching = engine.matching_hooks(&event);

        // hook-3 listens for a different event type.
        assert_eq!(matching.len(), 2);
        // Lower priority value first.
        assert_eq!(matching[0].id, "hook-2");
        assert_eq!(matching[1].id, "hook-1");
    }

    #[tokio::test]
    async fn test_engine_fire_no_hooks() {
        let engine = HookEngine::new();
        let event = make_pre_tool_event("s1", "Bash");

        let result = engine.fire(&event).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_no_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        // A hook without a handler lets the event through.
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_with_continue_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler("test-hook", Arc::new(ContinueHandler));

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_with_block_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler(
            "test-hook",
            Arc::new(BlockHandler {
                reason: "Dangerous command".to_string(),
            }),
        );

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        assert_eq!(expect_block(result), "Dangerous command");
    }

    #[tokio::test]
    async fn test_engine_fire_priority_order() {
        let engine = HookEngine::new();

        engine.register(Hook::new("block-hook", HookEventType::PreToolUse).with_config(with_priority(5)));
        engine.register(Hook::new("continue-hook", HookEventType::PreToolUse).with_config(with_priority(10)));

        engine.register_handler(
            "block-hook",
            Arc::new(BlockHandler {
                reason: "Blocked first".to_string(),
            }),
        );
        engine.register_handler("continue-hook", Arc::new(ContinueHandler));

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        assert_eq!(expect_block(result), "Blocked first");
    }

    #[tokio::test]
    async fn test_engine_fire_lower_priority_value_runs_first() {
        // Both hooks block, so only the reason reveals which one ran first; the
        // continue/block pairing above would block regardless of order.
        let engine = HookEngine::new();

        engine.register(Hook::new("second", HookEventType::PreToolUse).with_config(with_priority(10)));
        engine.register(Hook::new("first", HookEventType::PreToolUse).with_config(with_priority(5)));

        engine.register_handler(
            "second",
            Arc::new(BlockHandler {
                reason: "Blocked second".to_string(),
            }),
        );
        engine.register_handler(
            "first",
            Arc::new(BlockHandler {
                reason: "Blocked first".to_string(),
            }),
        );

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        assert_eq!(expect_block(result), "Blocked first");
    }

    #[tokio::test]
    async fn test_engine_unregister_handler_lets_event_through() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler(
            "test-hook",
            Arc::new(BlockHandler {
                reason: "Blocked".to_string(),
            }),
        );

        let event = make_pre_tool_event("s1", "Bash");
        assert!(engine.fire(&event).await.is_block());

        engine.unregister_handler("test-hook");
        assert!(engine.fire(&event).await.is_continue());
        // The hook definition itself is untouched.
        assert!(engine.get_hook("test-hook").is_some());
    }

    #[tokio::test]
    async fn test_engine_async_execution_ignores_response() {
        let engine = HookEngine::new();
        engine.register(
            Hook::new("async-hook", HookEventType::PreToolUse).with_config(HookConfig {
                async_execution: true,
                ..Default::default()
            }),
        );
        engine.register_handler(
            "async-hook",
            Arc::new(BlockHandler {
                reason: "Never observed".to_string(),
            }),
        );

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        // Background handlers cannot block: the engine proceeds as if they continued.
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_forwards_fired_event_to_channel() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let engine = HookEngine::new().with_event_channel(tx);

        let result = engine.fire(&make_pre_tool_event("s1", "Bash")).await;
        assert!(result.is_continue());

        let forwarded = rx.recv().await.expect("fired event should be forwarded");
        assert_eq!(forwarded.event_type(), HookEventType::PreToolUse);
    }

    #[test]
    fn test_hook_serialization() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"))
            .with_config(HookConfig {
                priority: 50,
                timeout_ms: 5000,
                async_execution: true,
                max_retries: 3,
            });

        let json = serde_json::to_string(&hook).unwrap();
        assert!(json.contains("test-hook"));
        assert!(json.contains("pre_tool_use"));
        assert!(json.contains("Bash"));

        let parsed: Hook = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "test-hook");
        assert_eq!(parsed.event_type, HookEventType::PreToolUse);
        assert_eq!(parsed.config.priority, 50);
    }

    #[test]
    fn test_all_hooks() {
        let engine = HookEngine::new();
        engine.register(Hook::new("hook-1", HookEventType::PreToolUse));
        engine.register(Hook::new("hook-2", HookEventType::PostToolUse));

        let all = engine.all_hooks();
        assert_eq!(all.len(), 2);
    }

    #[tokio::test]
    async fn test_engine_fire_skill_load() {
        let engine = HookEngine::new();

        engine.register(Hook::new("skill-load-hook", HookEventType::SkillLoad));
        engine.register_handler("skill-load-hook", Arc::new(ContinueHandler));

        let event = make_skill_load_event("my-skill", vec!["tool1", "tool2"]);
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_skill_unload() {
        let engine = HookEngine::new();

        engine.register(Hook::new("skill-unload-hook", HookEventType::SkillUnload));
        engine.register_handler("skill-unload-hook", Arc::new(ContinueHandler));

        let event = make_skill_unload_event("my-skill", vec!["tool1", "tool2"]);
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_skill_hook_with_matcher() {
        let engine = HookEngine::new();

        engine.register(
            Hook::new("specific-skill-hook", HookEventType::SkillLoad)
                .with_matcher(HookMatcher::skill("my-skill")),
        );
        engine.register_handler(
            "specific-skill-hook",
            Arc::new(BlockHandler {
                reason: "Skill blocked".to_string(),
            }),
        );

        let matching_event = make_skill_load_event("my-skill", vec!["tool1"]);
        let result = engine.fire(&matching_event).await;
        assert_eq!(expect_block(result), "Skill blocked");

        let non_matching_event = make_skill_load_event("other-skill", vec!["tool1"]);
        let result = engine.fire(&non_matching_event).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_skill_hook_pattern_matcher() {
        let engine = HookEngine::new();

        engine.register(
            Hook::new("test-skill-hook", HookEventType::SkillLoad)
                .with_matcher(HookMatcher::skill("test-*")),
        );
        engine.register_handler(
            "test-skill-hook",
            Arc::new(BlockHandler {
                reason: "Test skill blocked".to_string(),
            }),
        );

        let test_skill = make_skill_load_event("test-alpha", vec!["tool1"]);
        let result = engine.fire(&test_skill).await;
        assert!(result.is_block());

        let test_skill2 = make_skill_load_event("test-beta", vec!["tool1"]);
        let result = engine.fire(&test_skill2).await;
        assert!(result.is_block());

        let prod_skill = make_skill_load_event("prod-skill", vec!["tool1"]);
        let result = engine.fire(&prod_skill).await;
        assert!(result.is_continue());
    }
}