Skip to main content

a3s_code_core/hooks/
engine.rs

1//! Hook Engine
2//!
3//! Core engine responsible for managing and executing hooks.
4
5use super::events::{HookEvent, HookEventType};
6use super::matcher::HookMatcher;
7use super::{HookAction, HookResponse};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use tokio::sync::mpsc;
12
13use crate::error::{read_or_recover, write_or_recover};
14
/// Hook configuration
///
/// Per-hook execution settings. Fields omitted when deserializing take the same
/// values that [`HookConfig::default`] produces (priority 100, timeout 30000 ms,
/// synchronous execution, no retries).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookConfig {
    /// Priority (lower values = higher priority)
    ///
    /// Matching hooks are executed in ascending order of this value.
    #[serde(default = "default_priority")]
    pub priority: i32,

    /// Timeout in milliseconds
    ///
    /// Applied by the engine when the handler is run synchronously
    /// (`async_execution == false`); a timed-out hook continues execution.
    #[serde(default = "default_timeout")]
    pub timeout_ms: u64,

    /// Whether to execute asynchronously (fire-and-forget)
    ///
    /// When `true` the handler's response is discarded and the hook always continues.
    #[serde(default)]
    pub async_execution: bool,

    /// Maximum retry attempts
    ///
    /// NOTE(review): not read by `HookEngine` in this file — confirm where (or
    /// whether) this limit is enforced.
    #[serde(default)]
    pub max_retries: u32,
}
34
/// Serde default for `HookConfig::priority` when the field is omitted.
fn default_priority() -> i32 {
    const DEFAULT_PRIORITY: i32 = 100;
    DEFAULT_PRIORITY
}

/// Serde default for `HookConfig::timeout_ms` (30 seconds) when the field is omitted.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_MS: u64 = 30_000;
    DEFAULT_TIMEOUT_MS
}
42
43impl Default for HookConfig {
44    fn default() -> Self {
45        Self {
46            priority: default_priority(),
47            timeout_ms: default_timeout(),
48            async_execution: false,
49            max_retries: 0,
50        }
51    }
52}
53
/// Hook definition
///
/// Describes *when* a hook fires (event type plus optional matcher) and *how* it
/// runs (config). The behavior itself is supplied separately as a [`HookHandler`]
/// registered under the same `id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hook {
    /// Unique hook identifier
    ///
    /// Also the key used to look up this hook's handler in the engine.
    pub id: String,

    /// Event type that triggers this hook
    pub event_type: HookEventType,

    /// Event matcher (optional, None matches all events)
    ///
    /// `None` matches every event of `event_type`; omitted from serialized output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matcher: Option<HookMatcher>,

    /// Hook configuration
    ///
    /// Defaults to [`HookConfig::default`] when absent from serialized input.
    #[serde(default)]
    pub config: HookConfig,
}
71
72impl Hook {
73    /// Create a new hook
74    pub fn new(id: impl Into<String>, event_type: HookEventType) -> Self {
75        Self {
76            id: id.into(),
77            event_type,
78            matcher: None,
79            config: HookConfig::default(),
80        }
81    }
82
83    /// Set the matcher
84    pub fn with_matcher(mut self, matcher: HookMatcher) -> Self {
85        self.matcher = Some(matcher);
86        self
87    }
88
89    /// Set the configuration
90    pub fn with_config(mut self, config: HookConfig) -> Self {
91        self.config = config;
92        self
93    }
94
95    /// Check if an event matches this hook
96    pub fn matches(&self, event: &HookEvent) -> bool {
97        // First check event type
98        if event.event_type() != self.event_type {
99            return false;
100        }
101
102        // If there's a matcher, check it
103        if let Some(ref matcher) = self.matcher {
104            matcher.matches(event)
105        } else {
106            true
107        }
108    }
109}
110
111/// Hook execution result
112#[derive(Debug, Clone)]
113pub enum HookResult {
114    /// Continue execution (with optional modified data)
115    Continue(Option<serde_json::Value>),
116    /// Block execution
117    Block(String),
118    /// Retry after delay (milliseconds)
119    Retry(u64),
120    /// Skip remaining hooks but continue execution
121    Skip,
122}
123
124impl HookResult {
125    /// Create a continue result
126    pub fn continue_() -> Self {
127        Self::Continue(None)
128    }
129
130    /// Create a continue result with modifications
131    pub fn continue_with(modified: serde_json::Value) -> Self {
132        Self::Continue(Some(modified))
133    }
134
135    /// Create a block result
136    pub fn block(reason: impl Into<String>) -> Self {
137        Self::Block(reason.into())
138    }
139
140    /// Create a retry result
141    pub fn retry(delay_ms: u64) -> Self {
142        Self::Retry(delay_ms)
143    }
144
145    /// Create a skip result
146    pub fn skip() -> Self {
147        Self::Skip
148    }
149
150    /// Check if this is a continue result
151    pub fn is_continue(&self) -> bool {
152        matches!(self, Self::Continue(_))
153    }
154
155    /// Check if this is a block result
156    pub fn is_block(&self) -> bool {
157        matches!(self, Self::Block(_))
158    }
159}
160
/// Hook handler trait
///
/// Implemented by the code that decides the outcome for a hook (typically
/// registered by the SDK via `HookEngine::register_handler`). `handle` is a plain
/// synchronous method, not `async`. The `Send + Sync` bounds allow the engine to
/// share the handler behind an `Arc` and invoke it from a spawned task.
pub trait HookHandler: Send + Sync {
    /// Handle a hook event
    ///
    /// Returns the response (continue / block / retry / skip) for `event`.
    fn handle(&self, event: &HookEvent) -> HookResponse;
}
166
/// Hook engine
///
/// Stores hook definitions and their handlers (both keyed by hook ID) and
/// dispatches events to them. Both maps sit behind `RwLock`s, so hooks and
/// handlers can be registered through `&self`.
pub struct HookEngine {
    /// Registered hooks
    ///
    /// Keyed by `Hook::id`.
    hooks: Arc<RwLock<HashMap<String, Hook>>>,

