1use std::ffi::OsStr;
2use std::process::Command;
3
4use serde_json::json;
5
6use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
7
/// Lifecycle point around a tool call at which hook commands are run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookEvent {
    /// Hooks run before the tool executes.
    PreToolUse,
    /// Hooks run after the tool has produced its output and error flag.
    PostToolUse,
}

impl HookEvent {
    /// Returns the canonical name of the event.
    ///
    /// This is the string hooks see in the `HOOK_EVENT` environment variable and in
    /// the `hook_event_name` field of the JSON payload. It is a `const fn`, so it can
    /// also be used to build compile-time constants.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::PreToolUse => "PreToolUse",
            Self::PostToolUse => "PostToolUse",
        }
    }
}
22
/// Combined outcome of running every hook command configured for one event.
///
/// `denied` records whether some hook vetoed the tool call. `messages` holds the
/// feedback the hooks produced, in the order the hooks ran.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HookRunResult {
    denied: bool,
    messages: Vec<String>,
}

impl HookRunResult {
    /// Builds a result that does not deny the tool call and carries `messages`.
    #[must_use]
    pub fn allow(messages: Vec<String>) -> Self {
        let denied = false;
        Self { messages, denied }
    }

    /// Returns `true` when a hook blocked the tool call.
    #[must_use]
    pub fn is_denied(&self) -> bool {
        let Self { denied, .. } = self;
        *denied
    }

    /// Returns the collected hook messages, oldest first.
    #[must_use]
    pub fn messages(&self) -> &[String] {
        self.messages.as_slice()
    }
}
48
/// Supplies the shell command lines to run for each hook event.
///
/// Each entry is passed to the platform shell unchanged. An empty slice means no
/// hooks are configured for that event.
pub trait HookCommandSource {
    /// Command lines to run before a tool executes.
    fn pre_tool_use_commands(&self) -> &[String];

