// cmd_wrapper/std_command.rs

1use std::{
2    ffi::OsStr,
3    io::{self, Error, Read, Result, Write},
4    path::Path,
5    process::{Child, Command, ExitStatus, Stdio},
6    sync::mpsc::{self, Receiver},
7    thread::{self, JoinHandle},
8};
9
10use tracing::{error, instrument, warn};
11
/// Chainable wrapper around a [`std::process::Command`].
///
/// The configuration methods take `&mut self` and return `&mut Self`, so they can be
/// chained before one of the execution methods actually launches the program.
#[derive(Debug)]
pub struct StdCommand {
    /// Standard-library command that accumulates the configuration.
    command: Command,
}
16
17impl StdCommand {
18    pub fn new<S: Into<String>>(program: S) -> Self {
19        Self {
20            command: Command::new(program.into()),
21        }
22    }
23
24    pub fn arg<S: AsRef<OsStr>>(&mut self, arg: S) -> &mut Self {
25        self.command.arg(arg);
26        self
27    }
28
29    pub fn args<I, S>(&mut self, args: I) -> &mut Self
30    where
31        I: IntoIterator<Item = S>,
32        S: AsRef<OsStr>,
33    {
34        self.command.args(args);
35        self
36    }
37
38    pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Self
39    where
40        K: AsRef<OsStr>,
41        V: AsRef<OsStr>,
42    {
43        self.command.env(key, val);
44        self
45    }
46
47    pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Self
48    where
49        I: IntoIterator<Item = (K, V)>,
50        K: AsRef<OsStr>,
51        V: AsRef<OsStr>,
52    {
53        self.command.envs(vars);
54        self
55    }
56
57    pub fn current_dir<P: AsRef<Path>>(&mut self, dir: P) -> &mut Self {
58        self.command.current_dir(dir);
59        self
60    }
61
62    pub fn stdin<T: Into<Stdio>>(&mut self, stdin: T) -> &mut Self {
63        self.command.stdin(stdin);
64        self
65    }
66
67    pub fn stdout<T: Into<Stdio>>(&mut self, stdout: T) -> &mut Self {
68        self.command.stdout(stdout);
69        self
70    }
71
72    pub fn stderr<T: Into<Stdio>>(&mut self, stderr: T) -> &mut Self {
73        self.command.stderr(stderr);
74        self
75    }
76
77    #[instrument(level = "trace")]
78    pub fn into_inner(self) -> Command {
79        self.command
80    }
81
82    #[instrument(level = "trace")]
83    pub fn child(mut self) -> Result<Child> {
84        self.command.spawn()
85    }
86
87    #[instrument(level = "trace")]
88    pub fn spawn(&mut self) -> Result<StdProcessContext> {
89        let mut child = self.command.spawn()?;
90
91        let stdin = child
92            .stdin
93            .take()
94            .map(|s| Box::new(s) as Box<dyn Write>)
95            .unwrap_or_else(|| Box::new(SyncSink));
96        let stdout = child
97            .stdout
98            .take()
99            .map(|s| Box::new(s) as Box<dyn Read>)
100            .unwrap_or_else(|| Box::new(io::empty()));
101        let stderr = child
102            .stderr
103            .take()
104            .map(|s| Box::new(s) as Box<dyn Read>)
105            .unwrap_or_else(|| Box::new(io::empty()));
106
107        let pid = child.id();
108
109        let (end_tx, end_rx) = mpsc::channel::<i32>();
110        let end_handler = thread::spawn(move || {
111            let exit_code = match child.wait() {
112                Ok(status) if status.code() == Some(0) => 0,
113                Ok(_) => 1,
114                Err(err) => {
115                    error!("waiting for child process failed: {err}");
116                    -1
117                }
118            };
119
120            if let Err(err) = end_tx.send(exit_code) {
121                warn!("receiver dropped before sending exit code: {err}");
122            }
123        });
124
125        Ok(StdProcessContext {
126            stdin,
127            stdout,
128            stderr,
129            pid,
130            end_rx,
131            end_handler,
132        })
133    }
134
135    #[instrument(level = "trace")]
136    pub fn output(&mut self) -> Result<String> {
137        let output = self.command.output()?;
138
139        if !output.status.success() {
140            let stderr = String::from_utf8_lossy(&output.stderr);
141            return Err(Error::other(stderr));
142        }
143
144        Ok(String::from_utf8_lossy(&output.stdout).to_string())
145    }
146
147    #[instrument(level = "trace")]
148    pub fn status(&mut self) -> Result<ExitStatus> {
149        self.command.status()
150    }
151
152    #[instrument(level = "trace")]
153    pub fn execute_and_print(&mut self) -> Result<()> {
154        let output = self.command.output()?;
155        if !output.stdout.is_empty() {
156            io::stdout().write_all(&output.stdout)?;
157        }
158        if !output.stderr.is_empty() {
159            io::stderr().write_all(&output.stderr)?;
160        }
161        Ok(())
162    }
163
164    pub fn read_full_stderr_if_any(stderr: &mut impl Read) -> Result<()> {
165        let mut peek_buf = vec![0u8; 1024];
166        let n = stderr.read(&mut peek_buf)?;
167        if n == 0 {
168            return Ok(());
169        }
170
171        let mut full = peek_buf[..n].to_vec();
172        let mut rest = Vec::new();
173        stderr.read_to_end(&mut rest)?;
174        full.extend(rest);
175
176        let msg = String::from_utf8_lossy(&full).trim().to_string();
177        if !msg.is_empty() {
178            return Err(Error::other(msg));
179        }
180
181        Ok(())
182    }
183}
184
/// Handles to a child process started by `StdCommand::spawn`.
///
/// Fields drop in declaration order, so dropping the context closes `stdin` first.
pub struct StdProcessContext {
    /// Writer connected to the child's stdin. It discards all data if stdin was not piped.
    pub stdin: Box<dyn Write>,
    /// Reader for the child's stdout. It yields EOF immediately if stdout was not piped.
    pub stdout: Box<dyn Read>,
    /// Reader for the child's stderr. It yields EOF immediately if stderr was not piped.
    pub stderr: Box<dyn Read>,
    /// OS process id of the child, as reported by `Child::id`.
    pub pid: u32,
    /// Receives one value when the child finishes:
    /// - `0` if it exited with code 0;
    /// - `1` for any other exit status;
    /// - `-1` if waiting on the child failed.
    pub end_rx: Receiver<i32>,
    /// Background thread that waits on the child and sends the value on `end_rx`.
    pub end_handler: JoinHandle<()>,
}

impl std::fmt::Debug for StdProcessContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The boxed stdio trait objects are not `Debug`, so only the pid is shown.
        f.debug_struct("StdProcessContext")
            .field("pid", &self.pid)
            .finish_non_exhaustive()
    }
}
193
/// A writer that accepts every byte and throws it away.
///
/// `StdCommand::spawn` uses it as the stdin stand-in when the spawned child exposes
/// no stdin handle, so callers always have something to write to.
struct SyncSink;

impl Write for SyncSink {
    /// Reports the entire buffer as written without storing any of it.
    fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
        let accepted = bytes.len();
        Ok(accepted)
    }

    /// Nothing is ever buffered, so there is nothing to flush.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}