Skip to main content

a3s_code_core/hooks/
engine.rs

1//! Hook Engine
2//!
3//! Core engine responsible for managing and executing hooks.
4
5use super::events::{HookEvent, HookEventType};
6use super::matcher::HookMatcher;
7use super::{HookAction, HookResponse};
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use tokio::sync::mpsc;
13
14use crate::error::{read_or_recover, write_or_recover};
15
/// Hook configuration
///
/// Per-hook execution settings. Every field carries a serde default, so any
/// of them may be omitted when a configuration is deserialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookConfig {
    /// Priority (lower values = higher priority). Defaults to the value
    /// returned by `default_priority` when absent from the serialized form.
    #[serde(default = "default_priority")]
    pub priority: i32,

    /// Timeout in milliseconds. Defaults to the value returned by
    /// `default_timeout` when absent from the serialized form.
    #[serde(default = "default_timeout")]
    pub timeout_ms: u64,

    /// Whether to execute asynchronously (fire-and-forget).
    /// Defaults to `false`.
    #[serde(default)]
    pub async_execution: bool,

    /// Maximum retry attempts. Defaults to `0`.
    // NOTE(review): nothing in this file reads `max_retries`; confirm whether
    // retries are implemented elsewhere or this field is currently inert.
    #[serde(default)]
    pub max_retries: u32,
}
35
/// Serde default for `HookConfig::priority`.
fn default_priority() -> i32 {
    const DEFAULT_PRIORITY: i32 = 100;
    DEFAULT_PRIORITY
}
39
/// Serde default for `HookConfig::timeout_ms`, in milliseconds.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_MS: u64 = 30_000;
    DEFAULT_TIMEOUT_MS
}
43
44impl Default for HookConfig {
45    fn default() -> Self {
46        Self {
47            priority: default_priority(),
48            timeout_ms: default_timeout(),
49            async_execution: false,
50            max_retries: 0,
51        }
52    }
53}
54
/// Hook definition
///
/// Declares which events a hook reacts to and how it is run. The code that
/// runs for a hook is a `HookHandler` registered on the engine under the
/// same `id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hook {
    /// Unique hook identifier (the engine keys both its hook map and its
    /// handler map by this value)
    pub id: String,

    /// Event type that triggers this hook
    pub event_type: HookEventType,

    /// Event matcher (optional, None matches all events).
    /// Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matcher: Option<HookMatcher>,

    /// Hook configuration; falls back to `HookConfig::default()` when absent
    /// from the serialized form.
    #[serde(default)]
    pub config: HookConfig,
}
72
73impl Hook {
74    /// Create a new hook
75    pub fn new(id: impl Into<String>, event_type: HookEventType) -> Self {
76        Self {
77            id: id.into(),
78            event_type,
79            matcher: None,
80            config: HookConfig::default(),
81        }
82    }
83
84    /// Set the matcher
85    pub fn with_matcher(mut self, matcher: HookMatcher) -> Self {
86        self.matcher = Some(matcher);
87        self
88    }
89
90    /// Set the configuration
91    pub fn with_config(mut self, config: HookConfig) -> Self {
92        self.config = config;
93        self
94    }
95
96    /// Check if an event matches this hook
97    pub fn matches(&self, event: &HookEvent) -> bool {
98        // First check event type
99        if event.event_type() != self.event_type {
100            return false;
101        }
102
103        // If there's a matcher, check it
104        if let Some(ref matcher) = self.matcher {
105            matcher.matches(event)
106        } else {
107            true
108        }
109    }
110}
111
/// Hook execution result
///
/// The outcome reported to the code that fired an event, after the matching
/// hooks have been consulted.
#[derive(Debug, Clone)]
pub enum HookResult {
    /// Continue execution (with optional modified data)
    Continue(Option<serde_json::Value>),
    /// Block execution; the payload is the human-readable reason
    Block(String),
    /// Retry after delay (milliseconds)
    Retry(u64),
    /// Skip remaining hooks but continue execution.
    /// `HookEngine::fire` stops running further hooks when a hook yields this
    /// and reports a `Continue` to its caller.
    Skip,
    /// Escalate to human review
    Escalate {
        /// Why the escalation was requested
        reason: String,
        /// Optional escalation target
        // NOTE(review): the meaning of `target` (user, queue, channel, ...) is
        // not visible in this file; confirm against the consumer.
        target: Option<String>,
    },
}
129
130impl HookResult {
131    /// Create a continue result
132    pub fn continue_() -> Self {
133        Self::Continue(None)
134    }
135
136    /// Create a continue result with modifications
137    pub fn continue_with(modified: serde_json::Value) -> Self {
138        Self::Continue(Some(modified))
139    }
140
141    /// Create a block result
142    pub fn block(reason: impl Into<String>) -> Self {
143        Self::Block(reason.into())
144    }
145
146    /// Create a retry result
147    pub fn retry(delay_ms: u64) -> Self {
148        Self::Retry(delay_ms)
149    }
150
151    /// Create a skip result
152    pub fn skip() -> Self {
153        Self::Skip
154    }
155
156    /// Create an escalate result
157    pub fn escalate(reason: impl Into<String>, target: Option<String>) -> Self {
158        Self::Escalate {
159            reason: reason.into(),
160            target,
161        }
162    }
163
164    /// Check if this is a continue result
165    pub fn is_continue(&self) -> bool {
166        matches!(self, Self::Continue(_))
167    }
168
169    /// Check if this is a block result
170    pub fn is_block(&self) -> bool {
171        matches!(self, Self::Block(_))
172    }
173}
174
/// Hook handler trait
///
/// The callback that runs when a hook fires. Handlers are stored on the
/// engine as `Arc<dyn HookHandler>` keyed by hook id, which is why the trait
/// requires `Send + Sync`.
pub trait HookHandler: Send + Sync {
    /// Handle a hook event
    ///
    /// This is a synchronous call (not `async`); the returned `HookResponse`
    /// tells the engine what to do next.
    fn handle(&self, event: &HookEvent) -> HookResponse;
}
180
181/// Hook executor trait
182///
183/// Abstracts hook execution, allowing different implementations
184/// (e.g., full engine, no-op, test mocks) while keeping agent logic clean.
185#[async_trait::async_trait]
186pub trait HookExecutor: Send + Sync + std::fmt::Debug {
187    /// Fire a hook event and get the result
188    async fn fire(&self, event: &HookEvent) -> HookResult;
189}
190
/// Hook engine
///
/// Owns the registry of hook definitions and their handlers, and runs the
/// matching hooks when an event is fired. Both registries sit behind
/// `Arc<RwLock<..>>`, so registration methods take `&self`.
pub struct HookEngine {
    /// Registered hooks, keyed by hook id
    hooks: Arc<RwLock<HashMap<String, Hook>>>,