    /// Command lines to run after a tool has executed.
    fn post_tool_use_commands(&self) -> &[String];
}
53
54impl HookCommandSource for RuntimeHookConfig {
55 fn pre_tool_use_commands(&self) -> &[String] {
56 self.pre_tool_use()
57 }
58
59 fn post_tool_use_commands(&self) -> &[String] {
60 self.post_tool_use()
61 }
62}
63
/// Runs the hook commands supplied by a command source `S` around tool invocations.
///
/// The struct definition places no bound on `S`. The `HookCommandSource`
/// requirement lives on the `impl` blocks that actually need it, so storing or
/// deriving traits for a runner does not impose extra bounds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HookRunner<S> {
    source: S,
}
68
69impl<S: HookCommandSource> HookRunner<S> {
70 #[must_use]
71 pub fn new(source: S) -> Self {
72 Self { source }
73 }
74
75 #[must_use]
76 pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
77 run_hook_commands(
78 HookEvent::PreToolUse,
79 self.source.pre_tool_use_commands(),
80 tool_name,
81 tool_input,
82 None,
83 false,
84 )
85 }
86
87 #[must_use]
88 pub fn run_post_tool_use(
89 &self,
90 tool_name: &str,
91 tool_input: &str,
92 tool_output: &str,
93 is_error: bool,
94 ) -> HookRunResult {
95 run_hook_commands(
96 HookEvent::PostToolUse,
97 self.source.post_tool_use_commands(),
98 tool_name,
99 tool_input,
100 Some(tool_output),
101 is_error,
102 )
103 }
104}
105
106impl HookRunner<RuntimeHookConfig> {
107 #[must_use]
108 pub fn from_feature_config(feature_config: &RuntimeFeatureConfig) -> Self {
109 Self::new(feature_config.hooks().clone())
110 }
111}
112
113impl<S: HookCommandSource + Default> Default for HookRunner<S> {
114 fn default() -> Self {
115 Self::new(S::default())
116 }
117}
118
119#[derive(Debug, Clone, Copy)]
120struct HookContext<'a> {
121 event: HookEvent,
122 tool_name: &'a str,
123 tool_input: &'a str,
124 tool_output: Option<&'a str>,
125 is_error: bool,
126 payload: &'a str,
127}
128
129pub fn run_hook_commands(
130 event: HookEvent,
131 commands: &[String],
132 tool_name: &str,
133 tool_input: &str,
134 tool_output: Option<&str>,
135 is_error: bool,
136) -> HookRunResult {
137 if commands.is_empty() {
138 return HookRunResult::allow(Vec::new());
139 }
140
141 let payload = json!({
142 "hook_event_name": event.as_str(),
143 "tool_name": tool_name,
144 "tool_input": parse_tool_input(tool_input),
145 "tool_input_json": tool_input,
146 "tool_output": tool_output,
147 "tool_result_is_error": is_error,
148 })
149 .to_string();
150
151 let ctx = HookContext {
152 event,
153 tool_name,
154 tool_input,
155 tool_output,
156 is_error,
157 payload: &payload,
158 };
159
160 let mut messages = Vec::new();
161
162 for command in commands {
163 match run_hook_command(command, &ctx) {
164 HookCommandOutcome::Allow { message } => {
165 if let Some(message) = message {
166 messages.push(message);
167 }
168 }
169 HookCommandOutcome::Deny { message } => {
170 messages.push(message.unwrap_or_else(|| {
171 format!("{} hook denied tool `{tool_name}`", event.as_str())
172 }));
173 return HookRunResult {
174 denied: true,
175 messages,
176 };
177 }
178 HookCommandOutcome::Warn { message } => messages.push(message),
179 }
180 }
181
182 HookRunResult::allow(messages)
183}
184
185fn run_hook_command(command: &str, ctx: &HookContext<'_>) -> HookCommandOutcome {
186 let mut child = shell_command(command);
187 child.stdin(std::process::Stdio::piped());
188 child.stdout(std::process::Stdio::piped());
189 child.stderr(std::process::Stdio::piped());
190 child.env("HOOK_EVENT", ctx.event.as_str());
191 child.env("HOOK_TOOL_NAME", ctx.tool_name);
192 child.env("HOOK_TOOL_INPUT", ctx.tool_input);
193 child.env("HOOK_TOOL_IS_ERROR", if ctx.is_error { "1" } else { "0" });
194 if let Some(tool_output) = ctx.tool_output {
195 child.env("HOOK_TOOL_OUTPUT", tool_output);
196 }
197
198 match child.output_with_stdin(ctx.payload.as_bytes()) {
199 Ok(output) => {
200 let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
201 let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
202 let message = (!stdout.is_empty()).then_some(stdout);
203 const HOOK_EXIT_ALLOW: i32 = 0;
204 const HOOK_EXIT_DENY: i32 = 2;
205 match output.status.code() {
206 Some(HOOK_EXIT_ALLOW) => HookCommandOutcome::Allow { message },
207 Some(HOOK_EXIT_DENY) => HookCommandOutcome::Deny { message },
208 Some(code) => HookCommandOutcome::Warn {
209 message: format_hook_warning(
210 command,
211 code,
212 message.as_deref(),
213 stderr.as_str(),
214 ),
215 },
216 None => HookCommandOutcome::Warn {
217 message: format!(
218 "{} hook `{command}` terminated by signal while handling `{}`",
219 ctx.event.as_str(),
220 ctx.tool_name,
221 ),
222 },
223 }
224 }
225 Err(error) => HookCommandOutcome::Warn {
226 message: format!(
227 "{} hook `{command}` failed to start for `{}`: {error}",
228 ctx.event.as_str(),
229 ctx.tool_name,
230 ),
231 },
232 }
233}
234
/// Classification of a single hook command's result.
enum HookCommandOutcome {
    /// The hook allowed the tool call, optionally with a message.
    Allow { message: Option<String> },
    /// The hook denied the tool call; `None` means the hook gave no reason.
    Deny { message: Option<String> },
    /// The hook misbehaved (unexpected status, signal, failed run) but the call proceeds.
    Warn { message: String },
}
240
241fn parse_tool_input(tool_input: &str) -> serde_json::Value {
242 serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input }))
243}
244
/// Builds the warning shown when a hook exits with a status other than allow or deny.
///
/// Non-empty `stdout` is appended as the explanation. Otherwise non-empty `stderr`
/// is appended. If both are empty, the summary stands alone.
fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String {
    let summary =
        format!("Hook `{command}` exited with status {code}; allowing tool execution to continue");
    let detail = stdout
        .filter(|text| !text.is_empty())
        .or_else(|| (!stderr.is_empty()).then_some(stderr));
    match detail {
        Some(detail) => format!("{summary}: {detail}"),
        None => summary,
    }
}
257
258fn shell_command(command: &str) -> CommandWithStdin {
259 #[cfg(windows)]
260 let command_builder = {
261 let mut cmd = Command::new("cmd");
262 cmd.arg("/C").arg(command);
263 CommandWithStdin::new(cmd)
264 };
265
266 #[cfg(not(windows))]
267 let command_builder = if std::path::Path::new(command).exists() {
268 let mut cmd = Command::new("sh");
269 cmd.arg(command);
270 CommandWithStdin::new(cmd)
271 } else {
272 let mut cmd = Command::new("sh");
273 cmd.arg("-lc").arg(command);
274 CommandWithStdin::new(cmd)
275 };
276
277 command_builder
278}
279
/// Thin wrapper around [`Command`] that feeds a byte buffer to the child's stdin
/// while collecting its stdout and stderr.
struct CommandWithStdin {
    command: Command,
}

impl CommandWithStdin {
    /// Wraps an already-configured command.
    fn new(command: Command) -> Self {
        Self { command }
    }

    /// Sets the child's stdin configuration.
    fn stdin(&mut self, cfg: std::process::Stdio) -> &mut Self {
        self.command.stdin(cfg);
        self
    }

    /// Sets the child's stdout configuration.
    fn stdout(&mut self, cfg: std::process::Stdio) -> &mut Self {
        self.command.stdout(cfg);
        self
    }

    /// Sets the child's stderr configuration.
    fn stderr(&mut self, cfg: std::process::Stdio) -> &mut Self {
        self.command.stderr(cfg);
        self
    }

