// claude_agent/hooks/command.rs

//! Command-based hooks that execute shell commands.
2
3use std::collections::HashMap;
4use std::process::Stdio;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use regex::Regex;
9use tokio::io::AsyncWriteExt;
10use tokio::process::Command;
11
12use super::{Hook, HookContext, HookEvent, HookInput, HookOutput};
13use crate::config::{HookConfig, HooksSettings};
14
15pub struct CommandHook {
16    name: String,
17    command: String,
18    events: Vec<HookEvent>,
19    tool_pattern: Option<Regex>,
20    timeout_secs: u64,
21    extra_env: HashMap<String, String>,
22}
23
24impl CommandHook {
25    pub fn new(
26        name: impl Into<String>,
27        command: impl Into<String>,
28        events: Vec<HookEvent>,
29    ) -> Self {
30        Self {
31            name: name.into(),
32            command: command.into(),
33            events,
34            tool_pattern: None,
35            timeout_secs: 60,
36            extra_env: HashMap::new(),
37        }
38    }
39
40    pub fn matcher(mut self, pattern: &str) -> Self {
41        self.tool_pattern = Regex::new(pattern).ok();
42        self
43    }
44
45    pub fn timeout(mut self, secs: u64) -> Self {
46        self.timeout_secs = secs;
47        self
48    }
49
50    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
51        self.extra_env.insert(key.into(), value.into());
52        self
53    }
54
55    pub fn from_settings(settings: &HooksSettings) -> Vec<Self> {
56        let mut hooks = Vec::new();
57
58        for (name, config) in &settings.pre_tool_use {
59            hooks.push(Self::from_event_config(name, HookEvent::PreToolUse, config));
60        }
61
62        for (name, config) in &settings.post_tool_use {
63            hooks.push(Self::from_event_config(
64                name,
65                HookEvent::PostToolUse,
66                config,
67            ));
68        }
69
70        for (i, config) in settings.session_start.iter().enumerate() {
71            hooks.push(Self::from_event_config(
72                format!("session-start-{}", i),
73                HookEvent::SessionStart,
74                config,
75            ));
76        }
77
78        for (i, config) in settings.session_end.iter().enumerate() {
79            hooks.push(Self::from_event_config(
80                format!("session-end-{}", i),
81                HookEvent::SessionEnd,
82                config,
83            ));
84        }
85
86        hooks
87    }
88
89    pub fn from_event_config(
90        name: impl Into<String>,
91        event: HookEvent,
92        config: &HookConfig,
93    ) -> Self {
94        let (command, matcher, timeout) = Self::parse_config(config);
95        let mut hook = Self::new(name, command, vec![event]);
96        if let Some(m) = matcher {
97            hook = hook.matcher(&m);
98        }
99        if let Some(t) = timeout {
100            hook = hook.timeout(t);
101        }
102        hook
103    }
104
105    fn parse_config(config: &HookConfig) -> (String, Option<String>, Option<u64>) {
106        match config {
107            HookConfig::Command(cmd) => (cmd.clone(), None, None),
108            HookConfig::Full {
109                command,
110                timeout_secs,
111                matcher,
112            } => (command.clone(), matcher.clone(), *timeout_secs),
113        }
114    }
115}
116
117#[async_trait]
118impl Hook for CommandHook {
119    fn name(&self) -> &str {
120        &self.name
121    }
122
123    fn events(&self) -> &[HookEvent] {
124        &self.events
125    }
126
127    fn tool_matcher(&self) -> Option<&Regex> {
128        self.tool_pattern.as_ref()
129    }
130
131    fn timeout_secs(&self) -> u64 {
132        self.timeout_secs
133    }
134
135    async fn execute(
136        &self,
137        input: HookInput,
138        hook_context: &HookContext,
139    ) -> Result<HookOutput, crate::Error> {
140        let input_json = serde_json::to_string(&InputPayload::from_input(&input))
141            .map_err(|e| crate::Error::Config(format!("Failed to serialize hook input: {}", e)))?;
142
143        let mut child = Command::new("sh")
144            .arg("-c")
145            .arg(&self.command)
146            .stdin(Stdio::piped())
147            .stdout(Stdio::piped())
148            .stderr(Stdio::inherit())
149            .current_dir(
150                hook_context
151                    .cwd
152                    .as_deref()
153                    .unwrap_or(std::path::Path::new(".")),
154            )
155            .envs(&hook_context.env)
156            .envs(&self.extra_env)
157            .spawn()
158            .map_err(|e| crate::Error::Config(format!("Failed to spawn hook command: {}", e)))?;
159
160        if let Some(mut stdin) = child.stdin.take() {
161            stdin
162                .write_all(input_json.as_bytes())
163                .await
164                .map_err(|e| crate::Error::Config(format!("Failed to write to stdin: {}", e)))?;
165        }
166
167        let timeout = Duration::from_secs(self.timeout_secs);
168        let output = tokio::time::timeout(timeout, child.wait_with_output())
169            .await
170            .map_err(|_| crate::Error::Timeout(timeout))?
171            .map_err(|e| crate::Error::Config(format!("Hook command failed: {}", e)))?;
172
173        if !output.status.success() {
174            return Ok(HookOutput::block(format!(
175                "Hook '{}' failed with exit code: {:?}",
176                self.name,
177                output.status.code()
178            )));
179        }
180
181        let stdout = String::from_utf8_lossy(&output.stdout);
182        if stdout.trim().is_empty() {
183            return Ok(HookOutput::allow());
184        }
185
186        match serde_json::from_str::<OutputPayload>(stdout.trim()) {
187            Ok(payload) => Ok(payload.into_output()),
188            Err(_) => Ok(HookOutput::allow()),
189        }
190    }
191}
192
193#[derive(serde::Serialize)]
194struct InputPayload {
195    event: String,
196    session_id: String,
197    tool_name: Option<String>,
198    tool_input: Option<serde_json::Value>,
199}
200
201impl InputPayload {
202    fn from_input(input: &HookInput) -> Self {
203        Self {
204            event: input.event_type().to_string(),
205            session_id: input.session_id.clone(),
206            tool_name: input.tool_name().map(String::from),
207            tool_input: input.data.tool_input().cloned(),
208        }
209    }
210}
211
212#[derive(serde::Deserialize)]
213struct OutputPayload {
214    #[serde(default = "default_true")]
215    continue_execution: bool,
216    stop_reason: Option<String>,
217    updated_input: Option<serde_json::Value>,
218}
219
220use crate::common::serde_defaults::default_true;
221
222impl OutputPayload {
223    fn into_output(self) -> HookOutput {
224        HookOutput {
225            continue_execution: self.continue_execution,
226            stop_reason: self.stop_reason,
227            updated_input: self.updated_input,
228            ..Default::default()
229        }
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder settings must be visible through the `Hook` accessors.
    #[test]
    fn test_command_hook_creation() {
        let events = vec![HookEvent::PreToolUse];
        let built = CommandHook::new("test", "echo hello", events)
            .matcher("Bash")
            .timeout(30);

        assert_eq!(built.name(), "test");
        assert!(built.tool_matcher().is_some());
        assert_eq!(built.timeout_secs(), 30);
    }

    /// A single pre-tool-use entry yields one hook with its name and timeout.
    #[test]
    fn test_from_settings() {
        let security = HookConfig::Full {
            command: "check-security.sh".to_string(),
            timeout_secs: Some(10),
            matcher: Some("Bash".to_string()),
        };

        let mut settings = HooksSettings::default();
        settings
            .pre_tool_use
            .insert("security".to_string(), security);

        let hooks = CommandHook::from_settings(&settings);
        assert_eq!(hooks.len(), 1);

        let first = &hooks[0];
        assert_eq!(first.name(), "security");
        assert_eq!(first.timeout_secs(), 10);
    }

    /// A command printing an empty JSON object lets execution continue.
    #[tokio::test]
    async fn test_command_hook_execution() {
        let session = "test-session";
        let hook_context = HookContext::new(session);
        let input = HookInput::pre_tool_use(session, "Read", serde_json::json!({}));

        let hook = CommandHook::new("echo-test", "echo '{}'", vec![HookEvent::PreToolUse]);

        let result = hook.execute(input, &hook_context).await;
        let output = result.unwrap();
        assert!(output.continue_execution);
    }
}