1use hashbrown::HashMap;
10use std::io::{self, ErrorKind};
11use std::path::Path;
12use std::process::Stdio;
13use std::sync::Arc;
14use std::sync::Mutex as StdMutex;
15use std::sync::atomic::AtomicBool;
16
17use anyhow::{Context, Result};
18use bytes::Bytes;
19use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader};
20use tokio::process::Command;
21use tokio::sync::{broadcast, mpsc, oneshot};
22use tokio::task::JoinHandle;
23
24use crate::process::{ChildTerminator, ProcessHandle, SpawnedProcess};
25use crate::process_group;
26
/// Platform-specific identity of a pipe-spawned child, used to kill it later.
///
/// Only the field belonging to the current compilation target exists.
struct PipeChildTerminator {
    /// Process-group id of the child (unix only). The child is spawned with
    /// `process_group(0)`, so this equals the child's pid.
    #[cfg(unix)]
    process_group_id: u32,
    /// Pid of the child process (windows only).
    #[cfg(windows)]
    pid: u32,
}
34
impl ChildTerminator for PipeChildTerminator {
    /// Kills the spawned child using the current platform's mechanism.
    ///
    /// * unix: delegates to `process_group::kill_process_group` with the child's
    ///   process-group id, so the whole group is targeted rather than only the
    ///   direct child.
    /// * windows: delegates to `process_group::kill_process` with the child's pid.
    /// * any other target: does nothing and reports success.
    ///
    /// Exactly one of the `#[cfg]` blocks below survives compilation and becomes
    /// the function's tail expression, so each one must evaluate to `io::Result<()>`.
    fn kill(&mut self) -> io::Result<()> {
        #[cfg(unix)]
        {
            process_group::kill_process_group(self.process_group_id)
        }

        #[cfg(windows)]
        {
            process_group::kill_process(self.pid)
        }

        #[cfg(not(any(unix, windows)))]
        {
            Ok(())
        }
    }
}
53
54async fn read_output_stream<R>(mut reader: R, output_tx: broadcast::Sender<Bytes>)
56where
57 R: AsyncRead + Unpin,
58{
59 let mut buf = vec![0u8; 8_192];
60 loop {
61 match reader.read(&mut buf).await {
62 Ok(0) => break,
63 Ok(n) => {
64 let _ = output_tx.send(Bytes::copy_from_slice(&buf[..n]));
65 }
66 Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
67 Err(_) => break,
68 }
69 }
70}
71
/// How the child's stdin is connected when spawning through the pipe backend.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PipeStdinMode {
    /// Give the child a stdin pipe; bytes queued on the process handle's writer
    /// channel are forwarded to it.
    #[default]
    Piped,
    /// Attach the child's stdin to the null device.
    Null,
}

/// Everything needed to spawn a child process whose stdout/stderr are piped.
///
/// Build one with [`PipeSpawnOptions::new`] plus the chained setters, or with a
/// struct literal, since every field is public.
#[derive(Clone, Debug)]
pub struct PipeSpawnOptions {
    /// Program to execute. Spawning rejects an empty string.
    pub program: String,
    /// Arguments passed after the program name.
    pub args: Vec<String>,
    /// Working directory for the child.
    pub cwd: std::path::PathBuf,
    /// `Some` replaces the inherited environment entirely; `None` inherits it.
    pub env: Option<HashMap<String, String>>,
    /// Override for `argv[0]`. Honoured on unix only; ignored elsewhere.
    pub arg0: Option<String>,
    /// How the child's stdin is connected.
    pub stdin_mode: PipeStdinMode,
}

impl PipeSpawnOptions {
    /// Creates options for running `program` in `cwd`.
    ///
    /// Defaults: no arguments, inherited environment, no `argv[0]` override, and
    /// a piped stdin.
    #[must_use]
    pub fn new(program: impl Into<String>, cwd: impl Into<std::path::PathBuf>) -> Self {
        Self {
            program: program.into(),
            args: Vec::new(),
            cwd: cwd.into(),
            env: None,
            arg0: None,
            stdin_mode: PipeStdinMode::Piped,
        }
    }

    /// Replaces the argument list with `args`.
    #[must_use]
    pub fn args(mut self, args: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.args = args.into_iter().map(Into::into).collect();
        self
    }

    /// Sets an explicit environment that replaces, rather than extends, the
    /// inherited one.
    #[must_use]
    pub fn env(mut self, env: HashMap<String, String>) -> Self {
        self.env = Some(env);
        self
    }

    /// Overrides `argv[0]` for the child (unix only).
    #[must_use]
    pub fn arg0(mut self, arg0: impl Into<String>) -> Self {
        self.arg0 = Some(arg0.into());
        self
    }

