cmd_wrapper/
tokio_command.rs

1use std::{
2    ffi::OsStr,
3    io::{Error, Result},
4    path::Path,
5    pin::Pin,
6    process::{ExitStatus, Stdio},
7    task::{Context, Poll},
8};
9
10use derivative::Derivative;
11use tokio::{
12    io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
13    process::{Child, Command},
14    sync::oneshot::{self, Receiver},
15    task::JoinHandle,
16};
17use tracing::{error, instrument, warn};
18
/// Thin builder-style wrapper around a Tokio [`Command`].
///
/// The configuration methods forward to the wrapped command and return
/// `&mut Self` so calls can be chained; the execution methods (`spawn`,
/// `output`, `status`, `execute_and_print`) run the configured command.
#[derive(Debug)]
pub struct TokioCommand {
    /// The wrapped Tokio command that every method delegates to.
    command: Command,
}
23
24impl TokioCommand {
25    pub fn new<S: Into<String>>(program: S) -> Self {
26        Self {
27            command: Command::new(program.into()),
28        }
29    }
30
31    pub fn arg<S: AsRef<OsStr>>(&mut self, arg: S) -> &mut Self {
32        self.command.arg(arg);
33        self
34    }
35
36    pub fn args<I, S>(&mut self, args: I) -> &mut Self
37    where
38        I: IntoIterator<Item = S>,
39        S: AsRef<OsStr>,
40    {
41        self.command.args(args);
42        self
43    }
44
45    pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Self
46    where
47        K: AsRef<OsStr>,
48        V: AsRef<OsStr>,
49    {
50        self.command.env(key, val);
51        self
52    }
53
54    pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Self
55    where
56        I: IntoIterator<Item = (K, V)>,
57        K: AsRef<OsStr>,
58        V: AsRef<OsStr>,
59    {
60        self.command.envs(vars);
61        self
62    }
63
64    pub fn current_dir<P: AsRef<Path>>(&mut self, dir: P) -> &mut Self {
65        self.command.current_dir(dir);
66        self
67    }
68
69    pub fn stdin<T: Into<Stdio>>(&mut self, stdin: T) -> &mut Self {
70        self.command.stdin(stdin);
71        self
72    }
73
74    pub fn stdout<T: Into<Stdio>>(&mut self, stdout: T) -> &mut Self {
75        self.command.stdout(stdout);
76        self
77    }
78
79    pub fn stderr<T: Into<Stdio>>(&mut self, stderr: T) -> &mut Self {
80        self.command.stderr(stderr);
81        self
82    }
83
84    #[instrument(level = "trace")]
85    pub fn into_inner(self) -> Command {
86        self.command
87    }
88
89    #[instrument(level = "trace")]
90    pub fn child(mut self) -> Result<Child> {
91        self.command.spawn()
92    }
93
94    #[instrument(level = "trace")]
95    pub async fn spawn(&mut self) -> Result<TokioCommandProcess> {
96        let mut child = self.command.spawn()?;
97
98        let stdin = child
99            .stdin
100            .take()
101            .map(|s| Box::new(s) as Box<dyn AsyncWrite + Send + Unpin>)
102            .unwrap_or_else(|| Box::new(AsyncSink));
103        let stdout = child
104            .stdout
105            .take()
106            .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
107            .unwrap_or_else(|| Box::new(io::empty()));
108        let stderr = child
109            .stderr
110            .take()
111            .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
112            .unwrap_or_else(|| Box::new(io::empty()));
113
114        let pid = child.id().unwrap_or_default();
115
116        let (end_tx, end_rx) = oneshot::channel::<i32>();
117        let end_handler = tokio::spawn(async move {
118            let exit_code = match child.wait().await {
119                Ok(status) if status.code() == Some(0) => 0,
120                Ok(status) => status.code().unwrap_or(-1),
121                Err(err) => {
122                    error!("waiting for child process failed: {err}");
123                    -1
124                }
125            };
126
127            if let Err(err) = end_tx.send(exit_code) {
128                warn!("receiver dropped before sending exit code: {err}");
129            }
130        });
131
132        Ok(TokioCommandProcess {
133            stdin,
134            stdout,
135            stderr,
136            pid,
137            end_rx,
138            end_handler,
139        })
140    }
141
142    #[instrument(level = "trace")]
143    pub async fn output(&mut self) -> Result<String> {
144        let output = self.command.output().await?;
145
146        if !output.status.success() {
147            let stderr = String::from_utf8_lossy(&output.stderr);
148            return Err(Error::other(stderr));
149        }
150
151        Ok(String::from_utf8_lossy(&output.stdout).to_string())
152    }
153
154    #[instrument(level = "trace")]
155    pub async fn status(&mut self) -> Result<ExitStatus> {
156        self.command.status().await
157    }
158
159    #[instrument(level = "trace")]
160    pub async fn execute_and_print(&mut self) -> Result<()> {
161        let output = self.command.output().await?;
162        if !output.stdout.is_empty() {
163            io::stdout().write_all(&output.stdout).await?;
164        }
165        if !output.stderr.is_empty() {
166            io::stderr().write_all(&output.stderr).await?;
167        }
168        Ok(())
169    }
170
171    pub async fn read_full_stderr_if_any(stderr: &mut (impl AsyncRead + Unpin)) -> Result<()> {
172        let mut peek_buf = vec![0u8; 1024];
173        let n = stderr.read(&mut peek_buf).await?;
174        if n == 0 {
175            return Ok(());
176        }
177
178        let mut full = peek_buf[..n].to_vec();
179        let mut rest = Vec::new();
180        stderr.read_to_end(&mut rest).await?;
181        full.extend(rest);
182
183        let msg = String::from_utf8_lossy(&full).trim().to_string();
184        if !msg.is_empty() {
185            return Err(Error::other(msg));
186        }
187
188        Ok(())
189    }
190
191    pub async fn read_bytes_with_stderr_check(
192        stdout: &mut (impl AsyncRead + Unpin),
193        stderr: &mut (impl AsyncRead + Unpin),
194    ) -> Result<Option<Vec<u8>>> {
195        let mut out_buf = vec![0u8; 32 << 10];
196        let mut err_buf = vec![0u8; 1024];
197
198        loop {
199            tokio::select! {
200                read_stdout = stdout.read(&mut out_buf) => {
201                    let n = read_stdout?;
202                    if n == 0 {
203                        return Ok(None);
204                    }
205                    out_buf.truncate(n);
206                    return Ok(Some(out_buf))
207                }
208                read_stderr = stderr.read(&mut err_buf) => {
209                    let n = read_stderr?;
210                    if n > 0 {
211                        let msg = str::from_utf8(&err_buf[..n]).map_err(Error::other)?.trim();
212                        if !msg.is_empty() {
213                            return Err(Error::other(msg));
214                        }
215                    }
216                }
217            }
218        }
219    }
220}
221
/// Handles to a process started by `TokioCommand::spawn`.
///
/// The three stream fields are boxed trait objects and are skipped in the
/// derived `Debug` output.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct TokioCommandProcess {
    /// Writer for the child's stdin; a discarding sink when the child exposed
    /// no stdin handle.
    #[derivative(Debug = "ignore")]
    pub stdin: Box<dyn AsyncWrite + Send + Unpin>,
    /// Reader for the child's stdout; an empty reader when the child exposed
    /// no stdout handle.
    #[derivative(Debug = "ignore")]
    pub stdout: Box<dyn AsyncRead + Send + Unpin>,
    /// Reader for the child's stderr; an empty reader when the child exposed
    /// no stderr handle.
    #[derivative(Debug = "ignore")]
    pub stderr: Box<dyn AsyncRead + Send + Unpin>,
    /// Process id reported at spawn time, or `0` when none was available.
    pub pid: u32,
    /// Receives the exit code once the child has been waited on; `-1` stands
    /// for "no exit code" or "waiting failed".
    pub end_rx: Receiver<i32>,
    /// Join handle of the background task that waits for the child and sends
    /// the exit code to `end_rx`.
    pub end_handler: JoinHandle<()>,
}
235
236struct AsyncSink;
237
238impl AsyncWrite for AsyncSink {
239    fn poll_write(
240        self: Pin<&mut Self>,
241        _cx: &mut Context<'_>,
242        buf: &[u8],
243    ) -> Poll<std::io::Result<usize>> {
244        Poll::Ready(Ok(buf.len()))
245    }
246    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
247        Poll::Ready(Ok(()))
248    }
249    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
250        Poll::Ready(Ok(()))
251    }
252}