Skip to main content

agent_tools/base/
cmd.rs

1use crate::error::{Error, Result};
2use std::collections::HashMap;
3use std::fs::File;
4use std::io::{self, Read, Write};
5use std::path::PathBuf;
6use std::process::{Command, Stdio};
7use std::thread;
8use std::time::Duration;
9use wait_timeout::ChildExt;
10
/// Stateless entry point for running external commands.
///
/// Use [`CmdTool::run`] to execute a program directly with an argument list, or
/// [`CmdTool::run_shell`] to execute a command line through the platform shell.
pub struct CmdTool;
12
/// Where the child process reads its standard input from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CmdStdin {
    /// Written to the child through a pipe as the string's UTF-8 bytes.
    Text(String),
    /// Written to the child through a pipe verbatim.
    Bytes(Vec<u8>),
    /// Opened by this process and handed to the child directly as its stdin.
    /// A relative path is resolved against this process's working directory,
    /// not against the request's `cwd`.
    File(PathBuf),
}
19
/// Request to run a program directly (no shell) via [`CmdTool::run`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CmdRequest {
    /// Program to execute; resolved by the OS, not by a shell.
    pub program: String,
    /// Arguments passed to `program` as-is.
    pub args: Vec<String>,
    /// Working directory for the child; inherited from this process when `None`.
    pub cwd: Option<String>,
    /// Variables added to (or overriding) the inherited environment; the
    /// environment is not cleared.
    pub env: Option<HashMap<String, String>>,
    /// Maximum time to wait for a foreground child, in milliseconds. On expiry the
    /// child is killed and a timeout error is returned. Not used for background runs.
    pub timeout_ms: Option<u64>,
    /// When `true`, a non-zero exit code (including `-1` when no code is available)
    /// is returned as an error and the captured output is discarded.
    pub fail_on_non_zero: bool,
    /// Optional stdin source. When `None`, a foreground child inherits this
    /// process's stdin and a background child gets a null stdin.
    pub stdin: Option<CmdStdin>,
    /// Spawn and return immediately with the child's pid. stdout/stderr are sent
    /// to null and the child is not waited on.
    pub background: bool,
}
31
/// Request to run a command line through the platform shell via
/// [`CmdTool::run_shell`] (`sh -c` on Unix, `cmd.exe /c` on Windows).
///
/// All fields other than `command` behave exactly as on [`CmdRequest`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShellCmdRequest {
    /// Full command line, interpreted by the shell (pipes, redirects, quoting).
    pub command: String,
    /// Working directory for the shell; inherited from this process when `None`.
    pub cwd: Option<String>,
    /// Variables added to (or overriding) the inherited environment.
    pub env: Option<HashMap<String, String>>,
    /// Maximum time to wait for a foreground shell, in milliseconds.
    pub timeout_ms: Option<u64>,
    /// When `true`, a non-zero exit code is returned as an error.
    pub fail_on_non_zero: bool,
    /// Optional stdin source for the shell.
    pub stdin: Option<CmdStdin>,
    /// Spawn and return immediately with the shell's pid.
    pub background: bool,
}
42
/// Result of running a command.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CmdOutput {
    /// Captured stdout, decoded lossily as UTF-8 (invalid bytes become U+FFFD).
    /// Always empty for background runs.
    pub stdout: String,
    /// Captured stderr, decoded the same way. Always empty for background runs.
    pub stderr: String,
    /// Exit code of the child, or `-1` when none is available (e.g. the process
    /// was terminated by a signal on Unix). Always `0` for background runs,
    /// which are not waited on.
    pub exit_code: i32,
    /// Child process id; only set for background runs.
    pub pid: Option<u32>,
}
50
51impl CmdTool {
52    pub fn run(req: CmdRequest) -> Result<CmdOutput> {
53        let mut cmd = Command::new(&req.program);
54        cmd.args(&req.args);
55
56        run_inner(
57            &mut cmd,
58            req.cwd,
59            req.env,
60            req.timeout_ms,
61            req.fail_on_non_zero,
62            req.stdin,
63            req.background,
64        )
65    }
66
67    pub fn run_shell(req: ShellCmdRequest) -> Result<CmdOutput> {
68        let mut cmd = build_shell_command(&req.command);
69
70        run_inner(
71            &mut cmd,
72            req.cwd,
73            req.env,
74            req.timeout_ms,
75            req.fail_on_non_zero,
76            req.stdin,
77            req.background,
78        )
79    }
80}
81
82fn run_inner(
83    cmd: &mut Command,
84    cwd: Option<String>,
85    env: Option<HashMap<String, String>>,
86    timeout_ms: Option<u64>,
87    fail_on_non_zero: bool,
88    stdin: Option<CmdStdin>,
89    background: bool,
90) -> Result<CmdOutput> {
91    configure_command(cmd, cwd, env, stdin.as_ref(), background)?;
92    let mut child = spawn_child(cmd)?;
93    let stdout_handle = take_output_reader(&mut child.stdout);
94    let stderr_handle = take_output_reader(&mut child.stderr);
95
96    write_stdin(&mut child, stdin.as_ref())?;
97
98    if background {
99        return Ok(background_output(&child));
100    }
101
102    let status = match wait_for_child(&mut child, timeout_ms) {
103        Ok(status) => status,
104        Err(err) => {
105            let _ = collect_output(stdout_handle);
106            let _ = collect_output(stderr_handle);
107            return Err(err);
108        }
109    };
110
111    build_foreground_output(status, stdout_handle, stderr_handle, fail_on_non_zero)
112}
113
114fn configure_command(
115    cmd: &mut Command,
116    cwd: Option<String>,
117    env: Option<HashMap<String, String>>,
118    stdin: Option<&CmdStdin>,
119    background: bool,
120) -> Result<()> {
121    if let Some(cwd) = cwd {
122        cmd.current_dir(cwd);
123    }
124
125    if let Some(env) = env {
126        cmd.envs(env);
127    }
128
129    configure_stdin(cmd, stdin, background)?;
130    configure_output(cmd, background);
131    Ok(())
132}
133
134fn configure_stdin(cmd: &mut Command, stdin: Option<&CmdStdin>, background: bool) -> Result<()> {
135    match stdin {
136        Some(CmdStdin::File(path)) => {
137            let file = File::open(path).map_err(Error::tool_io)?;
138            cmd.stdin(file);
139        }
140        Some(_) => {
141            cmd.stdin(Stdio::piped());
142        }
143        None if background => {
144            cmd.stdin(Stdio::null());
145        }
146        None => {}
147    }
148
149    Ok(())
150}
151
/// Routes the child's stdout/stderr.
///
/// Foreground runs capture both streams through pipes. Background runs return
/// before any output is collected, so their output is discarded (null).
fn configure_output(cmd: &mut Command, background: bool) {
    let (out, err) = if background {
        (Stdio::null(), Stdio::null())
    } else {
        (Stdio::piped(), Stdio::piped())
    };
    cmd.stdout(out).stderr(err);
}
162
163fn spawn_child(cmd: &mut Command) -> Result<std::process::Child> {
164    cmd.spawn().map_err(Error::tool_io)
165}
166
167fn take_output_reader<R>(pipe: &mut Option<R>) -> Option<thread::JoinHandle<io::Result<Vec<u8>>>>
168where
169    R: Read + Send + 'static,
170{
171    pipe.take().map(spawn_reader)
172}
173
174fn background_output(child: &std::process::Child) -> CmdOutput {
175    CmdOutput {
176        stdout: String::new(),
177        stderr: String::new(),
178        exit_code: 0,
179        pid: Some(child.id()),
180    }
181}
182
183fn wait_for_child(
184    child: &mut std::process::Child,
185    timeout_ms: Option<u64>,
186) -> Result<std::process::ExitStatus> {
187    match timeout_ms {
188        Some(timeout_ms) => wait_with_timeout(child, timeout_ms),
189        None => child.wait().map_err(Error::tool_io),
190    }
191}
192
193fn wait_with_timeout(
194    child: &mut std::process::Child,
195    timeout_ms: u64,
196) -> Result<std::process::ExitStatus> {
197    let duration = Duration::from_millis(timeout_ms);
198    match child.wait_timeout(duration).map_err(Error::tool_io)? {
199        Some(status) => Ok(status),
200        None => {
201            kill_timed_out_child(child)?;
202            Err(Error::tool_timeout())
203        }
204    }
205}
206
207fn kill_timed_out_child(child: &mut std::process::Child) -> Result<()> {
208    child.kill().map_err(Error::tool_io)?;
209    child.wait().map_err(Error::tool_io)?;
210    Ok(())
211}
212
213fn write_stdin(child: &mut std::process::Child, stdin_content: Option<&CmdStdin>) -> Result<()> {
214    let Some(stdin_content) = stdin_content else {
215        return Ok(());
216    };
217
218    let Some(mut stdin) = child.stdin.take() else {
219        if matches!(stdin_content, CmdStdin::Text(_) | CmdStdin::Bytes(_)) {
220            return Err(Error::tool_io(io::Error::new(
221                io::ErrorKind::BrokenPipe,
222                "stdin pipe not available",
223            )));
224        }
225        return Ok(());
226    };
227
228    match stdin_content {
229        CmdStdin::Text(text) => stdin.write_all(text.as_bytes()).map_err(Error::tool_io)?,
230        CmdStdin::Bytes(bytes) => stdin.write_all(bytes).map_err(Error::tool_io)?,
231        CmdStdin::File(_) => {}
232    }
233
234    stdin.flush().map_err(Error::tool_io)?;
235    drop(stdin);
236    Ok(())
237}
238
/// Reads `reader` to EOF on a new thread and yields the collected bytes (or
/// the first read error) through the join handle.
fn spawn_reader<R>(mut reader: R) -> thread::JoinHandle<io::Result<Vec<u8>>>
where
    R: Read + Send + 'static,
{
    thread::spawn(move || {
        let mut collected = Vec::new();
        reader.read_to_end(&mut collected).map(|_| collected)
    })
}
249
250fn collect_output(handle: Option<thread::JoinHandle<io::Result<Vec<u8>>>>) -> Result<String> {
251    let Some(handle) = handle else {
252        return Ok(String::new());
253    };
254
255    let bytes = handle
256        .join()
257        .map_err(|_| Error::tool_io(io::Error::other("output reader thread panicked")))?
258        .map_err(Error::tool_io)?;
259
260    Ok(String::from_utf8_lossy(&bytes).into_owned())
261}
262
263fn build_foreground_output(
264    status: std::process::ExitStatus,
265    stdout_handle: Option<thread::JoinHandle<io::Result<Vec<u8>>>>,
266    stderr_handle: Option<thread::JoinHandle<io::Result<Vec<u8>>>>,
267    fail_on_non_zero: bool,
268) -> Result<CmdOutput> {
269    let stdout = collect_output(stdout_handle)?;
270    let stderr = collect_output(stderr_handle)?;
271    let exit_code = status.code().unwrap_or(-1);
272
273    if fail_on_non_zero && exit_code != 0 {
274        return Err(Error::tool_cmd_failed(exit_code));
275    }
276
277    Ok(CmdOutput {
278        stdout,
279        stderr,
280        exit_code,
281        pid: None,
282    })
283}
284
/// Wraps `command` in the platform shell so pipes, redirects, and quoting are
/// interpreted by the shell: `sh -c <command>` on non-Windows targets.
#[cfg(not(target_os = "windows"))]
fn build_shell_command(command: &str) -> Command {
    let mut cmd = Command::new("sh");
    cmd.arg("-c").arg(command);
    cmd
}