    /// Chooses how the child's stdin is connected.
    #[must_use]
    pub fn stdin_mode(mut self, mode: PipeStdinMode) -> Self {
        self.stdin_mode = mode;
        self
    }
}
135
136async fn spawn_process_internal(opts: PipeSpawnOptions) -> Result<SpawnedProcess> {
138 if opts.program.is_empty() {
139 anyhow::bail!("missing program for pipe spawn");
140 }
141
142 let mut command = Command::new(&opts.program);
143
144 #[cfg(unix)]
145 if let Some(ref arg0) = opts.arg0 {
146 command.arg0(arg0);
147 }
148
149 #[cfg(unix)]
150 {
151 command.process_group(0);
152 }
153
154 #[cfg(not(unix))]
155 let _ = &opts.arg0;
156
157 command.current_dir(&opts.cwd);
158
159 if let Some(ref env) = opts.env {
161 command.env_clear();
162 for (key, value) in env {
163 command.env(key, value);
164 }
165 }
166
167 for arg in &opts.args {
168 command.arg(arg);
169 }
170
171 match opts.stdin_mode {
172 PipeStdinMode::Piped => {
173 command.stdin(Stdio::piped());
174 }
175 PipeStdinMode::Null => {
176 command.stdin(Stdio::null());
177 }
178 }
179 command.stdout(Stdio::piped());
180 command.stderr(Stdio::piped());
181
182 let mut child = command.spawn().context("failed to spawn pipe process")?;
183 let pid = child
184 .id()
185 .ok_or_else(|| io::Error::other("missing child pid"))?;
186
187 #[cfg(unix)]
188 let process_group_id = pid;
189
190 let stdin = child.stdin.take();
191 let stdout = child.stdout.take();
192 let stderr = child.stderr.take();
193
194 let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
195 let (output_tx, _) = broadcast::channel::<Bytes>(256);
196 let initial_output_rx = output_tx.subscribe();
197
198 let writer_handle = if let Some(stdin) = stdin {
200 let writer = Arc::new(tokio::sync::Mutex::new(stdin));
201 tokio::spawn(async move {
202 while let Some(bytes) = writer_rx.recv().await {
203 let mut guard = writer.lock().await;
204 let _ = guard.write_all(&bytes).await;
205 let _ = guard.flush().await;
206 }
207 })
208 } else {
209 drop(writer_rx);
210 tokio::spawn(async {})
211 };
212
213 let stdout_handle = stdout.map(|stdout| {
215 let output_tx = output_tx.clone();
216 tokio::spawn(async move {
217 read_output_stream(BufReader::new(stdout), output_tx).await;
218 })
219 });
220
221 let stderr_handle = stderr.map(|stderr| {
222 let output_tx = output_tx.clone();
223 tokio::spawn(async move {
224 read_output_stream(BufReader::new(stderr), output_tx).await;
225 })
226 });
227
228 let mut reader_abort_handles = Vec::new();
229 if let Some(ref handle) = stdout_handle {
230 reader_abort_handles.push(handle.abort_handle());
231 }
232 if let Some(ref handle) = stderr_handle {
233 reader_abort_handles.push(handle.abort_handle());
234 }
235
236 let reader_handle = tokio::spawn(async move {
237 if let Some(handle) = stdout_handle {
238 let _ = handle.await;
239 }
240 if let Some(handle) = stderr_handle {
241 let _ = handle.await;
242 }
243 });
244
245 let (exit_tx, exit_rx) = oneshot::channel::<i32>();
247 let exit_status = Arc::new(AtomicBool::new(false));
248 let wait_exit_status = Arc::clone(&exit_status);
249 let exit_code = Arc::new(StdMutex::new(None));
250 let wait_exit_code = Arc::clone(&exit_code);
251
252 let wait_handle: JoinHandle<()> = tokio::spawn(async move {
253 let code = match child.wait().await {
254 Ok(status) => status.code().unwrap_or(-1),
255 Err(_) => -1,
256 };
257 wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst);
258 if let Ok(mut guard) = wait_exit_code.lock() {
259 *guard = Some(code);
260 }
261 let _ = exit_tx.send(code);
262 });
263
264 let (handle, output_rx) = ProcessHandle::new(
265 writer_tx,
266 output_tx,
267 initial_output_rx,
268 Box::new(PipeChildTerminator {
269 #[cfg(windows)]
270 pid,
271 #[cfg(unix)]
272 process_group_id,
273 }),
274 reader_handle,
275 reader_abort_handles,
276 writer_handle,
277 wait_handle,
278 exit_status,
279 exit_code,
280 None,
281 );
282
283 Ok(SpawnedProcess {
284 session: handle,
285 output_rx,
286 exit_rx,
287 })
288}
289
290pub async fn spawn_process(
304 program: &str,
305 args: &[String],
306 cwd: &Path,
307 env: &HashMap<String, String>,
308 arg0: &Option<String>,
309) -> Result<SpawnedProcess> {
310 let opts = PipeSpawnOptions {
311 program: program.to_string(),
312 args: args.to_vec(),
313 cwd: cwd.to_path_buf(),
314 env: Some(env.clone()),
315 arg0: arg0.clone(),
316 stdin_mode: PipeStdinMode::Piped,
317 };
318 spawn_process_internal(opts).await
319}
320
321pub async fn spawn_process_no_stdin(
325 program: &str,
326 args: &[String],
327 cwd: &Path,
328 env: &HashMap<String, String>,
329 arg0: &Option<String>,
330) -> Result<SpawnedProcess> {
331 let opts = PipeSpawnOptions {
332 program: program.to_string(),
333 args: args.to_vec(),
334 cwd: cwd.to_path_buf(),
335 env: Some(env.clone()),
336 arg0: arg0.clone(),
337 stdin_mode: PipeStdinMode::Null,
338 };
339 spawn_process_internal(opts).await
340}
341
342pub async fn spawn_process_with_options(opts: PipeSpawnOptions) -> Result<SpawnedProcess> {
344 spawn_process_internal(opts).await
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a platform-appropriate `echo` invocation: the program plus any
    /// arguments that must precede the text to print.
    fn find_echo_command() -> Option<(String, Vec<String>)> {
        #[cfg(windows)]
        {
            Some((
                "cmd.exe".to_string(),
                vec!["/C".to_string(), "echo".to_string()],
            ))
        }
        #[cfg(not(windows))]
        {
            Some(("echo".to_string(), vec![]))
        }
    }

    #[tokio::test]
    async fn test_spawn_process_echo() -> Result<()> {
        let Some((program, mut base_args)) = find_echo_command() else {
            return Ok(());
        };

        base_args.push("hello".to_string());

        let env: HashMap<String, String> = std::env::vars().collect();
        let spawned = spawn_process(&program, &base_args, Path::new("."), &env, &None).await?;

        let exit_code = spawned.exit_rx.await.unwrap_or(-1);
        assert_eq!(exit_code, 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_spawn_process_no_stdin_echo() -> Result<()> {
        let Some((program, mut base_args)) = find_echo_command() else {
            return Ok(());
        };

        base_args.push("hello".to_string());

        let env: HashMap<String, String> = std::env::vars().collect();
        let spawned =
            spawn_process_no_stdin(&program, &base_args, Path::new("."), &env, &None).await?;

        let exit_code = spawned.exit_rx.await.unwrap_or(-1);
        assert_eq!(exit_code, 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_spawn_with_options_inherits_env() -> Result<()> {
        let Some((program, mut base_args)) = find_echo_command() else {
            return Ok(());
        };

        base_args.push("hello".to_string());

        // No `env(...)` call: the child inherits this process's environment.
        let opts = PipeSpawnOptions::new(program, ".")
            .args(base_args)
            .stdin_mode(PipeStdinMode::Null);
        let spawned = spawn_process_with_options(opts).await?;

        let exit_code = spawned.exit_rx.await.unwrap_or(-1);
        assert_eq!(exit_code, 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_spawn_rejects_empty_program() {
        match spawn_process_with_options(PipeSpawnOptions::new("", ".")).await {
            Ok(_) => panic!("an empty program must be rejected"),
            Err(err) => assert!(err.to_string().contains("missing program")),
        }
    }

    // Pure data construction; no async runtime is needed.
    #[test]
    fn test_spawn_options_builder() {
        let opts = PipeSpawnOptions::new("echo", ".")
            .args(["hello", "world"])
            .stdin_mode(PipeStdinMode::Null);

        assert_eq!(opts.program, "echo");
        assert_eq!(opts.args, vec!["hello", "world"]);
        assert_eq!(opts.cwd, std::path::PathBuf::from("."));
        assert!(opts.env.is_none());
        assert!(opts.arg0.is_none());
        assert!(matches!(opts.stdin_mode, PipeStdinMode::Null));
    }

    #[test]
    fn test_spawn_options_defaults_and_overrides() {
        let defaults = PipeSpawnOptions::new("echo", ".");
        assert!(matches!(defaults.stdin_mode, PipeStdinMode::Piped));

        let mut env = HashMap::new();
        env.insert("KEY".to_string(), "VALUE".to_string());
        let opts = defaults.env(env).arg0("custom-name");

        assert_eq!(opts.arg0.as_deref(), Some("custom-name"));
        assert_eq!(
            opts.env
                .as_ref()
                .and_then(|vars| vars.get("KEY"))
                .map(String::as_str),
            Some("VALUE")
        );
    }
}