1use std::process::Stdio;
4use std::time::Duration;
5
6use async_trait::async_trait;
7use regex::Regex;
8use tokio::io::AsyncWriteExt;
9use tokio::process::Command;
10
11use super::{Hook, HookContext, HookEvent, HookInput, HookOutput};
12use crate::config::{HookConfig, HooksSettings};
13
14pub struct CommandHook {
15 name: String,
16 command: String,
17 events: Vec<HookEvent>,
18 tool_pattern: Option<Regex>,
19 timeout_secs: u64,
20}
21
22impl CommandHook {
23 pub fn new(
24 name: impl Into<String>,
25 command: impl Into<String>,
26 events: Vec<HookEvent>,
27 ) -> Self {
28 Self {
29 name: name.into(),
30 command: command.into(),
31 events,
32 tool_pattern: None,
33 timeout_secs: 60,
34 }
35 }
36
37 pub fn with_matcher(mut self, pattern: &str) -> Self {
38 self.tool_pattern = Regex::new(pattern).ok();
39 self
40 }
41
42 pub fn with_timeout(mut self, secs: u64) -> Self {
43 self.timeout_secs = secs;
44 self
45 }
46
47 pub fn from_settings(settings: &HooksSettings) -> Vec<Self> {
48 let mut hooks = Vec::new();
49
50 for (name, config) in &settings.pre_tool_use {
51 let (command, matcher, timeout) = Self::parse_config(config);
52 let mut hook = Self::new(name, command, vec![HookEvent::PreToolUse]);
53 if let Some(m) = matcher {
54 hook = hook.with_matcher(&m);
55 }
56 if let Some(t) = timeout {
57 hook = hook.with_timeout(t);
58 }
59 hooks.push(hook);
60 }
61
62 for (name, config) in &settings.post_tool_use {
63 let (command, matcher, timeout) = Self::parse_config(config);
64 let mut hook = Self::new(name, command, vec![HookEvent::PostToolUse]);
65 if let Some(m) = matcher {
66 hook = hook.with_matcher(&m);
67 }
68 if let Some(t) = timeout {
69 hook = hook.with_timeout(t);
70 }
71 hooks.push(hook);
72 }
73
74 for (i, config) in settings.session_start.iter().enumerate() {
75 let (command, _, timeout) = Self::parse_config(config);
76 let mut hook = Self::new(
77 format!("session-start-{}", i),
78 command,
79 vec![HookEvent::SessionStart],
80 );
81 if let Some(t) = timeout {
82 hook = hook.with_timeout(t);
83 }
84 hooks.push(hook);
85 }
86
87 for (i, config) in settings.session_end.iter().enumerate() {
88 let (command, _, timeout) = Self::parse_config(config);
89 let mut hook = Self::new(
90 format!("session-end-{}", i),
91 command,
92 vec![HookEvent::SessionEnd],
93 );
94 if let Some(t) = timeout {
95 hook = hook.with_timeout(t);
96 }
97 hooks.push(hook);
98 }
99
100 hooks
101 }
102
103 fn parse_config(config: &HookConfig) -> (String, Option<String>, Option<u64>) {
104 match config {
105 HookConfig::Command(cmd) => (cmd.clone(), None, None),
106 HookConfig::Full {
107 command,
108 timeout_secs,
109 matcher,
110 } => (command.clone(), matcher.clone(), *timeout_secs),
111 }
112 }
113}
114
115#[async_trait]
116impl Hook for CommandHook {
117 fn name(&self) -> &str {
118 &self.name
119 }
120
121 fn events(&self) -> &[HookEvent] {
122 &self.events
123 }
124
125 fn tool_matcher(&self) -> Option<&Regex> {
126 self.tool_pattern.as_ref()
127 }
128
129 fn timeout_secs(&self) -> u64 {
130 self.timeout_secs
131 }
132
133 async fn execute(
134 &self,
135 input: HookInput,
136 hook_context: &HookContext,
137 ) -> Result<HookOutput, crate::Error> {
138 let input_json = serde_json::to_string(&InputPayload::from_input(&input))
139 .map_err(|e| crate::Error::Config(format!("Failed to serialize hook input: {}", e)))?;
140
141 let mut child = Command::new("sh")
142 .arg("-c")
143 .arg(&self.command)
144 .stdin(Stdio::piped())
145 .stdout(Stdio::piped())
146 .stderr(Stdio::inherit())
147 .current_dir(
148 hook_context
149 .cwd
150 .as_deref()
151 .unwrap_or(std::path::Path::new(".")),
152 )
153 .envs(&hook_context.env)
154 .spawn()
155 .map_err(|e| crate::Error::Config(format!("Failed to spawn hook command: {}", e)))?;
156
157 if let Some(mut stdin) = child.stdin.take() {
158 stdin
159 .write_all(input_json.as_bytes())
160 .await
161 .map_err(|e| crate::Error::Config(format!("Failed to write to stdin: {}", e)))?;
162 }
163
164 let timeout = Duration::from_secs(self.timeout_secs);
165 let output = tokio::time::timeout(timeout, child.wait_with_output())
166 .await
167 .map_err(|_| crate::Error::Timeout(timeout))?
168 .map_err(|e| crate::Error::Config(format!("Hook command failed: {}", e)))?;
169
170 if !output.status.success() {
171 return Ok(HookOutput::block(format!(
172 "Hook '{}' failed with exit code: {:?}",
173 self.name,
174 output.status.code()
175 )));
176 }
177
178 let stdout = String::from_utf8_lossy(&output.stdout);
179 if stdout.trim().is_empty() {
180 return Ok(HookOutput::allow());
181 }
182
183 match serde_json::from_str::<OutputPayload>(stdout.trim()) {
184 Ok(payload) => Ok(payload.into_output()),
185 Err(_) => Ok(HookOutput::allow()),
186 }
187 }
188}
189
/// JSON document written to the hook command's stdin.
///
/// Serialized with `serde_json::to_string` before the command is spawned. The field
/// declaration order below determines the key order in the emitted JSON.
#[derive(serde::Serialize)]
struct InputPayload {
    /// Event name, from `HookInput::event_type()`.
    event: String,
    /// Session identifier copied from the `HookInput`.
    session_id: String,
    /// Tool name when the input carries one; serialized as `null` otherwise.
    tool_name: Option<String>,
    /// Tool input when the input carries one; serialized as `null` otherwise.
    tool_input: Option<serde_json::Value>,
}
197
198impl InputPayload {
199 fn from_input(input: &HookInput) -> Self {
200 Self {
201 event: input.event_type().to_string(),
202 session_id: input.session_id.clone(),
203 tool_name: input.tool_name().map(String::from),
204 tool_input: input.data.tool_input().cloned(),
205 }
206 }
207}
208
/// JSON object a hook command may print on stdout.
///
/// Unknown keys are ignored. Missing keys take defaults: `continue_execution` is
/// `true` and the optional fields are `None`.
#[derive(serde::Deserialize)]
struct OutputPayload {
    /// Whether execution should continue; `true` when the key is absent.
    #[serde(default = "default_true")]
    continue_execution: bool,
    /// Optional reason, copied into `HookOutput::stop_reason`.
    stop_reason: Option<String>,
    /// Optional value, copied into `HookOutput::updated_input`.
    updated_input: Option<serde_json::Value>,
}
216
/// serde default for `OutputPayload::continue_execution`: absent means "continue".
fn default_true() -> bool { true }
220
221impl OutputPayload {
222 fn into_output(self) -> HookOutput {
223 HookOutput {
224 continue_execution: self.continue_execution,
225 stop_reason: self.stop_reason,
226 updated_input: self.updated_input,
227 ..Default::default()
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_command_hook_creation() {
        let hook = CommandHook::new("test", "echo hello", vec![HookEvent::PreToolUse])
            .with_matcher("Bash")
            .with_timeout(30);

        assert_eq!(hook.name(), "test");
        assert!(hook.tool_matcher().is_some());
        assert_eq!(hook.timeout_secs(), 30);
    }

    #[test]
    fn test_from_settings() {
        let mut settings = HooksSettings::default();
        settings.pre_tool_use.insert(
            "security".to_string(),
            HookConfig::Full {
                command: "check-security.sh".to_string(),
                timeout_secs: Some(10),
                matcher: Some("Bash".to_string()),
            },
        );

        let hooks = CommandHook::from_settings(&settings);
        assert_eq!(hooks.len(), 1);
        assert_eq!(hooks[0].name(), "security");
        assert_eq!(hooks[0].timeout_secs(), 10);
        assert!(hooks[0].tool_matcher().is_some());
    }

    #[tokio::test]
    async fn test_command_hook_execution() {
        // `echo` never reads stdin, which exercises the "hook ignores its input" path.
        let hook = CommandHook::new("echo-test", "echo '{}'", vec![HookEvent::PreToolUse]);

        let input = HookInput::pre_tool_use("test-session", "Read", serde_json::json!({}));
        let hook_context = HookContext::new("test-session");

        let output = hook.execute(input, &hook_context).await.unwrap();
        assert!(output.continue_execution);
    }

    #[tokio::test]
    async fn test_command_hook_json_output_is_applied() {
        let hook = CommandHook::new(
            "deny-test",
            r#"cat > /dev/null; echo '{"continue_execution": false, "stop_reason": "denied"}'"#,
            vec![HookEvent::PreToolUse],
        );

        let input = HookInput::pre_tool_use("test-session", "Read", serde_json::json!({}));
        let hook_context = HookContext::new("test-session");

        let output = hook.execute(input, &hook_context).await.unwrap();
        assert!(!output.continue_execution);
        assert_eq!(output.stop_reason.as_deref(), Some("denied"));
    }

    #[tokio::test]
    async fn test_command_hook_receives_input_on_stdin() {
        // The hook answers "continue" only if the serialized tool name reached its stdin.
        let hook = CommandHook::new(
            "stdin-test",
            r#"if grep -q '"tool_name":"Read"'; then echo '{"continue_execution": true}'; else echo '{"continue_execution": false}'; fi"#,
            vec![HookEvent::PreToolUse],
        );

        let input = HookInput::pre_tool_use("test-session", "Read", serde_json::json!({}));
        let hook_context = HookContext::new("test-session");

        let output = hook.execute(input, &hook_context).await.unwrap();
        assert!(output.continue_execution);
    }
}