Skip to main content

a3s_code_core/hooks/
engine.rs

1//! Hook Engine
2//!
3//! Core engine responsible for managing and executing hooks.
4
5use super::events::{HookEvent, HookEventType};
6use super::matcher::HookMatcher;
7use super::{HookAction, HookResponse};
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use tokio::sync::mpsc;
13
14use crate::error::{read_or_recover, write_or_recover};
15
/// Per-hook execution settings.
///
/// Every field has a serde default, so any of them may be omitted when a config is
/// deserialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookConfig {
    /// Priority (lower values = higher priority).
    ///
    /// Matching hooks are sorted ascending by this value, so lower values run first.
    /// Defaults to 100.
    #[serde(default = "default_priority")]
    pub priority: i32,

    /// Timeout in milliseconds applied to a synchronously executed handler.
    /// Defaults to 30000 (30 seconds).
    #[serde(default = "default_timeout")]
    pub timeout_ms: u64,

    /// Whether to execute asynchronously (fire-and-forget).
    ///
    /// When true, the handler's response is ignored and the hook always continues.
    #[serde(default)]
    pub async_execution: bool,

    /// Maximum retry attempts. Defaults to 0.
    ///
    /// NOTE(review): `HookEngine` in this file never reads this field; a `Retry` result is
    /// handed back to the caller of `fire` instead. Confirm whether another module uses it.
    #[serde(default)]
    pub max_retries: u32,
}
35
/// Serde default for [`HookConfig::priority`], also used by `HookConfig::default`.
fn default_priority() -> i32 {
    const DEFAULT_PRIORITY: i32 = 100;
    DEFAULT_PRIORITY
}
39
/// Serde default for [`HookConfig::timeout_ms`]: 30 seconds, expressed in milliseconds.
fn default_timeout() -> u64 {
    30 * 1_000
}
43
44impl Default for HookConfig {
45    fn default() -> Self {
46        Self {
47            priority: default_priority(),
48            timeout_ms: default_timeout(),
49            async_execution: false,
50            max_retries: 0,
51        }
52    }
53}
54
/// Hook definition: an id, the event type it listens for, an optional matcher and its
/// execution settings.
///
/// The code that actually runs is a `HookHandler` registered separately on `HookEngine`
/// under the same `id`. A hook without a handler simply continues.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hook {
    /// Unique hook identifier; also the key under which its handler is registered.
    pub id: String,

    /// Event type that triggers this hook.
    pub event_type: HookEventType,

    /// Optional extra filter checked after the event type (None matches every event of
    /// `event_type`). Omitted from serialized output when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matcher: Option<HookMatcher>,

    /// Execution settings; `HookConfig::default()` when absent from serialized input.
    #[serde(default)]
    pub config: HookConfig,
}
72
73impl Hook {
74    /// Create a new hook
75    pub fn new(id: impl Into<String>, event_type: HookEventType) -> Self {
76        Self {
77            id: id.into(),
78            event_type,
79            matcher: None,
80            config: HookConfig::default(),
81        }
82    }
83
84    /// Set the matcher
85    pub fn with_matcher(mut self, matcher: HookMatcher) -> Self {
86        self.matcher = Some(matcher);
87        self
88    }
89
90    /// Set the configuration
91    pub fn with_config(mut self, config: HookConfig) -> Self {
92        self.config = config;
93        self
94    }
95
96    /// Check if an event matches this hook
97    pub fn matches(&self, event: &HookEvent) -> bool {
98        // First check event type
99        if event.event_type() != self.event_type {
100            return false;
101        }
102
103        // If there's a matcher, check it
104        if let Some(ref matcher) = self.matcher {
105            matcher.matches(event)
106        } else {
107            true
108        }
109    }
110}
111
/// Hook execution result: the outcome of one hook, or of firing an event through all
/// matching hooks.
#[derive(Debug, Clone)]
pub enum HookResult {
    /// Continue execution, optionally carrying modified data supplied by a handler.
    Continue(Option<serde_json::Value>),
    /// Block execution; the string is the human-readable reason.
    Block(String),
    /// Ask the caller to retry after the given delay, in milliseconds.
    Retry(u64),
    /// Skip the remaining (lower-priority) hooks but continue execution.
    Skip,
}
124
125impl HookResult {
126    /// Create a continue result
127    pub fn continue_() -> Self {
128        Self::Continue(None)
129    }
130
131    /// Create a continue result with modifications
132    pub fn continue_with(modified: serde_json::Value) -> Self {
133        Self::Continue(Some(modified))
134    }
135
136    /// Create a block result
137    pub fn block(reason: impl Into<String>) -> Self {
138        Self::Block(reason.into())
139    }
140
141    /// Create a retry result
142    pub fn retry(delay_ms: u64) -> Self {
143        Self::Retry(delay_ms)
144    }
145
146    /// Create a skip result
147    pub fn skip() -> Self {
148        Self::Skip
149    }
150
151    /// Check if this is a continue result
152    pub fn is_continue(&self) -> bool {
153        matches!(self, Self::Continue(_))
154    }
155
156    /// Check if this is a block result
157    pub fn is_block(&self) -> bool {
158        matches!(self, Self::Block(_))
159    }
160}
161
/// Hook handler trait: the synchronous callback run when a hook fires. It is registered
/// on `HookEngine` under the id of the hook it serves.
///
/// Implementations must be `Send + Sync` because the engine shares them through `Arc`
/// and may invoke them from spawned tasks.
pub trait HookHandler: Send + Sync {
    /// Handle a hook event. The response decides whether execution continues, blocks,
    /// retries or skips the remaining hooks.
    fn handle(&self, event: &HookEvent) -> HookResponse;
}
167
168/// Hook executor trait
169///
170/// Abstracts hook execution, allowing different implementations
171/// (e.g., full engine, no-op, test mocks) while keeping agent logic clean.
172#[async_trait::async_trait]
173pub trait HookExecutor: Send + Sync + std::fmt::Debug {
174    /// Fire a hook event and get the result
175    async fn fire(&self, event: &HookEvent) -> HookResult;
176}
177
/// Hook engine: registry and dispatcher for hooks.
///
/// Hook definitions and their handlers live in two separate maps keyed by hook id. Each
/// map sits behind an `Arc<RwLock<..>>`, so registration only needs `&self`.
pub struct HookEngine {
    /// Registered hooks, keyed by `Hook::id`.
    hooks: Arc<RwLock<HashMap<String, Hook>>>,

