cmd_wrapper/
tokio_command.rs

1use std::{
2    ffi::OsStr,
3    io::{Error, Result},
4    path::Path,
5    pin::Pin,
6    process::{ExitStatus, Stdio},
7    task::{Context, Poll},
8};
9
10use tokio::{
11    io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
12    process::{Child, Command},
13    sync::oneshot::{self, Receiver},
14    task::JoinHandle,
15};
16use tracing::{error, instrument, warn};
17
18#[derive(Debug)]
19pub struct TokioCommand {
20    command: Command,
21}
22
23impl TokioCommand {
24    pub fn new<S: Into<String>>(program: S) -> Self {
25        Self {
26            command: Command::new(program.into()),
27        }
28    }
29
30    pub fn arg<S: AsRef<OsStr>>(&mut self, arg: S) -> &mut Self {
31        self.command.arg(arg);
32        self
33    }
34
35    pub fn args<I, S>(&mut self, args: I) -> &mut Self
36    where
37        I: IntoIterator<Item = S>,
38        S: AsRef<OsStr>,
39    {
40        self.command.args(args);
41        self
42    }
43
44    pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Self
45    where
46        K: AsRef<OsStr>,
47        V: AsRef<OsStr>,
48    {
49        self.command.env(key, val);
50        self
51    }
52
53    pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Self
54    where
55        I: IntoIterator<Item = (K, V)>,
56        K: AsRef<OsStr>,
57        V: AsRef<OsStr>,
58    {
59        self.command.envs(vars);
60        self
61    }
62
63    pub fn current_dir<P: AsRef<Path>>(&mut self, dir: P) -> &mut Self {
64        self.command.current_dir(dir);
65        self
66    }
67
68    pub fn stdin<T: Into<Stdio>>(&mut self, stdin: T) -> &mut Self {
69        self.command.stdin(stdin);
70        self
71    }
72
73    pub fn stdout<T: Into<Stdio>>(&mut self, stdout: T) -> &mut Self {
74        self.command.stdout(stdout);
75        self
76    }
77
78    pub fn stderr<T: Into<Stdio>>(&mut self, stderr: T) -> &mut Self {
79        self.command.stderr(stderr);
80        self
81    }
82
83    #[instrument(level = "trace")]
84    pub fn into_inner(self) -> Command {
85        self.command
86    }
87
88    #[instrument(level = "trace")]
89    pub fn child(mut self) -> Result<Child> {
90        self.command.spawn()
91    }
92
93    #[instrument(level = "trace")]
94    pub async fn spawn(&mut self) -> Result<TokioProcessContext> {
95        let mut child = self.command.spawn()?;
96
97        let stdin = child
98            .stdin
99            .take()
100            .map(|s| Box::new(s) as Box<dyn AsyncWrite + Send + Unpin>)
101            .unwrap_or_else(|| Box::new(AsyncSink));
102        let stdout = child
103            .stdout
104            .take()
105            .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
106            .unwrap_or_else(|| Box::new(io::empty()));
107        let stderr = child
108            .stderr
109            .take()
110            .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
111            .unwrap_or_else(|| Box::new(io::empty()));
112
113        let pid = child.id().unwrap_or_default();
114
115        let (end_tx, end_rx) = oneshot::channel::<i32>();
116        let end_handler = tokio::spawn(async move {
117            let exit_code = match child.wait().await {
118                Ok(status) if status.code() == Some(0) => 0,
119                Ok(status) => status.code().unwrap_or(-1),
120                Err(err) => {
121                    error!("waiting for child process failed: {err}");
122                    -1
123                }
124            };
125
126            if let Err(err) = end_tx.send(exit_code) {
127                warn!("receiver dropped before sending exit code: {err}");
128            }
129        });
130
131        Ok(TokioProcessContext {
132            stdin,
133            stdout,
134            stderr,
135            pid,
136            end_rx,
137            end_handler,
138        })
139    }
140
141    #[instrument(level = "trace")]
142    pub async fn output(&mut self) -> Result<String> {
143        let output = self.command.output().await?;
144
145        if !output.status.success() {
146            let stderr = String::from_utf8_lossy(&output.stderr);
147            return Err(Error::other(stderr));
148        }
149
150        Ok(String::from_utf8_lossy(&output.stdout).to_string())
151    }
152
153    #[instrument(level = "trace")]
154    pub async fn status(&mut self) -> Result<ExitStatus> {
155        self.command.status().await
156    }
157
158    #[instrument(level = "trace")]
159    pub async fn execute_and_print(&mut self) -> Result<()> {
160        let output = self.command.output().await?;
161        if !output.stdout.is_empty() {
162            io::stdout().write_all(&output.stdout).await?;
163        }
164        if !output.stderr.is_empty() {
165            io::stderr().write_all(&output.stderr).await?;
166        }
167        Ok(())
168    }
169
170    pub async fn read_full_stderr_if_any(stderr: &mut (impl AsyncRead + Unpin)) -> Result<()> {
171        let mut peek_buf = vec![0u8; 1024];
172        let n = stderr.read(&mut peek_buf).await?;
173        if n == 0 {
174            return Ok(());
175        }
176
177        let mut full = peek_buf[..n].to_vec();
178        let mut rest = Vec::new();
179        stderr.read_to_end(&mut rest).await?;
180        full.extend(rest);
181
182        let msg = String::from_utf8_lossy(&full).trim().to_string();
183        if !msg.is_empty() {
184            return Err(Error::other(msg));
185        }
186
187        Ok(())
188    }
189
190    pub async fn read_bytes_with_stderr_check(
191        stdout: &mut (impl AsyncRead + Unpin),
192        stderr: &mut (impl AsyncRead + Unpin),
193    ) -> Result<Option<Vec<u8>>> {
194        let mut out_buf = vec![0u8; 32 << 10];
195        let mut err_buf = vec![0u8; 1024];
196
197        loop {
198            tokio::select! {
199                read_stdout = stdout.read(&mut out_buf) => {
200                    let n = read_stdout?;
201                    if n == 0 {
202                        return Ok(None);
203                    }
204                    out_buf.truncate(n);
205                    return Ok(Some(out_buf))
206                }
207                read_stderr = stderr.read(&mut err_buf) => {
208                    let n = read_stderr?;
209                    if n > 0 {
210                        let msg = str::from_utf8(&err_buf[..n]).map_err(Error::other)?.trim();
211                        if !msg.is_empty() {
212                            return Err(Error::other(msg));
213                        }
214                    }
215                }
216            }
217        }
218    }
219}
220
221pub struct TokioProcessContext {
222    pub stdin: Box<dyn AsyncWrite + Send + Unpin>,
223    pub stdout: Box<dyn AsyncRead + Send + Unpin>,
224    pub stderr: Box<dyn AsyncRead + Send + Unpin>,
225    pub pid: u32,
226    pub end_rx: Receiver<i32>,
227    pub end_handler: JoinHandle<()>,
228}
229
230struct AsyncSink;
231
232impl AsyncWrite for AsyncSink {
233    fn poll_write(
234        self: Pin<&mut Self>,
235        _cx: &mut Context<'_>,
236        buf: &[u8],
237    ) -> Poll<std::io::Result<usize>> {
238        Poll::Ready(Ok(buf.len()))
239    }
240    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
241        Poll::Ready(Ok(()))
242    }
243    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
244        Poll::Ready(Ok(()))
245    }
246}