cmd_wrapper/
tokio_command.rs1use std::{
2 ffi::OsStr,
3 io::{Error, Result},
4 path::Path,
5 pin::Pin,
6 process::{ExitStatus, Stdio},
7 task::{Context, Poll},
8};
9
10use tokio::{
11 io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
12 process::{Child, Command},
13 sync::oneshot::{self, Receiver},
14 task::JoinHandle,
15};
16use tracing::{error, instrument, warn};
17
18#[derive(Debug)]
19pub struct TokioCommand {
20 command: Command,
21}
22
23impl TokioCommand {
24 pub fn new<S: Into<String>>(program: S) -> Self {
25 Self {
26 command: Command::new(program.into()),
27 }
28 }
29
30 pub fn arg<S: AsRef<OsStr>>(&mut self, arg: S) -> &mut Self {
31 self.command.arg(arg);
32 self
33 }
34
35 pub fn args<I, S>(&mut self, args: I) -> &mut Self
36 where
37 I: IntoIterator<Item = S>,
38 S: AsRef<OsStr>,
39 {
40 self.command.args(args);
41 self
42 }
43
44 pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Self
45 where
46 K: AsRef<OsStr>,
47 V: AsRef<OsStr>,
48 {
49 self.command.env(key, val);
50 self
51 }
52
53 pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Self
54 where
55 I: IntoIterator<Item = (K, V)>,
56 K: AsRef<OsStr>,
57 V: AsRef<OsStr>,
58 {
59 self.command.envs(vars);
60 self
61 }
62
63 pub fn current_dir<P: AsRef<Path>>(&mut self, dir: P) -> &mut Self {
64 self.command.current_dir(dir);
65 self
66 }
67
68 pub fn stdin<T: Into<Stdio>>(&mut self, stdin: T) -> &mut Self {
69 self.command.stdin(stdin);
70 self
71 }
72
73 pub fn stdout<T: Into<Stdio>>(&mut self, stdout: T) -> &mut Self {
74 self.command.stdout(stdout);
75 self
76 }
77
78 pub fn stderr<T: Into<Stdio>>(&mut self, stderr: T) -> &mut Self {
79 self.command.stderr(stderr);
80 self
81 }
82
83 #[instrument(level = "trace")]
84 pub fn into_inner(self) -> Command {
85 self.command
86 }
87
88 #[instrument(level = "trace")]
89 pub fn child(mut self) -> Result<Child> {
90 self.command.spawn()
91 }
92
93 #[instrument(level = "trace")]
94 pub async fn spawn(&mut self) -> Result<TokioProcessContext> {
95 let mut child = self.command.spawn()?;
96
97 let stdin = child
98 .stdin
99 .take()
100 .map(|s| Box::new(s) as Box<dyn AsyncWrite + Send + Unpin>)
101 .unwrap_or_else(|| Box::new(AsyncSink));
102 let stdout = child
103 .stdout
104 .take()
105 .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
106 .unwrap_or_else(|| Box::new(io::empty()));
107 let stderr = child
108 .stderr
109 .take()
110 .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
111 .unwrap_or_else(|| Box::new(io::empty()));
112
113 let pid = child.id().unwrap_or_default();
114
115 let (end_tx, end_rx) = oneshot::channel::<i32>();
116 let end_handler = tokio::spawn(async move {
117 let exit_code = match child.wait().await {
118 Ok(status) if status.code() == Some(0) => 0,
119 Ok(status) => status.code().unwrap_or(-1),
120 Err(err) => {
121 error!("waiting for child process failed: {err}");
122 -1
123 }
124 };
125
126 if let Err(err) = end_tx.send(exit_code) {
127 warn!("receiver dropped before sending exit code: {err}");
128 }
129 });
130
131 Ok(TokioProcessContext {
132 stdin,
133 stdout,
134 stderr,
135 pid,
136 end_rx,
137 end_handler,
138 })
139 }
140
141 #[instrument(level = "trace")]
142 pub async fn output(&mut self) -> Result<String> {
143 let output = self.command.output().await?;
144
145 if !output.status.success() {
146 let stderr = String::from_utf8_lossy(&output.stderr);
147 return Err(Error::other(stderr));
148 }
149
150 Ok(String::from_utf8_lossy(&output.stdout).to_string())
151 }
152
153 #[instrument(level = "trace")]
154 pub async fn status(&mut self) -> Result<ExitStatus> {
155 self.command.status().await
156 }
157
158 #[instrument(level = "trace")]
159 pub async fn execute_and_print(&mut self) -> Result<()> {
160 let output = self.command.output().await?;
161 if !output.stdout.is_empty() {
162 io::stdout().write_all(&output.stdout).await?;
163 }
164 if !output.stderr.is_empty() {
165 io::stderr().write_all(&output.stderr).await?;
166 }
167 Ok(())
168 }
169
170 pub async fn read_full_stderr_if_any(stderr: &mut (impl AsyncRead + Unpin)) -> Result<()> {
171 let mut peek_buf = vec![0u8; 1024];
172 let n = stderr.read(&mut peek_buf).await?;
173 if n == 0 {
174 return Ok(());
175 }
176
177 let mut full = peek_buf[..n].to_vec();
178 let mut rest = Vec::new();
179 stderr.read_to_end(&mut rest).await?;
180 full.extend(rest);
181
182 let msg = String::from_utf8_lossy(&full).trim().to_string();
183 if !msg.is_empty() {
184 return Err(Error::other(msg));
185 }
186
187 Ok(())
188 }
189
190 pub async fn read_bytes_with_stderr_check(
191 stdout: &mut (impl AsyncRead + Unpin),
192 stderr: &mut (impl AsyncRead + Unpin),
193 ) -> Result<Option<Vec<u8>>> {
194 let mut out_buf = vec![0u8; 32 << 10];
195 let mut err_buf = vec![0u8; 1024];
196
197 loop {
198 tokio::select! {
199 read_stdout = stdout.read(&mut out_buf) => {
200 let n = read_stdout?;
201 if n == 0 {
202 return Ok(None);
203 }
204 out_buf.truncate(n);
205 return Ok(Some(out_buf))
206 }
207 read_stderr = stderr.read(&mut err_buf) => {
208 let n = read_stderr?;
209 if n > 0 {
210 let msg = str::from_utf8(&err_buf[..n]).map_err(Error::other)?.trim();
211 if !msg.is_empty() {
212 return Err(Error::other(msg));
213 }
214 }
215 }
216 }
217 }
218 }
219}
220
221pub struct TokioProcessContext {
222 pub stdin: Box<dyn AsyncWrite + Send + Unpin>,
223 pub stdout: Box<dyn AsyncRead + Send + Unpin>,
224 pub stderr: Box<dyn AsyncRead + Send + Unpin>,
225 pub pid: u32,
226 pub end_rx: Receiver<i32>,
227 pub end_handler: JoinHandle<()>,
228}
229
230struct AsyncSink;
231
232impl AsyncWrite for AsyncSink {
233 fn poll_write(
234 self: Pin<&mut Self>,
235 _cx: &mut Context<'_>,
236 buf: &[u8],
237 ) -> Poll<std::io::Result<usize>> {
238 Poll::Ready(Ok(buf.len()))
239 }
240 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
241 Poll::Ready(Ok(()))
242 }
243 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
244 Poll::Ready(Ok(()))
245 }
246}