    /// Hook handlers (registered by SDK), keyed by the id of the hook they serve.
    handlers: Arc<RwLock<HashMap<String, Arc<dyn HookHandler>>>>,

    /// Optional channel that receives a clone of every fired event (for SDK listeners).
    event_tx: Option<mpsc::Sender<HookEvent>>,
}
189
190impl std::fmt::Debug for HookEngine {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        f.debug_struct("HookEngine")
193            .field("hooks_count", &read_or_recover(&self.hooks).len())
194            .field("handlers_count", &read_or_recover(&self.handlers).len())
195            .field("has_event_channel", &self.event_tx.is_some())
196            .finish()
197    }
198}
199
200impl Default for HookEngine {
201    fn default() -> Self {
202        Self::new()
203    }
204}
205
206impl HookEngine {
207    /// Create a new hook engine
208    pub fn new() -> Self {
209        Self {
210            hooks: Arc::new(RwLock::new(HashMap::new())),
211            handlers: Arc::new(RwLock::new(HashMap::new())),
212            event_tx: None,
213        }
214    }
215
216    /// Set the event sender channel
217    pub fn with_event_channel(mut self, tx: mpsc::Sender<HookEvent>) -> Self {
218        self.event_tx = Some(tx);
219        self
220    }
221
222    /// Register a hook
223    pub fn register(&self, hook: Hook) {
224        let mut hooks = write_or_recover(&self.hooks);
225        hooks.insert(hook.id.clone(), hook);
226    }
227
228    /// Unregister a hook
229    pub fn unregister(&self, hook_id: &str) -> Option<Hook> {
230        let mut hooks = write_or_recover(&self.hooks);
231        hooks.remove(hook_id)
232    }
233
234    /// Register a handler
235    pub fn register_handler(&self, hook_id: &str, handler: Arc<dyn HookHandler>) {
236        let mut handlers = write_or_recover(&self.handlers);
237        handlers.insert(hook_id.to_string(), handler);
238    }
239
240    /// Unregister a handler
241    pub fn unregister_handler(&self, hook_id: &str) {
242        let mut handlers = write_or_recover(&self.handlers);
243        handlers.remove(hook_id);
244    }
245
246    /// Get all hooks matching an event (sorted by priority)
247    pub fn matching_hooks(&self, event: &HookEvent) -> Vec<Hook> {
248        let hooks = read_or_recover(&self.hooks);
249        let mut matching: Vec<Hook> = hooks
250            .values()
251            .filter(|h| h.matches(event))
252            .cloned()
253            .collect();
254
255        // Sort by priority (lower values = higher priority)
256        matching.sort_by_key(|h| h.config.priority);
257        matching
258    }
259
260    /// Fire an event and get the result
261    pub async fn fire(&self, event: &HookEvent) -> HookResult {
262        // Send event to channel if available
263        if let Some(ref tx) = self.event_tx {
264            let _ = tx.send(event.clone()).await;
265        }
266
267        // Get matching hooks
268        let matching_hooks = self.matching_hooks(event);
269
270        if matching_hooks.is_empty() {
271            return HookResult::continue_();
272        }
273
274        // Execute each hook
275        let mut last_modified: Option<serde_json::Value> = None;
276        for hook in matching_hooks {
277            let result = self.execute_hook(&hook, event).await;
278
279            match result {
280                HookResult::Continue(modified) => {
281                    // Track the last modification — continue to subsequent hooks
282                    if modified.is_some() {
283                        last_modified = modified;
284                    }
285                }
286                HookResult::Block(reason) => {
287                    return HookResult::Block(reason);
288                }
289                HookResult::Retry(delay) => {
290                    return HookResult::Retry(delay);
291                }
292                HookResult::Skip => {
293                    return HookResult::Continue(None);
294                }
295            }
296        }
297
298        HookResult::Continue(last_modified)
299    }
300
301    /// Execute a single hook
302    async fn execute_hook(&self, hook: &Hook, event: &HookEvent) -> HookResult {
303        // Find handler
304        let handler = {
305            let handlers = read_or_recover(&self.handlers);
306            handlers.get(&hook.id).cloned()
307        };
308
309        match handler {
310            Some(h) => {
311                // Handler found, execute it
312                let response = if hook.config.async_execution {
313                    // Async execution (fire-and-forget)
314                    let h = h.clone();
315                    let event = event.clone();
316                    tokio::spawn(async move {
317                        h.handle(&event);
318                    });
319                    HookResponse::continue_()
320                } else {
321                    // Sync execution (with timeout)
322                    let timeout = std::time::Duration::from_millis(hook.config.timeout_ms);
323                    let h = h.clone();
324                    let event = event.clone();
325
326                    match tokio::time::timeout(timeout, async move { h.handle(&event) }).await {
327                        Ok(response) => response,
328                        Err(_) => {
329                            // Timeout, continue execution
330                            HookResponse::continue_()
331                        }
332                    }
333                };
334
335                self.response_to_result(response)
336            }
337            None => {
338                // No handler, continue execution
339                HookResult::continue_()
340            }
341        }
342    }
343
344    /// Convert HookResponse to HookResult
345    fn response_to_result(&self, response: HookResponse) -> HookResult {
346        match response.action {
347            HookAction::Continue => HookResult::Continue(response.modified),
348            HookAction::Block => {
349                HookResult::Block(response.reason.unwrap_or_else(|| "Blocked".to_string()))
350            }
351            HookAction::Retry => HookResult::Retry(response.retry_delay_ms.unwrap_or(1000)),
352            HookAction::Skip => HookResult::Skip,
353        }
354    }
355
356    /// Get the number of registered hooks
357    pub fn hook_count(&self) -> usize {
358        read_or_recover(&self.hooks).len()
359    }
360
361    /// Get a hook by ID
362    pub fn get_hook(&self, id: &str) -> Option<Hook> {
363        read_or_recover(&self.hooks).get(id).cloned()
364    }
365
366    /// Get all hooks
367    pub fn all_hooks(&self) -> Vec<Hook> {
368        read_or_recover(&self.hooks).values().cloned().collect()
369    }
370}
371
372// Implement HookExecutor trait for HookEngine
373#[async_trait]
374impl HookExecutor for HookEngine {
375    async fn fire(&self, event: &HookEvent) -> HookResult {
376        self.fire(event).await
377    }
378}
379
#[cfg(test)]
mod tests {
    // Unit tests for hook definitions, `HookResult` helpers and `HookEngine` dispatch.