    /// Hook handlers (registered by SDK), keyed by the id of the hook they serve
    handlers: Arc<RwLock<HashMap<String, Arc<dyn HookHandler>>>>,

    /// Event sender channel (for SDK listeners); `None` until
    /// `with_event_channel` is called
    event_tx: Option<mpsc::Sender<HookEvent>>,
}
202
203impl std::fmt::Debug for HookEngine {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        f.debug_struct("HookEngine")
206            .field("hooks_count", &read_or_recover(&self.hooks).len())
207            .field("handlers_count", &read_or_recover(&self.handlers).len())
208            .field("has_event_channel", &self.event_tx.is_some())
209            .finish()
210    }
211}
212
213impl Default for HookEngine {
214    fn default() -> Self {
215        Self::new()
216    }
217}
218
219impl HookEngine {
220    /// Create a new hook engine
221    pub fn new() -> Self {
222        Self {
223            hooks: Arc::new(RwLock::new(HashMap::new())),
224            handlers: Arc::new(RwLock::new(HashMap::new())),
225            event_tx: None,
226        }
227    }
228
229    /// Set the event sender channel
230    pub fn with_event_channel(mut self, tx: mpsc::Sender<HookEvent>) -> Self {
231        self.event_tx = Some(tx);
232        self
233    }
234
235    /// Register a hook
236    pub fn register(&self, hook: Hook) {
237        let mut hooks = write_or_recover(&self.hooks);
238        hooks.insert(hook.id.clone(), hook);
239    }
240
241    /// Unregister a hook
242    pub fn unregister(&self, hook_id: &str) -> Option<Hook> {
243        let mut hooks = write_or_recover(&self.hooks);
244        hooks.remove(hook_id)
245    }
246
247    /// Register a handler
248    pub fn register_handler(&self, hook_id: &str, handler: Arc<dyn HookHandler>) {
249        let mut handlers = write_or_recover(&self.handlers);
250        handlers.insert(hook_id.to_string(), handler);
251    }
252
253    /// Unregister a handler
254    pub fn unregister_handler(&self, hook_id: &str) {
255        let mut handlers = write_or_recover(&self.handlers);
256        handlers.remove(hook_id);
257    }
258
259    /// Get all hooks matching an event (sorted by priority)
260    pub fn matching_hooks(&self, event: &HookEvent) -> Vec<Hook> {
261        let hooks = read_or_recover(&self.hooks);
262        let mut matching: Vec<Hook> = hooks
263            .values()
264            .filter(|h| h.matches(event))
265            .cloned()
266            .collect();
267
268        // Sort by priority (lower values = higher priority)
269        matching.sort_by_key(|h| h.config.priority);
270        matching
271    }
272
273    /// Fire an event and get the result
274    pub async fn fire(&self, event: &HookEvent) -> HookResult {
275        // Send event to channel if available
276        if let Some(ref tx) = self.event_tx {
277            let _ = tx.send(event.clone()).await;
278        }
279
280        // Get matching hooks
281        let matching_hooks = self.matching_hooks(event);
282
283        if matching_hooks.is_empty() {
284            return HookResult::continue_();
285        }
286
287        // Execute each hook
288        let mut last_modified: Option<serde_json::Value> = None;
289        for hook in matching_hooks {
290            let result = self.execute_hook(&hook, event).await;
291
292            match result {
293                HookResult::Continue(modified) => {
294                    // Track the last modification — continue to subsequent hooks
295                    if modified.is_some() {
296                        last_modified = modified;
297                    }
298                }
299                HookResult::Block(reason) => {
300                    return HookResult::Block(reason);
301                }
302                HookResult::Retry(delay) => {
303                    return HookResult::Retry(delay);
304                }
305                HookResult::Skip => {
306                    return HookResult::Continue(None);
307                }
308                HookResult::Escalate { reason, target } => {
309                    return HookResult::Escalate { reason, target };
310                }
311            }
312        }
313
314        HookResult::Continue(last_modified)
315    }
316
317    /// Execute a single hook
318    async fn execute_hook(&self, hook: &Hook, event: &HookEvent) -> HookResult {
319        // Find handler
320        let handler = {
321            let handlers = read_or_recover(&self.handlers);
322            handlers.get(&hook.id).cloned()
323        };
324
325        match handler {
326            Some(h) => {
327                // Handler found, execute it
328                let response = if hook.config.async_execution {
329                    // Async execution (fire-and-forget)
330                    let h = h.clone();
331                    let event = event.clone();
332                    tokio::spawn(async move {
333                        h.handle(&event);
334                    });
335                    HookResponse::continue_()
336                } else {
337                    // Sync execution (with timeout)
338                    let timeout = std::time::Duration::from_millis(hook.config.timeout_ms);
339                    let h = h.clone();
340                    let event = event.clone();
341
342                    match tokio::time::timeout(timeout, async move { h.handle(&event) }).await {
343                        Ok(response) => response,
344                        Err(_) => {
345                            // Timeout, continue execution
346                            HookResponse::continue_()
347                        }
348                    }
349                };
350
351                self.response_to_result(response)
352            }
353            None => {
354                // No handler, continue execution
355                HookResult::continue_()
356            }
357        }
358    }
359
360    /// Convert HookResponse to HookResult
361    fn response_to_result(&self, response: HookResponse) -> HookResult {
362        match response.action {
363            HookAction::Continue => HookResult::Continue(response.modified),
364            HookAction::Block => {
365                HookResult::Block(response.reason.unwrap_or_else(|| "Blocked".to_string()))
366            }
367            HookAction::Retry => HookResult::Retry(response.retry_delay_ms.unwrap_or(1000)),
368            HookAction::Skip => HookResult::Skip,
369        }
370    }
371
372    /// Get the number of registered hooks
373    pub fn hook_count(&self) -> usize {
374        read_or_recover(&self.hooks).len()
375    }
376
377    /// Get a hook by ID
378    pub fn get_hook(&self, id: &str) -> Option<Hook> {
379        read_or_recover(&self.hooks).get(id).cloned()
380    }
381
382    /// Get all hooks
383    pub fn all_hooks(&self) -> Vec<Hook> {
384        read_or_recover(&self.hooks).values().cloned().collect()
385    }
386}
387
388// Implement HookExecutor trait for HookEngine
389#[async_trait]
390impl HookExecutor for HookEngine {
391    async fn fire(&self, event: &HookEvent) -> HookResult {
392        self.fire(event).await
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::events::PreToolUseEvent;

    /// Build a `PreToolUse` event for `tool` in `session_id`, with empty
    /// arguments, a fixed working directory, and no tool history.
    fn make_pre_tool_event(session_id: &str, tool: &str) -> HookEvent {
        let payload = PreToolUseEvent {
            session_id: session_id.to_owned(),
            tool: tool.to_owned(),
            args: serde_json::json!({}),
            working_directory: String::from("/workspace"),
            recent_tools: Vec::new(),
        };
        HookEvent::PreToolUse(payload)
    }

    /// Default configuration with only the priority overridden.
    fn config_with_priority(priority: i32) -> HookConfig {
        HookConfig {
            priority,
            ..Default::default()
        }
    }

    #[test]
    fn test_hook_config_default() {
        let HookConfig {
            priority,
            timeout_ms,
            async_execution,
            max_retries,
        } = HookConfig::default();

        assert_eq!(priority, 100);
        assert_eq!(timeout_ms, 30000);
        assert!(!async_execution);
        assert_eq!(max_retries, 0);
    }

    #[test]
    fn test_hook_new() {
        let created = Hook::new("test-hook", HookEventType::PreToolUse);

        assert_eq!(created.id, "test-hook");
        assert_eq!(created.event_type, HookEventType::PreToolUse);
        assert!(created.matcher.is_none());
    }

    #[test]
    fn test_hook_with_matcher() {
        let bash_only = HookMatcher::tool("Bash");
        let hook = Hook::new("test-hook", HookEventType::PreToolUse).with_matcher(bash_only);

        assert!(hook.matcher.is_some());
        let matcher = hook.matcher.unwrap();
        assert_eq!(matcher.tool, Some("Bash".to_string()));
    }

    #[test]
    fn test_hook_matches_event_type() {
        use crate::hooks::events::{PostToolUseEvent, ToolResultData};

        let hook = Hook::new("test-hook", HookEventType::PreToolUse);

        // Same event type: matches.
        assert!(hook.matches(&make_pre_tool_event("s1", "Bash")));

        // PostToolUse doesn't match
        let tool_result = ToolResultData {
            success: true,
            output: "".to_string(),
            exit_code: Some(0),
            duration_ms: 100,
        };
        let post_event = HookEvent::PostToolUse(PostToolUseEvent {
            session_id: "s1".to_string(),
            tool: "Bash".to_string(),
            args: serde_json::json!({}),
            result: tool_result,
        });
        assert!(!hook.matches(&post_event));
    }

    #[test]
    fn test_hook_matches_with_matcher() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"));

        let accepted = make_pre_tool_event("s1", "Bash");
        let rejected = make_pre_tool_event("s1", "Read");

        assert!(hook.matches(&accepted));
        assert!(!hook.matches(&rejected));
    }

    #[test]
    fn test_hook_result_constructors() {
        let plain = HookResult::continue_();
        assert!(plain.is_continue());
        assert!(!plain.is_block());

        let with_data = HookResult::continue_with(serde_json::json!({"key": "value"}));
        assert!(with_data.is_continue());

        let blocked = HookResult::block("Blocked");
        assert!(blocked.is_block());
        assert!(!blocked.is_continue());

        // Retry and Skip are neither "continue" nor "block".
        let retry = HookResult::retry(1000);
        assert!(!retry.is_continue());
        assert!(!retry.is_block());

        let skip = HookResult::skip();
        assert!(!skip.is_continue());
        assert!(!skip.is_block());
    }

    #[test]
    fn test_engine_register_unregister() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));

        assert_eq!(engine.hook_count(), 1);
        assert!(engine.get_hook("test-hook").is_some());

        let removed = engine.unregister("test-hook");
        assert!(removed.is_some());
        assert_eq!(engine.hook_count(), 0);
    }

    #[test]
    fn test_engine_matching_hooks() {
        let engine = HookEngine::new();

        // Two PreToolUse hooks with different priorities, plus one hook for a
        // different event type.
        let unrestricted =
            Hook::new("hook-1", HookEventType::PreToolUse).with_config(config_with_priority(10));
        let bash_only = Hook::new("hook-2", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"))
            .with_config(config_with_priority(5));
        let other_event = Hook::new("hook-3", HookEventType::PostToolUse);

        engine.register(unrestricted);
        engine.register(bash_only);
        engine.register(other_event);

        let event = make_pre_tool_event("s1", "Bash");
        let matching = engine.matching_hooks(&event);

        // Should match hook-1 and hook-2 (both are PreToolUse)
        assert_eq!(matching.len(), 2);

        // Sorted by priority, hook-2 (priority=5) should be first
        let ids: Vec<&str> = matching.iter().map(|hook| hook.id.as_str()).collect();
        assert_eq!(ids[0], "hook-2");
        assert_eq!(ids[1], "hook-1");
    }

    #[tokio::test]
    async fn test_engine_fire_no_hooks() {
        let engine = HookEngine::new();

        let outcome = engine.fire(&make_pre_tool_event("s1", "Bash")).await;

        assert!(outcome.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_no_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));

        let outcome = engine.fire(&make_pre_tool_event("s1", "Bash")).await;

        // No handler, should continue
        assert!(outcome.is_continue());
    }

    /// Test handler: always continue
    struct ContinueHandler;
    impl HookHandler for ContinueHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::continue_()
        }
    }

    /// Test handler: always block
    struct BlockHandler {
        reason: String,
    }
    impl HookHandler for BlockHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::block(&self.reason)
        }
    }

    /// Shared-pointer `BlockHandler` that blocks with `reason`.
    fn block_handler(reason: &str) -> Arc<BlockHandler> {
        Arc::new(BlockHandler {
            reason: reason.to_string(),
        })
    }

    #[tokio::test]
    async fn test_engine_fire_with_continue_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler("test-hook", Arc::new(ContinueHandler));

        let outcome = engine.fire(&make_pre_tool_event("s1", "Bash")).await;

        assert!(outcome.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_with_block_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler("test-hook", block_handler("Dangerous command"));

        let outcome = engine.fire(&make_pre_tool_event("s1", "Bash")).await;

        assert!(outcome.is_block());
        match outcome {
            HookResult::Block(reason) => assert_eq!(reason, "Dangerous command"),
            // Ruled out by the `is_block` assertion above.
            _ => {}
        }
    }

    #[tokio::test]
    async fn test_engine_fire_priority_order() {
        let engine = HookEngine::new();

        // Register two hooks; the one that blocks has the lower priority
        // value, i.e. the higher priority, so it executes first.
        let blocker =
            Hook::new("block-hook", HookEventType::PreToolUse).with_config(config_with_priority(5));
        let follower = Hook::new("continue-hook", HookEventType::PreToolUse)
            .with_config(config_with_priority(10));
        engine.register(blocker);
        engine.register(follower);

        engine.register_handler("block-hook", block_handler("Blocked first"));
        engine.register_handler("continue-hook", Arc::new(ContinueHandler));

        let outcome = engine.fire(&make_pre_tool_event("s1", "Bash")).await;

        // block-hook executes first, should block
        assert!(outcome.is_block());
    }

    #[test]
    fn test_hook_serialization() {
        let config = HookConfig {
            priority: 50,
            timeout_ms: 5000,
            async_execution: true,
            max_retries: 3,
        };
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"))
            .with_config(config);

        let json = serde_json::to_string(&hook).unwrap();
        for expected_fragment in ["test-hook", "pre_tool_use", "Bash"] {
            assert!(json.contains(expected_fragment));
        }

        // Round-trip back into a `Hook`.
        let parsed: Hook = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "test-hook");
        assert_eq!(parsed.event_type, HookEventType::PreToolUse);
        assert_eq!(parsed.config.priority, 50);
    }

    #[test]
    fn test_all_hooks() {
        let engine = HookEngine::new();
        engine.register(Hook::new("hook-1", HookEventType::PreToolUse));
        engine.register(Hook::new("hook-2", HookEventType::PostToolUse));

        assert_eq!(engine.all_hooks().len(), 2);
    }

    /// Build a `SkillLoad` event for `skill_name` exposing `tools`, with fixed
    /// version, description, and load timestamp.
    fn make_skill_load_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        use crate::hooks::events::SkillLoadEvent;

        HookEvent::SkillLoad(SkillLoadEvent {
            skill_name: skill_name.to_string(),
            tool_names: tools.into_iter().map(String::from).collect(),
            version: Some("1.0.0".to_string()),
            description: Some("Test skill".to_string()),
            loaded_at: 1234567890,
        })
    }

    /// Build a `SkillUnload` event for `skill_name` exposing `tools`, with a
    /// fixed duration.
    fn make_skill_unload_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        use crate::hooks::events::SkillUnloadEvent;

        HookEvent::SkillUnload(SkillUnloadEvent {
            skill_name: skill_name.to_string(),
            tool_names: tools.into_iter().map(String::from).collect(),
            duration_ms: 60000,
        })
    }

    #[tokio::test]
    async fn test_engine_fire_skill_load() {
        let engine = HookEngine::new();

        // Register a hook for skill load events
        engine.register(Hook::new("skill-load-hook", HookEventType::SkillLoad));
        engine.register_handler("skill-load-hook", Arc::new(ContinueHandler));

        let event = make_skill_load_event("my-skill", vec!["tool1", "tool2"]);
        let outcome = engine.fire(&event).await;

        assert!(outcome.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_skill_unload() {
        let engine = HookEngine::new();

        // Register a hook for skill unload events
        engine.register(Hook::new("skill-unload-hook", HookEventType::SkillUnload));
        engine.register_handler("skill-unload-hook", Arc::new(ContinueHandler));

        let event = make_skill_unload_event("my-skill", vec!["tool1", "tool2"]);
        let outcome = engine.fire(&event).await;

        assert!(outcome.is_continue());
    }

    #[tokio::test]
    async fn test_engine_skill_hook_with_matcher() {
        let engine = HookEngine::new();

        // Register a hook that only matches specific skill
        let hook = Hook::new("specific-skill-hook", HookEventType::SkillLoad)
            .with_matcher(HookMatcher::skill("my-skill"));
        engine.register(hook);
        engine.register_handler("specific-skill-hook", block_handler("Skill blocked"));

        // Should match and block
        let matching_event = make_skill_load_event("my-skill", vec!["tool1"]);
        assert!(engine.fire(&matching_event).await.is_block());

        // Should not match (no hooks match, so continue)
        let non_matching_event = make_skill_load_event("other-skill", vec!["tool1"]);
        assert!(engine.fire(&non_matching_event).await.is_continue());
    }

    #[tokio::test]
    async fn test_engine_skill_hook_pattern_matcher() {
        let engine = HookEngine::new();

        // Register a hook with glob pattern
        let hook = Hook::new("test-skill-hook", HookEventType::SkillLoad)
            .with_matcher(HookMatcher::skill("test-*"));
        engine.register(hook);
        engine.register_handler("test-skill-hook", block_handler("Test skill blocked"));

        // Should match pattern
        let first = make_skill_load_event("test-alpha", vec!["tool1"]);
        assert!(engine.fire(&first).await.is_block());

        let second = make_skill_load_event("test-beta", vec!["tool1"]);
        assert!(engine.fire(&second).await.is_block());

        // Should not match pattern
        let unrelated = make_skill_load_event("prod-skill", vec!["tool1"]);
        assert!(engine.fire(&unrelated).await.is_continue());
    }
}