cmd_wrapper/
tokio_command.rs1use std::{
2 ffi::OsStr,
3 io::{Error, Result},
4 path::Path,
5 pin::Pin,
6 process::{ExitStatus, Stdio},
7 task::{Context, Poll},
8};
9
10use derivative::Derivative;
11use tokio::{
12 io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
13 process::{Child, Command},
14 sync::oneshot::{self, Receiver},
15 task::JoinHandle,
16};
17use tracing::{instrument, warn};
18
19use crate::ProcessEnd;
20
21#[derive(Debug)]
22pub struct TokioCommand {
23 command: Command,
24}
25
26impl TokioCommand {
27 pub fn new<S: AsRef<OsStr>>(program: S) -> Self {
28 Self {
29 command: Command::new(program),
30 }
31 }
32
33 pub fn arg<S: AsRef<OsStr>>(&mut self, arg: S) -> &mut Self {
34 self.command.arg(arg);
35 self
36 }
37
38 pub fn args<I, S>(&mut self, args: I) -> &mut Self
39 where
40 I: IntoIterator<Item = S>,
41 S: AsRef<OsStr>,
42 {
43 self.command.args(args);
44 self
45 }
46
47 pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Self
48 where
49 K: AsRef<OsStr>,
50 V: AsRef<OsStr>,
51 {
52 self.command.env(key, val);
53 self
54 }
55
56 pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Self
57 where
58 I: IntoIterator<Item = (K, V)>,
59 K: AsRef<OsStr>,
60 V: AsRef<OsStr>,
61 {
62 self.command.envs(vars);
63 self
64 }
65
66 pub fn current_dir<P: AsRef<Path>>(&mut self, dir: P) -> &mut Self {
67 self.command.current_dir(dir);
68 self
69 }
70
71 pub fn stdin<T: Into<Stdio>>(&mut self, stdin: T) -> &mut Self {
72 self.command.stdin(stdin);
73 self
74 }
75
76 pub fn stdout<T: Into<Stdio>>(&mut self, stdout: T) -> &mut Self {
77 self.command.stdout(stdout);
78 self
79 }
80
81 pub fn stderr<T: Into<Stdio>>(&mut self, stderr: T) -> &mut Self {
82 self.command.stderr(stderr);
83 self
84 }
85
86 #[instrument(level = "trace")]
87 pub fn into_inner(self) -> Command {
88 self.command
89 }
90
91 #[instrument(level = "trace")]
92 pub fn child(mut self) -> Result<Child> {
93 self.command.spawn()
94 }
95
96 #[instrument(level = "trace")]
97 pub async fn spawn(&mut self) -> Result<TokioCommandProcess> {
98 let mut child = self.command.spawn()?;
99
100 let stdin = child
101 .stdin
102 .take()
103 .map(|s| Box::new(s) as Box<dyn AsyncWrite + Send + Unpin>)
104 .unwrap_or_else(|| Box::new(AsyncSink));
105 let stdout = child
106 .stdout
107 .take()
108 .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
109 .unwrap_or_else(|| Box::new(io::empty()));
110 let stderr = child
111 .stderr
112 .take()
113 .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
114 .unwrap_or_else(|| Box::new(io::empty()));
115
116 let pid = child.id().unwrap_or_default();
117
118 let (end_tx, end_rx) = oneshot::channel::<ProcessEnd>();
119 let end_handler = tokio::spawn(async move {
120 let end_msg = match child.wait().await {
121 Ok(status) => {
122 if let Some(0) = status.code() {
123 ProcessEnd::Success
124 } else {
125 let err = Self::read_full_stderr_if_any(stderr).await.err();
126 if let Some(code) = status.code() {
127 ProcessEnd::Failed(code, err)
128 } else {
129 ProcessEnd::Killed(err)
130 }
131 }
132 }
133 Err(err) => ProcessEnd::Error(err),
134 };
135
136 if let Err(err) = end_tx.send(end_msg) {
137 warn!(?err, "receiver dropped before sending end msg");
138 }
139 });
140
141 Ok(TokioCommandProcess {
142 stdin,
143 stdout,
144 pid,
145 end_rx,
146 end_handler,
147 })
148 }
149
150 #[instrument(level = "trace")]
151 pub async fn output(&mut self) -> Result<String> {
152 let output = self.command.output().await?;
153
154 if !output.status.success() {
155 let stderr = String::from_utf8_lossy(&output.stderr);
156 return Err(Error::other(stderr));
157 }
158
159 Ok(String::from_utf8_lossy(&output.stdout).to_string())
160 }
161
162 #[instrument(level = "trace")]
163 pub async fn status(&mut self) -> Result<ExitStatus> {
164 self.command.status().await
165 }
166
167 #[instrument(level = "trace")]
168 pub async fn execute_and_print(&mut self) -> Result<()> {
169 let output = self.command.output().await?;
170 if !output.stdout.is_empty() {
171 io::stdout().write_all(&output.stdout).await?;
172 }
173 if !output.stderr.is_empty() {
174 io::stderr().write_all(&output.stderr).await?;
175 }
176 Ok(())
177 }
178
179 async fn read_full_stderr_if_any(mut stderr: impl AsyncRead + Unpin) -> Result<()> {
180 let mut peek_buf = vec![0u8; 1024];
181 let n = stderr.read(&mut peek_buf).await?;
182 if n == 0 {
183 return Ok(());
184 }
185
186 let mut full = peek_buf[..n].to_vec();
187 let mut rest = Vec::new();
188 stderr.read_to_end(&mut rest).await?;
189 full.extend(rest);
190
191 let msg = String::from_utf8_lossy(&full).trim().to_string();
192 if !msg.is_empty() {
193 return Err(Error::other(msg));
194 }
195
196 Ok(())
197 }
198}
199
200#[derive(Derivative)]
201#[derivative(Debug)]
202pub struct TokioCommandProcess {
203 #[derivative(Debug = "ignore")]
204 pub stdin: Box<dyn AsyncWrite + Send + Unpin>,
205 #[derivative(Debug = "ignore")]
206 pub stdout: Box<dyn AsyncRead + Send + Unpin>,
207 pub pid: u32,
208 pub end_rx: Receiver<ProcessEnd>,
209 pub end_handler: JoinHandle<()>,
210}
211
212struct AsyncSink;
213
214impl AsyncWrite for AsyncSink {
215 fn poll_write(
216 self: Pin<&mut Self>,
217 _cx: &mut Context<'_>,
218 buf: &[u8],
219 ) -> Poll<std::io::Result<usize>> {
220 Poll::Ready(Ok(buf.len()))
221 }
222 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
223 Poll::Ready(Ok(()))
224 }
225 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
226 Poll::Ready(Ok(()))
227 }
228}