Skip to main content

claude_agent/hooks/
traits.rs

1//! Hook traits and types.
2
3use crate::types::ToolOutput;
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use tokio_util::sync::CancellationToken;
10
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13#[non_exhaustive]
14pub enum HookEvent {
15    PreToolUse,
16    PostToolUse,
17    PostToolUseFailure,
18    UserPromptSubmit,
19    Stop,
20    SubagentStart,
21    SubagentStop,
22    PreCompact,
23    SessionStart,
24    SessionEnd,
25}
26
27impl HookEvent {
28    /// Returns true if this hook event can block execution.
29    ///
30    /// Blockable events use fail-closed semantics: if the hook fails or times out,
31    /// the operation is blocked. This ensures security policies are enforced.
32    ///
33    /// Blockable events:
34    /// - `PreToolUse`: Can block tool execution
35    /// - `UserPromptSubmit`: Can block prompt processing
36    /// - `SessionStart`: Can block session initialization
37    /// - `SubagentStart`: Can block subagent spawning
38    ///
39    /// Non-blockable pre-events:
40    /// - `PreCompact`: Compaction is non-critical; failures are logged but don't stop execution
41    pub fn can_block(&self) -> bool {
42        matches!(
43            self,
44            Self::PreToolUse | Self::UserPromptSubmit | Self::SessionStart | Self::SubagentStart
45        )
46    }
47
48    /// Parse a PascalCase event name (as used in hooks.json configs).
49    pub fn from_pascal_case(s: &str) -> Option<Self> {
50        match s {
51            "PreToolUse" => Some(Self::PreToolUse),
52            "PostToolUse" => Some(Self::PostToolUse),
53            "PostToolUseFailure" => Some(Self::PostToolUseFailure),
54            "UserPromptSubmit" => Some(Self::UserPromptSubmit),
55            "Stop" => Some(Self::Stop),
56            "SubagentStart" => Some(Self::SubagentStart),
57            "SubagentStop" => Some(Self::SubagentStop),
58            "PreCompact" => Some(Self::PreCompact),
59            "SessionStart" => Some(Self::SessionStart),
60            "SessionEnd" => Some(Self::SessionEnd),
61            _ => None,
62        }
63    }
64
65    pub fn all() -> &'static [HookEvent] {
66        &[
67            Self::PreToolUse,
68            Self::PostToolUse,
69            Self::PostToolUseFailure,
70            Self::UserPromptSubmit,
71            Self::Stop,
72            Self::SubagentStart,
73            Self::SubagentStop,
74            Self::PreCompact,
75            Self::SessionStart,
76            Self::SessionEnd,
77        ]
78    }
79}
80
81impl std::fmt::Display for HookEvent {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        let s = match self {
84            Self::PreToolUse => "pre_tool_use",
85            Self::PostToolUse => "post_tool_use",
86            Self::PostToolUseFailure => "post_tool_use_failure",
87            Self::UserPromptSubmit => "user_prompt_submit",
88            Self::Stop => "stop",
89            Self::SubagentStart => "subagent_start",
90            Self::SubagentStop => "subagent_stop",
91            Self::PreCompact => "pre_compact",
92            Self::SessionStart => "session_start",
93            Self::SessionEnd => "session_end",
94        };
95        write!(f, "{}", s)
96    }
97}
98
99#[derive(Clone, Debug)]
100#[non_exhaustive]
101pub enum HookEventData {
102    PreToolUse {
103        tool_name: String,
104        tool_input: Value,
105    },
106    PostToolUse {
107        tool_name: String,
108        tool_result: ToolOutput,
109    },
110    PostToolUseFailure {
111        tool_name: String,
112        error: String,
113    },
114    UserPromptSubmit {
115        prompt: String,
116    },
117    Stop,
118    SubagentStart {
119        subagent_id: String,
120        subagent_type: String,
121        description: String,
122    },
123    SubagentStop {
124        subagent_id: String,
125        success: bool,
126        error: Option<String>,
127    },
128    PreCompact,
129    SessionStart,
130    SessionEnd,
131}
132
133impl HookEventData {
134    pub fn event_type(&self) -> HookEvent {
135        match self {
136            Self::PreToolUse { .. } => HookEvent::PreToolUse,
137            Self::PostToolUse { .. } => HookEvent::PostToolUse,
138            Self::PostToolUseFailure { .. } => HookEvent::PostToolUseFailure,
139            Self::UserPromptSubmit { .. } => HookEvent::UserPromptSubmit,
140            Self::Stop => HookEvent::Stop,
141            Self::SubagentStart { .. } => HookEvent::SubagentStart,
142            Self::SubagentStop { .. } => HookEvent::SubagentStop,
143            Self::PreCompact => HookEvent::PreCompact,
144            Self::SessionStart => HookEvent::SessionStart,
145            Self::SessionEnd => HookEvent::SessionEnd,
146        }
147    }
148
149    pub fn tool_name(&self) -> Option<&str> {
150        match self {
151            Self::PreToolUse { tool_name, .. }
152            | Self::PostToolUse { tool_name, .. }
153            | Self::PostToolUseFailure { tool_name, .. } => Some(tool_name),
154            _ => None,
155        }
156    }
157
158    pub fn tool_input(&self) -> Option<&Value> {
159        match self {
160            Self::PreToolUse { tool_input, .. } => Some(tool_input),
161            _ => None,
162        }
163    }
164
165    pub fn subagent_id(&self) -> Option<&str> {
166        match self {
167            Self::SubagentStart { subagent_id, .. } | Self::SubagentStop { subagent_id, .. } => {
168                Some(subagent_id)
169            }
170            _ => None,
171        }
172    }
173}
174
175#[derive(Clone, Debug)]
176pub struct HookInput {
177    pub session_id: String,
178    pub timestamp: DateTime<Utc>,
179    pub data: HookEventData,
180    pub metadata: Option<Value>,
181}
182
183impl HookInput {
184    pub fn new(session_id: impl Into<String>, data: HookEventData) -> Self {
185        Self {
186            session_id: session_id.into(),
187            timestamp: Utc::now(),
188            data,
189            metadata: None,
190        }
191    }
192
193    pub fn event_type(&self) -> HookEvent {
194        self.data.event_type()
195    }
196
197    pub fn tool_name(&self) -> Option<&str> {
198        self.data.tool_name()
199    }
200
201    pub fn subagent_id(&self) -> Option<&str> {
202        self.data.subagent_id()
203    }
204
205    pub fn pre_tool_use(
206        session_id: impl Into<String>,
207        tool_name: impl Into<String>,
208        tool_input: Value,
209    ) -> Self {
210        Self::new(
211            session_id,
212            HookEventData::PreToolUse {
213                tool_name: tool_name.into(),
214                tool_input,
215            },
216        )
217    }
218
219    pub fn post_tool_use(
220        session_id: impl Into<String>,
221        tool_name: impl Into<String>,
222        tool_result: ToolOutput,
223    ) -> Self {
224        Self::new(
225            session_id,
226            HookEventData::PostToolUse {
227                tool_name: tool_name.into(),
228                tool_result,
229            },
230        )
231    }
232
233    pub fn post_tool_use_failure(
234        session_id: impl Into<String>,
235        tool_name: impl Into<String>,
236        error: impl Into<String>,
237    ) -> Self {
238        Self::new(
239            session_id,
240            HookEventData::PostToolUseFailure {
241                tool_name: tool_name.into(),
242                error: error.into(),
243            },
244        )
245    }
246
247    pub fn user_prompt_submit(session_id: impl Into<String>, prompt: impl Into<String>) -> Self {
248        Self::new(
249            session_id,
250            HookEventData::UserPromptSubmit {
251                prompt: prompt.into(),
252            },
253        )
254    }
255
256    pub fn session_start(session_id: impl Into<String>) -> Self {
257        Self::new(session_id, HookEventData::SessionStart)
258    }
259
260    pub fn session_end(session_id: impl Into<String>) -> Self {
261        Self::new(session_id, HookEventData::SessionEnd)
262    }
263
264    pub fn stop(session_id: impl Into<String>) -> Self {
265        Self::new(session_id, HookEventData::Stop)
266    }
267
268    pub fn pre_compact(session_id: impl Into<String>) -> Self {
269        Self::new(session_id, HookEventData::PreCompact)
270    }
271
272    pub fn subagent_start(
273        session_id: impl Into<String>,
274        subagent_id: impl Into<String>,
275        subagent_type: impl Into<String>,
276        description: impl Into<String>,
277    ) -> Self {
278        Self::new(
279            session_id,
280            HookEventData::SubagentStart {
281                subagent_id: subagent_id.into(),
282                subagent_type: subagent_type.into(),
283                description: description.into(),
284            },
285        )
286    }
287
288    pub fn subagent_stop(
289        session_id: impl Into<String>,
290        subagent_id: impl Into<String>,
291        success: bool,
292        error: Option<String>,
293    ) -> Self {
294        Self::new(
295            session_id,
296            HookEventData::SubagentStop {
297                subagent_id: subagent_id.into(),
298                success,
299                error,
300            },
301        )
302    }
303}
304
305#[derive(Clone, Debug, Default)]
306pub struct HookOutput {
307    pub continue_execution: bool,
308    pub stop_reason: Option<String>,
309    pub suppress_logging: bool,
310    pub system_message: Option<String>,
311    pub updated_input: Option<Value>,
312    pub additional_context: Option<String>,
313}
314
315impl HookOutput {
316    pub fn allow() -> Self {
317        Self {
318            continue_execution: true,
319            ..Default::default()
320        }
321    }
322
323    pub fn block(reason: impl Into<String>) -> Self {
324        Self {
325            continue_execution: false,
326            stop_reason: Some(reason.into()),
327            ..Default::default()
328        }
329    }
330
331    pub fn system_message(mut self, message: impl Into<String>) -> Self {
332        self.system_message = Some(message.into());
333        self
334    }
335
336    pub fn context(mut self, context: impl Into<String>) -> Self {
337        self.additional_context = Some(context.into());
338        self
339    }
340
341    pub fn updated_input(mut self, input: Value) -> Self {
342        self.updated_input = Some(input);
343        self
344    }
345
346    pub fn suppress_logging(mut self) -> Self {
347        self.suppress_logging = true;
348        self
349    }
350}
351
352#[derive(Clone, Debug)]
353pub struct HookContext {
354    pub session_id: String,
355    pub cancellation_token: CancellationToken,
356    pub cwd: Option<std::path::PathBuf>,
357    pub env: std::collections::HashMap<String, String>,
358}
359
360impl Default for HookContext {
361    fn default() -> Self {
362        Self {
363            session_id: String::new(),
364            cancellation_token: CancellationToken::new(),
365            cwd: None,
366            env: std::collections::HashMap::new(),
367        }
368    }
369}
370
371impl HookContext {
372    pub fn new(session_id: impl Into<String>) -> Self {
373        Self {
374            session_id: session_id.into(),
375            ..Default::default()
376        }
377    }
378
379    pub fn cancellation_token(mut self, token: CancellationToken) -> Self {
380        self.cancellation_token = token;
381        self
382    }
383
384    pub fn cwd(mut self, cwd: impl Into<std::path::PathBuf>) -> Self {
385        self.cwd = Some(cwd.into());
386        self
387    }
388
389    pub fn env(mut self, env: std::collections::HashMap<String, String>) -> Self {
390        self.env = env;
391        self
392    }
393}
394
/// Origin of a hook registration.
///
/// Defaults to [`HookSource::Builtin`]. `Hash` is derived (like `HookEvent`) so
/// sources can key `HashMap`s / `HashSet`s, e.g. to group hooks by origin.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum HookSource {
    /// Built-in hook; the default variant.
    #[default]
    Builtin,
    /// Presumably registered from user-level configuration. TODO confirm against the loader.
    User,
    /// Presumably registered from project-level configuration. TODO confirm against the loader.
    Project,
}
403
404/// Hook metadata for configuration.
405#[derive(Clone, Debug)]
406pub struct HookMetadata {
407    pub name: String,
408    pub events: Vec<HookEvent>,
409    pub priority: i32,
410    pub timeout_secs: u64,
411    pub tool_matcher: Option<Regex>,
412    pub source: HookSource,
413}
414
415impl HookMetadata {
416    pub fn new(name: impl Into<String>, events: Vec<HookEvent>) -> Self {
417        Self {
418            name: name.into(),
419            events,
420            priority: 0,
421            timeout_secs: 60,
422            tool_matcher: None,
423            source: HookSource::default(),
424        }
425    }
426
427    pub fn priority(mut self, priority: i32) -> Self {
428        self.priority = priority;
429        self
430    }
431
432    pub fn timeout(mut self, secs: u64) -> Self {
433        self.timeout_secs = secs;
434        self
435    }
436
437    pub fn tool_matcher(mut self, pattern: &str) -> Self {
438        if let Ok(regex) = Regex::new(pattern) {
439            self.tool_matcher = Some(regex);
440        }
441        self
442    }
443
444    pub fn source(mut self, source: HookSource) -> Self {
445        self.source = source;
446        self
447    }
448}
449
450#[async_trait]
451pub trait Hook: Send + Sync {
452    fn name(&self) -> &str;
453    fn events(&self) -> &[HookEvent];
454
455    #[inline]
456    fn tool_matcher(&self) -> Option<&Regex> {
457        None
458    }
459
460    #[inline]
461    fn timeout_secs(&self) -> u64 {
462        60
463    }
464
465    #[inline]
466    fn priority(&self) -> i32 {
467        0
468    }
469
470    async fn execute(
471        &self,
472        input: HookInput,
473        hook_context: &HookContext,
474    ) -> Result<HookOutput, crate::Error>;
475
476    #[inline]
477    fn source(&self) -> HookSource {
478        HookSource::Builtin
479    }
480
481    fn metadata(&self) -> HookMetadata {
482        HookMetadata {
483            name: self.name().to_string(),
484            events: self.events().to_vec(),
485            priority: self.priority(),
486            timeout_secs: self.timeout_secs(),
487            tool_matcher: self.tool_matcher().cloned(),
488            source: self.source(),
489        }
490    }
491}
492
493pub struct FnHook<F> {
494    name: String,
495    events: Vec<HookEvent>,
496    handler: F,
497    priority: i32,
498    timeout_secs: u64,
499    tool_matcher: Option<Regex>,
500}
501
502impl<F> FnHook<F> {
503    pub fn builder(name: impl Into<String>, events: Vec<HookEvent>) -> FnHookBuilder {
504        FnHookBuilder {
505            name: name.into(),
506            events,
507            priority: 0,
508            timeout_secs: 60,
509            tool_matcher: None,
510        }
511    }
512}
513
514pub struct FnHookBuilder {
515    name: String,
516    events: Vec<HookEvent>,
517    priority: i32,
518    timeout_secs: u64,
519    tool_matcher: Option<Regex>,
520}
521
522impl FnHookBuilder {
523    pub fn priority(mut self, priority: i32) -> Self {
524        self.priority = priority;
525        self
526    }
527
528    pub fn timeout_secs(mut self, secs: u64) -> Self {
529        self.timeout_secs = secs;
530        self
531    }
532
533    pub fn tool_matcher(mut self, pattern: &str) -> Self {
534        if let Ok(regex) = Regex::new(pattern) {
535            self.tool_matcher = Some(regex);
536        }
537        self
538    }
539
540    pub fn handler<F, Fut>(self, handler: F) -> FnHook<F>
541    where
542        F: Fn(HookInput, HookContext) -> Fut + Send + Sync,
543        Fut: std::future::Future<Output = Result<HookOutput, crate::Error>> + Send,
544    {
545        FnHook {
546            name: self.name,
547            events: self.events,
548            handler,
549            priority: self.priority,
550            timeout_secs: self.timeout_secs,
551            tool_matcher: self.tool_matcher,
552        }
553    }
554}
555
556#[async_trait]
557impl<F, Fut> Hook for FnHook<F>
558where
559    F: Fn(HookInput, HookContext) -> Fut + Send + Sync,
560    Fut: std::future::Future<Output = Result<HookOutput, crate::Error>> + Send,
561{
562    fn name(&self) -> &str {
563        &self.name
564    }
565
566    fn events(&self) -> &[HookEvent] {
567        &self.events
568    }
569
570    fn priority(&self) -> i32 {
571        self.priority
572    }
573
574    fn timeout_secs(&self) -> u64 {
575        self.timeout_secs
576    }
577
578    fn tool_matcher(&self) -> Option<&Regex> {
579        self.tool_matcher.as_ref()
580    }
581
582    async fn execute(
583        &self,
584        input: HookInput,
585        hook_context: &HookContext,
586    ) -> Result<HookOutput, crate::Error> {
587        (self.handler)(input, hook_context.clone()).await
588    }
589}
590
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_hook_event_display() {
        assert_eq!(HookEvent::PreToolUse.to_string(), "pre_tool_use");
        assert_eq!(HookEvent::PostToolUse.to_string(), "post_tool_use");
        assert_eq!(HookEvent::SessionStart.to_string(), "session_start");
    }

    #[test]
    fn test_hook_event_display_matches_serde() {
        // Display and the serde representation must spell every event the same way,
        // and serde must round-trip.
        for event in HookEvent::all() {
            let json = serde_json::to_string(event).expect("HookEvent serializes");
            assert_eq!(json, format!("\"{}\"", event));
            let back: HookEvent = serde_json::from_str(&json).expect("HookEvent deserializes");
            assert_eq!(back, *event);
        }
    }

    #[test]
    fn test_hook_event_all_has_no_duplicates() {
        let all = HookEvent::all();
        let unique: HashSet<HookEvent> = all.iter().copied().collect();
        assert_eq!(unique.len(), all.len());
        // Tripwire: update `all()` (and this count) whenever a variant is added.
        assert_eq!(all.len(), 10);
    }

    #[test]
    fn test_hook_event_from_pascal_case() {
        let cases = [
            ("PreToolUse", HookEvent::PreToolUse),
            ("PostToolUse", HookEvent::PostToolUse),
            ("PostToolUseFailure", HookEvent::PostToolUseFailure),
            ("UserPromptSubmit", HookEvent::UserPromptSubmit),
            ("Stop", HookEvent::Stop),
            ("SubagentStart", HookEvent::SubagentStart),
            ("SubagentStop", HookEvent::SubagentStop),
            ("PreCompact", HookEvent::PreCompact),
            ("SessionStart", HookEvent::SessionStart),
            ("SessionEnd", HookEvent::SessionEnd),
        ];
        for &(name, expected) in cases.iter() {
            assert_eq!(HookEvent::from_pascal_case(name), Some(expected));
        }
        // Matching is exact: snake_case, lowercase and empty names are rejected.
        assert_eq!(HookEvent::from_pascal_case("pre_tool_use"), None);
        assert_eq!(HookEvent::from_pascal_case("pretooluse"), None);
        assert_eq!(HookEvent::from_pascal_case(""), None);
    }

    #[test]
    fn test_hook_event_can_block() {
        // Blockable events (fail-closed semantics)
        assert!(HookEvent::PreToolUse.can_block());
        assert!(HookEvent::UserPromptSubmit.can_block());
        assert!(HookEvent::SessionStart.can_block());
        assert!(HookEvent::SubagentStart.can_block());

        // Non-blockable events (fail-open semantics)
        assert!(!HookEvent::PreCompact.can_block());
        assert!(!HookEvent::PostToolUse.can_block());
        assert!(!HookEvent::PostToolUseFailure.can_block());
        assert!(!HookEvent::SessionEnd.can_block());
        assert!(!HookEvent::SubagentStop.can_block());
        assert!(!HookEvent::Stop.can_block());
    }

    #[test]
    fn test_hook_input_builders() {
        let input =
            HookInput::pre_tool_use("session-1", "Read", serde_json::json!({"path": "/tmp"}));
        assert_eq!(input.event_type(), HookEvent::PreToolUse);
        assert_eq!(input.tool_name(), Some("Read"));
        assert_eq!(input.session_id, "session-1");
        assert!(input.metadata.is_none());

        let input = HookInput::session_start("session-2");
        assert_eq!(input.event_type(), HookEvent::SessionStart);
        assert_eq!(input.session_id, "session-2");
        assert_eq!(input.tool_name(), None);
    }

    #[test]
    fn test_hook_input_subagent_builders() {
        let input = HookInput::subagent_start("session-3", "agent-1", "explore", "look around");
        assert_eq!(input.event_type(), HookEvent::SubagentStart);
        assert_eq!(input.subagent_id(), Some("agent-1"));
        assert_eq!(input.tool_name(), None);

        let input =
            HookInput::subagent_stop("session-3", "agent-1", false, Some("boom".to_string()));
        assert_eq!(input.event_type(), HookEvent::SubagentStop);
        assert_eq!(input.subagent_id(), Some("agent-1"));
    }

    #[test]
    fn test_hook_output_builders() {
        let output = HookOutput::allow();
        assert!(output.continue_execution);
        assert!(output.stop_reason.is_none());

        let output = HookOutput::block("Dangerous operation");
        assert!(!output.continue_execution);
        assert_eq!(output.stop_reason, Some("Dangerous operation".to_string()));

        let output = HookOutput::allow()
            .system_message("Added context")
            .context("More info")
            .suppress_logging();
        assert!(output.continue_execution);
        assert!(output.suppress_logging);
        assert_eq!(output.system_message, Some("Added context".to_string()));
        assert_eq!(output.additional_context, Some("More info".to_string()));

        // The derived Default does NOT allow execution; `allow()` is required for that.
        assert!(!HookOutput::default().continue_execution);
    }

    #[test]
    fn test_hook_context_builders() {
        let ctx = HookContext::new("session-1").cwd("/tmp");
        assert_eq!(ctx.session_id, "session-1");
        assert_eq!(ctx.cwd.as_deref(), Some(std::path::Path::new("/tmp")));
        assert!(ctx.env.is_empty());
        assert!(!ctx.cancellation_token.is_cancelled());
    }

    #[test]
    fn test_hook_metadata_tool_matcher() {
        let meta = HookMetadata::new("guard", vec![HookEvent::PreToolUse]).tool_matcher("^Bash$");
        let matcher = meta.tool_matcher.as_ref().expect("valid pattern is kept");
        assert!(matcher.is_match("Bash"));
        assert!(!matcher.is_match("Read"));

        // An invalid pattern is ignored rather than reported: no matcher is stored.
        let meta = HookMetadata::new("guard", vec![HookEvent::PreToolUse]).tool_matcher("(");
        assert!(meta.tool_matcher.is_none());
        assert_eq!(meta.timeout_secs, 60);
        assert_eq!(meta.priority, 0);
        assert_eq!(meta.source, HookSource::Builtin);
    }

    #[test]
    fn test_hook_event_data_accessors() {
        let data = HookEventData::PreToolUse {
            tool_name: "Bash".to_string(),
            tool_input: serde_json::json!({"command": "ls"}),
        };
        assert_eq!(data.event_type(), HookEvent::PreToolUse);
        assert_eq!(data.tool_name(), Some("Bash"));
        assert!(data.tool_input().is_some());

        let data = HookEventData::SessionStart;
        assert_eq!(data.event_type(), HookEvent::SessionStart);
        assert_eq!(data.tool_name(), None);
        assert!(data.tool_input().is_none());
    }
}