Skip to main content

ensembler/
cmd.rs

1use crate::Result;
2use aho_corasick::AhoCorasick;
3use std::collections::HashSet;
4use std::ffi::OsStr;
5use std::fmt::{Debug, Display, Formatter};
6use std::path::Path;
7use std::process::{ExitStatus, Stdio};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
11use tokio::{
12    io::BufReader,
13    process::Command,
14    select,
15    sync::{oneshot, Mutex},
16};
17use tokio_util::sync::CancellationToken;
18
19use indexmap::IndexSet;
20use std::sync::LazyLock as Lazy;
21
22use crate::Error::ScriptFailed;
23#[cfg(feature = "progress")]
24use clx::progress::{self, ProgressJob};
25
26/// Holds the Aho-Corasick automaton and replacement strings for redaction.
27struct Redactor {
28    automaton: AhoCorasick,
29    replacements: Vec<&'static str>,
30}
31
32/// A builder for executing external commands with advanced output handling.
33///
34/// `CmdLineRunner` provides a fluent API for configuring and executing external
35/// commands. It supports output capture, secret redaction, progress bar integration,
36/// and cancellation.
37///
38/// # Example
39///
40/// ```no_run
41/// use ensembler::CmdLineRunner;
42///
43/// #[tokio::main]
44/// async fn main() -> ensembler::Result<()> {
45///     let result = CmdLineRunner::new("ls")
46///         .arg("-la")
47///         .current_dir("/tmp")
48///         .execute()
49///         .await?;
50///
51///     println!("{}", result.stdout);
52///     Ok(())
53/// }
54/// ```
55pub struct CmdLineRunner {
56    cmd: Command,
57    program: String,
58    args: Vec<String>,
59    #[cfg(feature = "progress")]
60    pr: Option<Arc<ProgressJob>>,
61    stdin: Option<String>,
62    redactions: IndexSet<String>,
63    #[cfg(feature = "progress")]
64    show_stderr_on_error: bool,
65    #[cfg(feature = "progress")]
66    stderr_to_progress: bool,
67    cancel: CancellationToken,
68    allow_non_zero: bool,
69    timeout: Option<Duration>,
70}
71
/// Process IDs of spawned children that are still being tracked.
static RUNNING_PIDS: Lazy<std::sync::Mutex<HashSet<u32>>> =
    Lazy::new(|| std::sync::Mutex::new(HashSet::new()));