    use super::*;
    use crate::hooks::events::PreToolUseEvent;

    /// Build a `PreToolUse` event for `tool` in session `session_id`, with empty args.
    fn make_pre_tool_event(session_id: &str, tool: &str) -> HookEvent {
        HookEvent::PreToolUse(PreToolUseEvent {
            session_id: session_id.to_string(),
            tool: tool.to_string(),
            args: serde_json::json!({}),
            working_directory: "/workspace".to_string(),
            recent_tools: vec![],
        })
    }

    #[test]
    fn test_hook_config_default() {
        let config = HookConfig::default();
        assert_eq!(config.priority, 100);
        assert_eq!(config.timeout_ms, 30000);
        assert!(!config.async_execution);
        assert_eq!(config.max_retries, 0);
    }

    #[test]
    fn test_hook_new() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse);
        assert_eq!(hook.id, "test-hook");
        assert_eq!(hook.event_type, HookEventType::PreToolUse);
        assert!(hook.matcher.is_none());
    }

    #[test]
    fn test_hook_with_matcher() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"));

        assert!(hook.matcher.is_some());
        assert_eq!(hook.matcher.unwrap().tool, Some("Bash".to_string()));
    }

    #[test]
    fn test_hook_matches_event_type() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse);

        let pre_event = make_pre_tool_event("s1", "Bash");
        assert!(hook.matches(&pre_event));

        // PostToolUse doesn't match a PreToolUse hook, even for the same tool.
        let post_event = HookEvent::PostToolUse(crate::hooks::events::PostToolUseEvent {
            session_id: "s1".to_string(),
            tool: "Bash".to_string(),
            args: serde_json::json!({}),
            result: crate::hooks::events::ToolResultData {
                success: true,
                output: "".to_string(),
                exit_code: Some(0),
                duration_ms: 100,
            },
        });
        assert!(!hook.matches(&post_event));
    }

    #[test]
    fn test_hook_matches_with_matcher() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"));

        let bash_event = make_pre_tool_event("s1", "Bash");
        let read_event = make_pre_tool_event("s1", "Read");

        assert!(hook.matches(&bash_event));
        assert!(!hook.matches(&read_event));
    }

    #[test]
    fn test_hook_result_constructors() {
        let cont = HookResult::continue_();
        assert!(cont.is_continue());
        assert!(!cont.is_block());

        let cont_with = HookResult::continue_with(serde_json::json!({"key": "value"}));
        assert!(cont_with.is_continue());

        let block = HookResult::block("Blocked");
        assert!(block.is_block());
        assert!(!block.is_continue());

        // Retry and Skip are neither continue nor block.
        let retry = HookResult::retry(1000);
        assert!(!retry.is_continue());
        assert!(!retry.is_block());

        let skip = HookResult::skip();
        assert!(!skip.is_continue());
        assert!(!skip.is_block());
    }

    #[test]
    fn test_engine_register_unregister() {
        let engine = HookEngine::new();

        let hook = Hook::new("test-hook", HookEventType::PreToolUse);
        engine.register(hook);

        assert_eq!(engine.hook_count(), 1);
        assert!(engine.get_hook("test-hook").is_some());

        let removed = engine.unregister("test-hook");
        assert!(removed.is_some());
        assert_eq!(engine.hook_count(), 0);
    }

    #[test]
    fn test_engine_matching_hooks() {
        let engine = HookEngine::new();

        // Register multiple hooks
        engine.register(
            Hook::new("hook-1", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 10,
                ..Default::default()
            }),
        );
        engine.register(
            Hook::new("hook-2", HookEventType::PreToolUse)
                .with_matcher(HookMatcher::tool("Bash"))
                .with_config(HookConfig {
                    priority: 5,
                    ..Default::default()
                }),
        );
        engine.register(Hook::new("hook-3", HookEventType::PostToolUse));

        let event = make_pre_tool_event("s1", "Bash");
        let matching = engine.matching_hooks(&event);

        // Should match hook-1 and hook-2 (both are PreToolUse)
        assert_eq!(matching.len(), 2);

        // Sorted by priority, hook-2 (priority=5) should be first
        assert_eq!(matching[0].id, "hook-2");
        assert_eq!(matching[1].id, "hook-1");
    }

    #[tokio::test]
    async fn test_engine_fire_no_hooks() {
        let engine = HookEngine::new();
        let event = make_pre_tool_event("s1", "Bash");

        let result = engine.fire(&event).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_no_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        // No handler, should continue
        assert!(result.is_continue());
    }

    /// Test handler: always continue
    struct ContinueHandler;
    impl HookHandler for ContinueHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::continue_()
        }
    }

    /// Test handler: always block with the configured reason
    struct BlockHandler {
        reason: String,
    }
    impl HookHandler for BlockHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::block(&self.reason)
        }
    }

    #[tokio::test]
    async fn test_engine_fire_with_continue_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler("test-hook", Arc::new(ContinueHandler));

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_with_block_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler(
            "test-hook",
            Arc::new(BlockHandler {
                reason: "Dangerous command".to_string(),
            }),
        );

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        // The handler's reason must be carried through unchanged.
        assert!(result.is_block());
        if let HookResult::Block(reason) = result {
            assert_eq!(reason, "Dangerous command");
        }
    }

    #[tokio::test]
    async fn test_engine_fire_priority_order() {
        let engine = HookEngine::new();

        // Register two hooks; the one with the lower priority value (5) runs first and blocks
        engine.register(
            Hook::new("block-hook", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 5, // Higher priority (executes first)
                ..Default::default()
            }),
        );
        engine.register(
            Hook::new("continue-hook", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 10,
                ..Default::default()
            }),
        );

        engine.register_handler(
            "block-hook",
            Arc::new(BlockHandler {
                reason: "Blocked first".to_string(),
            }),
        );
        engine.register_handler("continue-hook", Arc::new(ContinueHandler));

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        // block-hook executes first, should block
        assert!(result.is_block());
    }

    #[test]
    fn test_hook_serialization() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"))
            .with_config(HookConfig {
                priority: 50,
                timeout_ms: 5000,
                async_execution: true,
                max_retries: 3,
            });

        let json = serde_json::to_string(&hook).unwrap();
        assert!(json.contains("test-hook"));
        assert!(json.contains("pre_tool_use"));
        assert!(json.contains("Bash"));

        // Round-trip back into a Hook.
        let parsed: Hook = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "test-hook");
        assert_eq!(parsed.event_type, HookEventType::PreToolUse);
        assert_eq!(parsed.config.priority, 50);
    }

    #[test]
    fn test_all_hooks() {
        let engine = HookEngine::new();
        engine.register(Hook::new("hook-1", HookEventType::PreToolUse));
        engine.register(Hook::new("hook-2", HookEventType::PostToolUse));

        let all = engine.all_hooks();
        assert_eq!(all.len(), 2);
    }

    /// Build a `SkillLoad` event for `skill_name` exposing `tools`.
    fn make_skill_load_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        HookEvent::SkillLoad(crate::hooks::events::SkillLoadEvent {
            skill_name: skill_name.to_string(),
            tool_names: tools.iter().map(|s| s.to_string()).collect(),
            version: Some("1.0.0".to_string()),
            description: Some("Test skill".to_string()),
            loaded_at: 1234567890,
        })
    }

    /// Build a `SkillUnload` event for `skill_name` exposing `tools`.
    fn make_skill_unload_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        HookEvent::SkillUnload(crate::hooks::events::SkillUnloadEvent {
            skill_name: skill_name.to_string(),
            tool_names: tools.iter().map(|s| s.to_string()).collect(),
            duration_ms: 60000,
        })
    }

    #[tokio::test]
    async fn test_engine_fire_skill_load() {
        let engine = HookEngine::new();

        // Register a hook for skill load events
        engine.register(Hook::new("skill-load-hook", HookEventType::SkillLoad));
        engine.register_handler("skill-load-hook", Arc::new(ContinueHandler));

        let event = make_skill_load_event("my-skill", vec!["tool1", "tool2"]);
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_skill_unload() {
        let engine = HookEngine::new();

        // Register a hook for skill unload events
        engine.register(Hook::new("skill-unload-hook", HookEventType::SkillUnload));
        engine.register_handler("skill-unload-hook", Arc::new(ContinueHandler));

        let event = make_skill_unload_event("my-skill", vec!["tool1", "tool2"]);
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_skill_hook_with_matcher() {
        let engine = HookEngine::new();

        // Register a hook that only matches specific skill
        engine.register(
            Hook::new("specific-skill-hook", HookEventType::SkillLoad)
                .with_matcher(HookMatcher::skill("my-skill")),
        );
        engine.register_handler(
            "specific-skill-hook",
            Arc::new(BlockHandler {
                reason: "Skill blocked".to_string(),
            }),
        );

        // Should match and block
        let matching_event = make_skill_load_event("my-skill", vec!["tool1"]);
        let result = engine.fire(&matching_event).await;
        assert!(result.is_block());

        // Should not match (no hooks match, so continue)
        let non_matching_event = make_skill_load_event("other-skill", vec!["tool1"]);
        let result = engine.fire(&non_matching_event).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_skill_hook_pattern_matcher() {
        let engine = HookEngine::new();

        // Register a hook with glob pattern
        engine.register(
            Hook::new("test-skill-hook", HookEventType::SkillLoad)
                .with_matcher(HookMatcher::skill("test-*")),
        );
        engine.register_handler(
            "test-skill-hook",
            Arc::new(BlockHandler {
                reason: "Test skill blocked".to_string(),
            }),
        );

        // Should match pattern
        let test_skill = make_skill_load_event("test-alpha", vec!["tool1"]);
        let result = engine.fire(&test_skill).await;
        assert!(result.is_block());

        let test_skill2 = make_skill_load_event("test-beta", vec!["tool1"]);
        let result = engine.fire(&test_skill2).await;
        assert!(result.is_block());

        // Should not match pattern
        let prod_skill = make_skill_load_event("prod-skill", vec!["tool1"]);
        let result = engine.fire(&prod_skill).await;
        assert!(result.is_continue());
    }
}