/// Wraps `command` in the platform shell: `cmd.exe /c <command>` on Windows.
///
/// `cmd.exe` does not follow the MSVC argument-quoting rules that
/// [`Command::arg`] applies. Passing the command with `arg` would wrap it in
/// quotes and backslash-escape embedded `"`, which `cmd.exe` misparses. The
/// command line is therefore appended verbatim with `raw_arg`.
#[cfg(target_os = "windows")]
fn build_shell_command(command: &str) -> Command {
    use std::os::windows::process::CommandExt;

    let mut cmd = Command::new("cmd.exe");
    cmd.arg("/c").raw_arg(command);
    cmd
}
296
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Foreground, non-failing `CmdRequest` with no cwd/env/timeout/stdin.
    fn program_request(program: &str, args: &[&str]) -> CmdRequest {
        CmdRequest {
            program: program.to_string(),
            args: args.iter().map(|arg| arg.to_string()).collect(),
            cwd: None,
            env: None,
            timeout_ms: None,
            fail_on_non_zero: false,
            stdin: None,
            background: false,
        }
    }

    /// Foreground, non-failing `ShellCmdRequest` with no cwd/env/timeout/stdin.
    fn shell_request(command: &str) -> ShellCmdRequest {
        ShellCmdRequest {
            command: command.to_string(),
            cwd: None,
            env: None,
            timeout_ms: None,
            fail_on_non_zero: false,
            stdin: None,
            background: false,
        }
    }

    /// Runs `cat` with the given stdin source and returns its output.
    fn cat_with(input: CmdStdin) -> CmdOutput {
        let req = CmdRequest {
            stdin: Some(input),
            ..program_request("cat", &[])
        };
        CmdTool::run(req).expect("cat should succeed")
    }

    #[test]
    fn test_successful_command() {
        let output = CmdTool::run(program_request("echo", &["hello"])).expect("echo should succeed");
        assert_eq!(output.exit_code, 0);
        assert_eq!(output.stdout.trim(), "hello");
        assert!(output.pid.is_none());
    }

    #[test]
    fn test_shell_command() {
        let output = CmdTool::run_shell(shell_request("echo 'hello from shell'"))
            .expect("shell echo should succeed");
        assert_eq!(output.exit_code, 0);
        assert_eq!(output.stdout.trim(), "hello from shell");
        assert!(output.pid.is_none());
    }

    #[test]
    fn test_timeout_command() {
        let req = CmdRequest {
            timeout_ms: Some(100),
            ..program_request("sleep", &["2"])
        };
        let err = CmdTool::run(req).expect_err("sleep should time out");
        assert!(err.to_string().to_lowercase().contains("timed out"));
    }

    #[test]
    fn test_non_existent_command() {
        let err = CmdTool::run(program_request("this_command_does_not_exist_12345", &[]))
            .expect_err("missing program should fail to spawn");
        let err_msg = err.to_string().to_lowercase();
        assert!(err_msg.contains("no such file") || err_msg.contains("not found"));
    }

    #[test]
    fn test_stdin_text() {
        let output = cat_with(CmdStdin::Text("hello stdin text".to_string()));
        assert_eq!(output.exit_code, 0);
        assert_eq!(output.stdout, "hello stdin text");
        assert!(output.pid.is_none());
    }

    #[test]
    fn test_stdin_bytes() {
        let output = cat_with(CmdStdin::Bytes(b"hello stdin bytes".to_vec()));
        assert_eq!(output.exit_code, 0);
        assert_eq!(output.stdout, "hello stdin bytes");
        assert!(output.pid.is_none());
    }

    #[test]
    fn test_stdin_file() {
        let mut temp_file = NamedTempFile::new().unwrap();
        write!(temp_file, "hello stdin file").unwrap();

        let output = cat_with(CmdStdin::File(temp_file.path().to_path_buf()));
        assert_eq!(output.exit_code, 0);
        assert_eq!(output.stdout, "hello stdin file");
        assert!(output.pid.is_none());
    }

    #[test]
    fn test_background() {
        let req = CmdRequest {
            background: true,
            ..program_request("sleep", &["1"])
        };
        let output = CmdTool::run(req).expect("background spawn should succeed");
        assert_eq!(output.exit_code, 0);
        assert!(output.stdout.is_empty());
        let pid = output.pid.expect("background runs report a pid");
        assert!(pid > 0);
    }

    #[test]
    fn test_shell_pipe() {
        let command = if cfg!(target_os = "windows") {
            "echo hello pipe | findstr pipe"
        } else {
            "echo 'hello pipe' | grep pipe"
        };

        let output = CmdTool::run_shell(shell_request(command)).expect("pipeline should succeed");
        assert_eq!(output.exit_code, 0);
        assert!(output.stdout.contains("hello pipe"));
        assert!(output.pid.is_none());
    }

    #[test]
    fn test_non_zero_exit_can_fail() {
        let command = if cfg!(target_os = "windows") {
            "cmd /c exit 7"
        } else {
            "sh -c 'exit 7'"
        };
        let req = ShellCmdRequest {
            fail_on_non_zero: true,
            ..shell_request(command)
        };

        let err = CmdTool::run_shell(req).expect_err("exit 7 should be reported as a failure");
        assert!(err.to_string().to_lowercase().contains("exit code 7"));
    }

    #[test]
    fn test_non_zero_exit_can_be_observed_without_error() {
        let command = if cfg!(target_os = "windows") {
            "cmd /c exit 9"
        } else {
            "sh -c 'exit 9'"
        };

        let result = CmdTool::run_shell(shell_request(command)).unwrap();
        assert_eq!(result.exit_code, 9);
    }

    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_non_utf8_stdout_is_preserved_lossily() {
        let result = CmdTool::run_shell(shell_request("printf '\\377\\376abc'")).unwrap();
        assert!(result.stdout.contains("abc"));
        assert!(!result.stdout.is_empty());
    }
}
562}