cmd_wrapper/
tokio_command.rs1use std::{
2 ffi::OsStr,
3 io::{Error, Result},
4 path::Path,
5 pin::Pin,
6 process::{ExitStatus, Stdio},
7 task::{Context, Poll},
8};
9
10use derivative::Derivative;
11use tokio::{
12 io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
13 process::{Child, Command},
14 sync::oneshot::{self, Receiver},
15 task::JoinHandle,
16};
17use tracing::{error, instrument, warn};
18
19#[derive(Debug)]
20pub struct TokioCommand {
21 command: Command,
22}
23
24impl TokioCommand {
25 pub fn new<S: Into<String>>(program: S) -> Self {
26 Self {
27 command: Command::new(program.into()),
28 }
29 }
30
31 pub fn arg<S: AsRef<OsStr>>(&mut self, arg: S) -> &mut Self {
32 self.command.arg(arg);
33 self
34 }
35
36 pub fn args<I, S>(&mut self, args: I) -> &mut Self
37 where
38 I: IntoIterator<Item = S>,
39 S: AsRef<OsStr>,
40 {
41 self.command.args(args);
42 self
43 }
44
45 pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Self
46 where
47 K: AsRef<OsStr>,
48 V: AsRef<OsStr>,
49 {
50 self.command.env(key, val);
51 self
52 }
53
54 pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Self
55 where
56 I: IntoIterator<Item = (K, V)>,
57 K: AsRef<OsStr>,
58 V: AsRef<OsStr>,
59 {
60 self.command.envs(vars);
61 self
62 }
63
64 pub fn current_dir<P: AsRef<Path>>(&mut self, dir: P) -> &mut Self {
65 self.command.current_dir(dir);
66 self
67 }
68
69 pub fn stdin<T: Into<Stdio>>(&mut self, stdin: T) -> &mut Self {
70 self.command.stdin(stdin);
71 self
72 }
73
74 pub fn stdout<T: Into<Stdio>>(&mut self, stdout: T) -> &mut Self {
75 self.command.stdout(stdout);
76 self
77 }
78
79 pub fn stderr<T: Into<Stdio>>(&mut self, stderr: T) -> &mut Self {
80 self.command.stderr(stderr);
81 self
82 }
83
84 #[instrument(level = "trace")]
85 pub fn into_inner(self) -> Command {
86 self.command
87 }
88
89 #[instrument(level = "trace")]
90 pub fn child(mut self) -> Result<Child> {
91 self.command.spawn()
92 }
93
94 #[instrument(level = "trace")]
95 pub async fn spawn(&mut self) -> Result<TokioCommandProcess> {
96 let mut child = self.command.spawn()?;
97
98 let stdin = child
99 .stdin
100 .take()
101 .map(|s| Box::new(s) as Box<dyn AsyncWrite + Send + Unpin>)
102 .unwrap_or_else(|| Box::new(AsyncSink));
103 let stdout = child
104 .stdout
105 .take()
106 .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
107 .unwrap_or_else(|| Box::new(io::empty()));
108 let stderr = child
109 .stderr
110 .take()
111 .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
112 .unwrap_or_else(|| Box::new(io::empty()));
113
114 let pid = child.id().unwrap_or_default();
115
116 let (end_tx, end_rx) = oneshot::channel::<i32>();
117 let end_handler = tokio::spawn(async move {
118 let exit_code = match child.wait().await {
119 Ok(status) if status.code() == Some(0) => 0,
120 Ok(status) => status.code().unwrap_or(-1),
121 Err(err) => {
122 error!("waiting for child process failed: {err}");
123 -1
124 }
125 };
126
127 if let Err(err) = end_tx.send(exit_code) {
128 warn!("receiver dropped before sending exit code: {err}");
129 }
130 });
131
132 Ok(TokioCommandProcess {
133 stdin,
134 stdout,
135 stderr,
136 pid,
137 end_rx,
138 end_handler,
139 })
140 }
141
142 #[instrument(level = "trace")]
143 pub async fn output(&mut self) -> Result<String> {
144 let output = self.command.output().await?;
145
146 if !output.status.success() {
147 let stderr = String::from_utf8_lossy(&output.stderr);
148 return Err(Error::other(stderr));
149 }
150
151 Ok(String::from_utf8_lossy(&output.stdout).to_string())
152 }
153
154 #[instrument(level = "trace")]
155 pub async fn status(&mut self) -> Result<ExitStatus> {
156 self.command.status().await
157 }
158
159 #[instrument(level = "trace")]
160 pub async fn execute_and_print(&mut self) -> Result<()> {
161 let output = self.command.output().await?;
162 if !output.stdout.is_empty() {
163 io::stdout().write_all(&output.stdout).await?;
164 }
165 if !output.stderr.is_empty() {
166 io::stderr().write_all(&output.stderr).await?;
167 }
168 Ok(())
169 }
170
171 pub async fn read_full_stderr_if_any(stderr: &mut (impl AsyncRead + Unpin)) -> Result<()> {
172 let mut peek_buf = vec![0u8; 1024];
173 let n = stderr.read(&mut peek_buf).await?;
174 if n == 0 {
175 return Ok(());
176 }
177
178 let mut full = peek_buf[..n].to_vec();
179 let mut rest = Vec::new();
180 stderr.read_to_end(&mut rest).await?;
181 full.extend(rest);
182
183 let msg = String::from_utf8_lossy(&full).trim().to_string();
184 if !msg.is_empty() {
185 return Err(Error::other(msg));
186 }
187
188 Ok(())
189 }
190
191 pub async fn read_bytes_with_stderr_check(
192 stdout: &mut (impl AsyncRead + Unpin),
193 stderr: &mut (impl AsyncRead + Unpin),
194 ) -> Result<Option<Vec<u8>>> {
195 let mut out_buf = vec![0u8; 32 << 10];
196 let mut err_buf = vec![0u8; 1024];
197
198 loop {
199 tokio::select! {
200 read_stdout = stdout.read(&mut out_buf) => {
201 let n = read_stdout?;
202 if n == 0 {
203 return Ok(None);
204 }
205 out_buf.truncate(n);
206 return Ok(Some(out_buf))
207 }
208 read_stderr = stderr.read(&mut err_buf) => {
209 let n = read_stderr?;
210 if n > 0 {
211 let msg = str::from_utf8(&err_buf[..n]).map_err(Error::other)?.trim();
212 if !msg.is_empty() {
213 return Err(Error::other(msg));
214 }
215 }
216 }
217 }
218 }
219 }
220}
221
222#[derive(Derivative)]
223#[derivative(Debug)]
224pub struct TokioCommandProcess {
225 #[derivative(Debug = "ignore")]
226 pub stdin: Box<dyn AsyncWrite + Send + Unpin>,
227 #[derivative(Debug = "ignore")]
228 pub stdout: Box<dyn AsyncRead + Send + Unpin>,
229 #[derivative(Debug = "ignore")]
230 pub stderr: Box<dyn AsyncRead + Send + Unpin>,
231 pub pid: u32,
232 pub end_rx: Receiver<i32>,
233 pub end_handler: JoinHandle<()>,
234}
235
236struct AsyncSink;
237
238impl AsyncWrite for AsyncSink {
239 fn poll_write(
240 self: Pin<&mut Self>,
241 _cx: &mut Context<'_>,
242 buf: &[u8],
243 ) -> Poll<std::io::Result<usize>> {
244 Poll::Ready(Ok(buf.len()))
245 }
246 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
247 Poll::Ready(Ok(()))
248 }
249 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
250 Poll::Ready(Ok(()))
251 }
252}