    /// Adds or overrides one environment variable for the child.
    fn env<K, V>(&mut self, key: K, value: V) -> &mut Self
    where
        K: AsRef<OsStr>,
        V: AsRef<OsStr>,
    {
        self.command.env(key, value);
        self
    }

    /// Spawns the command, writes `stdin` to it (when stdin is piped), and waits for
    /// it to exit, collecting its output.
    ///
    /// The input is written from a helper thread while `wait_with_output` drains
    /// stdout and stderr. Writing all of the input first can deadlock: a child that
    /// fills its stdout pipe before reading all of its input blocks forever, and so
    /// does the writer. A child that exits without reading its input (`BrokenPipe`)
    /// is not treated as an error. The child is always waited on, even if writing
    /// its input fails.
    ///
    /// # Errors
    /// Returns the spawn or wait error, or any non-`BrokenPipe` error from writing stdin.
    fn output_with_stdin(&mut self, stdin: &[u8]) -> std::io::Result<std::process::Output> {
        let mut child = self.command.spawn()?;

        let writer = child.stdin.take().map(|mut child_stdin| {
            // The thread must own its data, so copy the payload once.
            let payload = stdin.to_vec();
            std::thread::spawn(move || {
                use std::io::Write as _;
                // `child_stdin` is dropped when this closure returns, closing the pipe (EOF).
                match child_stdin.write_all(&payload) {
                    Err(error) if error.kind() != std::io::ErrorKind::BrokenPipe => Err(error),
                    _ => Ok(()),
                }
            })
        });

        let output = child.wait_with_output()?;

        if let Some(writer) = writer {
            writer.join().unwrap_or_else(|_| {
                Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "hook stdin writer thread panicked",
                ))
            })?;
        }

        Ok(output)
    }
}
326
#[cfg(test)]
#[cfg(unix)]
mod tests {
    use super::{HookRunResult, HookRunner};
    use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};

    /// Builds a runner whose only hooks are the given pre-tool-use command lines.
    fn pre_tool_runner(commands: &[&str]) -> HookRunner<RuntimeHookConfig> {
        HookRunner::new(RuntimeHookConfig::new(
            commands.iter().map(|command| (*command).to_string()).collect(),
            Vec::new(),
        ))
    }

    #[test]
    fn allows_exit_code_zero_and_captures_stdout() {
        let runner = HookRunner::new(RuntimeHookConfig::new(
            vec!["printf 'pre ok'".to_string()],
            Vec::new(),
        ));

        let result = runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#);

        assert_eq!(result, HookRunResult::allow(vec!["pre ok".to_string()]));
    }

    #[test]
    fn denies_exit_code_two() {
        let runner = HookRunner::new(RuntimeHookConfig::new(
            vec!["printf 'blocked by hook'; exit 2".to_string()],
            Vec::new(),
        ));

        let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#);

        assert!(result.is_denied());
        assert_eq!(result.messages(), &["blocked by hook".to_string()]);
    }

    #[test]
    fn warns_for_other_non_zero_statuses() {
        let runner = HookRunner::from_feature_config(&RuntimeFeatureConfig::default().with_hooks(
            RuntimeHookConfig::new(
                vec!["printf 'warning hook'; exit 1".to_string()],
                Vec::new(),
            ),
        ));

        let result = runner.run_pre_tool_use("Edit", r#"{"file":"src/lib.rs"}"#);

        assert!(!result.is_denied());
        assert!(result
            .messages()
            .iter()
            .any(|message| message.contains("allowing tool execution to continue")));
    }

    #[test]
    fn allows_without_messages_when_no_hooks_are_configured() {
        let result = pre_tool_runner(&[]).run_pre_tool_use("Read", "{}");

        assert_eq!(result, HookRunResult::allow(Vec::new()));
    }

    #[test]
    fn deny_stops_later_hooks() {
        let runner = pre_tool_runner(&["printf 'first'; exit 2", "printf 'second'"]);

        let result = runner.run_pre_tool_use("Bash", r#"{"command":"ls"}"#);

        assert!(result.is_denied());
        assert_eq!(result.messages(), &["first".to_string()]);
    }

    #[test]
    fn post_tool_use_hook_sees_output_and_error_flag() {
        let runner = HookRunner::new(RuntimeHookConfig::new(
            Vec::new(),
            vec![r#"printf '%s:%s' "$HOOK_TOOL_OUTPUT" "$HOOK_TOOL_IS_ERROR""#.to_string()],
        ));

        let result = runner.run_post_tool_use("Bash", r#"{"command":"ls"}"#, "file.txt", true);

        assert_eq!(result, HookRunResult::allow(vec!["file.txt:1".to_string()]));
    }

    #[test]
    fn hook_receives_json_payload_on_stdin() {
        let runner = pre_tool_runner(&["cat"]);

        let result = runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#);

        assert!(!result.is_denied());
        let payload = &result.messages()[0];
        assert!(payload.contains(r#""hook_event_name":"PreToolUse""#));
        assert!(payload.contains(r#""tool_name":"Read""#));
        assert!(payload.contains(r#""tool_input":{"path":"README.md"}"#));
    }
}