1use std::path::Path;
2
3use anyhow::{Context, Result};
4use async_trait::async_trait;
5use glob::Pattern;
6use serde::Deserialize;
7use tokio::process::Command;
8
9use super::types::*;
10
11#[derive(Debug, Deserialize)]
13pub struct HooksConfig {
14 pub hooks: Vec<ShellHookDef>,
15}
16
17#[derive(Debug, Deserialize)]
19#[serde(rename_all = "camelCase")]
20pub struct ShellHookDef {
21 pub name: String,
23 pub event: String,
25 pub pattern: Option<String>,
27 pub command: String,
29 #[serde(default)]
31 pub feedback: bool,
32}
33
34impl ShellHookDef {
35 fn parse_event(&self) -> Option<(HookTiming, HookTarget)> {
36 match self.event.as_str() {
37 "beforeRead" => Some((HookTiming::Before, HookTarget::FsRead)),
38 "afterRead" => Some((HookTiming::After, HookTarget::FsRead)),
39 "beforeWrite" => Some((HookTiming::Before, HookTarget::FsWrite)),
40 "afterWrite" => Some((HookTiming::After, HookTarget::FsWrite)),
41 "beforeTerminal" => Some((HookTiming::Before, HookTarget::Terminal)),
42 "afterTerminal" => Some((HookTiming::After, HookTarget::Terminal)),
43 "turnEnd" => Some((HookTiming::After, HookTarget::TurnEnd)),
44 _ => None,
45 }
46 }
47}
48
49#[derive(Debug)]
51enum GlobFilter {
52 MatchAll,
54 Pattern(Pattern),
56 Invalid,
58}
59
60#[derive(Debug)]
62pub struct ShellHook {
63 def: ShellHookDef,
64 timing: HookTiming,
65 target: HookTarget,
66 glob: GlobFilter,
67}
68
69impl ShellHook {
70 pub fn from_def(def: ShellHookDef) -> Option<Self> {
71 let (timing, target) = def.parse_event()?;
72 let glob = match &def.pattern {
73 None => GlobFilter::MatchAll,
74 Some(p) => match Pattern::new(p) {
75 Ok(pattern) => GlobFilter::Pattern(pattern),
76 Err(e) => {
77 tracing::warn!(
78 "Hook '{}': invalid glob pattern '{}': {e} — hook will not match any files",
79 def.name,
80 p,
81 );
82 GlobFilter::Invalid
83 }
84 },
85 };
86 Some(Self {
87 def,
88 timing,
89 target,
90 glob,
91 })
92 }
93
94 fn matches_path(&self, path: &Path) -> bool {
96 match &self.glob {
97 GlobFilter::MatchAll => true,
98 GlobFilter::Invalid => false,
99 GlobFilter::Pattern(pattern) => {
100 let path_str = path.to_string_lossy();
101 pattern.matches(&path_str)
103 || path
104 .file_name()
105 .map(|f| pattern.matches(&f.to_string_lossy()))
106 .unwrap_or(false)
107 }
108 }
109 }
110
111 fn expand_command(&self, ctx: &HookContext) -> String {
113 let mut cmd = self.def.command.clone();
114 if let Some(path) = &ctx.path {
115 cmd = cmd.replace("${file}", &path.to_string_lossy());
116 }
117 cmd
118 }
119}
120
121#[async_trait(?Send)]
122impl Hook for ShellHook {
123 fn name(&self) -> &str {
124 &self.def.name
125 }
126
127 fn timing(&self) -> HookTiming {
128 self.timing
129 }
130
131 fn target(&self) -> HookTarget {
132 self.target
133 }
134
135 async fn run(&self, ctx: &HookContext) -> HookResult {
136 if let Some(path) = &ctx.path {
138 if !self.matches_path(path) {
139 return HookResult::Continue;
140 }
141 }
142
143 let cmd = self.expand_command(ctx);
144 tracing::info!("Running hook '{}': {}", self.def.name, cmd);
145
146 let output = shell_command(&cmd).output().await;
147
148 match output {
149 Ok(output) => {
150 let stdout = String::from_utf8_lossy(&output.stdout);
151 let stderr = String::from_utf8_lossy(&output.stderr);
152
153 if !output.status.success() {
154 let exit_info = match output.status.code() {
155 Some(code) => format!("exit {code}"),
156 None => {
157 #[cfg(unix)]
158 {
159 use std::os::unix::process::ExitStatusExt;
160 match output.status.signal() {
161 Some(sig) => format!("killed by signal {sig}"),
162 None => "terminated abnormally".to_string(),
163 }
164 }
165 #[cfg(not(unix))]
166 {
167 "terminated abnormally".to_string()
168 }
169 }
170 };
171 let combined = format!(
172 "Hook '{}' failed ({exit_info}):\n{stdout}{stderr}",
173 self.def.name,
174 );
175 tracing::warn!("{combined}");
176
177 if self.def.feedback {
178 return HookResult::FeedbackPrompt { text: combined };
179 }
180 if self.timing == HookTiming::Before {
183 return HookResult::Blocked { reason: combined };
184 }
185 return HookResult::Continue;
186 }
187
188 if self.def.feedback {
189 let combined = format!("{stdout}{stderr}");
190 if !combined.trim().is_empty() {
191 return HookResult::FeedbackPrompt {
192 text: format!(
193 "Hook '{}' output:\n{combined}",
194 self.def.name
195 ),
196 };
197 }
198 }
199
200 HookResult::Continue
201 }
202 Err(e) => {
203 tracing::error!("Failed to run hook '{}': {e}", self.def.name);
204 if self.timing == HookTiming::Before {
207 HookResult::Blocked {
208 reason: format!("Hook '{}' failed to execute: {e}", self.def.name),
209 }
210 } else {
211 HookResult::Continue
212 }
213 }
214 }
215 }
216}
217
/// Builds the process invocation that runs `cmd` through the platform shell.
///
/// Uses `cmd /C` on Windows and `sh -c` everywhere else. The returned command
/// is configured but not spawned; the caller decides how to run it and how
/// to capture its output.
///
/// Standard input is explicitly set to the null device. Hooks are
/// non-interactive, and without this a hook that reads stdin could consume
/// or block on the host process's own stdin, depending on how the caller
/// spawns the command.
fn shell_command(cmd: &str) -> Command {
    #[cfg(target_os = "windows")]
    let (shell, flag) = ("cmd", "/C");
    #[cfg(not(target_os = "windows"))]
    let (shell, flag) = ("sh", "-c");

    let mut c = Command::new(shell);
    c.args([flag, cmd]);
    c.stdin(std::process::Stdio::null());
    c
}
234
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;

    /// Builds a `test-hook` hook with no glob pattern.
    fn make_hook(event: &str, command: &str, feedback: bool) -> ShellHook {
        make_hook_with_pattern(event, command, feedback, None)
    }

    /// Builds a `test-hook` hook, optionally restricted by a glob pattern.
    ///
    /// Panics when `event` is not a recognised event string.
    fn make_hook_with_pattern(
        event: &str,
        command: &str,
        feedback: bool,
        pattern: Option<&str>,
    ) -> ShellHook {
        ShellHook::from_def(ShellHookDef {
            name: String::from("test-hook"),
            event: String::from(event),
            pattern: pattern.map(str::to_owned),
            command: String::from(command),
            feedback,
        })
        .expect("valid event string")
    }

    /// Context for a before-write operation on `path`.
    fn write_context(path: Option<PathBuf>) -> HookContext {
        HookContext {
            target: HookTarget::FsWrite,
            timing: HookTiming::Before,
            path,
            content: None,
            command: None,
        }
    }

    /// Same as `write_context`, but with `After` timing.
    fn after_write_context(path: Option<PathBuf>) -> HookContext {
        HookContext {
            timing: HookTiming::After,
            ..write_context(path)
        }
    }

    /// Unwraps the text of a `FeedbackPrompt`, panicking on any other result.
    fn feedback_text(result: HookResult) -> String {
        match result {
            HookResult::FeedbackPrompt { text } => text,
            other => panic!("expected FeedbackPrompt, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn hook_successful_command_returns_continue() {
        let hook = make_hook("beforeWrite", "echo hello", false);
        let result = hook.run(&write_context(None)).await;
        assert!(matches!(result, HookResult::Continue));
    }

    #[tokio::test]
    async fn before_hook_failure_blocks() {
        let hook = make_hook("beforeWrite", "exit 1", false);
        let result = hook.run(&write_context(None)).await;
        assert!(
            matches!(result, HookResult::Blocked { .. }),
            "before-hook failure should block, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn after_hook_failure_continues() {
        let hook = make_hook("afterWrite", "exit 1", false);
        let result = hook.run(&after_write_context(None)).await;
        assert!(
            matches!(result, HookResult::Continue),
            "after-hook failure should continue, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn feedback_hook_returns_output() {
        let hook = make_hook("afterWrite", "echo 'lint passed'", true);
        let text = feedback_text(hook.run(&after_write_context(None)).await);
        assert!(text.contains("lint passed"), "expected output in feedback: {text}");
    }

    #[tokio::test]
    async fn feedback_hook_failure_returns_feedback() {
        let hook = make_hook("beforeWrite", "echo 'bad format' >&2; exit 1", true);
        let text = feedback_text(hook.run(&write_context(None)).await);
        assert!(text.contains("bad format"), "expected stderr in feedback: {text}");
    }

    #[tokio::test]
    async fn hook_glob_filters_non_matching_path() {
        let hook = make_hook_with_pattern("beforeWrite", "exit 1", false, Some("*.rs"));
        let ctx = write_context(Some(PathBuf::from("src/main.py")));
        let result = hook.run(&ctx).await;
        assert!(
            matches!(result, HookResult::Continue),
            "non-matching glob should skip hook, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn hook_glob_matches_file() {
        let hook = make_hook_with_pattern("beforeWrite", "exit 1", false, Some("*.rs"));
        let ctx = write_context(Some(PathBuf::from("main.rs")));
        let result = hook.run(&ctx).await;
        assert!(
            matches!(result, HookResult::Blocked { .. }),
            "matching glob + failure should block, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn hook_invalid_glob_matches_nothing() {
        // `[invalid` is an unterminated character class, so compilation fails.
        let hook = ShellHook::from_def(ShellHookDef {
            name: String::from("bad-glob"),
            event: String::from("beforeWrite"),
            pattern: Some(String::from("[invalid")),
            command: String::from("exit 1"),
            feedback: false,
        })
        .unwrap();
        let ctx = write_context(Some(PathBuf::from("anything.rs")));
        let result = hook.run(&ctx).await;
        assert!(
            matches!(result, HookResult::Continue),
            "invalid glob should match nothing (fail closed), got: {result:?}"
        );
    }

    #[test]
    fn unknown_event_returns_none() {
        let built = ShellHook::from_def(ShellHookDef {
            name: String::from("bad"),
            event: String::from("onSomething"),
            pattern: None,
            command: String::from("echo hi"),
            feedback: false,
        });
        assert!(built.is_none());
    }

    #[test]
    fn file_placeholder_expanded() {
        let hook = make_hook("afterWrite", "cat ${file}", false);
        let ctx = after_write_context(Some(PathBuf::from("/tmp/test.txt")));
        assert_eq!(hook.expand_command(&ctx), "cat /tmp/test.txt");
    }

    #[test]
    fn file_placeholder_left_when_path_is_none() {
        let hook = make_hook("afterWrite", "cat ${file}", false);
        let ctx = after_write_context(None);
        assert_eq!(hook.expand_command(&ctx), "cat ${file}");
    }

    #[tokio::test]
    async fn glob_hook_runs_when_path_is_none() {
        // Without a path there is nothing for the glob to reject, so the hook runs.
        let hook = make_hook_with_pattern("beforeWrite", "exit 1", false, Some("*.rs"));
        let result = hook.run(&write_context(None)).await;
        assert!(
            matches!(result, HookResult::Blocked { .. }),
            "glob-filtered hook with no path should still run, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn feedback_hook_empty_output_returns_continue() {
        let hook = make_hook("afterWrite", "true", true);
        let result = hook.run(&after_write_context(None)).await;
        assert!(
            matches!(result, HookResult::Continue),
            "feedback hook with empty output should return Continue, got: {result:?}"
        );
    }
}
427
428pub fn load_hooks_config(path: &Path) -> Result<Vec<Box<dyn Hook>>> {
430 let content = std::fs::read_to_string(path)
431 .with_context(|| format!("Failed to read hooks config: {}", path.display()))?;
432
433 let config: HooksConfig = serde_json::from_str(&content)
434 .with_context(|| format!("Failed to parse hooks config: {}", path.display()))?;
435
436 let mut hooks: Vec<Box<dyn Hook>> = Vec::new();
437 for def in config.hooks {
438 let name = def.name.clone();
439 let event = def.event.clone();
440 match ShellHook::from_def(def) {
441 Some(hook) => {
442 tracing::info!("Loaded hook: {} ({})", name, event);
443 hooks.push(Box::new(hook));
444 }
445 None => {
446 tracing::warn!("Skipping hook '{}': unknown event '{}'", name, event);
447 }
448 }
449 }
450
451 Ok(hooks)
452}