claude_agent/hooks/
command.rs

1//! Command-based hooks that execute shell commands.
2
3use std::process::Stdio;
4use std::time::Duration;
5
6use async_trait::async_trait;
7use regex::Regex;
8use tokio::io::AsyncWriteExt;
9use tokio::process::Command;
10
11use super::{Hook, HookContext, HookEvent, HookInput, HookOutput};
12use crate::config::{HookConfig, HooksSettings};
13
14pub struct CommandHook {
15    name: String,
16    command: String,
17    events: Vec<HookEvent>,
18    tool_pattern: Option<Regex>,
19    timeout_secs: u64,
20}
21
22impl CommandHook {
23    pub fn new(
24        name: impl Into<String>,
25        command: impl Into<String>,
26        events: Vec<HookEvent>,
27    ) -> Self {
28        Self {
29            name: name.into(),
30            command: command.into(),
31            events,
32            tool_pattern: None,
33            timeout_secs: 60,
34        }
35    }
36
37    pub fn with_matcher(mut self, pattern: &str) -> Self {
38        self.tool_pattern = Regex::new(pattern).ok();
39        self
40    }
41
42    pub fn with_timeout(mut self, secs: u64) -> Self {
43        self.timeout_secs = secs;
44        self
45    }
46
47    pub fn from_settings(settings: &HooksSettings) -> Vec<Self> {
48        let mut hooks = Vec::new();
49
50        for (name, config) in &settings.pre_tool_use {
51            let (command, matcher, timeout) = Self::parse_config(config);
52            let mut hook = Self::new(name, command, vec![HookEvent::PreToolUse]);
53            if let Some(m) = matcher {
54                hook = hook.with_matcher(&m);
55            }
56            if let Some(t) = timeout {
57                hook = hook.with_timeout(t);
58            }
59            hooks.push(hook);
60        }
61
62        for (name, config) in &settings.post_tool_use {
63            let (command, matcher, timeout) = Self::parse_config(config);
64            let mut hook = Self::new(name, command, vec![HookEvent::PostToolUse]);
65            if let Some(m) = matcher {
66                hook = hook.with_matcher(&m);
67            }
68            if let Some(t) = timeout {
69                hook = hook.with_timeout(t);
70            }
71            hooks.push(hook);
72        }
73
74        for (i, config) in settings.session_start.iter().enumerate() {
75            let (command, _, timeout) = Self::parse_config(config);
76            let mut hook = Self::new(
77                format!("session-start-{}", i),
78                command,
79                vec![HookEvent::SessionStart],
80            );
81            if let Some(t) = timeout {
82                hook = hook.with_timeout(t);
83            }
84            hooks.push(hook);
85        }
86
87        for (i, config) in settings.session_end.iter().enumerate() {
88            let (command, _, timeout) = Self::parse_config(config);
89            let mut hook = Self::new(
90                format!("session-end-{}", i),
91                command,
92                vec![HookEvent::SessionEnd],
93            );
94            if let Some(t) = timeout {
95                hook = hook.with_timeout(t);
96            }
97            hooks.push(hook);
98        }
99
100        hooks
101    }
102
103    fn parse_config(config: &HookConfig) -> (String, Option<String>, Option<u64>) {
104        match config {
105            HookConfig::Command(cmd) => (cmd.clone(), None, None),
106            HookConfig::Full {
107                command,
108                timeout_secs,
109                matcher,
110            } => (command.clone(), matcher.clone(), *timeout_secs),
111        }
112    }
113}
114
115#[async_trait]
116impl Hook for CommandHook {
117    fn name(&self) -> &str {
118        &self.name
119    }
120
121    fn events(&self) -> &[HookEvent] {
122        &self.events
123    }
124
125    fn tool_matcher(&self) -> Option<&Regex> {
126        self.tool_pattern.as_ref()
127    }
128
129    fn timeout_secs(&self) -> u64 {
130        self.timeout_secs
131    }
132
133    async fn execute(
134        &self,
135        input: HookInput,
136        hook_context: &HookContext,
137    ) -> Result<HookOutput, crate::Error> {
138        let input_json = serde_json::to_string(&InputPayload::from_input(&input))
139            .map_err(|e| crate::Error::Config(format!("Failed to serialize hook input: {}", e)))?;
140
141        let mut child = Command::new("sh")
142            .arg("-c")
143            .arg(&self.command)
144            .stdin(Stdio::piped())
145            .stdout(Stdio::piped())
146            .stderr(Stdio::inherit())
147            .current_dir(
148                hook_context
149                    .cwd
150                    .as_deref()
151                    .unwrap_or(std::path::Path::new(".")),
152            )
153            .envs(&hook_context.env)
154            .spawn()
155            .map_err(|e| crate::Error::Config(format!("Failed to spawn hook command: {}", e)))?;
156
157        if let Some(mut stdin) = child.stdin.take() {
158            stdin
159                .write_all(input_json.as_bytes())
160                .await
161                .map_err(|e| crate::Error::Config(format!("Failed to write to stdin: {}", e)))?;
162        }
163
164        let timeout = Duration::from_secs(self.timeout_secs);
165        let output = tokio::time::timeout(timeout, child.wait_with_output())
166            .await
167            .map_err(|_| crate::Error::Timeout(timeout))?
168            .map_err(|e| crate::Error::Config(format!("Hook command failed: {}", e)))?;
169
170        if !output.status.success() {
171            return Ok(HookOutput::block(format!(
172                "Hook '{}' failed with exit code: {:?}",
173                self.name,
174                output.status.code()
175            )));
176        }
177
178        let stdout = String::from_utf8_lossy(&output.stdout);
179        if stdout.trim().is_empty() {
180            return Ok(HookOutput::allow());
181        }
182
183        match serde_json::from_str::<OutputPayload>(stdout.trim()) {
184            Ok(payload) => Ok(payload.into_output()),
185            Err(_) => Ok(HookOutput::allow()),
186        }
187    }
188}
189
190#[derive(serde::Serialize)]
191struct InputPayload {
192    event: String,
193    session_id: String,
194    tool_name: Option<String>,
195    tool_input: Option<serde_json::Value>,
196}
197
198impl InputPayload {
199    fn from_input(input: &HookInput) -> Self {
200        Self {
201            event: input.event_type().to_string(),
202            session_id: input.session_id.clone(),
203            tool_name: input.tool_name().map(String::from),
204            tool_input: input.data.tool_input().cloned(),
205        }
206    }
207}
208
209#[derive(serde::Deserialize)]
210struct OutputPayload {
211    #[serde(default = "default_true")]
212    continue_execution: bool,
213    stop_reason: Option<String>,
214    updated_input: Option<serde_json::Value>,
215}
216
217fn default_true() -> bool {
218    true
219}
220
221impl OutputPayload {
222    fn into_output(self) -> HookOutput {
223        HookOutput {
224            continue_execution: self.continue_execution,
225            stop_reason: self.stop_reason,
226            updated_input: self.updated_input,
227            ..Default::default()
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_command_hook_creation() {
        let hook = CommandHook::new("test", "echo hello", vec![HookEvent::PreToolUse])
            .with_matcher("Bash")
            .with_timeout(30);

        assert_eq!(hook.name(), "test");
        assert!(hook.tool_matcher().is_some());
        assert_eq!(hook.timeout_secs(), 30);
    }

    /// `"*"` is not valid regex syntax, so the matcher is left unset.
    #[test]
    fn test_invalid_matcher_leaves_matcher_unset() {
        let hook =
            CommandHook::new("test", "true", vec![HookEvent::PreToolUse]).with_matcher("*");

        assert!(hook.tool_matcher().is_none());
    }

    #[test]
    fn test_from_settings() {
        let mut settings = HooksSettings::default();
        settings.pre_tool_use.insert(
            "security".to_string(),
            HookConfig::Full {
                command: "check-security.sh".to_string(),
                timeout_secs: Some(10),
                matcher: Some("Bash".to_string()),
            },
        );

        let hooks = CommandHook::from_settings(&settings);
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].name(), "security");
        assert_eq!(hooks[0].timeout_secs(), 10);
    }

    #[tokio::test]
    async fn test_command_hook_execution() {
        let hook = CommandHook::new("echo-test", "echo '{}'", vec![HookEvent::PreToolUse]);

        let input = HookInput::pre_tool_use("test-session", "Read", serde_json::json!({}));
        let hook_context = HookContext::new("test-session");

        let output = hook.execute(input, &hook_context).await.unwrap();
        assert!(output.continue_execution);
    }

    /// JSON printed on stdout is parsed into the returned `HookOutput`.
    #[tokio::test]
    async fn test_command_hook_parses_json_output() {
        let hook = CommandHook::new(
            "deny-test",
            r#"echo '{"continue_execution": false, "stop_reason": "denied"}'"#,
            vec![HookEvent::PreToolUse],
        );

        let input = HookInput::pre_tool_use("test-session", "Bash", serde_json::json!({}));
        let hook_context = HookContext::new("test-session");

        let output = hook.execute(input, &hook_context).await.unwrap();
        assert!(!output.continue_execution);
        assert_eq!(output.stop_reason.as_deref(), Some("denied"));
    }
}