73
74impl CmdLineRunner {
75    /// Creates a new command runner for the given program.
76    ///
77    /// On Windows, commands are automatically wrapped with `cmd.exe /c`.
78    /// The command is configured with piped stdout/stderr and null stdin by default.
79    pub fn new<P: AsRef<OsStr>>(program: P) -> Self {
80        let program = program.as_ref().to_string_lossy().to_string();
81        let mut cmd = if cfg!(windows) {
82            let mut cmd = Command::new("cmd.exe");
83            cmd.arg("/c").arg(&program);
84            cmd
85        } else {
86            Command::new(&program)
87        };
88        cmd.stdin(Stdio::null());
89        cmd.stdout(Stdio::piped());
90        cmd.stderr(Stdio::piped());
91
92        Self {
93            cmd,
94            program,
95            args: vec![],
96            #[cfg(feature = "progress")]
97            pr: None,
98            stdin: None,
99            redactions: Default::default(),
100            #[cfg(feature = "progress")]
101            show_stderr_on_error: true,
102            #[cfg(feature = "progress")]
103            stderr_to_progress: false,
104            cancel: CancellationToken::new(),
105            allow_non_zero: false,
106            timeout: None,
107        }
108    }
109
110    /// Sends a signal to all running child process groups.
111    ///
112    /// Each child is placed in its own process group at spawn time, so this
113    /// kills the entire process tree (not just the direct child).
114    /// This is useful for graceful shutdown scenarios.
115    #[cfg(unix)]
116    pub fn kill_all(signal: nix::sys::signal::Signal) {
117        let Ok(pids) = RUNNING_PIDS.lock() else {
118            debug!("Failed to acquire lock on RUNNING_PIDS");
119            return;
120        };
121        for pid in pids.iter() {
122            let pgid = nix::unistd::Pid::from_raw(*pid as i32);
123            trace!("{signal}: pgid {pid}");
124            if let Err(e) = nix::sys::signal::killpg(pgid, signal) {
125                debug!("Failed to kill process group {pid}: {e}");
126            }
127        }
128    }
129
130    /// Terminates all running child processes on Windows.
131    ///
132    /// Uses `taskkill /F /T` to forcefully terminate process trees.
133    #[cfg(windows)]
134    pub fn kill_all() {
135        let Ok(pids) = RUNNING_PIDS.lock() else {
136            debug!("Failed to acquire lock on RUNNING_PIDS");
137            return;
138        };
139        for pid in pids.iter() {
140            if let Err(e) = Command::new("taskkill")
141                .arg("/F")
142                .arg("/T")
143                .arg("/PID")
144                .arg(pid.to_string())
145                .spawn()
146            {
147                warn!("Failed to kill cmd {pid}: {e}");
148            }
149        }
150    }
151
152    /// Configures stdin handling for the command.
153    pub fn stdin<T: Into<Stdio>>(mut self, cfg: T) -> Self {
154        self.cmd.stdin(cfg);
155        self
156    }
157
158    /// Configures stdout handling for the command.
159    pub fn stdout<T: Into<Stdio>>(mut self, cfg: T) -> Self {
160        self.cmd.stdout(cfg);
161        self
162    }
163
164    /// Configures stderr handling for the command.
165    pub fn stderr<T: Into<Stdio>>(mut self, cfg: T) -> Self {
166        self.cmd.stderr(cfg);
167        self
168    }
169
170    /// Adds strings to redact from command output.
171    ///
172    /// Any occurrence of these strings in stdout or stderr will be replaced
173    /// with `[redacted]`. This is useful for hiding sensitive data like
174    /// API keys or passwords.
175    ///
176    /// # Example
177    ///
178    /// ```no_run
179    /// use ensembler::CmdLineRunner;
180    ///
181    /// # #[tokio::main]
182    /// # async fn main() -> ensembler::Result<()> {
183    /// let result = CmdLineRunner::new("echo")
184    ///     .arg("secret-api-key")
185    ///     .redact(vec!["secret-api-key".to_string()])
186    ///     .execute()
187    ///     .await?;
188    ///
189    /// assert_eq!(result.stdout.trim(), "[redacted]");
190    /// # Ok(())
191    /// # }
192    /// ```
193    pub fn redact(mut self, redactions: impl IntoIterator<Item = String>) -> Self {
194        for r in redactions {
195            self.redactions.insert(r);
196        }
197        self
198    }
199
200    /// Attaches a progress bar to display command status.
201    ///
202    /// The progress bar will be updated with the command being run and
203    /// its output. Uses the `clx` crate's progress bar system.
204    ///
205    /// This method is only available when the `progress` feature is enabled.
206    #[cfg(feature = "progress")]
207    pub fn with_pr(mut self, pr: Arc<ProgressJob>) -> Self {
208        self.pr = Some(pr);
209        self
210    }
211
212    /// Sets a cancellation token for the command.
213    ///
214    /// When the token is cancelled, the running process will be killed.
215    pub fn with_cancel_token(mut self, cancel: CancellationToken) -> Self {
216        self.cancel = cancel;
217        self
218    }
219
220    /// Controls whether stderr is displayed when the command fails.
221    ///
222    /// Defaults to `true`.
223    ///
224    /// This method is only available when the `progress` feature is enabled.
225    #[cfg(feature = "progress")]
226    pub fn show_stderr_on_error(mut self, show: bool) -> Self {
227        self.show_stderr_on_error = show;
228        self
229    }
230
231    /// Routes stderr to the progress bar instead of printing it directly.
232    ///
233    /// When enabled, stderr lines update the progress bar's status.
234    /// When disabled (default), stderr is printed above the progress bar.
235    ///
236    /// This method is only available when the `progress` feature is enabled.
237    #[cfg(feature = "progress")]
238    pub fn stderr_to_progress(mut self, enable: bool) -> Self {
239        self.stderr_to_progress = enable;
240        self
241    }
242
243    /// Allows the command to exit with a non-zero status without returning an error.
244    ///
245    /// When enabled, the command result is returned even if the exit code is non-zero.
246    /// This is useful when you need to capture output from commands that may fail
247    /// but still produce useful output.
248    ///
249    /// # Example
250    ///
251    /// ```no_run
252    /// use ensembler::CmdLineRunner;
253    ///
254    /// # #[tokio::main]
255    /// # async fn main() -> ensembler::Result<()> {
256    /// let result = CmdLineRunner::new("bash")
257    ///     .arg("-c")
258    ///     .arg("echo 'output'; exit 1")
259    ///     .allow_non_zero(true)
260    ///     .execute()
261    ///     .await?;
262    ///
263    /// // Command succeeded (no error) even though exit code was 1
264    /// assert_eq!(result.status.code(), Some(1));
265    /// assert_eq!(result.stdout.trim(), "output");
266    /// # Ok(())
267    /// # }
268    /// ```
269    pub fn allow_non_zero(mut self, allow: bool) -> Self {
270        self.allow_non_zero = allow;
271        self
272    }
273
274    /// Sets a timeout for the command.
275    ///
276    /// If the command does not complete within the specified duration,
277    /// it will be killed and [`Error::TimedOut`] will be returned.
278    ///
279    /// # Example
280    ///
281    /// ```no_run
282    /// use ensembler::CmdLineRunner;
283    /// use std::time::Duration;
284    ///
285    /// # #[tokio::main]
286    /// # async fn main() {
287    /// let result = CmdLineRunner::new("sleep")
288    ///     .arg("60")
289    ///     .timeout(Duration::from_secs(1))
290    ///     .execute()
291    ///     .await;
292    ///
293    /// assert!(result.is_err()); // TimedOut error
294    /// # }
295    /// ```
296    pub fn timeout(mut self, duration: Duration) -> Self {
297        self.timeout = Some(duration);
298        self
299    }
300
301    /// Sets the working directory for the command.
302    pub fn current_dir<P: AsRef<Path>>(mut self, dir: P) -> Self {
303        self.cmd.current_dir(dir);
304        self
305    }
306
307    /// Clears all environment variables for the command.
308    pub fn env_clear(mut self) -> Self {
309        self.cmd.env_clear();
310        self
311    }
312
313    /// Sets an environment variable for the command.
314    pub fn env<K, V>(mut self, key: K, val: V) -> Self
315    where
316        K: AsRef<OsStr>,
317        V: AsRef<OsStr>,
318    {
319        self.cmd.env(key, val);
320        self
321    }
322
323    /// Sets multiple environment variables for the command.
324    pub fn envs<I, K, V>(mut self, vars: I) -> Self
325    where
326        I: IntoIterator<Item = (K, V)>,
327        K: AsRef<OsStr>,
328        V: AsRef<OsStr>,
329    {
330        self.cmd.envs(vars);
331        self
332    }
333
334    /// Adds an optional argument to the command.
335    ///
336    /// If `arg` is `None`, no argument is added.
337    pub fn opt_arg<S: AsRef<OsStr>>(mut self, arg: Option<S>) -> Self {
338        if let Some(arg) = arg {
339            self.cmd.arg(arg);
340        }
341        self
342    }
343
344    /// Adds a single argument to the command.
345    pub fn arg<S: AsRef<OsStr>>(mut self, arg: S) -> Self {
346        self.cmd.arg(arg.as_ref());
347        self.args.push(arg.as_ref().to_string_lossy().to_string());
348        self
349    }
350
351    /// Adds multiple arguments to the command.
352    pub fn args<I, S>(mut self, args: I) -> Self
353    where
354        I: IntoIterator<Item = S>,
355        S: AsRef<OsStr>,
356    {
357        let args = args
358            .into_iter()
359            .map(|s| s.as_ref().to_string_lossy().to_string())
360            .collect::<Vec<_>>();
361        self.cmd.args(&args);
362        self.args.extend(args);
363        self
364    }
365
366    /// Pipes a string to the command's stdin.
367    ///
368    /// This automatically configures stdin to be piped.
369    pub fn stdin_string(mut self, input: impl Into<String>) -> Self {
370        self.cmd.stdin(Stdio::piped());
371        self.stdin = Some(input.into());
372        self
373    }
374
375    /// Executes the command and waits for it to complete.
376    ///
377    /// Returns [`CmdResult`] containing captured stdout, stderr, and exit status
378    /// on success. Returns an error if the command fails to start or exits with
379    /// a non-zero status.
380    ///
381    /// # Errors
382    ///
383    /// - [`Error::Io`] if the command fails to start
384    /// - [`Error::ScriptFailed`] if the command exits with a non-zero status
385    pub async fn execute(mut self) -> Result<CmdResult> {
386        debug!("$ {self}");
387
388        // Build Aho-Corasick automaton for efficient multi-pattern redaction
389        // This is done before spawning to avoid orphan processes on build failure
390        let redactor: Option<Arc<Redactor>> = if self.redactions.is_empty() {
391            None
392        } else {
393            let automaton = AhoCorasick::new(self.redactions.iter()).map_err(|e| {
394                crate::Error::Internal(format!("failed to build redaction matcher: {e}"))
395            })?;
396            let replacements = vec!["[redacted]"; self.redactions.len()];
397            Some(Arc::new(Redactor {
398                automaton,
399                replacements,
400            }))
401        };
402
403        // Put the child in its own process group so we can kill the entire
404        // tree on timeout/cancellation (not just the direct child).
405        #[cfg(unix)]
406        self.cmd.process_group(0);
407
408        let mut cp = self.cmd.spawn()?;
409        let id = match cp.id() {
410            Some(id) => id,
411            None => {
412                let _ = cp.kill().await;
413                return Err(crate::Error::Internal("process has no id".to_string()));
414            }
415        };
416        if let Err(e) = RUNNING_PIDS
417            .lock()
418            .map(|mut pids| pids.insert(id))
419            .map_err(|e| e.to_string())
420        {
421            let _ = cp.kill().await;
422            return Err(crate::Error::Internal(format!(
423                "failed to lock RUNNING_PIDS: {e}"
424            )));
425        }
426        trace!("Started process: {id} for {}", self.program);
427        #[cfg(feature = "progress")]
428        if let Some(pr) = &self.pr {
429            pr.prop("ensembler_cmd", &self.to_string());
430            pr.prop("ensembler_stdout", &"".to_string());
431            pr.set_status(progress::ProgressStatus::Running);
432        }
433        let result = Arc::new(Mutex::new(CmdResult::default()));
434        let combined_output = Arc::new(Mutex::new(Vec::new()));
435
436        let (stdout_flush, stdout_ready) = oneshot::channel();
437        if let Some(stdout) = cp.stdout.take() {
438            let result = result.clone();
439            let combined_output = combined_output.clone();
440            let redactor = redactor.clone();
441            #[cfg(feature = "progress")]
442            let pr = self.pr.clone();
443            tokio::spawn(async move {
444                let stdout = BufReader::new(stdout);
445                let mut lines = stdout.lines();
446                while let Ok(Some(line)) = lines.next_line().await {
447                    let line = match &redactor {
448                        Some(r) => r.automaton.replace_all(&line, &r.replacements),
449                        None => line,
450                    };
451                    let mut result = result.lock().await;
452                    result.stdout += &line;
453                    result.stdout += "\n";
454                    result.combined_output += &line;
455                    result.combined_output += "\n";
456                    #[cfg(feature = "progress")]
457                    if let Some(pr) = &pr {
458                        pr.prop("ensembler_stdout", &line);
459                        pr.update();
460                    }
461                    combined_output.lock().await.push(line);
462                }
463                let _ = stdout_flush.send(());
464            });
465        } else {
466            drop(stdout_flush);
467        }
468        let (stderr_flush, stderr_ready) = oneshot::channel();
469        if let Some(stderr) = cp.stderr.take() {
470            let result = result.clone();
471            let combined_output = combined_output.clone();
472            #[cfg(feature = "progress")]
473            let pr = self.pr.clone();
474            #[cfg(feature = "progress")]
475            let stderr_to_progress = self.stderr_to_progress;
476            tokio::spawn(async move {
477                let stderr = BufReader::new(stderr);
478                let mut lines = stderr.lines();
479                while let Ok(Some(line)) = lines.next_line().await {
480                    let line = match &redactor {
481                        Some(r) => r.automaton.replace_all(&line, &r.replacements),
482                        None => line,
483                    };
484                    let mut result = result.lock().await;
485                    result.stderr += &line;
486                    result.stderr += "\n";
487                    result.combined_output += &line;
488                    result.combined_output += "\n";
489                    #[cfg(feature = "progress")]
490                    if let Some(pr) = &pr {
491                        if stderr_to_progress {
492                            // Update progress bar like stdout does
493                            pr.prop("ensembler_stdout", &line);
494                            pr.update();
495                        } else {
496                            // Print above progress bars (current behavior)
497                            pr.println(&line);
498                        }
499                    }
500                    combined_output.lock().await.push(line);
501                }
502                let _ = stderr_flush.send(());
503            });
504        } else {
505            drop(stderr_flush);
506        }
507        let (stdin_flush, stdin_ready) = oneshot::channel();
508        if let Some(text) = self.stdin.take() {
509            let Some(mut stdin) = cp.stdin.take() else {
510                let _ = cp.kill().await;
511                if let Err(e) = RUNNING_PIDS
512                    .lock()
513                    .map(|mut pids| pids.remove(&id))
514                    .map_err(|e| e.to_string())
515                {
516                    debug!("Failed to lock RUNNING_PIDS to remove pid {id}: {e}");
517                }
518                #[cfg(feature = "progress")]
519                if let Some(pr) = &self.pr {
520                    pr.set_status(progress::ProgressStatus::Failed);
521                }
522                return Err(crate::Error::Internal(
523                    "stdin was requested but not available".to_string(),
524                ));
525            };
526            tokio::spawn(async move {
527                if let Err(e) = stdin.write_all(text.as_bytes()).await {
528                    debug!("Failed to write to stdin: {e}");
529                }
530                let _ = stdin_flush.send(());
531            });
532        } else {
533            drop(stdin_flush);
534        }
535
536        // Create timeout future that either sleeps or waits forever
537        let timeout_fut = async {
538            if let Some(duration) = self.timeout {
539                tokio::time::sleep(duration).await;
540            } else {
541                std::future::pending::<()>().await;
542            }
543        };
544        tokio::pin!(timeout_fut);
545
546        let mut timed_out = false;
547        let mut was_cancelled = false;
548        let status = loop {
549            // Use biased select to prioritize process completion over timeout/cancellation.
550            // This prevents a race where if process completes at the same instant as timeout,
551            // we'd incorrectly report a timeout instead of success.
552            select! {
553                biased;
554                status = cp.wait() => {
555                    break status?;
556                }
557                _ = &mut timeout_fut => {
558                    timed_out = true;
559                    #[cfg(unix)]
560                    kill_process_group(id);
561                    let _ = cp.kill().await;
562                }
563                _ = self.cancel.cancelled() => {
564                    was_cancelled = true;
565                    #[cfg(unix)]
566                    kill_process_group(id);
567                    let _ = cp.kill().await;
568                }
569            }
570        };
571        if let Err(e) = RUNNING_PIDS
572            .lock()
573            .map(|mut pids| pids.remove(&id))
574            .map_err(|e| e.to_string())
575        {
576            debug!("Failed to lock RUNNING_PIDS to remove pid {id}: {e}");
577        }
578
579        if was_cancelled {
580            #[cfg(feature = "progress")]
581            if let Some(pr) = &self.pr {
582                pr.set_status(progress::ProgressStatus::Failed);
583            }
584            return Err(crate::Error::Cancelled);
585        }
586
587        if timed_out {
588            #[cfg(feature = "progress")]
589            if let Some(pr) = &self.pr {
590                pr.set_status(progress::ProgressStatus::Failed);
591            }
592            return Err(crate::Error::TimedOut);
593        }
594
595        result.lock().await.status = status;
596
597        // these are sent when the process has flushed IO
598        let _ = stdout_ready.await;
599        let _ = stderr_ready.await;
600        let _ = stdin_ready.await;
601
602        if status.success() || self.allow_non_zero {
603            #[cfg(feature = "progress")]
604            if let Some(pr) = &self.pr {
605                pr.set_status(progress::ProgressStatus::Done);
606            }
607        } else {
608            let result = result.lock().await.to_owned();
609            self.on_error(combined_output.lock().await.join("\n"), result)?;
610        }
611
612        let result = result.lock().await.to_owned();
613        Ok(result)
614    }
615
616    fn on_error(&self, output: String, result: CmdResult) -> Result<()> {
617        let output = output.trim().to_string();
618        #[cfg(feature = "progress")]
619        if let Some(pr) = &self.pr {
620            pr.set_status(progress::ProgressStatus::Failed);
621            if self.show_stderr_on_error {
622                pr.println(&output);
623            }
624        }
625        Err(ScriptFailed(Box::new((
626            self.program.clone(),
627            self.args.clone(),
628            output,
629            result,
630        ))))?
631    }
632}
633
634/// Kill an entire process group by PGID (which equals the child PID since
635/// we spawn with process_group(0)).
636#[cfg(unix)]
637fn kill_process_group(pid: u32) {
638    let pgid = nix::unistd::Pid::from_raw(pid as i32);
639    if let Err(e) = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGKILL) {
640        debug!("Failed to kill process group {pid}: {e}");
641    }
642}
643
644impl Display for CmdLineRunner {
645    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
646        let args = self.args.join(" ");
647        let mut cmd = format!("{} {}", &self.program, args);
648        if cmd.starts_with("sh -o errexit -c ") {
649            cmd = cmd[17..].to_string();
650        }
651        write!(f, "{cmd}")
652    }
653}
654
655impl Debug for CmdLineRunner {
656    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
657        let args = self.args.join(" ");
658        write!(f, "{} {args}", self.program)
659    }
660}
661
/// Outcome of a finished command.
///
/// Bundles the captured output streams with the process exit status.
#[derive(Clone, Debug, Default)]
pub struct CmdResult {
    /// Everything the command wrote to standard output.
    pub stdout: String,
    /// Everything the command wrote to standard error.
    pub stderr: String,
    /// Stdout and stderr interleaved in the order the lines were received.
    pub combined_output: String,
    /// Exit status reported by the process.
    pub status: ExitStatus,
}