1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, HashSet};
4use std::env;
5use std::ffi::OsString;
6use std::io;
7use std::io::{Read, Write};
8use std::path::{Path, PathBuf};
9use std::process::{Child, Command, ExitStatus, Stdio};
10use std::thread;
11use std::time::{Duration, Instant};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct HookRunRequest {
16 pub phase_event: String,
18
19 pub hook_name: String,
21
22 pub command: Vec<String>,
24
25 pub workspace_root: PathBuf,
27
28 pub cwd: Option<PathBuf>,
30
31 pub env: HashMap<String, String>,
33
34 pub timeout_seconds: u64,
36
37 pub max_output_bytes: u64,
39
40 pub stdin_payload: serde_json::Value,
42}
43
44#[derive(Debug, Clone, Default, Serialize, Deserialize)]
46pub struct HookStreamOutput {
47 pub content: String,
49
50 pub truncated: bool,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct HookRunResult {
57 pub started_at: DateTime<Utc>,
59
60 pub ended_at: DateTime<Utc>,
62
63 pub duration_ms: u64,
65
66 pub exit_code: Option<i32>,
68
69 pub timed_out: bool,
71
72 pub stdout: HookStreamOutput,
74
75 pub stderr: HookStreamOutput,
77}
78
79#[derive(Debug, thiserror::Error)]
81pub enum HookExecutorError {
82 #[error("hook '{hook_name}' for phase-event '{phase_event}' has an empty command argv")]
84 EmptyCommand {
85 phase_event: String,
86 hook_name: String,
87 },
88
89 #[error(
91 "hook '{hook_name}' for phase-event '{phase_event}' command '{command}' could not be resolved: {reason}"
92 )]
93 CommandResolution {
94 phase_event: String,
95 hook_name: String,
96 command: String,
97 reason: String,
98 },
99
100 #[error(
102 "failed to spawn hook '{hook_name}' for phase-event '{phase_event}' with command '{command}' (cwd: {cwd}): {source}"
103 )]
104 Spawn {
105 phase_event: String,
106 hook_name: String,
107 command: String,
108 cwd: String,
109 #[source]
110 source: io::Error,
111 },
112
113 #[error(
115 "failed to serialize stdin payload for hook '{hook_name}' phase-event '{phase_event}' with command '{command}': {source}"
116 )]
117 StdinSerialize {
118 phase_event: String,
119 hook_name: String,
120 command: String,
121 #[source]
122 source: serde_json::Error,
123 },
124
125 #[error(
127 "failed to write stdin payload for hook '{hook_name}' phase-event '{phase_event}' with command '{command}': {source}"
128 )]
129 StdinWrite {
130 phase_event: String,
131 hook_name: String,
132 command: String,
133 #[source]
134 source: io::Error,
135 },
136
137 #[error(
139 "hook '{hook_name}' for phase-event '{phase_event}' exceeded timeout ({timeout_seconds}s) and could not be terminated (command: '{command}'): {source}"
140 )]
141 TimeoutTerminate {
142 phase_event: String,
143 hook_name: String,
144 command: String,
145 timeout_seconds: u64,
146 #[source]
147 source: io::Error,
148 },
149
150 #[error(
152 "failed to capture {stream} for hook '{hook_name}' phase-event '{phase_event}' with command '{command}': {source}"
153 )]
154 OutputRead {
155 phase_event: String,
156 hook_name: String,
157 command: String,
158 stream: &'static str,
159 #[source]
160 source: io::Error,
161 },
162
163 #[error(
165 "hook '{hook_name}' phase-event '{phase_event}' output collector for {stream} panicked (command: '{command}')"
166 )]
167 OutputCollectorJoin {
168 phase_event: String,
169 hook_name: String,
170 command: String,
171 stream: &'static str,
172 },
173
174 #[error(
176 "failed while waiting for hook '{hook_name}' for phase-event '{phase_event}' with command '{command}': {source}"
177 )]
178 Wait {
179 phase_event: String,
180 hook_name: String,
181 command: String,
182 #[source]
183 source: io::Error,
184 },
185}
186
187pub trait HookExecutorContract {
189 fn run(&self, request: HookRunRequest) -> Result<HookRunResult, HookExecutorError>;
191}
192
/// Stateless hook executor; it carries no fields.
#[derive(Clone, Debug, Default)]
pub struct HookExecutor;

