1use super::events::{HookEvent, HookEventType};
6use super::matcher::HookMatcher;
7use super::{HookAction, HookResponse};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use tokio::sync::mpsc;
12
13use crate::error::{read_or_recover, write_or_recover};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct HookConfig {
18 #[serde(default = "default_priority")]
20 pub priority: i32,
21
22 #[serde(default = "default_timeout")]
24 pub timeout_ms: u64,
25
26 #[serde(default)]
28 pub async_execution: bool,
29
30 #[serde(default)]
32 pub max_retries: u32,
33}
34
35fn default_priority() -> i32 {
36 100
37}
38
39fn default_timeout() -> u64 {
40 30000
41}
42
43impl Default for HookConfig {
44 fn default() -> Self {
45 Self {
46 priority: default_priority(),
47 timeout_ms: default_timeout(),
48 async_execution: false,
49 max_retries: 0,
50 }
51 }
52}
53
/// A hook definition: the event type it listens to, an optional finer-grained
/// matcher, and its execution settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hook {
    /// Identifier of the hook. The engine keys both its hook map and its handler
    /// map by this value, so a handler is found via the hook's id.
    pub id: String,

    /// Only events whose `event_type()` equals this value can match the hook.
    pub event_type: HookEventType,

    /// Optional extra filter applied after the event-type check; `None` matches
    /// every event of `event_type`. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matcher: Option<HookMatcher>,

    /// Execution settings; `HookConfig::default()` is used when absent from input.
    #[serde(default)]
    pub config: HookConfig,
}
71
72impl Hook {
73 pub fn new(id: impl Into<String>, event_type: HookEventType) -> Self {
75 Self {
76 id: id.into(),
77 event_type,
78 matcher: None,
79 config: HookConfig::default(),
80 }
81 }
82
83 pub fn with_matcher(mut self, matcher: HookMatcher) -> Self {
85 self.matcher = Some(matcher);
86 self
87 }
88
89 pub fn with_config(mut self, config: HookConfig) -> Self {
91 self.config = config;
92 self
93 }
94
95 pub fn matches(&self, event: &HookEvent) -> bool {
97 if event.event_type() != self.event_type {
99 return false;
100 }
101
102 if let Some(ref matcher) = self.matcher {
104 matcher.matches(event)
105 } else {
106 true
107 }
108 }
109}
110
/// Outcome of running hooks for an event, as returned by `HookEngine::fire`.
#[derive(Debug, Clone)]
pub enum HookResult {
    /// Proceed. Carries the optional `modified` value produced by a handler
    /// (see `HookResponse::modified`); `None` means nothing was modified.
    Continue(Option<serde_json::Value>),
    /// Stop; carries the reason supplied by the blocking handler
    /// (or `"Blocked"` when the handler gave none).
    Block(String),
    /// Ask the caller to retry after the given delay, in milliseconds.
    Retry(u64),
    /// A handler asked to skip. `HookEngine::fire` reports this to its caller as
    /// `Continue(None)` and does not run any further hooks.
    Skip,
}
123
124impl HookResult {
125 pub fn continue_() -> Self {
127 Self::Continue(None)
128 }
129
130 pub fn continue_with(modified: serde_json::Value) -> Self {
132 Self::Continue(Some(modified))
133 }
134
135 pub fn block(reason: impl Into<String>) -> Self {
137 Self::Block(reason.into())
138 }
139
140 pub fn retry(delay_ms: u64) -> Self {
142 Self::Retry(delay_ms)
143 }
144
145 pub fn skip() -> Self {
147 Self::Skip
148 }
149
150 pub fn is_continue(&self) -> bool {
152 matches!(self, Self::Continue(_))
153 }
154
155 pub fn is_block(&self) -> bool {
157 matches!(self, Self::Block(_))
158 }
159}
160
/// Synchronous callback that decides what happens when its hook fires.
///
/// `Send + Sync` is required because handlers are stored as `Arc<dyn HookHandler>`
/// and the engine may invoke them from spawned tasks.
pub trait HookHandler: Send + Sync {
    /// Inspect `event` and return the action the engine should take.
    fn handle(&self, event: &HookEvent) -> HookResponse;
}
166
/// Registry of hooks and their handlers, plus the logic that dispatches events to them.
pub struct HookEngine {
    /// Hook definitions keyed by `Hook::id`.
    hooks: Arc<RwLock<HashMap<String, Hook>>>,

