cmd_wrapper/
tokio_command.rs

1use std::{
2    ffi::OsStr,
3    io::{Error, Result},
4    path::Path,
5    pin::Pin,
6    process::{ExitStatus, Stdio},
7    task::{Context, Poll},
8};
9
10use derivative::Derivative;
11use tokio::{
12    io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
13    process::{Child, Command},
14    sync::oneshot::{self, Receiver},
15    task::JoinHandle,
16};
17use tracing::{instrument, warn};
18
19use crate::ProcessEnd;
20
21#[derive(Debug)]
22pub struct TokioCommand {
23    command: Command,
24}
25
26impl TokioCommand {
27    pub fn new<S: Into<String>>(program: S) -> Self {
28        Self {
29            command: Command::new(program.into()),
30        }
31    }
32
33    pub fn arg<S: AsRef<OsStr>>(&mut self, arg: S) -> &mut Self {
34        self.command.arg(arg);
35        self
36    }
37
38    pub fn args<I, S>(&mut self, args: I) -> &mut Self
39    where
40        I: IntoIterator<Item = S>,
41        S: AsRef<OsStr>,
42    {
43        self.command.args(args);
44        self
45    }
46
47    pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Self
48    where
49        K: AsRef<OsStr>,
50        V: AsRef<OsStr>,
51    {
52        self.command.env(key, val);
53        self
54    }
55
56    pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Self
57    where
58        I: IntoIterator<Item = (K, V)>,
59        K: AsRef<OsStr>,
60        V: AsRef<OsStr>,
61    {
62        self.command.envs(vars);
63        self
64    }
65
66    pub fn current_dir<P: AsRef<Path>>(&mut self, dir: P) -> &mut Self {
67        self.command.current_dir(dir);
68        self
69    }
70
71    pub fn stdin<T: Into<Stdio>>(&mut self, stdin: T) -> &mut Self {
72        self.command.stdin(stdin);
73        self
74    }
75
76    pub fn stdout<T: Into<Stdio>>(&mut self, stdout: T) -> &mut Self {
77        self.command.stdout(stdout);
78        self
79    }
80
81    pub fn stderr<T: Into<Stdio>>(&mut self, stderr: T) -> &mut Self {
82        self.command.stderr(stderr);
83        self
84    }
85
86    #[instrument(level = "trace")]
87    pub fn into_inner(self) -> Command {
88        self.command
89    }
90
91    #[instrument(level = "trace")]
92    pub fn child(mut self) -> Result<Child> {
93        self.command.spawn()
94    }
95
96    #[instrument(level = "trace")]
97    pub async fn spawn(&mut self) -> Result<TokioCommandProcess> {
98        let mut child = self.command.spawn()?;
99
100        let stdin = child
101            .stdin
102            .take()
103            .map(|s| Box::new(s) as Box<dyn AsyncWrite + Send + Unpin>)
104            .unwrap_or_else(|| Box::new(AsyncSink));
105        let stdout = child
106            .stdout
107            .take()
108            .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
109            .unwrap_or_else(|| Box::new(io::empty()));
110        let stderr = child
111            .stderr
112            .take()
113            .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
114            .unwrap_or_else(|| Box::new(io::empty()));
115
116        let pid = child.id().unwrap_or_default();
117
118        let (end_tx, end_rx) = oneshot::channel::<ProcessEnd>();
119        let end_handler = tokio::spawn(async move {
120            let end_msg = match child.wait().await {
121                Ok(status) => {
122                    if let Some(0) = status.code() {
123                        ProcessEnd::Success
124                    } else {
125                        let err = Self::read_full_stderr_if_any(stderr).await.err();
126                        if let Some(code) = status.code() {
127                            ProcessEnd::Failed(code, err)
128                        } else {
129                            ProcessEnd::Killed(err)
130                        }
131                    }
132                }
133                Err(err) => ProcessEnd::Error(err),
134            };
135
136            if let Err(err) = end_tx.send(end_msg) {
137                warn!(?err, "receiver dropped before sending end msg");
138            }
139        });
140
141        Ok(TokioCommandProcess {
142            stdin,
143            stdout,
144            pid,
145            end_rx,
146            end_handler,
147        })
148    }
149
150    #[instrument(level = "trace")]
151    pub async fn output(&mut self) -> Result<String> {
152        let output = self.command.output().await?;
153
154        if !output.status.success() {
155            let stderr = String::from_utf8_lossy(&output.stderr);
156            return Err(Error::other(stderr));
157        }
158
159        Ok(String::from_utf8_lossy(&output.stdout).to_string())
160    }
161
162    #[instrument(level = "trace")]
163    pub async fn status(&mut self) -> Result<ExitStatus> {
164        self.command.status().await
165    }
166
167    #[instrument(level = "trace")]
168    pub async fn execute_and_print(&mut self) -> Result<()> {
169        let output = self.command.output().await?;
170        if !output.stdout.is_empty() {
171            io::stdout().write_all(&output.stdout).await?;
172        }
173        if !output.stderr.is_empty() {
174            io::stderr().write_all(&output.stderr).await?;
175        }
176        Ok(())
177    }
178
179    async fn read_full_stderr_if_any(mut stderr: impl AsyncRead + Unpin) -> Result<()> {
180        let mut peek_buf = vec![0u8; 1024];
181        let n = stderr.read(&mut peek_buf).await?;
182        if n == 0 {
183            return Ok(());
184        }
185
186        let mut full = peek_buf[..n].to_vec();
187        let mut rest = Vec::new();
188        stderr.read_to_end(&mut rest).await?;
189        full.extend(rest);
190
191        let msg = String::from_utf8_lossy(&full).trim().to_string();
192        if !msg.is_empty() {
193            return Err(Error::other(msg));
194        }
195
196        Ok(())
197    }
198}
199
200#[derive(Derivative)]
201#[derivative(Debug)]
202pub struct TokioCommandProcess {
203    #[derivative(Debug = "ignore")]
204    pub stdin: Box<dyn AsyncWrite + Send + Unpin>,
205    #[derivative(Debug = "ignore")]
206    pub stdout: Box<dyn AsyncRead + Send + Unpin>,
207    pub pid: u32,
208    pub end_rx: Receiver<ProcessEnd>,
209    pub end_handler: JoinHandle<()>,
210}
211
212struct AsyncSink;
213
214impl AsyncWrite for AsyncSink {
215    fn poll_write(
216        self: Pin<&mut Self>,
217        _cx: &mut Context<'_>,
218        buf: &[u8],
219    ) -> Poll<std::io::Result<usize>> {
220        Poll::Ready(Ok(buf.len()))
221    }
222    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
223        Poll::Ready(Ok(()))
224    }
225    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
226        Poll::Ready(Ok(()))
227    }
228}