1use crate::Result;
2use aho_corasick::AhoCorasick;
3use std::collections::HashSet;
4use std::ffi::OsStr;
5use std::fmt::{Debug, Display, Formatter};
6use std::path::Path;
7use std::process::{ExitStatus, Stdio};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
11use tokio::{
12 io::BufReader,
13 process::Command,
14 select,
15 sync::{oneshot, Mutex},
16};
17use tokio_util::sync::CancellationToken;
18
19use indexmap::IndexSet;
20use std::sync::LazyLock as Lazy;
21
22use crate::Error::ScriptFailed;
23#[cfg(feature = "progress")]
24use clx::progress::{self, ProgressJob};
25
26struct Redactor {
28 automaton: AhoCorasick,
29 replacements: Vec<&'static str>,
30}
31
32pub struct CmdLineRunner {
56 cmd: Command,
57 program: String,
58 args: Vec<String>,
59 #[cfg(feature = "progress")]
60 pr: Option<Arc<ProgressJob>>,
61 stdin: Option<String>,
62 redactions: IndexSet<String>,
63 #[cfg(feature = "progress")]
64 show_stderr_on_error: bool,
65 #[cfg(feature = "progress")]
66 stderr_to_progress: bool,
67 cancel: CancellationToken,
68 allow_non_zero: bool,
69 timeout: Option<Duration>,
70}
71
/// Process ids of children that are currently running.
///
/// `execute` inserts a pid after spawning and removes it once the child has been waited
/// on; `kill_all` iterates the set to signal every child that is still registered.
static RUNNING_PIDS: Lazy<std::sync::Mutex<HashSet<u32>>> =
    Lazy::new(|| std::sync::Mutex::new(HashSet::new()));
