// claude_agent/hooks/traits.rs

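//! Hook traits and types: the lifecycle events hooks subscribe to, the
//! input/output/context values passed to and from hooks, the [`Hook`] trait,
//! and the closure-backed [`FnHook`].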
2
3use crate::types::ToolOutput;
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use tokio_util::sync::CancellationToken;
10
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum HookEvent {
14    PreToolUse,
15    PostToolUse,
16    PostToolUseFailure,
17    UserPromptSubmit,
18    Stop,
19    SubagentStart,
20    SubagentStop,
21    PreCompact,
22    SessionStart,
23    SessionEnd,
24}
25
26impl HookEvent {
27    pub fn can_block(&self) -> bool {
28        matches!(self, Self::PreToolUse | Self::UserPromptSubmit)
29    }
30
31    pub fn can_modify_input(&self) -> bool {
32        matches!(self, Self::PreToolUse | Self::UserPromptSubmit)
33    }
34
35    pub fn all() -> &'static [HookEvent] {
36        &[
37            Self::PreToolUse,
38            Self::PostToolUse,
39            Self::PostToolUseFailure,
40            Self::UserPromptSubmit,
41            Self::Stop,
42            Self::SubagentStart,
43            Self::SubagentStop,
44            Self::PreCompact,
45            Self::SessionStart,
46            Self::SessionEnd,
47        ]
48    }
49}
50
51impl std::fmt::Display for HookEvent {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        let s = match self {
54            Self::PreToolUse => "pre_tool_use",
55            Self::PostToolUse => "post_tool_use",
56            Self::PostToolUseFailure => "post_tool_use_failure",
57            Self::UserPromptSubmit => "user_prompt_submit",
58            Self::Stop => "stop",
59            Self::SubagentStart => "subagent_start",
60            Self::SubagentStop => "subagent_stop",
61            Self::PreCompact => "pre_compact",
62            Self::SessionStart => "session_start",
63            Self::SessionEnd => "session_end",
64        };
65        write!(f, "{}", s)
66    }
67}
68
69#[derive(Clone, Debug)]
70pub enum HookEventData {
71    PreToolUse {
72        tool_name: String,
73        tool_input: Value,
74    },
75    PostToolUse {
76        tool_name: String,
77        tool_result: ToolOutput,
78    },
79    PostToolUseFailure {
80        tool_name: String,
81        error: String,
82    },
83    UserPromptSubmit {
84        prompt: String,
85    },
86    Stop,
87    SubagentStart {
88        subagent_id: String,
89        subagent_type: String,
90        description: String,
91    },
92    SubagentStop {
93        subagent_id: String,
94        success: bool,
95        error: Option<String>,
96    },
97    PreCompact,
98    SessionStart,
99    SessionEnd,
100}
101
102impl HookEventData {
103    pub fn event_type(&self) -> HookEvent {
104        match self {
105            Self::PreToolUse { .. } => HookEvent::PreToolUse,
106            Self::PostToolUse { .. } => HookEvent::PostToolUse,
107            Self::PostToolUseFailure { .. } => HookEvent::PostToolUseFailure,
108            Self::UserPromptSubmit { .. } => HookEvent::UserPromptSubmit,
109            Self::Stop => HookEvent::Stop,
110            Self::SubagentStart { .. } => HookEvent::SubagentStart,
111            Self::SubagentStop { .. } => HookEvent::SubagentStop,
112            Self::PreCompact => HookEvent::PreCompact,
113            Self::SessionStart => HookEvent::SessionStart,
114            Self::SessionEnd => HookEvent::SessionEnd,
115        }
116    }
117
118    pub fn tool_name(&self) -> Option<&str> {
119        match self {
120            Self::PreToolUse { tool_name, .. }
121            | Self::PostToolUse { tool_name, .. }
122            | Self::PostToolUseFailure { tool_name, .. } => Some(tool_name),
123            _ => None,
124        }
125    }
126
127    pub fn tool_input(&self) -> Option<&Value> {
128        match self {
129            Self::PreToolUse { tool_input, .. } => Some(tool_input),
130            _ => None,
131        }
132    }
133
134    pub fn subagent_id(&self) -> Option<&str> {
135        match self {
136            Self::SubagentStart { subagent_id, .. } | Self::SubagentStop { subagent_id, .. } => {
137                Some(subagent_id)
138            }
139            _ => None,
140        }
141    }
142}
143
144#[derive(Clone, Debug)]
145pub struct HookInput {
146    pub session_id: String,
147    pub timestamp: DateTime<Utc>,
148    pub data: HookEventData,
149    pub metadata: Option<Value>,
150}
151
152impl HookInput {
153    pub fn new(session_id: impl Into<String>, data: HookEventData) -> Self {
154        Self {
155            session_id: session_id.into(),
156            timestamp: Utc::now(),
157            data,
158            metadata: None,
159        }
160    }
161
162    pub fn with_metadata(mut self, metadata: Value) -> Self {
163        self.metadata = Some(metadata);
164        self
165    }
166
167    pub fn event_type(&self) -> HookEvent {
168        self.data.event_type()
169    }
170
171    pub fn tool_name(&self) -> Option<&str> {
172        self.data.tool_name()
173    }
174
175    pub fn subagent_id(&self) -> Option<&str> {
176        self.data.subagent_id()
177    }
178
179    pub fn pre_tool_use(
180        session_id: impl Into<String>,
181        tool_name: impl Into<String>,
182        tool_input: Value,
183    ) -> Self {
184        Self::new(
185            session_id,
186            HookEventData::PreToolUse {
187                tool_name: tool_name.into(),
188                tool_input,
189            },
190        )
191    }
192
193    pub fn post_tool_use(
194        session_id: impl Into<String>,
195        tool_name: impl Into<String>,
196        tool_result: ToolOutput,
197    ) -> Self {
198        Self::new(
199            session_id,
200            HookEventData::PostToolUse {
201                tool_name: tool_name.into(),
202                tool_result,
203            },
204        )
205    }
206
207    pub fn post_tool_use_failure(
208        session_id: impl Into<String>,
209        tool_name: impl Into<String>,
210        error: impl Into<String>,
211    ) -> Self {
212        Self::new(
213            session_id,
214            HookEventData::PostToolUseFailure {
215                tool_name: tool_name.into(),
216                error: error.into(),
217            },
218        )
219    }
220
221    pub fn user_prompt_submit(session_id: impl Into<String>, prompt: impl Into<String>) -> Self {
222        Self::new(
223            session_id,
224            HookEventData::UserPromptSubmit {
225                prompt: prompt.into(),
226            },
227        )
228    }
229
230    pub fn session_start(session_id: impl Into<String>) -> Self {
231        Self::new(session_id, HookEventData::SessionStart)
232    }
233
234    pub fn session_end(session_id: impl Into<String>) -> Self {
235        Self::new(session_id, HookEventData::SessionEnd)
236    }
237
238    pub fn stop(session_id: impl Into<String>) -> Self {
239        Self::new(session_id, HookEventData::Stop)
240    }
241
242    pub fn pre_compact(session_id: impl Into<String>) -> Self {
243        Self::new(session_id, HookEventData::PreCompact)
244    }
245
246    pub fn subagent_start(
247        session_id: impl Into<String>,
248        subagent_id: impl Into<String>,
249        subagent_type: impl Into<String>,
250        description: impl Into<String>,
251    ) -> Self {
252        Self::new(
253            session_id,
254            HookEventData::SubagentStart {
255                subagent_id: subagent_id.into(),
256                subagent_type: subagent_type.into(),
257                description: description.into(),
258            },
259        )
260    }
261
262    pub fn subagent_stop(
263        session_id: impl Into<String>,
264        subagent_id: impl Into<String>,
265        success: bool,
266        error: Option<String>,
267    ) -> Self {
268        Self::new(
269            session_id,
270            HookEventData::SubagentStop {
271                subagent_id: subagent_id.into(),
272                success,
273                error,
274            },
275        )
276    }
277}
278
279#[derive(Clone, Debug, Default)]
280pub struct HookOutput {
281    pub continue_execution: bool,
282    pub stop_reason: Option<String>,
283    pub suppress_logging: bool,
284    pub system_message: Option<String>,
285    pub updated_input: Option<Value>,
286    pub additional_context: Option<String>,
287}
288
289impl HookOutput {
290    pub fn allow() -> Self {
291        Self {
292            continue_execution: true,
293            ..Default::default()
294        }
295    }
296
297    pub fn block(reason: impl Into<String>) -> Self {
298        Self {
299            continue_execution: false,
300            stop_reason: Some(reason.into()),
301            ..Default::default()
302        }
303    }
304
305    pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
306        self.system_message = Some(message.into());
307        self
308    }
309
310    pub fn with_context(mut self, context: impl Into<String>) -> Self {
311        self.additional_context = Some(context.into());
312        self
313    }
314
315    pub fn with_updated_input(mut self, input: Value) -> Self {
316        self.updated_input = Some(input);
317        self
318    }
319
320    pub fn suppress_logging(mut self) -> Self {
321        self.suppress_logging = true;
322        self
323    }
324}
325
326#[derive(Clone, Debug)]
327pub struct HookContext {
328    pub session_id: String,
329    pub cancellation_token: CancellationToken,
330    pub cwd: Option<std::path::PathBuf>,
331    pub env: std::collections::HashMap<String, String>,
332}
333
334impl Default for HookContext {
335    fn default() -> Self {
336        Self {
337            session_id: String::new(),
338            cancellation_token: CancellationToken::new(),
339            cwd: None,
340            env: std::collections::HashMap::new(),
341        }
342    }
343}
344
345impl HookContext {
346    pub fn new(session_id: impl Into<String>) -> Self {
347        Self {
348            session_id: session_id.into(),
349            ..Default::default()
350        }
351    }
352
353    pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
354        self.cancellation_token = token;
355        self
356    }
357
358    pub fn with_cwd(mut self, cwd: impl Into<std::path::PathBuf>) -> Self {
359        self.cwd = Some(cwd.into());
360        self
361    }
362
363    pub fn with_env(mut self, env: std::collections::HashMap<String, String>) -> Self {
364        self.env = env;
365        self
366    }
367}
368
369/// Hook metadata for configuration.
370#[derive(Clone, Debug)]
371pub struct HookMetadata {
372    pub name: String,
373    pub events: Vec<HookEvent>,
374    pub priority: i32,
375    pub timeout_secs: u64,
376    pub tool_matcher: Option<Regex>,
377}
378
379impl HookMetadata {
380    pub fn new(name: impl Into<String>, events: Vec<HookEvent>) -> Self {
381        Self {
382            name: name.into(),
383            events,
384            priority: 0,
385            timeout_secs: 60,
386            tool_matcher: None,
387        }
388    }
389
390    pub fn with_priority(mut self, priority: i32) -> Self {
391        self.priority = priority;
392        self
393    }
394
395    pub fn with_timeout(mut self, secs: u64) -> Self {
396        self.timeout_secs = secs;
397        self
398    }
399
400    pub fn with_tool_matcher(mut self, pattern: &str) -> Self {
401        if let Ok(regex) = Regex::new(pattern) {
402            self.tool_matcher = Some(regex);
403        }
404        self
405    }
406}
407
408#[async_trait]
409pub trait Hook: Send + Sync {
410    fn name(&self) -> &str;
411    fn events(&self) -> &[HookEvent];
412
413    #[inline]
414    fn tool_matcher(&self) -> Option<&Regex> {
415        None
416    }
417
418    #[inline]
419    fn timeout_secs(&self) -> u64 {
420        60
421    }
422
423    #[inline]
424    fn priority(&self) -> i32 {
425        0
426    }
427
428    async fn execute(
429        &self,
430        input: HookInput,
431        hook_context: &HookContext,
432    ) -> Result<HookOutput, crate::Error>;
433
434    /// Get full metadata as a struct.
435    fn metadata(&self) -> HookMetadata {
436        HookMetadata {
437            name: self.name().to_string(),
438            events: self.events().to_vec(),
439            priority: self.priority(),
440            timeout_secs: self.timeout_secs(),
441            tool_matcher: self.tool_matcher().cloned(),
442        }
443    }
444}
445
446pub struct FnHook<F> {
447    name: String,
448    events: Vec<HookEvent>,
449    handler: F,
450    priority: i32,
451    timeout_secs: u64,
452    tool_matcher: Option<Regex>,
453}
454
455impl<F> FnHook<F> {
456    pub fn builder(name: impl Into<String>, events: Vec<HookEvent>) -> FnHookBuilder {
457        FnHookBuilder {
458            name: name.into(),
459            events,
460            priority: 0,
461            timeout_secs: 60,
462            tool_matcher: None,
463        }
464    }
465}
466
467pub struct FnHookBuilder {
468    name: String,
469    events: Vec<HookEvent>,
470    priority: i32,
471    timeout_secs: u64,
472    tool_matcher: Option<Regex>,
473}
474
475impl FnHookBuilder {
476    pub fn priority(mut self, priority: i32) -> Self {
477        self.priority = priority;
478        self
479    }
480
481    pub fn timeout_secs(mut self, secs: u64) -> Self {
482        self.timeout_secs = secs;
483        self
484    }
485
486    pub fn tool_matcher(mut self, pattern: &str) -> Self {
487        if let Ok(regex) = Regex::new(pattern) {
488            self.tool_matcher = Some(regex);
489        }
490        self
491    }
492
493    pub fn handler<F, Fut>(self, handler: F) -> FnHook<F>
494    where
495        F: Fn(HookInput, HookContext) -> Fut + Send + Sync,
496        Fut: std::future::Future<Output = Result<HookOutput, crate::Error>> + Send,
497    {
498        FnHook {
499            name: self.name,
500            events: self.events,
501            handler,
502            priority: self.priority,
503            timeout_secs: self.timeout_secs,
504            tool_matcher: self.tool_matcher,
505        }
506    }
507}
508
509#[async_trait]
510impl<F, Fut> Hook for FnHook<F>
511where
512    F: Fn(HookInput, HookContext) -> Fut + Send + Sync,
513    Fut: std::future::Future<Output = Result<HookOutput, crate::Error>> + Send,
514{
515    fn name(&self) -> &str {
516        &self.name
517    }
518
519    fn events(&self) -> &[HookEvent] {
520        &self.events
521    }
522
523    fn priority(&self) -> i32 {
524        self.priority
525    }
526
527    fn timeout_secs(&self) -> u64 {
528        self.timeout_secs
529    }
530
531    fn tool_matcher(&self) -> Option<&Regex> {
532        self.tool_matcher.as_ref()
533    }
534
535    async fn execute(
536        &self,
537        input: HookInput,
538        hook_context: &HookContext,
539    ) -> Result<HookOutput, crate::Error> {
540        (self.handler)(input, hook_context.clone()).await
541    }
542}
543
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hook_event_display() {
        assert_eq!(HookEvent::PreToolUse.to_string(), "pre_tool_use");
        assert_eq!(HookEvent::PostToolUse.to_string(), "post_tool_use");
        assert_eq!(HookEvent::SessionStart.to_string(), "session_start");
    }

    /// `Display` is hand-written while serde derives names from
    /// `rename_all = "snake_case"`; keep the two spellings in lockstep.
    #[test]
    fn test_hook_event_display_matches_serde() {
        for &event in HookEvent::all() {
            let name = event.to_string();
            let serialized = serde_json::to_value(event).expect("HookEvent serializes");
            assert_eq!(
                serialized,
                serde_json::Value::String(name.clone()),
                "Display and serde disagree for {:?}",
                event
            );
            let parsed: HookEvent = serde_json::from_value(serde_json::Value::String(name))
                .expect("HookEvent deserializes from its Display name");
            assert_eq!(parsed, event);
        }
    }

    #[test]
    fn test_hook_event_all_lists_each_variant_once() {
        let unique: std::collections::HashSet<HookEvent> =
            HookEvent::all().iter().copied().collect();
        assert_eq!(unique.len(), HookEvent::all().len());
        assert_eq!(unique.len(), 10);
    }

    #[test]
    fn test_hook_event_can_block() {
        assert!(HookEvent::PreToolUse.can_block());
        assert!(HookEvent::UserPromptSubmit.can_block());
        assert!(!HookEvent::PostToolUse.can_block());
        assert!(!HookEvent::SessionEnd.can_block());
        assert!(!HookEvent::SubagentStart.can_block());
        assert!(!HookEvent::SubagentStop.can_block());
    }

    #[test]
    fn test_hook_event_can_modify_input() {
        assert!(HookEvent::PreToolUse.can_modify_input());
        assert!(HookEvent::UserPromptSubmit.can_modify_input());
        assert!(!HookEvent::PostToolUse.can_modify_input());
        assert!(!HookEvent::Stop.can_modify_input());
    }

    #[test]
    fn test_hook_input_builders() {
        let input =
            HookInput::pre_tool_use("session-1", "Read", serde_json::json!({"path": "/tmp"}));
        assert_eq!(input.event_type(), HookEvent::PreToolUse);
        assert_eq!(input.tool_name(), Some("Read"));
        assert_eq!(input.session_id, "session-1");

        let input = HookInput::session_start("session-2");
        assert_eq!(input.event_type(), HookEvent::SessionStart);
        assert_eq!(input.session_id, "session-2");
    }

    #[test]
    fn test_subagent_and_failure_inputs() {
        let start = HookInput::subagent_start("s", "agent-1", "explore", "look around");
        assert_eq!(start.event_type(), HookEvent::SubagentStart);
        assert_eq!(start.subagent_id(), Some("agent-1"));
        assert_eq!(start.tool_name(), None);

        let stop = HookInput::subagent_stop("s", "agent-1", false, Some("boom".to_string()));
        assert_eq!(stop.event_type(), HookEvent::SubagentStop);
        assert_eq!(stop.subagent_id(), Some("agent-1"));

        let failure = HookInput::post_tool_use_failure("s", "Bash", "exit 1");
        assert_eq!(failure.event_type(), HookEvent::PostToolUseFailure);
        assert_eq!(failure.tool_name(), Some("Bash"));
        assert!(failure.data.tool_input().is_none());
    }

    #[test]
    fn test_hook_output_builders() {
        let output = HookOutput::allow();
        assert!(output.continue_execution);
        assert!(output.stop_reason.is_none());

        let output = HookOutput::block("Dangerous operation");
        assert!(!output.continue_execution);
        assert_eq!(output.stop_reason, Some("Dangerous operation".to_string()));

        let output = HookOutput::allow()
            .with_system_message("Added context")
            .with_context("More info")
            .suppress_logging();
        assert!(output.continue_execution);
        assert!(output.suppress_logging);
        assert_eq!(output.system_message, Some("Added context".to_string()));
        assert_eq!(output.additional_context, Some("More info".to_string()));
    }

    /// Pins the (surprising) derived default: it does not continue execution.
    #[test]
    fn test_hook_output_default_blocks() {
        let output = HookOutput::default();
        assert!(!output.continue_execution);
        assert!(output.stop_reason.is_none());
    }

    #[test]
    fn test_hook_event_data_accessors() {
        let data = HookEventData::PreToolUse {
            tool_name: "Bash".to_string(),
            tool_input: serde_json::json!({"command": "ls"}),
        };
        assert_eq!(data.event_type(), HookEvent::PreToolUse);
        assert_eq!(data.tool_name(), Some("Bash"));
        assert!(data.tool_input().is_some());

        let data = HookEventData::SessionStart;
        assert_eq!(data.event_type(), HookEvent::SessionStart);
        assert_eq!(data.tool_name(), None);
        assert!(data.tool_input().is_none());
    }

    #[test]
    fn test_hook_context_new() {
        let ctx = HookContext::new("session-3").with_cwd("/tmp");
        assert_eq!(ctx.session_id, "session-3");
        assert_eq!(ctx.cwd.as_deref(), Some(std::path::Path::new("/tmp")));
        assert!(ctx.env.is_empty());
        assert!(!ctx.cancellation_token.is_cancelled());
    }

    #[test]
    fn test_hook_metadata_tool_matcher() {
        let meta = HookMetadata::new("guard", vec![HookEvent::PreToolUse])
            .with_tool_matcher("^(Bash|Write)$");
        assert_eq!(meta.priority, 0);
        assert_eq!(meta.timeout_secs, 60);
        let matcher = meta
            .tool_matcher
            .as_ref()
            .expect("a valid pattern should compile");
        assert!(matcher.is_match("Bash"));
        assert!(!matcher.is_match("Read"));

        // An invalid pattern is ignored rather than reported by this builder.
        let meta = HookMetadata::new("guard", vec![HookEvent::PreToolUse]).with_tool_matcher("(");
        assert!(meta.tool_matcher.is_none());
    }
}