1use std::collections::HashMap;
4use std::process::Stdio;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use regex::Regex;
9use tokio::io::AsyncWriteExt;
10use tokio::process::Command;
11
12use super::{Hook, HookContext, HookEvent, HookInput, HookOutput};
13use crate::config::{HookConfig, HooksSettings};
14
15pub struct CommandHook {
16 name: String,
17 command: String,
18 events: Vec<HookEvent>,
19 tool_pattern: Option<Regex>,
20 timeout_secs: u64,
21 extra_env: HashMap<String, String>,
22}
23
24impl CommandHook {
25 pub fn new(
26 name: impl Into<String>,
27 command: impl Into<String>,
28 events: Vec<HookEvent>,
29 ) -> Self {
30 Self {
31 name: name.into(),
32 command: command.into(),
33 events,
34 tool_pattern: None,
35 timeout_secs: 60,
36 extra_env: HashMap::new(),
37 }
38 }
39
40 pub fn with_matcher(mut self, pattern: &str) -> Self {
41 self.tool_pattern = Regex::new(pattern).ok();
42 self
43 }
44
45 pub fn with_timeout(mut self, secs: u64) -> Self {
46 self.timeout_secs = secs;
47 self
48 }
49
50 pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
51 self.extra_env.insert(key.into(), value.into());
52 self
53 }
54
55 pub fn from_settings(settings: &HooksSettings) -> Vec<Self> {
56 let mut hooks = Vec::new();
57
58 for (name, config) in &settings.pre_tool_use {
59 hooks.push(Self::from_event_config(name, HookEvent::PreToolUse, config));
60 }
61
62 for (name, config) in &settings.post_tool_use {
63 hooks.push(Self::from_event_config(
64 name,
65 HookEvent::PostToolUse,
66 config,
67 ));
68 }
69
70 for (i, config) in settings.session_start.iter().enumerate() {
71 hooks.push(Self::from_event_config(
72 format!("session-start-{}", i),
73 HookEvent::SessionStart,
74 config,
75 ));
76 }
77
78 for (i, config) in settings.session_end.iter().enumerate() {
79 hooks.push(Self::from_event_config(
80 format!("session-end-{}", i),
81 HookEvent::SessionEnd,
82 config,
83 ));
84 }
85
86 hooks
87 }
88
89 pub fn from_event_config(
90 name: impl Into<String>,
91 event: HookEvent,
92 config: &HookConfig,
93 ) -> Self {
94 let (command, matcher, timeout) = Self::parse_config(config);
95 let mut hook = Self::new(name, command, vec![event]);
96 if let Some(m) = matcher {
97 hook = hook.with_matcher(&m);
98 }
99 if let Some(t) = timeout {
100 hook = hook.with_timeout(t);
101 }
102 hook
103 }
104
105 fn parse_config(config: &HookConfig) -> (String, Option<String>, Option<u64>) {
106 match config {
107 HookConfig::Command(cmd) => (cmd.clone(), None, None),
108 HookConfig::Full {
109 command,
110 timeout_secs,
111 matcher,
112 } => (command.clone(), matcher.clone(), *timeout_secs),
113 }
114 }
115}
116
117#[async_trait]
118impl Hook for CommandHook {
119 fn name(&self) -> &str {
120 &self.name
121 }
122
123 fn events(&self) -> &[HookEvent] {
124 &self.events
125 }
126
127 fn tool_matcher(&self) -> Option<&Regex> {
128 self.tool_pattern.as_ref()
129 }
130
131 fn timeout_secs(&self) -> u64 {
132 self.timeout_secs
133 }
134
135 async fn execute(
136 &self,
137 input: HookInput,
138 hook_context: &HookContext,
139 ) -> Result<HookOutput, crate::Error> {
140 let input_json = serde_json::to_string(&InputPayload::from_input(&input))
141 .map_err(|e| crate::Error::Config(format!("Failed to serialize hook input: {}", e)))?;
142
143 let mut child = Command::new("sh")
144 .arg("-c")
145 .arg(&self.command)
146 .stdin(Stdio::piped())
147 .stdout(Stdio::piped())
148 .stderr(Stdio::inherit())
149 .current_dir(
150 hook_context
151 .cwd
152 .as_deref()
153 .unwrap_or(std::path::Path::new(".")),
154 )
155 .envs(&hook_context.env)
156 .envs(&self.extra_env)
157 .spawn()
158 .map_err(|e| crate::Error::Config(format!("Failed to spawn hook command: {}", e)))?;
159
160 if let Some(mut stdin) = child.stdin.take() {
161 stdin
162 .write_all(input_json.as_bytes())
163 .await
164 .map_err(|e| crate::Error::Config(format!("Failed to write to stdin: {}", e)))?;
165 }
166
167 let timeout = Duration::from_secs(self.timeout_secs);
168 let output = tokio::time::timeout(timeout, child.wait_with_output())
169 .await
170 .map_err(|_| crate::Error::Timeout(timeout))?
171 .map_err(|e| crate::Error::Config(format!("Hook command failed: {}", e)))?;
172
173 if !output.status.success() {
174 return Ok(HookOutput::block(format!(
175 "Hook '{}' failed with exit code: {:?}",
176 self.name,
177 output.status.code()
178 )));
179 }
180
181 let stdout = String::from_utf8_lossy(&output.stdout);
182 if stdout.trim().is_empty() {
183 return Ok(HookOutput::allow());
184 }
185
186 match serde_json::from_str::<OutputPayload>(stdout.trim()) {
187 Ok(payload) => Ok(payload.into_output()),
188 Err(_) => Ok(HookOutput::allow()),
189 }
190 }
191}
192
/// JSON document serialized and written to a hook command's stdin.
#[derive(serde::Serialize)]
struct InputPayload {
    /// String form of the input's event type.
    event: String,
    /// Session identifier copied from the hook input.
    session_id: String,
    /// Tool name from the hook input, when it has one.
    tool_name: Option<String>,
    /// Tool input from the hook input's data, when it has one.
    tool_input: Option<serde_json::Value>,
}
200
201impl InputPayload {
202 fn from_input(input: &HookInput) -> Self {
203 Self {
204 event: input.event_type().to_string(),
205 session_id: input.session_id.clone(),
206 tool_name: input.tool_name().map(String::from),
207 tool_input: input.data.tool_input().cloned(),
208 }
209 }
210}
211
/// JSON document a hook command may print on stdout to steer execution.
#[derive(serde::Deserialize)]
struct OutputPayload {
    /// Whether execution should continue. When the key is absent the value
    /// comes from `default_true` (presumably `true` — defined outside this
    /// file, confirm there).
    #[serde(default = "default_true")]
    continue_execution: bool,
    /// Optional reason, forwarded unchanged to `HookOutput::stop_reason`.
    stop_reason: Option<String>,
    /// Optional replacement input, forwarded unchanged to
    /// `HookOutput::updated_input`.
    updated_input: Option<serde_json::Value>,
}
219
220use crate::common::serde_defaults::default_true;
221
222impl OutputPayload {
223 fn into_output(self) -> HookOutput {
224 HookOutput {
225 continue_execution: self.continue_execution,
226 stop_reason: self.stop_reason,
227 updated_input: self.updated_input,
228 ..Default::default()
229 }
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder methods set the name, matcher and timeout on the hook.
    #[test]
    fn test_command_hook_creation() {
        let base = CommandHook::new("test", "echo hello", vec![HookEvent::PreToolUse]);
        let hook = base.with_matcher("Bash").with_timeout(30);

        assert_eq!(hook.name(), "test");
        assert!(hook.tool_matcher().is_some());
        assert_eq!(hook.timeout_secs(), 30);
    }

    /// A single pre-tool-use entry yields one hook carrying its name and timeout.
    #[test]
    fn test_from_settings() {
        let security = HookConfig::Full {
            command: "check-security.sh".to_string(),
            timeout_secs: Some(10),
            matcher: Some("Bash".to_string()),
        };

        let mut settings = HooksSettings::default();
        settings
            .pre_tool_use
            .insert("security".to_string(), security);

        let hooks = CommandHook::from_settings(&settings);
        assert_eq!(hooks.len(), 1);

        let only = &hooks[0];
        assert_eq!(only.name(), "security");
        assert_eq!(only.timeout_secs(), 10);
    }

    /// A command printing an empty JSON object lets execution continue.
    #[tokio::test]
    async fn test_command_hook_execution() {
        let session = "test-session";
        let hook = CommandHook::new("echo-test", "echo '{}'", vec![HookEvent::PreToolUse]);

        let hook_context = HookContext::new(session);
        let input = HookInput::pre_tool_use(session, "Read", serde_json::json!({}));

        let output = hook.execute(input, &hook_context).await.unwrap();
        assert!(output.continue_execution);
    }
}