73
74impl CmdLineRunner {
75 pub fn new<P: AsRef<OsStr>>(program: P) -> Self {
80 let program = program.as_ref().to_string_lossy().to_string();
81 let mut cmd = if cfg!(windows) {
82 let mut cmd = Command::new("cmd.exe");
83 cmd.arg("/c").arg(&program);
84 cmd
85 } else {
86 Command::new(&program)
87 };
88 cmd.stdin(Stdio::null());
89 cmd.stdout(Stdio::piped());
90 cmd.stderr(Stdio::piped());
91
92 Self {
93 cmd,
94 program,
95 args: vec![],
96 #[cfg(feature = "progress")]
97 pr: None,
98 stdin: None,
99 redactions: Default::default(),
100 #[cfg(feature = "progress")]
101 show_stderr_on_error: true,
102 #[cfg(feature = "progress")]
103 stderr_to_progress: false,
104 cancel: CancellationToken::new(),
105 allow_non_zero: false,
106 timeout: None,
107 }
108 }
109
110 #[cfg(unix)]
116 pub fn kill_all(signal: nix::sys::signal::Signal) {
117 let Ok(pids) = RUNNING_PIDS.lock() else {
118 debug!("Failed to acquire lock on RUNNING_PIDS");
119 return;
120 };
121 for pid in pids.iter() {
122 let pgid = nix::unistd::Pid::from_raw(*pid as i32);
123 trace!("{signal}: pgid {pid}");
124 if let Err(e) = nix::sys::signal::killpg(pgid, signal) {
125 debug!("Failed to kill process group {pid}: {e}");
126 }
127 }
128 }
129
/// Spawns `taskkill /F /T /PID <pid>` for every registered child.
///
/// Spawn failures (and a poisoned registry lock) are logged and otherwise ignored.
#[cfg(windows)]
pub fn kill_all() {
    let Ok(pids) = RUNNING_PIDS.lock() else {
        debug!("Failed to acquire lock on RUNNING_PIDS");
        return;
    };
    for pid in pids.iter() {
        let spawned = Command::new("taskkill")
            .args(["/F", "/T", "/PID"])
            .arg(pid.to_string())
            .spawn();
        if let Err(e) = spawned {
            warn!("Failed to kill cmd {pid}: {e}");
        }
    }
}
151
152 pub fn stdin<T: Into<Stdio>>(mut self, cfg: T) -> Self {
154 self.cmd.stdin(cfg);
155 self
156 }
157
158 pub fn stdout<T: Into<Stdio>>(mut self, cfg: T) -> Self {
160 self.cmd.stdout(cfg);
161 self
162 }
163
164 pub fn stderr<T: Into<Stdio>>(mut self, cfg: T) -> Self {
166 self.cmd.stderr(cfg);
167 self
168 }
169
170 pub fn redact(mut self, redactions: impl IntoIterator<Item = String>) -> Self {
194 for r in redactions {
195 self.redactions.insert(r);
196 }
197 self
198 }
199
/// Attaches a progress job that receives status changes and output lines during
/// `execute`.
#[cfg(feature = "progress")]
pub fn with_pr(mut self, pr: Arc<ProgressJob>) -> Self {
    self.pr = pr.into();
    self
}
211
212 pub fn with_cancel_token(mut self, cancel: CancellationToken) -> Self {
216 self.cancel = cancel;
217 self
218 }
219
/// Controls whether the combined output is printed to the progress job when the command
/// fails.
#[cfg(feature = "progress")]
pub fn show_stderr_on_error(mut self, show: bool) -> Self {
    self.show_stderr_on_error = show;
    self
}
230
/// Chooses how stderr lines reach the progress job.
///
/// When enabled, each stderr line updates the job's `ensembler_stdout` property; when
/// disabled, each line is printed through the job instead.
#[cfg(feature = "progress")]
pub fn stderr_to_progress(mut self, enable: bool) -> Self {
    self.stderr_to_progress = enable;
    self
}
242
243 pub fn allow_non_zero(mut self, allow: bool) -> Self {
270 self.allow_non_zero = allow;
271 self
272 }
273
274 pub fn timeout(mut self, duration: Duration) -> Self {
297 self.timeout = Some(duration);
298 self
299 }
300
301 pub fn current_dir<P: AsRef<Path>>(mut self, dir: P) -> Self {
303 self.cmd.current_dir(dir);
304 self
305 }
306
307 pub fn env_clear(mut self) -> Self {
309 self.cmd.env_clear();
310 self
311 }
312
313 pub fn env<K, V>(mut self, key: K, val: V) -> Self
315 where
316 K: AsRef<OsStr>,
317 V: AsRef<OsStr>,
318 {
319 self.cmd.env(key, val);
320 self
321 }
322
323 pub fn envs<I, K, V>(mut self, vars: I) -> Self
325 where
326 I: IntoIterator<Item = (K, V)>,
327 K: AsRef<OsStr>,
328 V: AsRef<OsStr>,
329 {
330 self.cmd.envs(vars);
331 self
332 }
333
334 pub fn opt_arg<S: AsRef<OsStr>>(mut self, arg: Option<S>) -> Self {
338 if let Some(arg) = arg {
339 self.cmd.arg(arg);
340 }
341 self
342 }
343
344 pub fn arg<S: AsRef<OsStr>>(mut self, arg: S) -> Self {
346 self.cmd.arg(arg.as_ref());
347 self.args.push(arg.as_ref().to_string_lossy().to_string());
348 self
349 }
350
351 pub fn args<I, S>(mut self, args: I) -> Self
353 where
354 I: IntoIterator<Item = S>,
355 S: AsRef<OsStr>,
356 {
357 let args = args
358 .into_iter()
359 .map(|s| s.as_ref().to_string_lossy().to_string())
360 .collect::<Vec<_>>();
361 self.cmd.args(&args);
362 self.args.extend(args);
363 self
364 }
365
366 pub fn stdin_string(mut self, input: impl Into<String>) -> Self {
370 self.cmd.stdin(Stdio::piped());
371 self.stdin = Some(input.into());
372 self
373 }
374
375 pub async fn execute(mut self) -> Result<CmdResult> {
386 debug!("$ {self}");
387
388 let redactor: Option<Arc<Redactor>> = if self.redactions.is_empty() {
391 None
392 } else {
393 let automaton = AhoCorasick::new(self.redactions.iter()).map_err(|e| {
394 crate::Error::Internal(format!("failed to build redaction matcher: {e}"))
395 })?;
396 let replacements = vec!["[redacted]"; self.redactions.len()];
397 Some(Arc::new(Redactor {
398 automaton,
399 replacements,
400 }))
401 };
402
403 #[cfg(unix)]
406 self.cmd.process_group(0);
407
408 let mut cp = self.cmd.spawn()?;
409 let id = match cp.id() {
410 Some(id) => id,
411 None => {
412 let _ = cp.kill().await;
413 return Err(crate::Error::Internal("process has no id".to_string()));
414 }
415 };
416 if let Err(e) = RUNNING_PIDS
417 .lock()
418 .map(|mut pids| pids.insert(id))
419 .map_err(|e| e.to_string())
420 {
421 let _ = cp.kill().await;
422 return Err(crate::Error::Internal(format!(
423 "failed to lock RUNNING_PIDS: {e}"
424 )));
425 }
426 trace!("Started process: {id} for {}", self.program);
427 #[cfg(feature = "progress")]
428 if let Some(pr) = &self.pr {
429 pr.prop("ensembler_cmd", &self.to_string());
430 pr.prop("ensembler_stdout", &"".to_string());
431 pr.set_status(progress::ProgressStatus::Running);
432 }
433 let result = Arc::new(Mutex::new(CmdResult::default()));
434 let combined_output = Arc::new(Mutex::new(Vec::new()));
435
436 let (stdout_flush, stdout_ready) = oneshot::channel();
437 if let Some(stdout) = cp.stdout.take() {
438 let result = result.clone();
439 let combined_output = combined_output.clone();
440 let redactor = redactor.clone();
441 #[cfg(feature = "progress")]
442 let pr = self.pr.clone();
443 tokio::spawn(async move {
444 let stdout = BufReader::new(stdout);
445 let mut lines = stdout.lines();
446 while let Ok(Some(line)) = lines.next_line().await {
447 let line = match &redactor {
448 Some(r) => r.automaton.replace_all(&line, &r.replacements),
449 None => line,
450 };
451 let mut result = result.lock().await;
452 result.stdout += &line;
453 result.stdout += "\n";
454 result.combined_output += &line;
455 result.combined_output += "\n";
456 #[cfg(feature = "progress")]
457 if let Some(pr) = &pr {
458 pr.prop("ensembler_stdout", &line);
459 pr.update();
460 }
461 combined_output.lock().await.push(line);
462 }
463 let _ = stdout_flush.send(());
464 });
465 } else {
466 drop(stdout_flush);
467 }
468 let (stderr_flush, stderr_ready) = oneshot::channel();
469 if let Some(stderr) = cp.stderr.take() {
470 let result = result.clone();
471 let combined_output = combined_output.clone();
472 #[cfg(feature = "progress")]
473 let pr = self.pr.clone();
474 #[cfg(feature = "progress")]
475 let stderr_to_progress = self.stderr_to_progress;
476 tokio::spawn(async move {
477 let stderr = BufReader::new(stderr);
478 let mut lines = stderr.lines();
479 while let Ok(Some(line)) = lines.next_line().await {
480 let line = match &redactor {
481 Some(r) => r.automaton.replace_all(&line, &r.replacements),
482 None => line,
483 };
484 let mut result = result.lock().await;
485 result.stderr += &line;
486 result.stderr += "\n";
487 result.combined_output += &line;
488 result.combined_output += "\n";
489 #[cfg(feature = "progress")]
490 if let Some(pr) = &pr {
491 if stderr_to_progress {
492 pr.prop("ensembler_stdout", &line);
494 pr.update();
495 } else {
496 pr.println(&line);
498 }
499 }
500 combined_output.lock().await.push(line);
501 }
502 let _ = stderr_flush.send(());
503 });
504 } else {
505 drop(stderr_flush);
506 }
507 let (stdin_flush, stdin_ready) = oneshot::channel();
508 if let Some(text) = self.stdin.take() {
509 let Some(mut stdin) = cp.stdin.take() else {
510 let _ = cp.kill().await;
511 if let Err(e) = RUNNING_PIDS
512 .lock()
513 .map(|mut pids| pids.remove(&id))
514 .map_err(|e| e.to_string())
515 {
516 debug!("Failed to lock RUNNING_PIDS to remove pid {id}: {e}");
517 }
518 #[cfg(feature = "progress")]
519 if let Some(pr) = &self.pr {
520 pr.set_status(progress::ProgressStatus::Failed);
521 }
522 return Err(crate::Error::Internal(
523 "stdin was requested but not available".to_string(),
524 ));
525 };
526 tokio::spawn(async move {
527 if let Err(e) = stdin.write_all(text.as_bytes()).await {
528 debug!("Failed to write to stdin: {e}");
529 }
530 let _ = stdin_flush.send(());
531 });
532 } else {
533 drop(stdin_flush);
534 }
535
536 let timeout_fut = async {
538 if let Some(duration) = self.timeout {
539 tokio::time::sleep(duration).await;
540 } else {
541 std::future::pending::<()>().await;
542 }
543 };
544 tokio::pin!(timeout_fut);
545
546 let mut timed_out = false;
547 let mut was_cancelled = false;
548 let status = loop {
549 select! {
553 biased;
554 status = cp.wait() => {
555 break status?;
556 }
557 _ = &mut timeout_fut => {
558 timed_out = true;
559 #[cfg(unix)]
560 kill_process_group(id);
561 let _ = cp.kill().await;
562 }
563 _ = self.cancel.cancelled() => {
564 was_cancelled = true;
565 #[cfg(unix)]
566 kill_process_group(id);
567 let _ = cp.kill().await;
568 }
569 }
570 };
571 if let Err(e) = RUNNING_PIDS
572 .lock()
573 .map(|mut pids| pids.remove(&id))
574 .map_err(|e| e.to_string())
575 {
576 debug!("Failed to lock RUNNING_PIDS to remove pid {id}: {e}");
577 }
578
579 if was_cancelled {
580 #[cfg(feature = "progress")]
581 if let Some(pr) = &self.pr {
582 pr.set_status(progress::ProgressStatus::Failed);
583 }
584 return Err(crate::Error::Cancelled);
585 }
586
587 if timed_out {
588 #[cfg(feature = "progress")]
589 if let Some(pr) = &self.pr {
590 pr.set_status(progress::ProgressStatus::Failed);
591 }
592 return Err(crate::Error::TimedOut);
593 }
594
595 result.lock().await.status = status;
596
597 let _ = stdout_ready.await;
599 let _ = stderr_ready.await;
600 let _ = stdin_ready.await;
601
602 if status.success() || self.allow_non_zero {
603 #[cfg(feature = "progress")]
604 if let Some(pr) = &self.pr {
605 pr.set_status(progress::ProgressStatus::Done);
606 }
607 } else {
608 let result = result.lock().await.to_owned();
609 self.on_error(combined_output.lock().await.join("\n"), result)?;
610 }
611
612 let result = result.lock().await.to_owned();
613 Ok(result)
614 }
615
616 fn on_error(&self, output: String, result: CmdResult) -> Result<()> {
617 let output = output.trim().to_string();
618 #[cfg(feature = "progress")]
619 if let Some(pr) = &self.pr {
620 pr.set_status(progress::ProgressStatus::Failed);
621 if self.show_stderr_on_error {
622 pr.println(&output);
623 }
624 }
625 Err(ScriptFailed(Box::new((
626 self.program.clone(),
627 self.args.clone(),
628 output,
629 result,
630 ))))?
631 }
632}
633
634#[cfg(unix)]
637fn kill_process_group(pid: u32) {
638 let pgid = nix::unistd::Pid::from_raw(pid as i32);
639 if let Err(e) = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGKILL) {
640 debug!("Failed to kill process group {pid}: {e}");
641 }
642}
643
644impl Display for CmdLineRunner {
645 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
646 let args = self.args.join(" ");
647 let mut cmd = format!("{} {}", &self.program, args);
648 if cmd.starts_with("sh -o errexit -c ") {
649 cmd = cmd[17..].to_string();
650 }
651 write!(f, "{cmd}")
652 }
653}
654
655impl Debug for CmdLineRunner {
656 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
657 let args = self.args.join(" ");
658 write!(f, "{} {args}", self.program)
659 }
660}
661
/// Captured outcome of a finished command.
///
/// Implements `PartialEq`/`Eq` so results can be compared directly, e.g. in tests.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct CmdResult {
    /// Lines read from stdout, each followed by `\n`.
    pub stdout: String,
    /// Lines read from stderr, each followed by `\n`.
    pub stderr: String,
    /// Stdout and stderr lines in the order they were captured, each followed by `\n`.
    pub combined_output: String,
    /// Exit status reported for the process.
    pub status: ExitStatus,
}