    /// Hook handlers (registered by SDK)
    ///
    /// Keyed by the ID of the hook they serve. A hook with no entry here continues.
    handlers: Arc<RwLock<HashMap<String, Arc<dyn HookHandler>>>>,

    /// Event sender channel (for SDK listeners)
    ///
    /// When set, every fired event is also sent here before hooks run.
    event_tx: Option<mpsc::Sender<HookEvent>>,
}
178
179impl std::fmt::Debug for HookEngine {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        f.debug_struct("HookEngine")
182            .field("hooks_count", &read_or_recover(&self.hooks).len())
183            .field("handlers_count", &read_or_recover(&self.handlers).len())
184            .field("has_event_channel", &self.event_tx.is_some())
185            .finish()
186    }
187}
188
189impl Default for HookEngine {
190    fn default() -> Self {
191        Self::new()
192    }
193}
194
195impl HookEngine {
196    /// Create a new hook engine
197    pub fn new() -> Self {
198        Self {
199            hooks: Arc::new(RwLock::new(HashMap::new())),
200            handlers: Arc::new(RwLock::new(HashMap::new())),
201            event_tx: None,
202        }
203    }
204
205    /// Set the event sender channel
206    pub fn with_event_channel(mut self, tx: mpsc::Sender<HookEvent>) -> Self {
207        self.event_tx = Some(tx);
208        self
209    }
210
211    /// Register a hook
212    pub fn register(&self, hook: Hook) {
213        let mut hooks = write_or_recover(&self.hooks);
214        hooks.insert(hook.id.clone(), hook);
215    }
216
217    /// Unregister a hook
218    pub fn unregister(&self, hook_id: &str) -> Option<Hook> {
219        let mut hooks = write_or_recover(&self.hooks);
220        hooks.remove(hook_id)
221    }
222
223    /// Register a handler
224    pub fn register_handler(&self, hook_id: &str, handler: Arc<dyn HookHandler>) {
225        let mut handlers = write_or_recover(&self.handlers);
226        handlers.insert(hook_id.to_string(), handler);
227    }
228
229    /// Unregister a handler
230    pub fn unregister_handler(&self, hook_id: &str) {
231        let mut handlers = write_or_recover(&self.handlers);
232        handlers.remove(hook_id);
233    }
234
235    /// Get all hooks matching an event (sorted by priority)
236    pub fn matching_hooks(&self, event: &HookEvent) -> Vec<Hook> {
237        let hooks = read_or_recover(&self.hooks);
238        let mut matching: Vec<Hook> = hooks
239            .values()
240            .filter(|h| h.matches(event))
241            .cloned()
242            .collect();
243
244        // Sort by priority (lower values = higher priority)
245        matching.sort_by_key(|h| h.config.priority);
246        matching
247    }
248
249    /// Fire an event and get the result
250    pub async fn fire(&self, event: &HookEvent) -> HookResult {
251        // Send event to channel if available
252        if let Some(ref tx) = self.event_tx {
253            let _ = tx.send(event.clone()).await;
254        }
255
256        // Get matching hooks
257        let matching_hooks = self.matching_hooks(event);
258
259        if matching_hooks.is_empty() {
260            return HookResult::continue_();
261        }
262
263        // Execute each hook
264        for hook in matching_hooks {
265            let result = self.execute_hook(&hook, event).await;
266
267            match result {
268                HookResult::Continue(modified) => {
269                    // If modified, can apply to subsequent hooks
270                    // For now, simple handling: continue to next hook
271                    if modified.is_some() {
272                        return HookResult::Continue(modified);
273                    }
274                }
275                HookResult::Block(reason) => {
276                    return HookResult::Block(reason);
277                }
278                HookResult::Retry(delay) => {
279                    return HookResult::Retry(delay);
280                }
281                HookResult::Skip => {
282                    return HookResult::Continue(None);
283                }
284            }
285        }
286
287        HookResult::continue_()
288    }
289
290    /// Execute a single hook
291    async fn execute_hook(&self, hook: &Hook, event: &HookEvent) -> HookResult {
292        // Find handler
293        let handler = {
294            let handlers = read_or_recover(&self.handlers);
295            handlers.get(&hook.id).cloned()
296        };
297
298        match handler {
299            Some(h) => {
300                // Handler found, execute it
301                let response = if hook.config.async_execution {
302                    // Async execution (fire-and-forget)
303                    let h = h.clone();
304                    let event = event.clone();
305                    tokio::spawn(async move {
306                        h.handle(&event);
307                    });
308                    HookResponse::continue_()
309                } else {
310                    // Sync execution (with timeout)
311                    let timeout = std::time::Duration::from_millis(hook.config.timeout_ms);
312                    let h = h.clone();
313                    let event = event.clone();
314
315                    match tokio::time::timeout(timeout, async move { h.handle(&event) }).await {
316                        Ok(response) => response,
317                        Err(_) => {
318                            // Timeout, continue execution
319                            HookResponse::continue_()
320                        }
321                    }
322                };
323
324                self.response_to_result(response)
325            }
326            None => {
327                // No handler, continue execution
328                HookResult::continue_()
329            }
330        }
331    }
332
333    /// Convert HookResponse to HookResult
334    fn response_to_result(&self, response: HookResponse) -> HookResult {
335        match response.action {
336            HookAction::Continue => HookResult::Continue(response.modified),
337            HookAction::Block => {
338                HookResult::Block(response.reason.unwrap_or_else(|| "Blocked".to_string()))
339            }
340            HookAction::Retry => HookResult::Retry(response.retry_delay_ms.unwrap_or(1000)),
341            HookAction::Skip => HookResult::Skip,
342        }
343    }
344
345    /// Get the number of registered hooks
346    pub fn hook_count(&self) -> usize {
347        read_or_recover(&self.hooks).len()
348    }
349
350    /// Get a hook by ID
351    pub fn get_hook(&self, id: &str) -> Option<Hook> {
352        read_or_recover(&self.hooks).get(id).cloned()
353    }
354
355    /// Get all hooks
356    pub fn all_hooks(&self) -> Vec<Hook> {
357        read_or_recover(&self.hooks).values().cloned().collect()
358    }
359}
360
// Unit tests for hook configuration, hook matching and `HookEngine` dispatch.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::events::PreToolUseEvent;

    /// Build a `PreToolUse` event for `tool` in `session_id`, with empty args,
    /// a fixed `/workspace` working directory and no recent tools.
    fn make_pre_tool_event(session_id: &str, tool: &str) -> HookEvent {
        HookEvent::PreToolUse(PreToolUseEvent {
            session_id: session_id.to_string(),
            tool: tool.to_string(),
            args: serde_json::json!({}),
            working_directory: "/workspace".to_string(),
            recent_tools: vec![],
        })
    }

    // `HookConfig::default()` must agree with the serde default functions.
    #[test]
    fn test_hook_config_default() {
        let config = HookConfig::default();
        assert_eq!(config.priority, 100);
        assert_eq!(config.timeout_ms, 30000);
        assert!(!config.async_execution);
        assert_eq!(config.max_retries, 0);
    }

    // A freshly built hook has no matcher, so it matches every event of its type.
    #[test]
    fn test_hook_new() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse);
        assert_eq!(hook.id, "test-hook");
        assert_eq!(hook.event_type, HookEventType::PreToolUse);
        assert!(hook.matcher.is_none());
    }

    // The builder stores the matcher as given.
    #[test]
    fn test_hook_with_matcher() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"));

        assert!(hook.matcher.is_some());
        assert_eq!(hook.matcher.unwrap().tool, Some("Bash".to_string()));
    }

    // Without a matcher, only the event type decides whether a hook matches.
    #[test]
    fn test_hook_matches_event_type() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse);

        let pre_event = make_pre_tool_event("s1", "Bash");
        assert!(hook.matches(&pre_event));

        // PostToolUse doesn't match
        let post_event = HookEvent::PostToolUse(crate::hooks::events::PostToolUseEvent {
            session_id: "s1".to_string(),
            tool: "Bash".to_string(),
            args: serde_json::json!({}),
            result: crate::hooks::events::ToolResultData {
                success: true,
                output: "".to_string(),
                exit_code: Some(0),
                duration_ms: 100,
            },
        });
        assert!(!hook.matches(&post_event));
    }

    // With a tool matcher, events of the right type but another tool are rejected.
    #[test]
    fn test_hook_matches_with_matcher() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"));

        let bash_event = make_pre_tool_event("s1", "Bash");
        let read_event = make_pre_tool_event("s1", "Read");

        assert!(hook.matches(&bash_event));
        assert!(!hook.matches(&read_event));
    }

    // Each constructor yields its variant; `is_continue`/`is_block` only report
    // their own variant (Retry and Skip are neither).
    #[test]
    fn test_hook_result_constructors() {
        let cont = HookResult::continue_();
        assert!(cont.is_continue());
        assert!(!cont.is_block());

        let cont_with = HookResult::continue_with(serde_json::json!({"key": "value"}));
        assert!(cont_with.is_continue());

        let block = HookResult::block("Blocked");
        assert!(block.is_block());
        assert!(!block.is_continue());

        let retry = HookResult::retry(1000);
        assert!(!retry.is_continue());
        assert!(!retry.is_block());

        let skip = HookResult::skip();
        assert!(!skip.is_continue());
        assert!(!skip.is_block());
    }

    // Registering and unregistering by ID updates the hook count.
    #[test]
    fn test_engine_register_unregister() {
        let engine = HookEngine::new();

        let hook = Hook::new("test-hook", HookEventType::PreToolUse);
        engine.register(hook);

        assert_eq!(engine.hook_count(), 1);
        assert!(engine.get_hook("test-hook").is_some());

        let removed = engine.unregister("test-hook");
        assert!(removed.is_some());
        assert_eq!(engine.hook_count(), 0);
    }

    // Only hooks of the event's type (and accepting matcher) are returned,
    // ordered by ascending priority value.
    #[test]
    fn test_engine_matching_hooks() {
        let engine = HookEngine::new();

        // Register multiple hooks
        engine.register(
            Hook::new("hook-1", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 10,
                ..Default::default()
            }),
        );
        engine.register(
            Hook::new("hook-2", HookEventType::PreToolUse)
                .with_matcher(HookMatcher::tool("Bash"))
                .with_config(HookConfig {
                    priority: 5,
                    ..Default::default()
                }),
        );
        engine.register(Hook::new("hook-3", HookEventType::PostToolUse));

        let event = make_pre_tool_event("s1", "Bash");
        let matching = engine.matching_hooks(&event);

        // Should match hook-1 and hook-2 (both are PreToolUse)
        assert_eq!(matching.len(), 2);

        // Sorted by priority, hook-2 (priority=5) should be first
        assert_eq!(matching[0].id, "hook-2");
        assert_eq!(matching[1].id, "hook-1");
    }

    // Firing with nothing registered continues.
    #[tokio::test]
    async fn test_engine_fire_no_hooks() {
        let engine = HookEngine::new();
        let event = make_pre_tool_event("s1", "Bash");

        let result = engine.fire(&event).await;
        assert!(result.is_continue());
    }

    // A matching hook with no registered handler continues.
    #[tokio::test]
    async fn test_engine_fire_no_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        // No handler, should continue
        assert!(result.is_continue());
    }

    /// Test handler: always continue
    struct ContinueHandler;
    impl HookHandler for ContinueHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::continue_()
        }
    }

    /// Test handler: always block
    ///
    /// Blocks every event with the configured `reason`.
    struct BlockHandler {
        reason: String,
    }
    impl HookHandler for BlockHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::block(&self.reason)
        }
    }

    // A continuing handler yields a continue result from `fire`.
    #[tokio::test]
    async fn test_engine_fire_with_continue_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler("test-hook", Arc::new(ContinueHandler));

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    // A blocking handler's reason is propagated unchanged through `fire`.
    #[tokio::test]
    async fn test_engine_fire_with_block_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler(
            "test-hook",
            Arc::new(BlockHandler {
                reason: "Dangerous command".to_string(),
            }),
        );

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        assert!(result.is_block());
        if let HookResult::Block(reason) = result {
            assert_eq!(reason, "Dangerous command");
        }
    }

    // The hook with the lower priority value runs first and its block ends dispatch.
    #[tokio::test]
    async fn test_engine_fire_priority_order() {
        let engine = HookEngine::new();

        // Register two hooks; the one with the lower priority value (it runs first) blocks
        engine.register(
            Hook::new("block-hook", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 5, // Higher priority (executes first)
                ..Default::default()
            }),
        );
        engine.register(
            Hook::new("continue-hook", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 10,
                ..Default::default()
            }),
        );

        engine.register_handler(
            "block-hook",
            Arc::new(BlockHandler {
                reason: "Blocked first".to_string(),
            }),
        );
        engine.register_handler("continue-hook", Arc::new(ContinueHandler));

        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;

        // block-hook executes first, should block
        assert!(result.is_block());
    }

    // A hook survives a JSON round trip; the event type serializes as `pre_tool_use`.
    #[test]
    fn test_hook_serialization() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"))
            .with_config(HookConfig {
                priority: 50,
                timeout_ms: 5000,
                async_execution: true,
                max_retries: 3,
            });

        let json = serde_json::to_string(&hook).unwrap();
        assert!(json.contains("test-hook"));
        assert!(json.contains("pre_tool_use"));
        assert!(json.contains("Bash"));

        let parsed: Hook = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "test-hook");
        assert_eq!(parsed.event_type, HookEventType::PreToolUse);
        assert_eq!(parsed.config.priority, 50);
    }

    // `all_hooks` returns every hook regardless of event type.
    #[test]
    fn test_all_hooks() {
        let engine = HookEngine::new();
        engine.register(Hook::new("hook-1", HookEventType::PreToolUse));
        engine.register(Hook::new("hook-2", HookEventType::PostToolUse));

        let all = engine.all_hooks();
        assert_eq!(all.len(), 2);
    }

    /// Build a `SkillLoad` event for `skill_name` exposing `tools`, with fixed
    /// version, description and timestamp.
    fn make_skill_load_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        HookEvent::SkillLoad(crate::hooks::events::SkillLoadEvent {
            skill_name: skill_name.to_string(),
            tool_names: tools.iter().map(|s| s.to_string()).collect(),
            version: Some("1.0.0".to_string()),
            description: Some("Test skill".to_string()),
            loaded_at: 1234567890,
        })
    }

    /// Build a `SkillUnload` event for `skill_name` exposing `tools`, with a fixed
    /// duration of 60000 ms.
    fn make_skill_unload_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        HookEvent::SkillUnload(crate::hooks::events::SkillUnloadEvent {
            skill_name: skill_name.to_string(),
            tool_names: tools.iter().map(|s| s.to_string()).collect(),
            duration_ms: 60000,
        })
    }

    // Skill-load events dispatch to skill-load hooks.
    #[tokio::test]
    async fn test_engine_fire_skill_load() {
        let engine = HookEngine::new();

        // Register a hook for skill load events
        engine.register(Hook::new("skill-load-hook", HookEventType::SkillLoad));
        engine.register_handler("skill-load-hook", Arc::new(ContinueHandler));

        let event = make_skill_load_event("my-skill", vec!["tool1", "tool2"]);
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    // Skill-unload events dispatch to skill-unload hooks.
    #[tokio::test]
    async fn test_engine_fire_skill_unload() {
        let engine = HookEngine::new();

        // Register a hook for skill unload events
        engine.register(Hook::new("skill-unload-hook", HookEventType::SkillUnload));
        engine.register_handler("skill-unload-hook", Arc::new(ContinueHandler));

        let event = make_skill_unload_event("my-skill", vec!["tool1", "tool2"]);
        let result = engine.fire(&event).await;

        assert!(result.is_continue());
    }

    // An exact skill-name matcher blocks only that skill.
    #[tokio::test]
    async fn test_engine_skill_hook_with_matcher() {
        let engine = HookEngine::new();

        // Register a hook that only matches specific skill
        engine.register(
            Hook::new("specific-skill-hook", HookEventType::SkillLoad)
                .with_matcher(HookMatcher::skill("my-skill")),
        );
        engine.register_handler(
            "specific-skill-hook",
            Arc::new(BlockHandler {
                reason: "Skill blocked".to_string(),
            }),
        );

        // Should match and block
        let matching_event = make_skill_load_event("my-skill", vec!["tool1"]);
        let result = engine.fire(&matching_event).await;
        assert!(result.is_block());

        // Should not match (no hooks match, so continue)
        let non_matching_event = make_skill_load_event("other-skill", vec!["tool1"]);
        let result = engine.fire(&non_matching_event).await;
        assert!(result.is_continue());
    }

    // A glob skill matcher (`test-*`) blocks every skill with that prefix and nothing else.
    #[tokio::test]
    async fn test_engine_skill_hook_pattern_matcher() {
        let engine = HookEngine::new();

        // Register a hook with glob pattern
        engine.register(
            Hook::new("test-skill-hook", HookEventType::SkillLoad)
                .with_matcher(HookMatcher::skill("test-*")),
        );
        engine.register_handler(
            "test-skill-hook",
            Arc::new(BlockHandler {
                reason: "Test skill blocked".to_string(),
            }),
        );

        // Should match pattern
        let test_skill = make_skill_load_event("test-alpha", vec!["tool1"]);
        let result = engine.fire(&test_skill).await;
        assert!(result.is_block());

        let test_skill2 = make_skill_load_event("test-beta", vec!["tool1"]);
        let result = engine.fire(&test_skill2).await;
        assert!(result.is_block());

        // Should not match pattern
        let prod_skill = make_skill_load_event("prod-skill", vec!["tool1"]);
        let result = engine.fire(&prod_skill).await;
        assert!(result.is_continue());
    }
}