    /// Handlers keyed by the id of the hook they serve. A matching hook with no
    /// entry here is treated as an unmodified continue when fired.
    handlers: Arc<RwLock<HashMap<String, Arc<dyn HookHandler>>>>,

    /// Optional observer channel; `fire` sends a clone of every event here
    /// before running any hook.
    event_tx: Option<mpsc::Sender<HookEvent>>,
}
178
179impl std::fmt::Debug for HookEngine {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 f.debug_struct("HookEngine")
182 .field("hooks_count", &read_or_recover(&self.hooks).len())
183 .field("handlers_count", &read_or_recover(&self.handlers).len())
184 .field("has_event_channel", &self.event_tx.is_some())
185 .finish()
186 }
187}
188
189impl Default for HookEngine {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
195impl HookEngine {
196 pub fn new() -> Self {
198 Self {
199 hooks: Arc::new(RwLock::new(HashMap::new())),
200 handlers: Arc::new(RwLock::new(HashMap::new())),
201 event_tx: None,
202 }
203 }
204
205 pub fn with_event_channel(mut self, tx: mpsc::Sender<HookEvent>) -> Self {
207 self.event_tx = Some(tx);
208 self
209 }
210
211 pub fn register(&self, hook: Hook) {
213 let mut hooks = write_or_recover(&self.hooks);
214 hooks.insert(hook.id.clone(), hook);
215 }
216
217 pub fn unregister(&self, hook_id: &str) -> Option<Hook> {
219 let mut hooks = write_or_recover(&self.hooks);
220 hooks.remove(hook_id)
221 }
222
223 pub fn register_handler(&self, hook_id: &str, handler: Arc<dyn HookHandler>) {
225 let mut handlers = write_or_recover(&self.handlers);
226 handlers.insert(hook_id.to_string(), handler);
227 }
228
229 pub fn unregister_handler(&self, hook_id: &str) {
231 let mut handlers = write_or_recover(&self.handlers);
232 handlers.remove(hook_id);
233 }
234
235 pub fn matching_hooks(&self, event: &HookEvent) -> Vec<Hook> {
237 let hooks = read_or_recover(&self.hooks);
238 let mut matching: Vec<Hook> = hooks
239 .values()
240 .filter(|h| h.matches(event))
241 .cloned()
242 .collect();
243
244 matching.sort_by_key(|h| h.config.priority);
246 matching
247 }
248
249 pub async fn fire(&self, event: &HookEvent) -> HookResult {
251 if let Some(ref tx) = self.event_tx {
253 let _ = tx.send(event.clone()).await;
254 }
255
256 let matching_hooks = self.matching_hooks(event);
258
259 if matching_hooks.is_empty() {
260 return HookResult::continue_();
261 }
262
263 for hook in matching_hooks {
265 let result = self.execute_hook(&hook, event).await;
266
267 match result {
268 HookResult::Continue(modified) => {
269 if modified.is_some() {
272 return HookResult::Continue(modified);
273 }
274 }
275 HookResult::Block(reason) => {
276 return HookResult::Block(reason);
277 }
278 HookResult::Retry(delay) => {
279 return HookResult::Retry(delay);
280 }
281 HookResult::Skip => {
282 return HookResult::Continue(None);
283 }
284 }
285 }
286
287 HookResult::continue_()
288 }
289
290 async fn execute_hook(&self, hook: &Hook, event: &HookEvent) -> HookResult {
292 let handler = {
294 let handlers = read_or_recover(&self.handlers);
295 handlers.get(&hook.id).cloned()
296 };
297
298 match handler {
299 Some(h) => {
300 let response = if hook.config.async_execution {
302 let h = h.clone();
304 let event = event.clone();
305 tokio::spawn(async move {
306 h.handle(&event);
307 });
308 HookResponse::continue_()
309 } else {
310 let timeout = std::time::Duration::from_millis(hook.config.timeout_ms);
312 let h = h.clone();
313 let event = event.clone();
314
315 match tokio::time::timeout(timeout, async move { h.handle(&event) }).await {
316 Ok(response) => response,
317 Err(_) => {
318 HookResponse::continue_()
320 }
321 }
322 };
323
324 self.response_to_result(response)
325 }
326 None => {
327 HookResult::continue_()
329 }
330 }
331 }
332
333 fn response_to_result(&self, response: HookResponse) -> HookResult {
335 match response.action {
336 HookAction::Continue => HookResult::Continue(response.modified),
337 HookAction::Block => {
338 HookResult::Block(response.reason.unwrap_or_else(|| "Blocked".to_string()))
339 }
340 HookAction::Retry => HookResult::Retry(response.retry_delay_ms.unwrap_or(1000)),
341 HookAction::Skip => HookResult::Skip,
342 }
343 }
344
345 pub fn hook_count(&self) -> usize {
347 read_or_recover(&self.hooks).len()
348 }
349
350 pub fn get_hook(&self, id: &str) -> Option<Hook> {
352 read_or_recover(&self.hooks).get(id).cloned()
353 }
354
355 pub fn all_hooks(&self) -> Vec<Hook> {
357 read_or_recover(&self.hooks).values().cloned().collect()
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::events::PreToolUseEvent;

    /// Builds a `PreToolUse` event for `tool` with empty args.
    fn make_pre_tool_event(session_id: &str, tool: &str) -> HookEvent {
        HookEvent::PreToolUse(PreToolUseEvent {
            session_id: session_id.to_string(),
            tool: tool.to_string(),
            args: serde_json::json!({}),
            working_directory: "/workspace".to_string(),
            recent_tools: vec![],
        })
    }

    #[test]
    fn test_hook_config_default() {
        let config = HookConfig::default();
        assert_eq!(config.priority, 100);
        assert_eq!(config.timeout_ms, 30000);
        assert!(!config.async_execution);
        assert_eq!(config.max_retries, 0);
    }

    #[test]
    fn test_hook_new() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse);
        assert_eq!(hook.id, "test-hook");
        assert_eq!(hook.event_type, HookEventType::PreToolUse);
        assert!(hook.matcher.is_none());
    }

    #[test]
    fn test_hook_with_matcher() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"));

        assert!(hook.matcher.is_some());
        assert_eq!(hook.matcher.unwrap().tool, Some("Bash".to_string()));
    }

    #[test]
    fn test_hook_matches_event_type() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse);

        let pre_event = make_pre_tool_event("s1", "Bash");
        assert!(hook.matches(&pre_event));

        let post_event = HookEvent::PostToolUse(crate::hooks::events::PostToolUseEvent {
            session_id: "s1".to_string(),
            tool: "Bash".to_string(),
            args: serde_json::json!({}),
            result: crate::hooks::events::ToolResultData {
                success: true,
                output: "".to_string(),
                exit_code: Some(0),
                duration_ms: 100,
            },
        });
        assert!(!hook.matches(&post_event));
    }

    #[test]
    fn test_hook_matches_with_matcher() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"));

        let bash_event = make_pre_tool_event("s1", "Bash");
        let read_event = make_pre_tool_event("s1", "Read");

        assert!(hook.matches(&bash_event));
        assert!(!hook.matches(&read_event));
    }

    #[test]
    fn test_hook_result_constructors() {
        let cont = HookResult::continue_();
        assert!(cont.is_continue());
        assert!(!cont.is_block());

        let cont_with = HookResult::continue_with(serde_json::json!({"key": "value"}));
        assert!(cont_with.is_continue());

        let block = HookResult::block("Blocked");
        assert!(block.is_block());
        assert!(!block.is_continue());

        let retry = HookResult::retry(1000);
        assert!(!retry.is_continue());
        assert!(!retry.is_block());

        let skip = HookResult::skip();
        assert!(!skip.is_continue());
        assert!(!skip.is_block());
    }

    #[test]
    fn test_engine_register_unregister() {
        let engine = HookEngine::new();

        let hook = Hook::new("test-hook", HookEventType::PreToolUse);
        engine.register(hook);

        assert_eq!(engine.hook_count(), 1);
        assert!(engine.get_hook("test-hook").is_some());

        let removed = engine.unregister("test-hook");
        assert!(removed.is_some());
        assert_eq!(engine.hook_count(), 0);
    }

    #[test]
    fn test_engine_matching_hooks() {
        let engine = HookEngine::new();

        engine.register(
            Hook::new("hook-1", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 10,
                ..Default::default()
            }),
        );
        engine.register(
            Hook::new("hook-2", HookEventType::PreToolUse)
                .with_matcher(HookMatcher::tool("Bash"))
                .with_config(HookConfig {
                    priority: 5,
                    ..Default::default()
                }),
        );
        engine.register(Hook::new("hook-3", HookEventType::PostToolUse));

        let event = make_pre_tool_event("s1", "Bash");
        let matching = engine.matching_hooks(&event);

        // hook-3 listens to a different event type.
        assert_eq!(matching.len(), 2);

        // Lower priority value first.
        assert_eq!(matching[0].id, "hook-2");
        assert_eq!(matching[1].id, "hook-1");
    }

    #[tokio::test]
    async fn test_engine_fire_no_hooks() {
        let engine = HookEngine::new();
        let event = make_pre_tool_event("s1", "Bash");

        let result = engine.fire(&event).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_no_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        // A matching hook without a handler does not block.
        assert!(result.is_continue());
    }

    /// Always continues without modification.
    struct ContinueHandler;

    impl HookHandler for ContinueHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::continue_()
        }
    }

    /// Always blocks with the configured reason.
    struct BlockHandler {
        reason: String,
    }

    impl HookHandler for BlockHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::block(&self.reason)
        }
    }

    /// Continues with a modification, which makes `fire` stop at this hook.
    struct ModifyHandler;

    impl HookHandler for ModifyHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            let mut response = HookResponse::continue_();
            response.modified = Some(serde_json::json!({"modified": true}));
            response
        }
    }

    #[tokio::test]
    async fn test_engine_fire_with_continue_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler("test-hook", Arc::new(ContinueHandler));

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_with_block_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler(
            "test-hook",
            Arc::new(BlockHandler {
                reason: "Dangerous command".to_string(),
            }),
        );

        let event = make_pre_tool_event("s1", "Bash");
        match engine.fire(&event).await {
            HookResult::Block(reason) => assert_eq!(reason, "Dangerous command"),
            other => panic!("expected Block, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_engine_fire_priority_order() {
        // The modifying hook short-circuits `fire` if it runs first, so the
        // outcome reveals which hook executed first.
        let engine = HookEngine::new();

        engine.register(
            Hook::new("block-hook", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 5,
                ..Default::default()
            }),
        );
        engine.register(
            Hook::new("modify-hook", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 10,
                ..Default::default()
            }),
        );

        engine.register_handler(
            "block-hook",
            Arc::new(BlockHandler {
                reason: "Blocked first".to_string(),
            }),
        );
        engine.register_handler("modify-hook", Arc::new(ModifyHandler));

        let event = make_pre_tool_event("s1", "Bash");
        match engine.fire(&event).await {
            HookResult::Block(reason) => assert_eq!(reason, "Blocked first"),
            other => panic!("expected Block from the priority-5 hook, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_engine_fire_modification_short_circuits() {
        // Mirror of the priority test: the modifying hook now runs first and its
        // modification is returned without running the blocking hook.
        let engine = HookEngine::new();

        engine.register(
            Hook::new("modify-hook", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 5,
                ..Default::default()
            }),
        );
        engine.register(
            Hook::new("block-hook", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 10,
                ..Default::default()
            }),
        );

        engine.register_handler("modify-hook", Arc::new(ModifyHandler));
        engine.register_handler(
            "block-hook",
            Arc::new(BlockHandler {
                reason: "Should not run".to_string(),
            }),
        );

        let event = make_pre_tool_event("s1", "Bash");
        match engine.fire(&event).await {
            HookResult::Continue(Some(modified)) => {
                assert_eq!(modified, serde_json::json!({"modified": true}))
            }
            other => panic!("expected Continue(Some(_)), got {:?}", other),
        }
    }

    #[test]
    fn test_hook_serialization() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"))
            .with_config(HookConfig {
                priority: 50,
                timeout_ms: 5000,
                async_execution: true,
                max_retries: 3,
            });

        let json = serde_json::to_string(&hook).unwrap();
        assert!(json.contains("test-hook"));
        assert!(json.contains("pre_tool_use"));
        assert!(json.contains("Bash"));

        let parsed: Hook = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "test-hook");
        assert_eq!(parsed.event_type, HookEventType::PreToolUse);
        assert_eq!(parsed.config.priority, 50);
    }

    #[test]
    fn test_all_hooks() {
        let engine = HookEngine::new();
        engine.register(Hook::new("hook-1", HookEventType::PreToolUse));
        engine.register(Hook::new("hook-2", HookEventType::PostToolUse));

        let all = engine.all_hooks();
        assert_eq!(all.len(), 2);
    }

    /// Builds a `SkillLoad` event for `skill_name` exposing `tools`.
    fn make_skill_load_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        HookEvent::SkillLoad(crate::hooks::events::SkillLoadEvent {
            skill_name: skill_name.to_string(),
            tool_names: tools.iter().map(|s| s.to_string()).collect(),
            version: Some("1.0.0".to_string()),
            description: Some("Test skill".to_string()),
            loaded_at: 1234567890,
        })
    }

    /// Builds a `SkillUnload` event for `skill_name` exposing `tools`.
    fn make_skill_unload_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        HookEvent::SkillUnload(crate::hooks::events::SkillUnloadEvent {
            skill_name: skill_name.to_string(),
            tool_names: tools.iter().map(|s| s.to_string()).collect(),
            duration_ms: 60000,
        })
    }

    #[tokio::test]
    async fn test_engine_fire_skill_load() {
        let engine = HookEngine::new();

        engine.register(Hook::new("skill-load-hook", HookEventType::SkillLoad));
        engine.register_handler("skill-load-hook", Arc::new(ContinueHandler));

        let event = make_skill_load_event("my-skill", vec!["tool1", "tool2"]);
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_skill_unload() {
        let engine = HookEngine::new();

        engine.register(Hook::new("skill-unload-hook", HookEventType::SkillUnload));
        engine.register_handler("skill-unload-hook", Arc::new(ContinueHandler));

        let event = make_skill_unload_event("my-skill", vec!["tool1", "tool2"]);
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_skill_hook_with_matcher() {
        let engine = HookEngine::new();

        engine.register(
            Hook::new("specific-skill-hook", HookEventType::SkillLoad)
                .with_matcher(HookMatcher::skill("my-skill")),
        );
        engine.register_handler(
            "specific-skill-hook",
            Arc::new(BlockHandler {
                reason: "Skill blocked".to_string(),
            }),
        );

        let matching_event = make_skill_load_event("my-skill", vec!["tool1"]);
        let result = engine.fire(&matching_event).await;
        assert!(result.is_block());

        let non_matching_event = make_skill_load_event("other-skill", vec!["tool1"]);
        let result = engine.fire(&non_matching_event).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_skill_hook_pattern_matcher() {
        let engine = HookEngine::new();

        engine.register(
            Hook::new("test-skill-hook", HookEventType::SkillLoad)
                .with_matcher(HookMatcher::skill("test-*")),
        );
        engine.register_handler(
            "test-skill-hook",
            Arc::new(BlockHandler {
                reason: "Test skill blocked".to_string(),
            }),
        );

        let test_skill = make_skill_load_event("test-alpha", vec!["tool1"]);
        let result = engine.fire(&test_skill).await;
        assert!(result.is_block());

        let test_skill2 = make_skill_load_event("test-beta", vec!["tool1"]);
        let result = engine.fire(&test_skill2).await;
        assert!(result.is_block());

        let prod_skill = make_skill_load_event("prod-skill", vec!["tool1"]);
        let result = engine.fire(&prod_skill).await;
        assert!(result.is_continue());
    }
}