impl HookExecutor {
    /// Creates an executor.
    #[must_use]
    pub fn new() -> Self {
        HookExecutor
    }
}
204
205impl HookExecutorContract for HookExecutor {
206 fn run(&self, request: HookRunRequest) -> Result<HookRunResult, HookExecutorError> {
207 let started_at = Utc::now();
208 let resolved_cwd = resolve_hook_cwd(&request.workspace_root, request.cwd.as_deref());
209
210 let executable = request
211 .command
212 .first()
213 .map(String::as_str)
214 .map(str::trim)
215 .filter(|value| !value.is_empty())
216 .ok_or_else(|| HookExecutorError::EmptyCommand {
217 phase_event: request.phase_event.clone(),
218 hook_name: request.hook_name.clone(),
219 })?;
220
221 let resolved_command =
222 resolve_hook_command(executable, &resolved_cwd, hook_path_override(&request.env))
223 .map_err(|reason| HookExecutorError::CommandResolution {
224 phase_event: request.phase_event.clone(),
225 hook_name: request.hook_name.clone(),
226 command: executable.to_string(),
227 reason,
228 })?;
229
230 let command_display = request.command.join(" ");
231
232 let mut command = Command::new(&resolved_command);
233 command.args(request.command.iter().skip(1));
234 command.current_dir(&resolved_cwd);
235 command.envs(&request.env);
236
237 command.stdin(Stdio::piped());
239
240 command.stdout(Stdio::piped());
242 command.stderr(Stdio::piped());
243
244 let mut child = None;
248 for attempt in 0..3 {
249 match command.spawn() {
250 Ok(c) => {
251 child = Some(c);
252 break;
253 }
254 Err(e) if e.raw_os_error() == Some(26 ) && attempt < 2 => {
255 std::thread::sleep(std::time::Duration::from_millis(10));
256 }
257 Err(source) => {
258 return Err(HookExecutorError::Spawn {
259 phase_event: request.phase_event.clone(),
260 hook_name: request.hook_name.clone(),
261 command: command_display,
262 cwd: resolved_cwd.display().to_string(),
263 source,
264 });
265 }
266 }
267 }
268 let mut child = child.expect("spawn loop must break or return");
269
270 write_stdin_payload(
271 &mut child,
272 &request.stdin_payload,
273 &request.phase_event,
274 &request.hook_name,
275 &command_display,
276 )?;
277
278 let stdout_collector =
279 spawn_stream_collector(child.stdout.take(), request.max_output_bytes);
280 let stderr_collector =
281 spawn_stream_collector(child.stderr.take(), request.max_output_bytes);
282
283 let (status, timed_out) = wait_for_completion(
284 &mut child,
285 request.timeout_seconds,
286 &request.phase_event,
287 &request.hook_name,
288 &command_display,
289 )?;
290
291 let stdout = collect_stream_output(
292 stdout_collector,
293 "stdout",
294 &request.phase_event,
295 &request.hook_name,
296 &command_display,
297 )?;
298 let stderr = collect_stream_output(
299 stderr_collector,
300 "stderr",
301 &request.phase_event,
302 &request.hook_name,
303 &command_display,
304 )?;
305
306 let ended_at = Utc::now();
307
308 Ok(HookRunResult {
309 started_at,
310 ended_at,
311 duration_ms: duration_ms(started_at, ended_at),
312 exit_code: status.code(),
313 timed_out,
314 stdout,
315 stderr,
316 })
317 }
318}
319
320const WAIT_POLL_INTERVAL: Duration = Duration::from_millis(10);
321const STREAM_READ_BUFFER_BYTES: usize = 4096;
322
323type StreamCollector = thread::JoinHandle<io::Result<HookStreamOutput>>;
324
325fn write_stdin_payload(
326 child: &mut Child,
327 stdin_payload: &serde_json::Value,
328 phase_event: &str,
329 hook_name: &str,
330 command: &str,
331) -> Result<(), HookExecutorError> {
332 let Some(mut stdin) = child.stdin.take() else {
333 return Ok(());
334 };
335
336 let payload =
337 serde_json::to_vec(stdin_payload).map_err(|source| HookExecutorError::StdinSerialize {
338 phase_event: phase_event.to_string(),
339 hook_name: hook_name.to_string(),
340 command: command.to_string(),
341 source,
342 })?;
343
344 if let Err(source) = stdin.write_all(&payload)
345 && source.kind() != io::ErrorKind::BrokenPipe
346 {
347 return Err(HookExecutorError::StdinWrite {
348 phase_event: phase_event.to_string(),
349 hook_name: hook_name.to_string(),
350 command: command.to_string(),
351 source,
352 });
353 }
354
355 if let Err(source) = stdin.flush()
356 && source.kind() != io::ErrorKind::BrokenPipe
357 {
358 return Err(HookExecutorError::StdinWrite {
359 phase_event: phase_event.to_string(),
360 hook_name: hook_name.to_string(),
361 command: command.to_string(),
362 source,
363 });
364 }
365
366 Ok(())
367}
368
369fn wait_for_completion(
370 child: &mut Child,
371 timeout_seconds: u64,
372 phase_event: &str,
373 hook_name: &str,
374 command: &str,
375) -> Result<(ExitStatus, bool), HookExecutorError> {
376 let timeout = Duration::from_secs(timeout_seconds);
377 let wait_started_at = Instant::now();
378
379 loop {
380 match child.try_wait() {
381 Ok(Some(status)) => return Ok((status, false)),
382 Ok(None) => {
383 if wait_started_at.elapsed() >= timeout {
384 let status = terminate_for_timeout(
385 child,
386 timeout_seconds,
387 phase_event,
388 hook_name,
389 command,
390 )?;
391 return Ok((status, true));
392 }
393
394 let elapsed = wait_started_at.elapsed();
395 let remaining = timeout.saturating_sub(elapsed);
396 thread::sleep(remaining.min(WAIT_POLL_INTERVAL));
397 }
398 Err(source) => {
399 return Err(HookExecutorError::Wait {
400 phase_event: phase_event.to_string(),
401 hook_name: hook_name.to_string(),
402 command: command.to_string(),
403 source,
404 });
405 }
406 }
407 }
408}
409
410fn terminate_for_timeout(
411 child: &mut Child,
412 timeout_seconds: u64,
413 phase_event: &str,
414 hook_name: &str,
415 command: &str,
416) -> Result<ExitStatus, HookExecutorError> {
417 if let Err(source) = child.kill() {
418 if let Ok(Some(status)) = child.try_wait() {
419 return Ok(status);
420 }
421
422 return Err(HookExecutorError::TimeoutTerminate {
423 phase_event: phase_event.to_string(),
424 hook_name: hook_name.to_string(),
425 command: command.to_string(),
426 timeout_seconds,
427 source,
428 });
429 }
430
431 child.wait().map_err(|source| HookExecutorError::Wait {
432 phase_event: phase_event.to_string(),
433 hook_name: hook_name.to_string(),
434 command: command.to_string(),
435 source,
436 })
437}
438
439fn spawn_stream_collector<R>(stream: Option<R>, max_output_bytes: u64) -> StreamCollector
440where
441 R: Read + Send + 'static,
442{
443 thread::spawn(move || {
444 let Some(reader) = stream else {
445 return Ok(HookStreamOutput::default());
446 };
447
448 capture_stream_output(reader, max_output_bytes)
449 })
450}
451
452fn collect_stream_output(
453 collector: StreamCollector,
454 stream: &'static str,
455 phase_event: &str,
456 hook_name: &str,
457 command: &str,
458) -> Result<HookStreamOutput, HookExecutorError> {
459 let captured = collector
460 .join()
461 .map_err(|_| HookExecutorError::OutputCollectorJoin {
462 phase_event: phase_event.to_string(),
463 hook_name: hook_name.to_string(),
464 command: command.to_string(),
465 stream,
466 })?;
467
468 captured.map_err(|source| HookExecutorError::OutputRead {
469 phase_event: phase_event.to_string(),
470 hook_name: hook_name.to_string(),
471 command: command.to_string(),
472 stream,
473 source,
474 })
475}
476
477fn capture_stream_output<R: Read>(
478 mut reader: R,
479 max_output_bytes: u64,
480) -> io::Result<HookStreamOutput> {
481 let capture_limit = usize::try_from(max_output_bytes).unwrap_or(usize::MAX);
482 let mut captured = Vec::with_capacity(capture_limit.min(STREAM_READ_BUFFER_BYTES));
483 let mut truncated = false;
484 let mut buffer = [0_u8; STREAM_READ_BUFFER_BYTES];
485
486 loop {
487 let bytes_read = reader.read(&mut buffer)?;
488 if bytes_read == 0 {
489 break;
490 }
491
492 if captured.len() < capture_limit {
493 let remaining = capture_limit - captured.len();
494 let to_copy = remaining.min(bytes_read);
495 captured.extend_from_slice(&buffer[..to_copy]);
496
497 if to_copy < bytes_read {
498 truncated = true;
499 }
500 } else {
501 truncated = true;
502 }
503 }
504
505 if let Err(error) = std::str::from_utf8(&captured)
506 && error.error_len().is_none()
507 {
508 captured.truncate(error.valid_up_to());
509 }
510
511 Ok(HookStreamOutput {
512 content: String::from_utf8_lossy(&captured).into_owned(),
513 truncated,
514 })
515}
516
/// Picks the working directory for a hook.
///
/// Without a requested directory the workspace root is used; an absolute
/// request is taken as-is and a relative one is joined onto the workspace
/// root.
fn resolve_hook_cwd(workspace_root: &Path, hook_cwd: Option<&Path>) -> PathBuf {
    let Some(requested) = hook_cwd else {
        return workspace_root.to_path_buf();
    };

    if requested.is_absolute() {
        requested.to_path_buf()
    } else {
        workspace_root.join(requested)
    }
}
524
/// Returns the search path supplied through the hook's own environment.
///
/// The key `PATH` is preferred; `Path` is consulted only when `PATH` is
/// absent. Other spellings are not recognized.
fn hook_path_override(env_map: &HashMap<String, String>) -> Option<&str> {
    ["PATH", "Path"]
        .iter()
        .find_map(|key| env_map.get(*key))
        .map(String::as_str)
}
531
532fn resolve_hook_command(
533 command: &str,
534 cwd: &Path,
535 path_override: Option<&str>,
536) -> Result<PathBuf, String> {
537 let command_path = Path::new(command);
538 if command_path.is_absolute() || command_path.components().count() > 1 {
539 let resolved = if command_path.is_absolute() {
540 command_path.to_path_buf()
541 } else {
542 cwd.join(command_path)
543 };
544
545 if !resolved.exists() {
546 return Err(format!(
547 "command '{command}' resolves to '{}' but the file does not exist",
548 resolved.display()
549 ));
550 }
551
552 if !is_executable_file(&resolved) {
553 return Err(format!(
554 "command '{command}' resolves to '{}' but it is not executable",
555 resolved.display()
556 ));
557 }
558
559 return Ok(resolved);
560 }
561
562 let path_value = path_override
563 .map(OsString::from)
564 .or_else(|| env::var_os("PATH"))
565 .ok_or_else(|| {
566 format!(
567 "PATH is not set while resolving command '{command}'; set PATH or provide an absolute/relative path"
568 )
569 })?;
570
571 let mut visited = HashSet::new();
572 let extensions = executable_extensions();
573
574 for dir in env::split_paths(&path_value) {
575 if !visited.insert(dir.clone()) {
576 continue;
577 }
578
579 for extension in &extensions {
580 let candidate = if extension.is_empty() {
581 dir.join(command)
582 } else {
583 dir.join(format!("{command}{}", extension.to_string_lossy()))
584 };
585
586 if is_executable_file(&candidate) {
587 return Ok(candidate);
588 }
589 }
590 }
591
592 Err(format!("command '{command}' was not found in PATH"))
593}
594
/// Lists the file-name suffixes to try when probing a directory for a command.
///
/// Outside Windows only the empty suffix is returned, i.e. the name is used
/// as given. On Windows the suffixes come from `PATHEXT` (falling back to
/// `.COM;.EXE;.BAT;.CMD` when it is unset or not valid Unicode), followed by
/// the empty suffix so that a command that already carries its extension
/// (for example `tool.exe`) can still be found.
fn executable_extensions() -> Vec<OsString> {
    if cfg!(windows) {
        windows_executable_extensions(env::var("PATHEXT").ok().as_deref())
    } else {
        vec![OsString::new()]
    }
}

