cmd_wrapper/
tokio_command.rs

1use std::{
2    ffi::OsStr,
3    io::{Error, Result},
4    path::Path,
5    pin::Pin,
6    process::{ExitStatus, Stdio},
7    task::{Context, Poll},
8};
9
10use derivative::Derivative;
11use tokio::{
12    io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
13    process::{Child, Command},
14    sync::oneshot::{self, Receiver},
15    task::JoinHandle,
16};
17use tracing::{instrument, warn};
18
19#[derive(Debug)]
20pub struct TokioCommand {
21    command: Command,
22}
23
24impl TokioCommand {
25    pub fn new<S: Into<String>>(program: S) -> Self {
26        Self {
27            command: Command::new(program.into()),
28        }
29    }
30
31    pub fn arg<S: AsRef<OsStr>>(&mut self, arg: S) -> &mut Self {
32        self.command.arg(arg);
33        self
34    }
35
36    pub fn args<I, S>(&mut self, args: I) -> &mut Self
37    where
38        I: IntoIterator<Item = S>,
39        S: AsRef<OsStr>,
40    {
41        self.command.args(args);
42        self
43    }
44
45    pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Self
46    where
47        K: AsRef<OsStr>,
48        V: AsRef<OsStr>,
49    {
50        self.command.env(key, val);
51        self
52    }
53
54    pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Self
55    where
56        I: IntoIterator<Item = (K, V)>,
57        K: AsRef<OsStr>,
58        V: AsRef<OsStr>,
59    {
60        self.command.envs(vars);
61        self
62    }
63
64    pub fn current_dir<P: AsRef<Path>>(&mut self, dir: P) -> &mut Self {
65        self.command.current_dir(dir);
66        self
67    }
68
69    pub fn stdin<T: Into<Stdio>>(&mut self, stdin: T) -> &mut Self {
70        self.command.stdin(stdin);
71        self
72    }
73
74    pub fn stdout<T: Into<Stdio>>(&mut self, stdout: T) -> &mut Self {
75        self.command.stdout(stdout);
76        self
77    }
78
79    pub fn stderr<T: Into<Stdio>>(&mut self, stderr: T) -> &mut Self {
80        self.command.stderr(stderr);
81        self
82    }
83
84    #[instrument(level = "trace")]
85    pub fn into_inner(self) -> Command {
86        self.command
87    }
88
89    #[instrument(level = "trace")]
90    pub fn child(mut self) -> Result<Child> {
91        self.command.spawn()
92    }
93
94    #[instrument(level = "trace")]
95    pub async fn spawn(&mut self) -> Result<TokioCommandProcess> {
96        let mut child = self.command.spawn()?;
97
98        let stdin = child
99            .stdin
100            .take()
101            .map(|s| Box::new(s) as Box<dyn AsyncWrite + Send + Unpin>)
102            .unwrap_or_else(|| Box::new(AsyncSink));
103        let stdout = child
104            .stdout
105            .take()
106            .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
107            .unwrap_or_else(|| Box::new(io::empty()));
108        let stderr = child
109            .stderr
110            .take()
111            .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
112            .unwrap_or_else(|| Box::new(io::empty()));
113
114        let pid = child.id().unwrap_or_default();
115
116        let (end_tx, end_rx) = oneshot::channel::<(Option<i32>, Option<io::Error>)>();
117        let end_handler = tokio::spawn(async move {
118            let end_msg = match child.wait().await {
119                Ok(status) if status.code() == Some(0) => (Some(0), None),
120                Ok(status) => (
121                    status.code(),
122                    Self::read_full_stderr_if_any(stderr).await.err(),
123                ),
124                Err(err) => (Some(-1), Some(err)),
125            };
126
127            if let Err(err) = end_tx.send(end_msg) {
128                warn!(?err, "receiver dropped before sending end msg");
129            }
130        });
131
132        Ok(TokioCommandProcess {
133            stdin,
134            stdout,
135            pid,
136            end_rx,
137            end_handler,
138        })
139    }
140
141    #[instrument(level = "trace")]
142    pub async fn output(&mut self) -> Result<String> {
143        let output = self.command.output().await?;
144
145        if !output.status.success() {
146            let stderr = String::from_utf8_lossy(&output.stderr);
147            return Err(Error::other(stderr));
148        }
149
150        Ok(String::from_utf8_lossy(&output.stdout).to_string())
151    }
152
153    #[instrument(level = "trace")]
154    pub async fn status(&mut self) -> Result<ExitStatus> {
155        self.command.status().await
156    }
157
158    #[instrument(level = "trace")]
159    pub async fn execute_and_print(&mut self) -> Result<()> {
160        let output = self.command.output().await?;
161        if !output.stdout.is_empty() {
162            io::stdout().write_all(&output.stdout).await?;
163        }
164        if !output.stderr.is_empty() {
165            io::stderr().write_all(&output.stderr).await?;
166        }
167        Ok(())
168    }
169
170    async fn read_full_stderr_if_any(mut stderr: impl AsyncRead + Unpin) -> Result<()> {
171        let mut peek_buf = vec![0u8; 1024];
172        let n = stderr.read(&mut peek_buf).await?;
173        if n == 0 {
174            return Ok(());
175        }
176
177        let mut full = peek_buf[..n].to_vec();
178        let mut rest = Vec::new();
179        stderr.read_to_end(&mut rest).await?;
180        full.extend(rest);
181
182        let msg = String::from_utf8_lossy(&full).trim().to_string();
183        if !msg.is_empty() {
184            return Err(Error::other(msg));
185        }
186
187        Ok(())
188    }
189}
190
191#[derive(Derivative)]
192#[derivative(Debug)]
193pub struct TokioCommandProcess {
194    #[derivative(Debug = "ignore")]
195    pub stdin: Box<dyn AsyncWrite + Send + Unpin>,
196    #[derivative(Debug = "ignore")]
197    pub stdout: Box<dyn AsyncRead + Send + Unpin>,
198    pub pid: u32,
199    pub end_rx: Receiver<(Option<i32>, Option<io::Error>)>,
200    pub end_handler: JoinHandle<()>,
201}
202
203struct AsyncSink;
204
205impl AsyncWrite for AsyncSink {
206    fn poll_write(
207        self: Pin<&mut Self>,
208        _cx: &mut Context<'_>,
209        buf: &[u8],
210    ) -> Poll<std::io::Result<usize>> {
211        Poll::Ready(Ok(buf.len()))
212    }
213    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
214        Poll::Ready(Ok(()))
215    }
216    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
217        Poll::Ready(Ok(()))
218    }
219}