Skip to main content

claude_agent/hooks/
traits.rs

1//! Hook traits and types.
2
3use crate::types::ToolOutput;
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use tokio_util::sync::CancellationToken;
10
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13#[non_exhaustive]
14pub enum HookEvent {
15    PreToolUse,
16    PostToolUse,
17    PostToolUseFailure,
18    UserPromptSubmit,
19    Stop,
20    SubagentStart,
21    SubagentStop,
22    PreCompact,
23    SessionStart,
24    SessionEnd,
25}
26
27impl HookEvent {
28    /// Returns true if this hook event can block execution.
29    ///
30    /// Blockable events use fail-closed semantics: if the hook fails or times out,
31    /// the operation is blocked. This ensures security policies are enforced.
32    ///
33    /// Blockable events:
34    /// - `PreToolUse`: Can block tool execution
35    /// - `UserPromptSubmit`: Can block prompt processing
36    /// - `SessionStart`: Can block session initialization
37    /// - `PreCompact`: Can block context compaction
38    /// - `SubagentStart`: Can block subagent spawning
39    pub fn can_block(&self) -> bool {
40        matches!(
41            self,
42            Self::PreToolUse
43                | Self::UserPromptSubmit
44                | Self::SessionStart
45                | Self::PreCompact
46                | Self::SubagentStart
47        )
48    }
49
50    /// Parse a PascalCase event name (as used in hooks.json configs).
51    pub fn from_pascal_case(s: &str) -> Option<Self> {
52        match s {
53            "PreToolUse" => Some(Self::PreToolUse),
54            "PostToolUse" => Some(Self::PostToolUse),
55            "PostToolUseFailure" => Some(Self::PostToolUseFailure),
56            "UserPromptSubmit" => Some(Self::UserPromptSubmit),
57            "Stop" => Some(Self::Stop),
58            "SubagentStart" => Some(Self::SubagentStart),
59            "SubagentStop" => Some(Self::SubagentStop),
60            "PreCompact" => Some(Self::PreCompact),
61            "SessionStart" => Some(Self::SessionStart),
62            "SessionEnd" => Some(Self::SessionEnd),
63            _ => None,
64        }
65    }
66
67    pub fn all() -> &'static [HookEvent] {
68        &[
69            Self::PreToolUse,
70            Self::PostToolUse,
71            Self::PostToolUseFailure,
72            Self::UserPromptSubmit,
73            Self::Stop,
74            Self::SubagentStart,
75            Self::SubagentStop,
76            Self::PreCompact,
77            Self::SessionStart,
78            Self::SessionEnd,
79        ]
80    }
81}
82
83impl std::fmt::Display for HookEvent {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        let s = match self {
86            Self::PreToolUse => "pre_tool_use",
87            Self::PostToolUse => "post_tool_use",
88            Self::PostToolUseFailure => "post_tool_use_failure",
89            Self::UserPromptSubmit => "user_prompt_submit",
90            Self::Stop => "stop",
91            Self::SubagentStart => "subagent_start",
92            Self::SubagentStop => "subagent_stop",
93            Self::PreCompact => "pre_compact",
94            Self::SessionStart => "session_start",
95            Self::SessionEnd => "session_end",
96        };
97        write!(f, "{}", s)
98    }
99}
100
101#[derive(Clone, Debug)]
102#[non_exhaustive]
103pub enum HookEventData {
104    PreToolUse {
105        tool_name: String,
106        tool_input: Value,
107    },
108    PostToolUse {
109        tool_name: String,
110        tool_result: ToolOutput,
111    },
112    PostToolUseFailure {
113        tool_name: String,
114        error: String,
115    },
116    UserPromptSubmit {
117        prompt: String,
118    },
119    Stop,
120    SubagentStart {
121        subagent_id: String,
122        subagent_type: String,
123        description: String,
124    },
125    SubagentStop {
126        subagent_id: String,
127        success: bool,
128        error: Option<String>,
129    },
130    PreCompact,
131    SessionStart,
132    SessionEnd,
133}
134
135impl HookEventData {
136    pub fn event_type(&self) -> HookEvent {
137        match self {
138            Self::PreToolUse { .. } => HookEvent::PreToolUse,
139            Self::PostToolUse { .. } => HookEvent::PostToolUse,
140            Self::PostToolUseFailure { .. } => HookEvent::PostToolUseFailure,
141            Self::UserPromptSubmit { .. } => HookEvent::UserPromptSubmit,
142            Self::Stop => HookEvent::Stop,
143            Self::SubagentStart { .. } => HookEvent::SubagentStart,
144            Self::SubagentStop { .. } => HookEvent::SubagentStop,
145            Self::PreCompact => HookEvent::PreCompact,
146            Self::SessionStart => HookEvent::SessionStart,
147            Self::SessionEnd => HookEvent::SessionEnd,
148        }
149    }
150
151    pub fn tool_name(&self) -> Option<&str> {
152        match self {
153            Self::PreToolUse { tool_name, .. }
154            | Self::PostToolUse { tool_name, .. }
155            | Self::PostToolUseFailure { tool_name, .. } => Some(tool_name),
156            _ => None,
157        }
158    }
159
160    pub fn tool_input(&self) -> Option<&Value> {
161        match self {
162            Self::PreToolUse { tool_input, .. } => Some(tool_input),
163            _ => None,
164        }
165    }
166
167    pub fn subagent_id(&self) -> Option<&str> {
168        match self {
169            Self::SubagentStart { subagent_id, .. } | Self::SubagentStop { subagent_id, .. } => {
170                Some(subagent_id)
171            }
172            _ => None,
173        }
174    }
175}
176
177#[derive(Clone, Debug)]
178pub struct HookInput {
179    pub session_id: String,
180    pub timestamp: DateTime<Utc>,
181    pub data: HookEventData,
182    pub metadata: Option<Value>,
183}
184
185impl HookInput {
186    pub fn new(session_id: impl Into<String>, data: HookEventData) -> Self {
187        Self {
188            session_id: session_id.into(),
189            timestamp: Utc::now(),
190            data,
191            metadata: None,
192        }
193    }
194
195    pub fn event_type(&self) -> HookEvent {
196        self.data.event_type()
197    }
198
199    pub fn tool_name(&self) -> Option<&str> {
200        self.data.tool_name()
201    }
202
203    pub fn subagent_id(&self) -> Option<&str> {
204        self.data.subagent_id()
205    }
206
207    pub fn pre_tool_use(
208        session_id: impl Into<String>,
209        tool_name: impl Into<String>,
210        tool_input: Value,
211    ) -> Self {
212        Self::new(
213            session_id,
214            HookEventData::PreToolUse {
215                tool_name: tool_name.into(),
216                tool_input,
217            },
218        )
219    }
220
221    pub fn post_tool_use(
222        session_id: impl Into<String>,
223        tool_name: impl Into<String>,
224        tool_result: ToolOutput,
225    ) -> Self {
226        Self::new(
227            session_id,
228            HookEventData::PostToolUse {
229                tool_name: tool_name.into(),
230                tool_result,
231            },
232        )
233    }
234
235    pub fn post_tool_use_failure(
236        session_id: impl Into<String>,
237        tool_name: impl Into<String>,
238        error: impl Into<String>,
239    ) -> Self {
240        Self::new(
241            session_id,
242            HookEventData::PostToolUseFailure {
243                tool_name: tool_name.into(),
244                error: error.into(),
245            },
246        )
247    }
248
249    pub fn user_prompt_submit(session_id: impl Into<String>, prompt: impl Into<String>) -> Self {
250        Self::new(
251            session_id,
252            HookEventData::UserPromptSubmit {
253                prompt: prompt.into(),
254            },
255        )
256    }
257
258    pub fn session_start(session_id: impl Into<String>) -> Self {
259        Self::new(session_id, HookEventData::SessionStart)
260    }
261
262    pub fn session_end(session_id: impl Into<String>) -> Self {
263        Self::new(session_id, HookEventData::SessionEnd)
264    }
265
266    pub fn stop(session_id: impl Into<String>) -> Self {
267        Self::new(session_id, HookEventData::Stop)
268    }
269
270    pub fn pre_compact(session_id: impl Into<String>) -> Self {
271        Self::new(session_id, HookEventData::PreCompact)
272    }
273
274    pub fn subagent_start(
275        session_id: impl Into<String>,
276        subagent_id: impl Into<String>,
277        subagent_type: impl Into<String>,
278        description: impl Into<String>,
279    ) -> Self {
280        Self::new(
281            session_id,
282            HookEventData::SubagentStart {
283                subagent_id: subagent_id.into(),
284                subagent_type: subagent_type.into(),
285                description: description.into(),
286            },
287        )
288    }
289
290    pub fn subagent_stop(
291        session_id: impl Into<String>,
292        subagent_id: impl Into<String>,
293        success: bool,
294        error: Option<String>,
295    ) -> Self {
296        Self::new(
297            session_id,
298            HookEventData::SubagentStop {
299                subagent_id: subagent_id.into(),
300                success,
301                error,
302            },
303        )
304    }
305}
306
307#[derive(Clone, Debug, Default)]
308pub struct HookOutput {
309    pub continue_execution: bool,
310    pub stop_reason: Option<String>,
311    pub suppress_logging: bool,
312    pub system_message: Option<String>,
313    pub updated_input: Option<Value>,
314    pub additional_context: Option<String>,
315}
316
317impl HookOutput {
318    pub fn allow() -> Self {
319        Self {
320            continue_execution: true,
321            ..Default::default()
322        }
323    }
324
325    pub fn block(reason: impl Into<String>) -> Self {
326        Self {
327            continue_execution: false,
328            stop_reason: Some(reason.into()),
329            ..Default::default()
330        }
331    }
332
333    pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
334        self.system_message = Some(message.into());
335        self
336    }
337
338    pub fn with_context(mut self, context: impl Into<String>) -> Self {
339        self.additional_context = Some(context.into());
340        self
341    }
342
343    pub fn with_updated_input(mut self, input: Value) -> Self {
344        self.updated_input = Some(input);
345        self
346    }
347
348    pub fn suppress_logging(mut self) -> Self {
349        self.suppress_logging = true;
350        self
351    }
352}
353
354#[derive(Clone, Debug)]
355pub struct HookContext {
356    pub session_id: String,
357    pub cancellation_token: CancellationToken,
358    pub cwd: Option<std::path::PathBuf>,
359    pub env: std::collections::HashMap<String, String>,
360}
361
362impl Default for HookContext {
363    fn default() -> Self {
364        Self {
365            session_id: String::new(),
366            cancellation_token: CancellationToken::new(),
367            cwd: None,
368            env: std::collections::HashMap::new(),
369        }
370    }
371}
372
373impl HookContext {
374    pub fn new(session_id: impl Into<String>) -> Self {
375        Self {
376            session_id: session_id.into(),
377            ..Default::default()
378        }
379    }
380
381    pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
382        self.cancellation_token = token;
383        self
384    }
385
386    pub fn with_cwd(mut self, cwd: impl Into<std::path::PathBuf>) -> Self {
387        self.cwd = Some(cwd.into());
388        self
389    }
390
391    pub fn with_env(mut self, env: std::collections::HashMap<String, String>) -> Self {
392        self.env = env;
393        self
394    }
395}
396
397/// Hook metadata for configuration.
398#[derive(Clone, Debug)]
399pub struct HookMetadata {
400    pub name: String,
401    pub events: Vec<HookEvent>,
402    pub priority: i32,
403    pub timeout_secs: u64,
404    pub tool_matcher: Option<Regex>,
405}
406
407impl HookMetadata {
408    pub fn new(name: impl Into<String>, events: Vec<HookEvent>) -> Self {
409        Self {
410            name: name.into(),
411            events,
412            priority: 0,
413            timeout_secs: 60,
414            tool_matcher: None,
415        }
416    }
417
418    pub fn with_priority(mut self, priority: i32) -> Self {
419        self.priority = priority;
420        self
421    }
422
423    pub fn with_timeout(mut self, secs: u64) -> Self {
424        self.timeout_secs = secs;
425        self
426    }
427
428    pub fn with_tool_matcher(mut self, pattern: &str) -> Self {
429        if let Ok(regex) = Regex::new(pattern) {
430            self.tool_matcher = Some(regex);
431        }
432        self
433    }
434}
435
436#[async_trait]
437pub trait Hook: Send + Sync {
438    fn name(&self) -> &str;
439    fn events(&self) -> &[HookEvent];
440
441    #[inline]
442    fn tool_matcher(&self) -> Option<&Regex> {
443        None
444    }
445
446    #[inline]
447    fn timeout_secs(&self) -> u64 {
448        60
449    }
450
451    #[inline]
452    fn priority(&self) -> i32 {
453        0
454    }
455
456    async fn execute(
457        &self,
458        input: HookInput,
459        hook_context: &HookContext,
460    ) -> Result<HookOutput, crate::Error>;
461
462    /// Get full metadata as a struct.
463    fn metadata(&self) -> HookMetadata {
464        HookMetadata {
465            name: self.name().to_string(),
466            events: self.events().to_vec(),
467            priority: self.priority(),
468            timeout_secs: self.timeout_secs(),
469            tool_matcher: self.tool_matcher().cloned(),
470        }
471    }
472}
473
474pub struct FnHook<F> {
475    name: String,
476    events: Vec<HookEvent>,
477    handler: F,
478    priority: i32,
479    timeout_secs: u64,
480    tool_matcher: Option<Regex>,
481}
482
483impl<F> FnHook<F> {
484    pub fn builder(name: impl Into<String>, events: Vec<HookEvent>) -> FnHookBuilder {
485        FnHookBuilder {
486            name: name.into(),
487            events,
488            priority: 0,
489            timeout_secs: 60,
490            tool_matcher: None,
491        }
492    }
493}
494
495pub struct FnHookBuilder {
496    name: String,
497    events: Vec<HookEvent>,
498    priority: i32,
499    timeout_secs: u64,
500    tool_matcher: Option<Regex>,
501}
502
503impl FnHookBuilder {
504    pub fn priority(mut self, priority: i32) -> Self {
505        self.priority = priority;
506        self
507    }
508
509    pub fn timeout_secs(mut self, secs: u64) -> Self {
510        self.timeout_secs = secs;
511        self
512    }
513
514    pub fn tool_matcher(mut self, pattern: &str) -> Self {
515        if let Ok(regex) = Regex::new(pattern) {
516            self.tool_matcher = Some(regex);
517        }
518        self
519    }
520
521    pub fn handler<F, Fut>(self, handler: F) -> FnHook<F>
522    where
523        F: Fn(HookInput, HookContext) -> Fut + Send + Sync,
524        Fut: std::future::Future<Output = Result<HookOutput, crate::Error>> + Send,
525    {
526        FnHook {
527            name: self.name,
528            events: self.events,
529            handler,
530            priority: self.priority,
531            timeout_secs: self.timeout_secs,
532            tool_matcher: self.tool_matcher,
533        }
534    }
535}
536
537#[async_trait]
538impl<F, Fut> Hook for FnHook<F>
539where
540    F: Fn(HookInput, HookContext) -> Fut + Send + Sync,
541    Fut: std::future::Future<Output = Result<HookOutput, crate::Error>> + Send,
542{
543    fn name(&self) -> &str {
544        &self.name
545    }
546
547    fn events(&self) -> &[HookEvent] {
548        &self.events
549    }
550
551    fn priority(&self) -> i32 {
552        self.priority
553    }
554
555    fn timeout_secs(&self) -> u64 {
556        self.timeout_secs
557    }
558
559    fn tool_matcher(&self) -> Option<&Regex> {
560        self.tool_matcher.as_ref()
561    }
562
563    async fn execute(
564        &self,
565        input: HookInput,
566        hook_context: &HookContext,
567    ) -> Result<HookOutput, crate::Error> {
568        (self.handler)(input, hook_context.clone()).await
569    }
570}
571
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hook_event_display() {
        assert_eq!(HookEvent::PreToolUse.to_string(), "pre_tool_use");
        assert_eq!(HookEvent::PostToolUse.to_string(), "post_tool_use");
        assert_eq!(HookEvent::SessionStart.to_string(), "session_start");
    }

    #[test]
    fn test_display_matches_serde_name() {
        // Display and the serde `rename_all = "snake_case"` form must agree for every event,
        // and the serialized form must round-trip.
        for event in HookEvent::all() {
            let serialized = serde_json::to_value(event).expect("HookEvent serializes");
            assert_eq!(serialized, serde_json::Value::String(event.to_string()));
            let parsed: HookEvent =
                serde_json::from_value(serialized).expect("HookEvent deserializes");
            assert_eq!(parsed, *event);
        }
    }

    #[test]
    fn test_from_pascal_case_covers_all_events() {
        let names = [
            ("PreToolUse", HookEvent::PreToolUse),
            ("PostToolUse", HookEvent::PostToolUse),
            ("PostToolUseFailure", HookEvent::PostToolUseFailure),
            ("UserPromptSubmit", HookEvent::UserPromptSubmit),
            ("Stop", HookEvent::Stop),
            ("SubagentStart", HookEvent::SubagentStart),
            ("SubagentStop", HookEvent::SubagentStop),
            ("PreCompact", HookEvent::PreCompact),
            ("SessionStart", HookEvent::SessionStart),
            ("SessionEnd", HookEvent::SessionEnd),
        ];
        // Guards against a variant being added to `all()` without a PascalCase mapping.
        assert_eq!(names.len(), HookEvent::all().len());
        for &(name, event) in names.iter() {
            assert_eq!(HookEvent::from_pascal_case(name), Some(event));
            assert!(HookEvent::all().contains(&event));
        }

        // Only the exact PascalCase spelling is accepted.
        assert_eq!(HookEvent::from_pascal_case("pre_tool_use"), None);
        assert_eq!(HookEvent::from_pascal_case("pretooluse"), None);
        assert_eq!(HookEvent::from_pascal_case(""), None);
    }

    #[test]
    fn test_hook_event_can_block() {
        // Blockable events (fail-closed semantics)
        assert!(HookEvent::PreToolUse.can_block());
        assert!(HookEvent::UserPromptSubmit.can_block());
        assert!(HookEvent::SessionStart.can_block());
        assert!(HookEvent::PreCompact.can_block());
        assert!(HookEvent::SubagentStart.can_block());

        // Non-blockable events (fail-open semantics)
        assert!(!HookEvent::PostToolUse.can_block());
        assert!(!HookEvent::PostToolUseFailure.can_block());
        assert!(!HookEvent::SessionEnd.can_block());
        assert!(!HookEvent::SubagentStop.can_block());
        assert!(!HookEvent::Stop.can_block());
    }

    #[test]
    fn test_hook_input_builders() {
        let input =
            HookInput::pre_tool_use("session-1", "Read", serde_json::json!({"path": "/tmp"}));
        assert_eq!(input.event_type(), HookEvent::PreToolUse);
        assert_eq!(input.tool_name(), Some("Read"));
        assert_eq!(input.session_id, "session-1");

        let input = HookInput::session_start("session-2");
        assert_eq!(input.event_type(), HookEvent::SessionStart);
        assert_eq!(input.session_id, "session-2");
    }

    #[test]
    fn test_subagent_and_failure_inputs() {
        let input = HookInput::subagent_start("s", "agent-1", "explore", "look around");
        assert_eq!(input.event_type(), HookEvent::SubagentStart);
        assert_eq!(input.subagent_id(), Some("agent-1"));
        assert_eq!(input.tool_name(), None);

        let input = HookInput::subagent_stop("s", "agent-1", false, Some("boom".to_string()));
        assert_eq!(input.event_type(), HookEvent::SubagentStop);
        assert_eq!(input.subagent_id(), Some("agent-1"));

        let input = HookInput::post_tool_use_failure("s", "Bash", "exit 1");
        assert_eq!(input.event_type(), HookEvent::PostToolUseFailure);
        assert_eq!(input.tool_name(), Some("Bash"));
        assert!(input.data.tool_input().is_none());
        assert_eq!(input.subagent_id(), None);
    }

    #[test]
    fn test_hook_output_builders() {
        let output = HookOutput::allow();
        assert!(output.continue_execution);
        assert!(output.stop_reason.is_none());

        let output = HookOutput::block("Dangerous operation");
        assert!(!output.continue_execution);
        assert_eq!(output.stop_reason, Some("Dangerous operation".to_string()));

        let output = HookOutput::allow()
            .with_system_message("Added context")
            .with_context("More info")
            .suppress_logging();
        assert!(output.continue_execution);
        assert!(output.suppress_logging);
        assert_eq!(output.system_message, Some("Added context".to_string()));
        assert_eq!(output.additional_context, Some("More info".to_string()));
    }

    #[test]
    fn test_hook_event_data_accessors() {
        let data = HookEventData::PreToolUse {
            tool_name: "Bash".to_string(),
            tool_input: serde_json::json!({"command": "ls"}),
        };
        assert_eq!(data.event_type(), HookEvent::PreToolUse);
        assert_eq!(data.tool_name(), Some("Bash"));
        assert!(data.tool_input().is_some());

        let data = HookEventData::SessionStart;
        assert_eq!(data.event_type(), HookEvent::SessionStart);
        assert_eq!(data.tool_name(), None);
        assert!(data.tool_input().is_none());
    }
}