/// Builds the Windows suffix list from a raw `PATHEXT` value.
///
/// Blank entries are skipped and surrounding whitespace is trimmed. The empty
/// suffix is appended last so that names with a `PATHEXT` suffix win over an
/// extension-less file of the same name.
fn windows_executable_extensions(pathext: Option<&str>) -> Vec<OsString> {
    let mut extensions: Vec<OsString> = pathext
        .unwrap_or(".COM;.EXE;.BAT;.CMD")
        .split(';')
        .map(str::trim)
        .filter(|ext| !ext.is_empty())
        .map(OsString::from)
        .collect();
    extensions.push(OsString::new());
    extensions
}
606
/// Reports whether `path` points at a regular file that may be executed.
///
/// On Unix at least one execute permission bit must be set; on other
/// platforms every regular file qualifies. Paths whose metadata cannot be
/// read are reported as not executable.
fn is_executable_file(path: &Path) -> bool {
    let Ok(metadata) = std::fs::metadata(path) else {
        return false;
    };

    if !metadata.is_file() {
        return false;
    }

    #[cfg(unix)]
    let executable = {
        use std::os::unix::fs::PermissionsExt;

        // Any of the owner/group/other execute bits is enough.
        metadata.permissions().mode() & 0o111 != 0
    };

    #[cfg(not(unix))]
    let executable = true;

    executable
}
626
627fn duration_ms(started_at: DateTime<Utc>, ended_at: DateTime<Utc>) -> u64 {
628 let milliseconds = ended_at
629 .signed_duration_since(started_at)
630 .num_milliseconds();
631 if milliseconds <= 0 {
632 return 0;
633 }
634
635 u64::try_from(milliseconds).unwrap_or(u64::MAX)
636}
637
#[cfg(all(test, unix))]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;
    use std::fs;
    use std::path::{Path, PathBuf};
    use tempfile::{TempDir, tempdir};

    /// Writes `body` into an executable `/bin/sh` script inside `temp_dir`
    /// and returns the script's path.
    fn write_executable_script(temp_dir: &TempDir, file_name: &str, body: &str) -> PathBuf {
        use std::io::Write;
        use std::os::unix::fs::OpenOptionsExt;

        let script_path = temp_dir.path().join(file_name);
        let contents = format!("#!/bin/sh\nset -eu\n{body}\n");

        let mut options = fs::OpenOptions::new();
        options.write(true).create(true).truncate(true).mode(0o755);

        let mut file = options.open(&script_path).expect("create script file");
        file.write_all(contents.as_bytes())
            .expect("write script file");
        file.sync_all().expect("sync script file");
        // Close the write handle before anything tries to execute the script.
        drop(file);

        let _ = fs::metadata(&script_path).expect("stat script");

        script_path
    }

    /// Builds a request with test defaults around the given argv.
    fn request_with_command(workspace_root: &Path, command: Vec<String>) -> HookRunRequest {
        HookRunRequest {
            phase_event: "pre.loop.start".to_string(),
            hook_name: "test-hook".to_string(),
            command,
            workspace_root: workspace_root.to_path_buf(),
            cwd: None,
            env: HashMap::new(),
            timeout_seconds: 2,
            max_output_bytes: 1024,
            stdin_payload: json!({"schema_version": 1, "phase_event": "pre.loop.start"}),
        }
    }

    /// Converts paths into the owned strings an argv is made of.
    fn argv(parts: &[&Path]) -> Vec<String> {
        parts
            .iter()
            .map(|part| part.to_string_lossy().into_owned())
            .collect()
    }

    #[test]
    fn run_reports_successful_exit_and_stream_content() {
        let sandbox = tempdir().expect("tempdir");
        let script = write_executable_script(
            &sandbox,
            "success.sh",
            "printf 'ok-stdout'\nprintf 'ok-stderr' >&2",
        );
        let request = request_with_command(sandbox.path(), argv(&[script.as_path()]));

        let outcome = HookExecutor::new().run(request).expect("hook run succeeds");

        assert_eq!(outcome.exit_code, Some(0));
        assert!(!outcome.timed_out);
        assert_eq!(outcome.stdout.content, "ok-stdout");
        assert!(!outcome.stdout.truncated);
        assert_eq!(outcome.stderr.content, "ok-stderr");
        assert!(!outcome.stderr.truncated);
        assert!(outcome.ended_at >= outcome.started_at);
    }

    #[test]
    fn run_preserves_non_zero_exit_code_without_timeout() {
        let sandbox = tempdir().expect("tempdir");
        let script = write_executable_script(
            &sandbox,
            "nonzero.sh",
            "printf 'failing-hook' >&2\nexit 17",
        );
        let request = request_with_command(sandbox.path(), argv(&[script.as_path()]));

        let outcome = HookExecutor::new()
            .run(request)
            .expect("hook run completes");

        assert_eq!(outcome.exit_code, Some(17));
        assert!(!outcome.timed_out);
        assert_eq!(outcome.stderr.content, "failing-hook");
        assert!(!outcome.stderr.truncated);
    }

    #[test]
    fn run_marks_timed_out_when_command_exceeds_timeout() {
        let sandbox = tempdir().expect("tempdir");
        let script = write_executable_script(&sandbox, "timeout.sh", "while :; do :; done");

        let request = HookRunRequest {
            timeout_seconds: 1,
            ..request_with_command(sandbox.path(), argv(&[script.as_path()]))
        };

        let outcome = HookExecutor::new()
            .run(request)
            .expect("hook run completes");

        assert!(outcome.timed_out);
        assert_ne!(outcome.exit_code, Some(0));
    }

    #[test]
    fn run_truncates_stdout_and_stderr_at_max_output_bytes() {
        let sandbox = tempdir().expect("tempdir");
        let script = write_executable_script(
            &sandbox,
            "truncate.sh",
            "printf '1234567890'\nprintf 'abcdefghij' >&2",
        );

        let request = HookRunRequest {
            max_output_bytes: 8,
            ..request_with_command(sandbox.path(), argv(&[script.as_path()]))
        };

        let outcome = HookExecutor::new().run(request).expect("hook run succeeds");

        assert_eq!(outcome.exit_code, Some(0));
        assert_eq!(outcome.stdout.content, "12345678");
        assert!(outcome.stdout.truncated);
        assert_eq!(outcome.stderr.content, "abcdefgh");
        assert!(outcome.stderr.truncated);
    }

    #[test]
    fn run_writes_json_payload_to_hook_stdin() {
        let sandbox = tempdir().expect("tempdir");
        let script = write_executable_script(&sandbox, "stdin.sh", "cat > \"$1\"");
        let captured_path = sandbox.path().join("stdin-captured.json");

        let payload = json!({
            "schema_version": 1,
            "phase": "pre",
            "event": "loop.start",
            "loop": {"id": "loop-test", "is_primary": true}
        });

        let request = HookRunRequest {
            stdin_payload: payload.clone(),
            ..request_with_command(
                sandbox.path(),
                argv(&[script.as_path(), captured_path.as_path()]),
            )
        };

        let outcome = HookExecutor::new().run(request).expect("hook run succeeds");

        assert_eq!(outcome.exit_code, Some(0));
        assert!(!outcome.timed_out);

        let written = fs::read_to_string(captured_path).expect("read captured stdin");
        let parsed: serde_json::Value =
            serde_json::from_str(&written).expect("parse captured stdin json");

        assert_eq!(parsed, payload);
    }
}