// cmd_wrapper/tokio_command.rs
use std::{
    ffi::OsStr,
    io::{Error, Result},
    path::Path,
    pin::Pin,
    process::{ExitStatus, Stdio},
    task::{Context, Poll},
};

use derivative::Derivative;
use tokio::{
    io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
    process::{Child, Command},
    sync::oneshot::{self, Receiver},
    task::JoinHandle,
};
use tracing::{instrument, warn};

19#[derive(Debug)]
20pub struct TokioCommand {
21 command: Command,
22}
23
24impl TokioCommand {
25 pub fn new<S: Into<String>>(program: S) -> Self {
26 Self {
27 command: Command::new(program.into()),
28 }
29 }
30
31 pub fn arg<S: AsRef<OsStr>>(&mut self, arg: S) -> &mut Self {
32 self.command.arg(arg);
33 self
34 }
35
36 pub fn args<I, S>(&mut self, args: I) -> &mut Self
37 where
38 I: IntoIterator<Item = S>,
39 S: AsRef<OsStr>,
40 {
41 self.command.args(args);
42 self
43 }
44
45 pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Self
46 where
47 K: AsRef<OsStr>,
48 V: AsRef<OsStr>,
49 {
50 self.command.env(key, val);
51 self
52 }
53
54 pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Self
55 where
56 I: IntoIterator<Item = (K, V)>,
57 K: AsRef<OsStr>,
58 V: AsRef<OsStr>,
59 {
60 self.command.envs(vars);
61 self
62 }
63
64 pub fn current_dir<P: AsRef<Path>>(&mut self, dir: P) -> &mut Self {
65 self.command.current_dir(dir);
66 self
67 }
68
69 pub fn stdin<T: Into<Stdio>>(&mut self, stdin: T) -> &mut Self {
70 self.command.stdin(stdin);
71 self
72 }
73
74 pub fn stdout<T: Into<Stdio>>(&mut self, stdout: T) -> &mut Self {
75 self.command.stdout(stdout);
76 self
77 }
78
79 pub fn stderr<T: Into<Stdio>>(&mut self, stderr: T) -> &mut Self {
80 self.command.stderr(stderr);
81 self
82 }
83
84 #[instrument(level = "trace")]
85 pub fn into_inner(self) -> Command {
86 self.command
87 }
88
89 #[instrument(level = "trace")]
90 pub fn child(mut self) -> Result<Child> {
91 self.command.spawn()
92 }
93
94 #[instrument(level = "trace")]
95 pub async fn spawn(&mut self) -> Result<TokioCommandProcess> {
96 let mut child = self.command.spawn()?;
97
98 let stdin = child
99 .stdin
100 .take()
101 .map(|s| Box::new(s) as Box<dyn AsyncWrite + Send + Unpin>)
102 .unwrap_or_else(|| Box::new(AsyncSink));
103 let stdout = child
104 .stdout
105 .take()
106 .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
107 .unwrap_or_else(|| Box::new(io::empty()));
108 let stderr = child
109 .stderr
110 .take()
111 .map(|s| Box::new(s) as Box<dyn AsyncRead + Send + Unpin>)
112 .unwrap_or_else(|| Box::new(io::empty()));
113
114 let pid = child.id().unwrap_or_default();
115
116 let (end_tx, end_rx) = oneshot::channel::<(Option<i32>, Option<io::Error>)>();
117 let end_handler = tokio::spawn(async move {
118 let end_msg = match child.wait().await {
119 Ok(status) if status.code() == Some(0) => (Some(0), None),
120 Ok(status) => (
121 status.code(),
122 Self::read_full_stderr_if_any(stderr).await.err(),
123 ),
124 Err(err) => (Some(-1), Some(err)),
125 };
126
127 if let Err(err) = end_tx.send(end_msg) {
128 warn!(?err, "receiver dropped before sending end msg");
129 }
130 });
131
132 Ok(TokioCommandProcess {
133 stdin,
134 stdout,
135 pid,
136 end_rx,
137 end_handler,
138 })
139 }
140
141 #[instrument(level = "trace")]
142 pub async fn output(&mut self) -> Result<String> {
143 let output = self.command.output().await?;
144
145 if !output.status.success() {
146 let stderr = String::from_utf8_lossy(&output.stderr);
147 return Err(Error::other(stderr));
148 }
149
150 Ok(String::from_utf8_lossy(&output.stdout).to_string())
151 }
152
153 #[instrument(level = "trace")]
154 pub async fn status(&mut self) -> Result<ExitStatus> {
155 self.command.status().await
156 }
157
158 #[instrument(level = "trace")]
159 pub async fn execute_and_print(&mut self) -> Result<()> {
160 let output = self.command.output().await?;
161 if !output.stdout.is_empty() {
162 io::stdout().write_all(&output.stdout).await?;
163 }
164 if !output.stderr.is_empty() {
165 io::stderr().write_all(&output.stderr).await?;
166 }
167 Ok(())
168 }
169
170 async fn read_full_stderr_if_any(mut stderr: impl AsyncRead + Unpin) -> Result<()> {
171 let mut peek_buf = vec![0u8; 1024];
172 let n = stderr.read(&mut peek_buf).await?;
173 if n == 0 {
174 return Ok(());
175 }
176
177 let mut full = peek_buf[..n].to_vec();
178 let mut rest = Vec::new();
179 stderr.read_to_end(&mut rest).await?;
180 full.extend(rest);
181
182 let msg = String::from_utf8_lossy(&full).trim().to_string();
183 if !msg.is_empty() {
184 return Err(Error::other(msg));
185 }
186
187 Ok(())
188 }
189}
190
191#[derive(Derivative)]
192#[derivative(Debug)]
193pub struct TokioCommandProcess {
194 #[derivative(Debug = "ignore")]
195 pub stdin: Box<dyn AsyncWrite + Send + Unpin>,
196 #[derivative(Debug = "ignore")]
197 pub stdout: Box<dyn AsyncRead + Send + Unpin>,
198 pub pid: u32,
199 pub end_rx: Receiver<(Option<i32>, Option<io::Error>)>,
200 pub end_handler: JoinHandle<()>,
201}
202
203struct AsyncSink;
204
205impl AsyncWrite for AsyncSink {
206 fn poll_write(
207 self: Pin<&mut Self>,
208 _cx: &mut Context<'_>,
209 buf: &[u8],
210 ) -> Poll<std::io::Result<usize>> {
211 Poll::Ready(Ok(buf.len()))
212 }
213 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
214 Poll::Ready(Ok(()))
215 }
216 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
217 Poll::Ready(Ok(()))
218 }
219}