1use crate::error::{Error, Result};
2use std::collections::HashMap;
3use std::fs::File;
4use std::io::{self, Read, Write};
5use std::path::PathBuf;
6use std::process::{Command, Stdio};
7use std::thread;
8use std::time::Duration;
9use wait_timeout::ChildExt;
10
/// Stateless entry point for running external commands.
///
/// All functionality is exposed through associated functions, so the type is
/// never instantiated with data.
pub struct CmdTool;
12
/// Source of the data presented to a child process on its standard input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CmdStdin {
    /// UTF-8 text written to the child's stdin pipe.
    Text(String),
    /// Raw bytes written to the child's stdin pipe.
    Bytes(Vec<u8>),
    /// Path of a file that is opened and attached directly as the child's stdin.
    File(PathBuf),
}
19
20#[derive(Debug, Clone, PartialEq, Eq)]
21pub struct CmdRequest {
22 pub program: String,
23 pub args: Vec<String>,
24 pub cwd: Option<String>,
25 pub env: Option<HashMap<String, String>>,
26 pub timeout_ms: Option<u64>,
27 pub fail_on_non_zero: bool,
28 pub stdin: Option<CmdStdin>,
29 pub background: bool,
30}
31
32#[derive(Debug, Clone, PartialEq, Eq)]
33pub struct ShellCmdRequest {
34 pub command: String,
35 pub cwd: Option<String>,
36 pub env: Option<HashMap<String, String>>,
37 pub timeout_ms: Option<u64>,
38 pub fail_on_non_zero: bool,
39 pub stdin: Option<CmdStdin>,
40 pub background: bool,
41}
42
/// Result of running a command.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CmdOutput {
    /// Captured standard output, decoded lossily as UTF-8. Empty for
    /// background launches.
    pub stdout: String,
    /// Captured standard error, decoded lossily as UTF-8. Empty for
    /// background launches.
    pub stderr: String,
    /// Exit code of a foreground child, or `-1` when the exit status carries no
    /// code. Background launches always report `0` because the child has not
    /// been waited on.
    pub exit_code: i32,
    /// Process id of the child; populated only for background launches.
    pub pid: Option<u32>,
}
50
51impl CmdTool {
52 pub fn run(req: CmdRequest) -> Result<CmdOutput> {
53 let mut cmd = Command::new(&req.program);
54 cmd.args(&req.args);
55
56 run_inner(
57 &mut cmd,
58 req.cwd,
59 req.env,
60 req.timeout_ms,
61 req.fail_on_non_zero,
62 req.stdin,
63 req.background,
64 )
65 }
66
67 pub fn run_shell(req: ShellCmdRequest) -> Result<CmdOutput> {
68 let mut cmd = build_shell_command(&req.command);
69
70 run_inner(
71 &mut cmd,
72 req.cwd,
73 req.env,
74 req.timeout_ms,
75 req.fail_on_non_zero,
76 req.stdin,
77 req.background,
78 )
79 }
80}
81
82fn run_inner(
83 cmd: &mut Command,
84 cwd: Option<String>,
85 env: Option<HashMap<String, String>>,
86 timeout_ms: Option<u64>,
87 fail_on_non_zero: bool,
88 stdin: Option<CmdStdin>,
89 background: bool,
90) -> Result<CmdOutput> {
91 configure_command(cmd, cwd, env, stdin.as_ref(), background)?;
92 let mut child = spawn_child(cmd)?;
93 let stdout_handle = take_output_reader(&mut child.stdout);
94 let stderr_handle = take_output_reader(&mut child.stderr);
95
96 write_stdin(&mut child, stdin.as_ref())?;
97
98 if background {
99 return Ok(background_output(&child));
100 }
101
102 let status = match wait_for_child(&mut child, timeout_ms) {
103 Ok(status) => status,
104 Err(err) => {
105 let _ = collect_output(stdout_handle);
106 let _ = collect_output(stderr_handle);
107 return Err(err);
108 }
109 };
110
111 build_foreground_output(status, stdout_handle, stderr_handle, fail_on_non_zero)
112}
113
114fn configure_command(
115 cmd: &mut Command,
116 cwd: Option<String>,
117 env: Option<HashMap<String, String>>,
118 stdin: Option<&CmdStdin>,
119 background: bool,
120) -> Result<()> {
121 if let Some(cwd) = cwd {
122 cmd.current_dir(cwd);
123 }
124
125 if let Some(env) = env {
126 cmd.envs(env);
127 }
128
129 configure_stdin(cmd, stdin, background)?;
130 configure_output(cmd, background);
131 Ok(())
132}
133
134fn configure_stdin(cmd: &mut Command, stdin: Option<&CmdStdin>, background: bool) -> Result<()> {
135 match stdin {
136 Some(CmdStdin::File(path)) => {
137 let file = File::open(path).map_err(Error::tool_io)?;
138 cmd.stdin(file);
139 }
140 Some(_) => {
141 cmd.stdin(Stdio::piped());
142 }
143 None if background => {
144 cmd.stdin(Stdio::null());
145 }
146 None => {}
147 }
148
149 Ok(())
150}
151
/// Wires the child's stdout and stderr: both are discarded for background
/// runs and both are piped for foreground runs.
fn configure_output(cmd: &mut Command, background: bool) {
    let sink = || {
        if background {
            Stdio::null()
        } else {
            Stdio::piped()
        }
    };

    cmd.stdout(sink()).stderr(sink());
}
162
163fn spawn_child(cmd: &mut Command) -> Result<std::process::Child> {
164 cmd.spawn().map_err(Error::tool_io)
165}
166
167fn take_output_reader<R>(pipe: &mut Option<R>) -> Option<thread::JoinHandle<io::Result<Vec<u8>>>>
168where
169 R: Read + Send + 'static,
170{
171 pipe.take().map(spawn_reader)
172}
173
174fn background_output(child: &std::process::Child) -> CmdOutput {
175 CmdOutput {
176 stdout: String::new(),
177 stderr: String::new(),
178 exit_code: 0,
179 pid: Some(child.id()),
180 }
181}
182
183fn wait_for_child(
184 child: &mut std::process::Child,
185 timeout_ms: Option<u64>,
186) -> Result<std::process::ExitStatus> {
187 match timeout_ms {
188 Some(timeout_ms) => wait_with_timeout(child, timeout_ms),
189 None => child.wait().map_err(Error::tool_io),
190 }
191}
192
193fn wait_with_timeout(
194 child: &mut std::process::Child,
195 timeout_ms: u64,
196) -> Result<std::process::ExitStatus> {
197 let duration = Duration::from_millis(timeout_ms);
198 match child.wait_timeout(duration).map_err(Error::tool_io)? {
199 Some(status) => Ok(status),
200 None => {
201 kill_timed_out_child(child)?;
202 Err(Error::tool_timeout())
203 }
204 }
205}
206
207fn kill_timed_out_child(child: &mut std::process::Child) -> Result<()> {
208 child.kill().map_err(Error::tool_io)?;
209 child.wait().map_err(Error::tool_io)?;
210 Ok(())
211}
212
213fn write_stdin(child: &mut std::process::Child, stdin_content: Option<&CmdStdin>) -> Result<()> {
214 let Some(stdin_content) = stdin_content else {
215 return Ok(());
216 };
217
218 let Some(mut stdin) = child.stdin.take() else {
219 if matches!(stdin_content, CmdStdin::Text(_) | CmdStdin::Bytes(_)) {
220 return Err(Error::tool_io(io::Error::new(
221 io::ErrorKind::BrokenPipe,
222 "stdin pipe not available",
223 )));
224 }
225 return Ok(());
226 };
227
228 match stdin_content {
229 CmdStdin::Text(text) => stdin.write_all(text.as_bytes()).map_err(Error::tool_io)?,
230 CmdStdin::Bytes(bytes) => stdin.write_all(bytes).map_err(Error::tool_io)?,
231 CmdStdin::File(_) => {}
232 }
233
234 stdin.flush().map_err(Error::tool_io)?;
235 drop(stdin);
236 Ok(())
237}
238
/// Starts a thread that reads `reader` until end-of-stream and yields the
/// collected bytes, or the I/O error that interrupted the read.
fn spawn_reader<R>(mut reader: R) -> thread::JoinHandle<io::Result<Vec<u8>>>
where
    R: Read + Send + 'static,
{
    thread::spawn(move || {
        let mut collected = Vec::new();
        reader.read_to_end(&mut collected).map(|_| collected)
    })
}
249
250fn collect_output(handle: Option<thread::JoinHandle<io::Result<Vec<u8>>>>) -> Result<String> {
251 let Some(handle) = handle else {
252 return Ok(String::new());
253 };
254
255 let bytes = handle
256 .join()
257 .map_err(|_| Error::tool_io(io::Error::other("output reader thread panicked")))?
258 .map_err(Error::tool_io)?;
259
260 Ok(String::from_utf8_lossy(&bytes).into_owned())
261}
262
263fn build_foreground_output(
264 status: std::process::ExitStatus,
265 stdout_handle: Option<thread::JoinHandle<io::Result<Vec<u8>>>>,
266 stderr_handle: Option<thread::JoinHandle<io::Result<Vec<u8>>>>,
267 fail_on_non_zero: bool,
268) -> Result<CmdOutput> {
269 let stdout = collect_output(stdout_handle)?;
270 let stderr = collect_output(stderr_handle)?;
271 let exit_code = status.code().unwrap_or(-1);
272
273 if fail_on_non_zero && exit_code != 0 {
274 return Err(Error::tool_cmd_failed(exit_code));
275 }
276
277 Ok(CmdOutput {
278 stdout,
279 stderr,
280 exit_code,
281 pid: None,
282 })
283}
284
/// Wraps `command` in the platform shell invocation: `cmd.exe /c <command>` on
/// Windows and `sh -c <command>` everywhere else.
fn build_shell_command(command: &str) -> Command {
    let (shell, flag) = if cfg!(target_os = "windows") {
        ("cmd.exe", "/c")
    } else {
        ("sh", "-c")
    };

    let mut cmd = Command::new(shell);
    cmd.args([flag, command]);
    cmd
}
296
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Foreground request for `program` with every optional setting left off.
    fn direct(program: &str, args: &[&str]) -> CmdRequest {
        CmdRequest {
            program: program.to_string(),
            args: args.iter().map(|arg| arg.to_string()).collect(),
            cwd: None,
            env: None,
            timeout_ms: None,
            fail_on_non_zero: false,
            stdin: None,
            background: false,
        }
    }

    /// Foreground shell request for `command` with every optional setting left off.
    fn shell(command: &str) -> ShellCmdRequest {
        ShellCmdRequest {
            command: command.to_string(),
            cwd: None,
            env: None,
            timeout_ms: None,
            fail_on_non_zero: false,
            stdin: None,
            background: false,
        }
    }

    /// `cat` with no arguments, fed from `stdin`.
    fn cat(stdin: CmdStdin) -> CmdRequest {
        CmdRequest {
            stdin: Some(stdin),
            ..direct("cat", &[])
        }
    }

    #[test]
    fn test_successful_command() {
        let output = CmdTool::run(direct("echo", &["hello"])).expect("echo should succeed");

        assert_eq!(output.exit_code, 0);
        assert_eq!(output.stdout.trim(), "hello");
        assert!(output.pid.is_none());
    }

    #[test]
    fn test_shell_command() {
        let output =
            CmdTool::run_shell(shell("echo 'hello from shell'")).expect("shell echo should succeed");

        assert_eq!(output.exit_code, 0);
        assert_eq!(output.stdout.trim(), "hello from shell");
        assert!(output.pid.is_none());
    }

    #[test]
    fn test_timeout_command() {
        let req = CmdRequest {
            timeout_ms: Some(100),
            ..direct("sleep", &["2"])
        };

        let err = CmdTool::run(req).expect_err("sleep should be cut short");
        let err_msg = err.to_string().to_lowercase();
        assert!(err_msg.contains("timed out"));
    }

    #[test]
    fn test_non_existent_command() {
        let req = direct("this_command_does_not_exist_12345", &[]);

        let err = CmdTool::run(req).expect_err("spawning a missing program should fail");
        let err_msg = err.to_string().to_lowercase();
        assert!(err_msg.contains("no such file") || err_msg.contains("not found"));
    }

    #[test]
    fn test_stdin_text() {
        let req = cat(CmdStdin::Text("hello stdin text".to_string()));

        let output = CmdTool::run(req).expect("cat should succeed");
        assert_eq!(output.exit_code, 0);
        assert_eq!(output.stdout, "hello stdin text");
        assert!(output.pid.is_none());
    }

    #[test]
    fn test_stdin_bytes() {
        let req = cat(CmdStdin::Bytes(b"hello stdin bytes".to_vec()));

        let output = CmdTool::run(req).expect("cat should succeed");
        assert_eq!(output.exit_code, 0);
        assert_eq!(output.stdout, "hello stdin bytes");
        assert!(output.pid.is_none());
    }

    #[test]
    fn test_stdin_file() {
        let mut temp_file = NamedTempFile::new().unwrap();
        write!(temp_file, "hello stdin file").unwrap();

        let req = cat(CmdStdin::File(temp_file.path().to_path_buf()));

        let output = CmdTool::run(req).expect("cat should succeed");
        assert_eq!(output.exit_code, 0);
        assert_eq!(output.stdout, "hello stdin file");
        assert!(output.pid.is_none());
    }

    #[test]
    fn test_background() {
        let req = CmdRequest {
            background: true,
            ..direct("sleep", &["1"])
        };

        let output = CmdTool::run(req).expect("background launch should succeed");
        assert_eq!(output.exit_code, 0);
        assert!(output.stdout.is_empty());
        assert!(matches!(output.pid, Some(pid) if pid > 0));
    }

    #[test]
    fn test_shell_pipe() {
        let command = if cfg!(target_os = "windows") {
            "echo hello pipe | findstr pipe"
        } else {
            "echo 'hello pipe' | grep pipe"
        };

        let output = CmdTool::run_shell(shell(command)).expect("pipeline should succeed");
        assert_eq!(output.exit_code, 0);
        assert!(output.stdout.contains("hello pipe"));
        assert!(output.pid.is_none());
    }

    #[test]
    fn test_non_zero_exit_can_fail() {
        let command = if cfg!(target_os = "windows") {
            "cmd /c exit 7"
        } else {
            "sh -c 'exit 7'"
        };
        let req = ShellCmdRequest {
            fail_on_non_zero: true,
            ..shell(command)
        };

        let err = CmdTool::run_shell(req).expect_err("non-zero exit should be an error");
        let err_msg = err.to_string().to_lowercase();
        assert!(err_msg.contains("exit code 7"));
    }

    #[test]
    fn test_non_zero_exit_can_be_observed_without_error() {
        let command = if cfg!(target_os = "windows") {
            "cmd /c exit 9"
        } else {
            "sh -c 'exit 9'"
        };

        let result = CmdTool::run_shell(shell(command)).unwrap();
        assert_eq!(result.exit_code, 9);
    }

    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_non_utf8_stdout_is_preserved_lossily() {
        let result = CmdTool::run_shell(shell("printf '\\377\\376abc'")).unwrap();

        assert!(result.stdout.contains("abc"));
        assert!(!result.stdout.is_empty());
    }
}