clitest_lib/
script.rs

1use std::{
2    collections::{HashMap, VecDeque},
3    path::Path,
4    process::ExitStatus,
5    sync::{Arc, Mutex, atomic::AtomicBool},
6    thread::ScopedJoinHandle,
7    time::{Duration, Instant},
8};
9
10use grok::Grok;
11use keepcalm::SharedMut;
12use serde::{Serialize, ser::SerializeMap};
13use termcolor::{Color, ColorChoice, WriteColor};
14
15use crate::{
16    command::{CommandLine, CommandResult},
17    util::{NicePathBuf, NiceTempDir},
18};
19use crate::{cwrite, cwriteln, cwriteln_rule};
20use crate::{output::*, util::ShellBit};
21
/// Timeout stored in a `ScriptRunContext` when no `global_timeout` is supplied
/// (used by both `ScriptRunContext::default` and `ScriptRunContext::new`).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
23
/// A position inside a script: the file it belongs to and a line number.
///
/// Displayed as `<file>:<line>`.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct ScriptLocation {
    /// The script file this location points into.
    pub file: ScriptFile,
    /// Line number within `file` (`ScriptLine::parse` assigns these starting at 1).
    pub line: usize,
}
29
30impl ScriptLocation {
31    pub fn new(file: ScriptFile, line: usize) -> Self {
32        Self { file, line }
33    }
34}
35
36impl std::fmt::Display for ScriptLocation {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        write!(f, "{}:{}", self.file, self.line)
39    }
40}
41
/// Identifies a script file by its path.
///
/// The path is held behind an `Arc`, so cloning a `ScriptFile` only bumps a
/// reference count. `Display` prints the wrapped path.
#[derive(
    derive_more::Debug, derive_more::Display, Clone, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash,
)]
#[display("{}", file)]
pub struct ScriptFile {
    /// Shared path of the script file.
    pub file: Arc<NicePathBuf>,
}
49
50impl ScriptFile {
51    pub fn new(file: impl AsRef<Path>) -> Self {
52        Self {
53            file: Arc::new(NicePathBuf::new(file)),
54        }
55    }
56}
57
58impl<T: AsRef<Path>> From<T> for ScriptFile {
59    fn from(file: T) -> Self {
60        Self::new(file)
61    }
62}
63
/// A parsed script: its top-level blocks, the scripts it includes, and the file
/// it was loaded from. All collections are behind `Arc`, so clones are cheap.
#[derive(Clone, derive_more::Debug, Serialize)]
pub struct Script {
    /// Top-level blocks, executed by `Script::run`.
    pub commands: Arc<Vec<ScriptBlock>>,
    /// Included scripts keyed by string
    /// (presumably the include path as written in the script — TODO confirm).
    pub includes: Arc<HashMap<String, Script>>,
    /// The file this script came from.
    pub file: ScriptFile,
}
70
/// Options controlling a script run.
///
/// NOTE(review): several fields are only consumed outside this part of the file;
/// those are left undocumented or hedged rather than guessed at.
#[derive(Debug, Clone, Default)]
pub struct ScriptRunArgs {
    // NOTE(review): presumably a delay inserted between steps; unit not visible here.
    pub delay_steps: Option<u64>,
    // NOTE(review): presumably disables exit-code checking — confirm at the use site.
    pub ignore_exit_codes: bool,
    // NOTE(review): presumably disables output-pattern checking — confirm at the use site.
    pub ignore_matches: bool,
    /// When set, `Script::run_with_args` omits the elapsed time from the final
    /// PASSED/FAILED line.
    pub simplified_output: bool,
    pub show_line_numbers: bool,
    pub runner: Option<String>,
    pub quiet: bool,
    /// When set, background contexts write to the parent's output instead of a
    /// private in-memory buffer (see `ScriptRunContext::new_background`).
    pub verbose: bool,
    /// Overrides `DEFAULT_TIMEOUT` as the context timeout in `ScriptRunContext::new`.
    pub global_timeout: Option<Duration>,
    /// When set, buffered background output is captured without ANSI colors.
    pub no_color: bool,
}
84
/// Cloneable handle to the color-capable stream a script run writes to.
///
/// Clones share the same underlying stream through `SharedMut`.
#[derive(derive_more::Debug, Clone)]
pub struct ScriptOutput {
    /// The shared stream: either a standard stream or an in-memory buffer.
    #[debug(skip)]
    stream: SharedMut<Box<dyn WriteColorAny>>,
}
90
/// Object-safe wrapper over `WriteColor` that can additionally hand back its
/// contents when the concrete stream is a `termcolor::Buffer`.
trait WriteColorAny: WriteColor + Send + Sync + std::any::Any + 'static + std::fmt::Debug {
    /// Workaround for lack of upcasting
    ///
    /// Consumes the stream and returns it as a buffer, or an error message when
    /// the stream is not a buffer.
    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String>;
    /// Returns a copy of the stream as a buffer, or an error message when the
    /// stream is not a buffer.
    fn clone_buffer(&self) -> Result<termcolor::Buffer, String>;
}
96
97impl WriteColorAny for termcolor::StandardStream {
98    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String> {
99        Err("not a buffer".to_string())
100    }
101    fn clone_buffer(&self) -> Result<termcolor::Buffer, String> {
102        Err("not a buffer".to_string())
103    }
104}
105
106impl WriteColorAny for termcolor::Buffer {
107    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String> {
108        Ok(*self)
109    }
110    fn clone_buffer(&self) -> Result<termcolor::Buffer, String> {
111        Ok(self.clone())
112    }
113}
114
115impl ScriptOutput {
116    pub fn no_color() -> Self {
117        let stm = termcolor::StandardStream::stdout(ColorChoice::Never);
118        Self {
119            stream: SharedMut::new(Box::new(stm) as _),
120        }
121    }
122
123    pub fn quiet(no_color: bool) -> Self {
124        let stm = if no_color {
125            termcolor::Buffer::no_color()
126        } else {
127            termcolor::Buffer::ansi()
128        };
129        Self {
130            stream: SharedMut::new(Box::new(stm) as _),
131        }
132    }
133
134    pub fn take_buffer(self) -> String {
135        let stream = match SharedMut::try_unwrap(self.stream) {
136            Ok(stream) => stream.take_buffer().expect("wrong stream type"),
137            Err(shared) => shared.read().clone_buffer().expect("wrong stream type"),
138        };
139        String::from_utf8_lossy(&stream.into_inner()).to_string()
140    }
141}
142
143impl Default for ScriptOutput {
144    fn default() -> Self {
145        let stm = termcolor::StandardStream::stdout(ColorChoice::Auto);
146        Self {
147            stream: SharedMut::new(Box::new(stm) as _),
148        }
149    }
150}
151
152impl std::io::Write for ScriptOutputLock<'_> {
153    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
154        self.stream.write(buf)
155    }
156    fn flush(&mut self) -> std::io::Result<()> {
157        self.stream.flush()
158    }
159}
160
161impl termcolor::WriteColor for ScriptOutputLock<'_> {
162    fn supports_color(&self) -> bool {
163        self.stream.supports_color()
164    }
165    fn set_color(&mut self, spec: &termcolor::ColorSpec) -> std::io::Result<()> {
166        self.stream.set_color(spec)
167    }
168    fn reset(&mut self) -> std::io::Result<()> {
169        self.stream.reset()
170    }
171    fn is_synchronous(&self) -> bool {
172        self.stream.is_synchronous()
173    }
174    fn set_hyperlink(&mut self, _link: &termcolor::HyperlinkSpec) -> std::io::Result<()> {
175        self.stream.set_hyperlink(_link)
176    }
177    fn supports_hyperlinks(&self) -> bool {
178        self.stream.supports_hyperlinks()
179    }
180}
181
/// Write guard over a `ScriptOutput` stream; implements `Write` and `WriteColor`
/// by delegating to the locked stream.
struct ScriptOutputLock<'a> {
    /// The held write lock on the shared stream.
    stream: keepcalm::SharedWriteLock<'a, Box<dyn WriteColorAny>>,
}
185
/// How the blocks of a `ScriptRunContext` are being executed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScriptMode {
    /// Mode of a freshly created context.
    Normal,
    /// Mode of contexts created by `ScriptRunContext::new_deferred`.
    Deferred,
    /// Mode of contexts created by `ScriptRunContext::new_background`.
    Background,
}
192
/// Mutable state threaded through a script run: options, environment variables,
/// output stream, kill flag and global output patterns.
#[derive(derive_more::Debug)]
pub struct ScriptRunContext {
    /// Options for this run.
    pub args: ScriptRunArgs,
    /// Grok instance, created with the default pattern set.
    pub grok: Grok,
    /// Timeout for this context (`Duration::MAX` for background contexts).
    timeout: Duration,
    /// Script-visible variables used by `$NAME` / `${NAME}` expansion; `PWD`
    /// holds the current working directory.
    env_vars: HashMap<String, String>,
    /// Included scripts of the script currently being run.
    includes: Arc<HashMap<String, Script>>,
    /// Execution mode of this context.
    background: ScriptMode,
    /// Read side of the kill flag.
    #[debug(skip)]
    kill: ScriptKillReceiver,
    /// Write side of the same kill flag.
    #[debug(skip)]
    kill_sender: ScriptKillSender,
    /// Where this context writes its output.
    output: ScriptOutput,

    // NOTE(review): presumably patterns applied to the output of every command —
    // confirm against `OutputPatterns` and the block runner.
    global_ignore: OutputPatterns,
    global_reject: OutputPatterns,
}
210
211impl Default for ScriptRunContext {
212    fn default() -> Self {
213        let kill = Arc::new(AtomicBool::new(false));
214        Self {
215            args: ScriptRunArgs::default(),
216            grok: Grok::with_default_patterns(),
217            timeout: DEFAULT_TIMEOUT,
218            env_vars: HashMap::new(),
219            background: ScriptMode::Normal,
220            includes: Arc::new(HashMap::new()),
221            kill: ScriptKillReceiver::new(kill.clone()),
222            kill_sender: ScriptKillSender::new(kill.clone()),
223            output: ScriptOutput::default(),
224            global_ignore: OutputPatterns::default(),
225            global_reject: OutputPatterns::default(),
226        }
227    }
228}
229
230impl ScriptRunContext {
231    pub fn new_background(&self) -> Self {
232        let kill = Arc::new(AtomicBool::new(false));
233        Self {
234            args: self.args.clone(),
235            grok: self.grok.clone(),
236            // Background processes are not subject to timeouts
237            timeout: Duration::MAX,
238            env_vars: self.env_vars.clone(),
239            background: ScriptMode::Background,
240            kill: ScriptKillReceiver::new(kill.clone()),
241            kill_sender: ScriptKillSender::new(kill.clone()),
242            includes: self.includes.clone(),
243            output: if self.args.verbose {
244                self.output.clone()
245            } else {
246                ScriptOutput::quiet(self.args.no_color)
247            },
248            global_ignore: self.global_ignore.clone(),
249            global_reject: self.global_reject.clone(),
250        }
251    }
252
253    pub fn new_deferred(&self) -> Self {
254        Self {
255            args: self.args.clone(),
256            grok: self.grok.clone(),
257            timeout: self.timeout,
258            env_vars: self.env_vars.clone(),
259            background: ScriptMode::Deferred,
260            kill: self.kill.clone(),
261            kill_sender: self.kill_sender.clone(),
262            includes: self.includes.clone(),
263            output: self.output.clone(),
264            global_ignore: self.global_ignore.clone(),
265            global_reject: self.global_reject.clone(),
266        }
267    }
268
269    pub fn pwd(&self) -> NicePathBuf {
270        self.env_vars
271            .get("PWD")
272            .cloned()
273            .map(NicePathBuf::from)
274            .unwrap_or_else(NicePathBuf::cwd)
275    }
276
277    pub fn get_env(&self, name: &str) -> Option<&str> {
278        self.env_vars.get(name).map(|s| s.as_str())
279    }
280
281    pub fn set_env(&mut self, name: impl Into<String>, value: impl Into<String>) {
282        let name = name.into();
283        if name == "PWD" {
284            self.set_pwd(value.into());
285        } else {
286            self.env_vars.insert(name, value.into());
287        }
288    }
289
290    pub fn set_pwd(&mut self, pwd: impl Into<NicePathBuf>) {
291        let pwd = pwd.into().env_string();
292        self.env_vars.insert("PWD".to_string(), pwd);
293    }
294
295    pub fn take_output(self) -> String {
296        self.output.take_buffer()
297    }
298
299    fn expand(&self, value: &ShellBit) -> Result<String, ScriptRunError> {
300        match value {
301            ShellBit::Literal(s) => Ok(s.clone()),
302            ShellBit::Quoted(s) => self.expand_str(s),
303        }
304    }
305
306    /// Perform shell expansion on a string.
307    fn expand_str(&self, value: impl AsRef<str>) -> Result<String, ScriptRunError> {
308        enum State {
309            Normal,
310            EscapeNext,
311            InCurly,
312            Dollar,
313            InDollar,
314        }
315
316        let value = value.as_ref();
317
318        // "\" triggers escaping
319        // ${A} expands to the value of A
320        // $A expands to the value of A (variable ends on first non-alphanumeric character)
321
322        let mut state = State::Normal;
323        let mut variable = String::new();
324        let mut expanded = String::new();
325
326        for c in value.chars() {
327            match state {
328                State::Normal => {
329                    if c == '$' {
330                        state = State::Dollar;
331                        continue;
332                    }
333                    if c == '\\' {
334                        state = State::EscapeNext;
335                        continue;
336                    }
337                    expanded.push(c);
338                }
339                State::EscapeNext => {
340                    expanded.push(c);
341                    state = State::Normal;
342                }
343                State::InCurly => {
344                    if c == '}' {
345                        if let Some(value) = self.get_env(&std::mem::take(&mut variable)) {
346                            expanded.push_str(value);
347                        } else {
348                            return Err(ScriptRunError::ExpansionError(format!(
349                                "undefined variable in ${{...}}: {:?} (in {value:?})",
350                                variable
351                            )));
352                        }
353                        state = State::Normal;
354                    } else {
355                        variable.push(c);
356                    }
357                }
358                State::Dollar => {
359                    if c.is_alphanumeric() || c == '_' {
360                        state = State::InDollar;
361                        variable.push(c);
362                    } else if c == '{' {
363                        state = State::InCurly;
364                    } else {
365                        return Err(ScriptRunError::ExpansionError(format!(
366                            "invalid variable: {:?} (in {value:?})",
367                            c
368                        )));
369                    }
370                }
371                State::InDollar => {
372                    if c.is_alphanumeric() || c == '_' {
373                        variable.push(c);
374                    } else {
375                        if let Some(value) = self.get_env(&std::mem::take(&mut variable)) {
376                            expanded.push_str(value);
377                        } else {
378                            return Err(ScriptRunError::ExpansionError(format!(
379                                "undefined variable in $...: {:?} (in {value:?})",
380                                variable
381                            )));
382                        }
383                        expanded.push(c);
384                        state = State::Normal;
385                    }
386                }
387            }
388        }
389        match state {
390            State::InDollar => {
391                if let Some(value) = self.get_env(&variable) {
392                    expanded.push_str(value);
393                } else {
394                    return Err(ScriptRunError::ExpansionError(format!(
395                        "undefined variable: {}",
396                        variable
397                    )));
398                }
399            }
400            State::Dollar => {
401                return Err(ScriptRunError::ExpansionError(
402                    "incomplete variable".to_string(),
403                ));
404            }
405            State::InCurly => {
406                return Err(ScriptRunError::ExpansionError(format!(
407                    "unclosed variable: {}",
408                    variable
409                )));
410            }
411            State::Normal => {}
412            State::EscapeNext => {
413                return Err(ScriptRunError::ExpansionError(
414                    "unclosed backslash".to_string(),
415                ));
416            }
417        }
418        Ok(expanded)
419    }
420
421    /// Get a mutable reference to the output stream.
422    pub fn stream(&self) -> impl termcolor::WriteColor + use<'_> {
423        ScriptOutputLock {
424            stream: self.output.stream.write(),
425        }
426    }
427}
428
/// Read side of a shared kill flag; clones observe the same flag.
#[derive(Clone)]
pub struct ScriptKillReceiver {
    /// The shared flag; `true` once a kill has been requested.
    kill_receiver: Arc<AtomicBool>,
}
433
434impl ScriptKillReceiver {
435    pub fn new(kill_receiver: Arc<AtomicBool>) -> Self {
436        Self { kill_receiver }
437    }
438
439    pub fn is_killed(&self) -> bool {
440        self.kill_receiver.load(std::sync::atomic::Ordering::SeqCst)
441    }
442
443    pub fn run_with<T>(&self, kill: impl FnOnce() + Send, wait: impl FnOnce() -> T) -> T {
444        std::thread::scope(|s| {
445            let done = Arc::new(AtomicBool::new(false));
446            let done_clone = done.clone();
447            let t = s.spawn(move || {
448                while !done_clone.load(std::sync::atomic::Ordering::SeqCst) {
449                    if self.is_killed() {
450                        kill();
451                        break;
452                    }
453                    std::thread::sleep(Duration::from_millis(10));
454                }
455            });
456            let res = wait();
457            done.store(true, std::sync::atomic::Ordering::SeqCst);
458            t.join().unwrap();
459            res
460        })
461    }
462
463    #[cfg(windows)]
464    pub fn run_cmd(
465        &self,
466        output: std::process::Child,
467        warn_time: Duration,
468    ) -> std::io::Result<ExitStatus> {
469        use std::os::windows::io::AsRawHandle;
470        use win32job::Job;
471
472        fn map_job_error(e: win32job::JobError) -> std::io::Error {
473            match e {
474                win32job::JobError::AssignFailed(e) => e,
475                win32job::JobError::CreateFailed(e) => e,
476                win32job::JobError::GetInfoFailed(e) => e,
477                win32job::JobError::SetInfoFailed(e) => e,
478                _ => std::io::Error::new(std::io::ErrorKind::Other, "Unknown error"),
479            }
480        }
481
482        // Create a new Job object
483        let job = Job::create().map_err(map_job_error)?;
484
485        // Configure the job to terminate all child processes when the job is closed
486        let mut info = job.query_extended_limit_info().map_err(map_job_error)?;
487        info.limit_kill_on_job_close();
488        job.set_extended_limit_info(&info).map_err(map_job_error)?;
489        job.assign_process(output.as_raw_handle() as _)?;
490
491        // Resume the main thread for the process
492        let id = output.id();
493        for thread_entry in tlhelp32::Snapshot::new_thread()? {
494            if thread_entry.owner_process_id == id {
495                use windows_sys::Win32::Foundation::CloseHandle;
496                use windows_sys::Win32::System::Threading::*;
497
498                // TODO: error handling
499                unsafe {
500                    let thread = OpenThread(THREAD_SUSPEND_RESUME, 0, thread_entry.thread_id);
501                    ResumeThread(thread);
502                    CloseHandle(thread);
503                }
504            }
505        }
506
507        let job = Mutex::new(Some(job));
508        let output = Mutex::new(output);
509        self.run_with(
510            || {
511                _ = job.lock().unwrap().take();
512                _ = output.lock().unwrap().kill();
513            },
514            || {
515                let start = std::time::Instant::now();
516                let mut warned = false;
517                loop {
518                    let res = output.lock().unwrap().try_wait()?;
519                    if let Some(status) = res {
520                        return Ok::<_, std::io::Error>(status);
521                    }
522                    if start.elapsed() > warn_time {
523                        if !warned {
524                            let child = output.lock().unwrap().id();
525                            eprintln!("Process #{child} taking too long to finish.");
526                            warned = true;
527                        }
528                    }
529                    std::thread::sleep(Duration::from_millis(10));
530                }
531            },
532        )
533    }
534
535    #[cfg(unix)]
536    pub fn run_cmd(
537        &self,
538        output: std::process::Child,
539        warn_time: Duration,
540    ) -> std::io::Result<ExitStatus> {
541        let output = Mutex::new(output);
542        self.run_with(
543            || {
544                use signal_child::{signal, signal::*};
545                let id = output.lock().unwrap().id() as i32;
546                _ = signal(-id, SIGINT);
547                std::thread::sleep(Duration::from_millis(10));
548                _ = signal(-id, SIGTERM);
549            },
550            || {
551                let start = std::time::Instant::now();
552                let mut warned = false;
553                loop {
554                    let res = output.lock().unwrap().try_wait()?;
555                    if let Some(status) = res {
556                        return Ok::<_, std::io::Error>(status);
557                    }
558                    if start.elapsed() > warn_time && !warned {
559                        let child = output.lock().unwrap().id();
560                        eprintln!("Process #{child} taking too long to finish.");
561                        warned = true;
562                    }
563                    std::thread::sleep(Duration::from_millis(10));
564                }
565            },
566        )
567    }
568}
569
/// Write side of a shared kill flag; clones raise the same flag.
#[derive(Clone)]
pub struct ScriptKillSender {
    /// The shared flag; set to `true` by [`ScriptKillSender::kill`].
    kill_sender: Arc<AtomicBool>,
}

impl ScriptKillSender {
    /// Creates a sender that raises `kill_sender`.
    pub fn new(kill_sender: Arc<AtomicBool>) -> Self {
        ScriptKillSender { kill_sender }
    }

    /// Requests a kill by setting the shared flag. Calling it again is harmless.
    pub fn kill(&self) {
        use std::sync::atomic::Ordering;

        let flag: &AtomicBool = &self.kill_sender;
        flag.store(true, Ordering::SeqCst);
    }
}
585
586impl ScriptRunContext {
587    pub fn new(args: ScriptRunArgs, script_path: impl AsRef<Path>, output: ScriptOutput) -> Self {
588        let mut env_vars = HashMap::new();
589
590        macro_rules! target {
591            ($env:ident, $var:ident, [$($vals:expr),*]) => {
592                $(
593                if cfg!($var = $vals) {
594                    env_vars.insert(stringify!($env).to_string(), $vals.to_string());
595                }
596                )*
597            };
598        }
599
600        target!(
601            TARGET_OS,
602            target_os,
603            ["windows", "linux", "macos", "ios", "android"]
604        );
605        target!(TARGET_FAMILY, target_family, ["windows", "unix", "wasm"]);
606        target!(
607            TARGET_ARCH,
608            target_arch,
609            ["x86", "x86_64", "arm", "aarch64"]
610        );
611
612        // Set the current working directory as a special variable "PWD"
613        env_vars.insert(
614            "PWD".to_string(),
615            NicePathBuf::from(script_path.as_ref().parent().unwrap()).env_string(),
616        );
617        // Save the initial PWD as INITIAL_PWD so it can easily be restored
618        env_vars.insert("INITIAL_PWD".to_string(), env_vars["PWD"].clone());
619
620        let kill = Arc::new(AtomicBool::new(false));
621
622        Self {
623            timeout: args.global_timeout.unwrap_or(DEFAULT_TIMEOUT),
624            args,
625            env_vars,
626            grok: Grok::with_default_patterns(),
627            includes: Arc::new(HashMap::new()),
628            background: ScriptMode::Normal,
629            kill: ScriptKillReceiver::new(kill.clone()),
630            kill_sender: ScriptKillSender::new(kill.clone()),
631            output,
632            global_ignore: OutputPatterns::default(),
633            global_reject: OutputPatterns::default(),
634        }
635    }
636}
637
/// One line of a script together with where it came from.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ScriptLine {
    /// File and line number of this line.
    pub location: ScriptLocation,
    /// The line exactly as read, including surrounding whitespace.
    text: String,
}
643
644impl ScriptLine {
645    pub fn new(file: ScriptFile, line: usize, text: impl AsRef<str>) -> Self {
646        Self {
647            location: ScriptLocation::new(file, line),
648            text: text.as_ref().to_string(),
649        }
650    }
651
652    pub fn parse(file: ScriptFile, text: impl AsRef<str>) -> Vec<Self> {
653        text.as_ref()
654            .lines()
655            .enumerate()
656            .map(|(line, text)| Self {
657                location: ScriptLocation::new(file.clone(), line + 1),
658                text: text.to_string(),
659            })
660            .collect()
661    }
662
663    pub fn starts_with(&self, text: &str) -> bool {
664        self.text.trim().starts_with(text)
665    }
666
667    pub fn first_char(&self) -> Option<char> {
668        self.text.trim().chars().next()
669    }
670
671    pub fn text(&self) -> &str {
672        self.text.trim()
673    }
674
675    pub fn text_untrimmed(&self) -> &str {
676        &self.text
677    }
678
679    pub fn is_empty(&self) -> bool {
680        self.text.trim().is_empty()
681    }
682
683    pub fn strip_prefix(&self, prefix: &str) -> Option<&str> {
684        self.text.strip_prefix(prefix)
685    }
686}
687
/// An error tied to a script location.
///
/// Displayed as `<error> at <location>`, followed by `: <associated_data>` when
/// extra data is present.
#[derive(Debug, thiserror::Error, derive_more::Display)]
#[display("{error} at {location}{}", associated_data.as_deref().map_or("".to_string(), |d| format!(": {d}")))]
pub struct ScriptError {
    /// What went wrong.
    pub error: ScriptErrorType,
    /// Where in the script it went wrong.
    pub location: ScriptLocation,
    /// Optional extra detail appended to the message.
    pub associated_data: Option<String>,
}
695
696impl ScriptError {
697    pub fn new(error: ScriptErrorType, location: ScriptLocation) -> Self {
698        if std::env::var("PANIC_ON_ERROR").is_ok() {
699            panic!("ScriptError: {error} at {location}");
700        }
701        Self {
702            error,
703            location,
704            associated_data: None,
705        }
706    }
707
708    pub fn new_with_data(
709        error: ScriptErrorType,
710        location: ScriptLocation,
711        associated_data: String,
712    ) -> Self {
713        if std::env::var("PANIC_ON_ERROR").is_ok() {
714            panic!("ScriptError: {error} at {location}: {associated_data}");
715        }
716        Self {
717            error,
718            location,
719            associated_data: Some(associated_data),
720        }
721    }
722}
723
/// The kinds of error a [`ScriptError`] can carry; each variant's message is
/// given by its `#[error]` attribute.
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
pub enum ScriptErrorType {
    #[error("background process not allowed")]
    BackgroundProcessNotAllowed,
    #[error("unclosed quote")]
    UnclosedQuote,
    #[error("unclosed backslash")]
    UnclosedBackslash,
    #[error("illegal shell command format")]
    IllegalShellCommand,
    #[error("unsupported redirection")]
    UnsupportedRedirection,
    #[error("invalid pattern definition")]
    InvalidPatternDefinition,
    #[error("invalid pattern")]
    InvalidPattern,
    #[error("invalid meta command")]
    InvalidMetaCommand,
    #[error("invalid pattern at global level (only reject or ignore allowed here)")]
    InvalidGlobalPattern,
    #[error("invalid block type")]
    InvalidBlockType,
    #[error("invalid block arguments")]
    InvalidBlockArgs,
    #[error("unsupported command position")]
    UnsupportedCommandPosition,
    #[error("invalid trailing pattern after *")]
    InvalidAnyPattern,
    #[error("invalid exit status")]
    InvalidExitStatus,
    #[error("invalid set variable")]
    InvalidSetVariable,
    #[error("invalid version")]
    InvalidVersion,
    #[error("invalid internal command")]
    InvalidInternalCommand,
    #[error("missing command lines")]
    MissingCommandLines,
    #[error(
        "block end without matching block start, too many closing braces or braces not properly nested"
    )]
    InvalidBlockEnd,
    #[error("invalid if condition")]
    InvalidIfCondition,
    #[error("expected block or semi-colon (did you forget to add ';' at the end of this line?)")]
    ExpectedBlockOrSemi,
}
771
/// Errors produced while running a script.
#[derive(Debug, thiserror::Error)]
pub enum ScriptRunError {
    /// Output pattern matching failed.
    #[error("{0}")]
    Pattern(#[from] OutputPatternMatchFailure),
    /// An output pattern could not be prepared.
    #[error("{0}")]
    PatternPrepareError(#[from] OutputPatternPrepareError),
    /// Carries a command result.
    // NOTE(review): presumably raised when a command's result does not match
    // the expected exit — confirm at the construction site.
    #[error("{0}")]
    Exit(CommandResult),
    #[error("included file not found: {0}")]
    IncludedFileNotFound(String),
    #[error("expected failure, but passed")]
    ExpectedFailure,
    /// Variable expansion failed; the payload is the full message.
    #[error("{0}")]
    ExpansionError(String),
    #[error("{0}")]
    IO(#[from] std::io::Error),
    #[error("killed")]
    Killed,
    #[error("background process took too long to finish")]
    BackgroundProcessTookTooLong,
    #[error("retry took too long to finish")]
    RetryTookTooLong,
    /// Internal flow control: exit the script
    ///
    /// `Script::run` turns this into a successful result.
    #[error("exiting script")]
    ExitScript,
}
798
799impl ScriptRunError {
800    #[allow(unused)]
801    pub fn short(&self) -> String {
802        match self {
803            Self::Pattern(_) => "Pattern".to_string(),
804            Self::PatternPrepareError(e) => format!("PatternPrepareError({:?})", e),
805            Self::Exit(status) => format!("Exit({})", status),
806            Self::ExpectedFailure => "ExpectedFailure".to_string(),
807            Self::IO(e) => format!("IO({:?})", e.kind()),
808            Self::Killed => "Killed".to_string(),
809            Self::BackgroundProcessTookTooLong => "BackgroundProcessTookTooLong".to_string(),
810            Self::ExpansionError(e) => "ExpansionError".to_string(),
811            Self::RetryTookTooLong => "RetryTookTooLong".to_string(),
812            Self::ExitScript => unreachable!(),
813            Self::IncludedFileNotFound(path) => format!("IncludedFileNotFound({})", path),
814        }
815    }
816}
817
818impl Script {
819    pub fn new(file: ScriptFile) -> Self {
820        Self {
821            commands: Arc::new(vec![]),
822            includes: Arc::new(HashMap::new()),
823            file,
824        }
825    }
826
827    /// Collect all included script paths from the script.
828    pub fn includes(&self) -> Vec<(ScriptLocation, String)> {
829        self.commands
830            .iter()
831            .flat_map(|block| block.includes())
832            .collect()
833    }
834
835    pub fn run(&self, context: &mut ScriptRunContext) -> Result<(), ScriptRunError> {
836        let old_includes = context.includes.clone();
837        context.includes = self.includes.clone();
838        let res = ScriptBlock::run_blocks(context, &self.commands);
839        context.includes = old_includes;
840        let v = match res {
841            Ok(v) => v,
842            // Bypass normal script processing and exit successfully
843            Err(ScriptRunError::ExitScript) => return Ok(()),
844            Err(e) => return Err(e),
845        };
846        assert!(v.is_empty(), "script did not run to completion: {v:?}");
847        Ok(())
848    }
849
850    pub fn run_with_args(
851        &self,
852        args: ScriptRunArgs,
853        output: ScriptOutput,
854    ) -> Result<(), ScriptRunError> {
855        let start = Instant::now();
856        let script_path = &*self.file.file;
857        let mut context = ScriptRunContext::new(args, script_path, output);
858
859        // Write "Running..." message with colors
860        cwrite!(context.stream(), "Running ");
861        cwrite!(context.stream(), fg = Color::Cyan, "{}", script_path);
862        cwriteln!(context.stream(), " ...");
863        cwriteln!(context.stream());
864
865        let result = self.run(&mut context);
866
867        // Handle success and error output
868        if let Err(ref e) = result {
869            cwrite!(context.stream(), fg = Color::Cyan, "{} ", script_path);
870            cwrite!(context.stream(), fg = Color::Red, "FAILED");
871            if !context.args.simplified_output {
872                cwriteln!(context.stream(), " ({:.2}s)", start.elapsed().as_secs_f32());
873            } else {
874                cwriteln!(context.stream());
875            }
876            cwrite!(context.stream(), fg = Color::Red, "Error: ");
877            cwriteln!(context.stream(), "{}", e);
878            cwriteln!(context.stream());
879        } else {
880            cwrite!(context.stream(), fg = Color::Cyan, "{} ", script_path);
881            cwrite!(context.stream(), fg = Color::Green, "PASSED");
882            if !context.args.simplified_output {
883                cwriteln!(context.stream(), " ({:.2}s)", start.elapsed().as_secs_f32());
884            } else {
885                cwriteln!(context.stream());
886            }
887        }
888
889        result
890    }
891}
892
/// The exit behavior a command is expected to show; checked with
/// `CommandExit::matches`.
#[derive(Debug, Default, Serialize)]
pub enum CommandExit {
    /// The command must exit successfully.
    #[default]
    Success,
    /// The command must exit with exactly this code.
    Failure(i32),
    /// The command must time out.
    Timeout,
    /// Any outcome is accepted.
    Any,
    /// Anything other than a successful exit is accepted.
    AnyFailure,
}
902
903impl CommandExit {
904    pub fn matches(&self, status: CommandResult) -> bool {
905        match (self, status) {
906            (CommandExit::Success, CommandResult::Exit(status)) => status.success(),
907            (CommandExit::Failure(code), CommandResult::Exit(status)) => {
908                *code == status.code().unwrap_or(-1)
909            }
910            (CommandExit::Timeout, CommandResult::TimedOut) => true,
911            (CommandExit::Any, _) => true,
912            (CommandExit::AnyFailure, CommandResult::Exit(status)) => !status.success(),
913            (CommandExit::AnyFailure, _) => true,
914            _ => false,
915        }
916    }
917
918    pub fn is_success(&self) -> bool {
919        matches!(self, CommandExit::Success)
920    }
921}
922
/// One node of a parsed script: a command, an internal command, a global
/// pattern declaration, or a container of nested blocks.
#[derive(derive_more::Debug)]
pub enum ScriptBlock {
    /// A command to run.
    Command(ScriptCommand),
    /// A built-in command and the location it appeared at.
    InternalCommand(ScriptLocation, InternalCommand),
    /// Nested blocks of a background block.
    Background(Vec<ScriptBlock>),
    /// Nested blocks of a defer block.
    Defer(Vec<ScriptBlock>),
    /// Nested blocks guarded by a condition.
    If(IfCondition, Vec<ScriptBlock>),
    /// Nested blocks driven by a `for` condition.
    For(ForCondition, Vec<ScriptBlock>),
    /// Nested blocks of a retry block.
    Retry(Vec<ScriptBlock>),
    /// Global ignore patterns.
    GlobalIgnore(OutputPatterns),
    /// Global reject patterns.
    GlobalReject(OutputPatterns),
}
935
936impl ScriptBlock {
937    pub fn includes(&self) -> Vec<(ScriptLocation, String)> {
938        match self {
939            ScriptBlock::Command(..) => vec![],
940            ScriptBlock::InternalCommand(location, InternalCommand::Include(path)) => {
941                vec![(location.clone(), path.clone())]
942            }
943            ScriptBlock::InternalCommand(..) => vec![],
944            ScriptBlock::Background(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
945            ScriptBlock::Defer(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
946            ScriptBlock::If(_, blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
947            ScriptBlock::For(_, blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
948            ScriptBlock::Retry(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
949            ScriptBlock::GlobalIgnore(_) => vec![],
950            ScriptBlock::GlobalReject(_) => vec![],
951        }
952    }
953
954    pub fn run_blocks(
955        context: &mut ScriptRunContext,
956        blocks: &[ScriptBlock],
957    ) -> Result<Vec<ScriptResult>, ScriptRunError> {
958        enum Deferred<'a> {
959            Scripts(&'a [ScriptBlock]),
960            Internal(
961                Box<
962                    dyn FnOnce(&mut ScriptRunContext) -> Result<(), ScriptRunError>
963                        + Send
964                        + Sync
965                        + 'a,
966                >,
967            ),
968            Background(
969                ScopedJoinHandle<'a, Result<Vec<ScriptResult>, ScriptRunError>>,
970                ScriptKillSender,
971            ),
972        }
973
974        let mut results = Vec::new();
975        std::thread::scope(|s| {
976            let mut defer_blocks = VecDeque::new();
977            let mut pending_error = None;
978            for block in blocks {
979                if context.kill.is_killed() {
980                    return Err(ScriptRunError::Killed);
981                }
982                match block {
983                    ScriptBlock::Background(blocks) => {
984                        let mut context = context.new_background();
985                        let kill_sender = context.kill_sender.clone();
986                        let handle = s.spawn(move || Self::run_blocks(&mut context, blocks));
987                        defer_blocks.push_front(Deferred::Background(handle, kill_sender));
988                    }
989                    ScriptBlock::Defer(blocks) => {
990                        // Insert at the front of the queue by extending and
991                        // then rotating
992                        defer_blocks.push_front(Deferred::Scripts(blocks));
993                    }
994                    ScriptBlock::InternalCommand(_, command) => {
995                        if context.background == ScriptMode::Deferred {
996                            cwrite!(context.stream(), dimmed = true, "(deferred) ");
997                        }
998                        if let Some(f) = command.run(context)? {
999                            defer_blocks.push_front(Deferred::Internal(f));
1000                        }
1001                    }
1002                    _ => match block.run(context) {
1003                        Ok(res) => results.extend(res),
1004                        Err(e) => {
1005                            pending_error = Some(e);
1006                            break;
1007                        }
1008                    },
1009                }
1010            }
1011            for block in defer_blocks {
1012                match block {
1013                    Deferred::Scripts(blocks) => {
1014                        let mut context = context.new_deferred();
1015                        ScriptBlock::run_blocks(&mut context, blocks)?;
1016                    }
1017                    Deferred::Internal(block) => {
1018                        cwrite!(context.stream(), dimmed = true, "(cleanup) ");
1019                        block(context)?;
1020                    }
1021                    Deferred::Background(handle, kill_sender) => {
1022                        kill_sender.kill();
1023                        let start = std::time::Instant::now();
1024                        let mut warned = false;
1025
1026                        let timeout = context.timeout;
1027                        let warn_at = timeout * 8 / 10;
1028
1029                        let results = loop {
1030                            if handle.is_finished() {
1031                                break handle.join().unwrap()?;
1032                            }
1033                            std::thread::sleep(std::time::Duration::from_millis(10));
1034                            if !warned && start.elapsed() > warn_at {
1035                                cwriteln!(
1036                                    context.stream(),
1037                                    fg = Color::Yellow,
1038                                    "Background process is taking too long to finish."
1039                                );
1040                                warned = true;
1041                            }
1042                            if start.elapsed() > timeout {
1043                                cwriteln!(
1044                                    context.stream(),
1045                                    fg = Color::Red,
1046                                    "Background process took too long to finish."
1047                                );
1048                                return Err(ScriptRunError::BackgroundProcessTookTooLong);
1049                            }
1050                        };
1051                        for result in results {
1052                            cwrite!(context.stream(), dimmed = true, "(background) ");
1053                            for line in result.command.command.split('\n') {
1054                                cwriteln!(context.stream(), fg = Color::Green, "{}", line);
1055                            }
1056                            if context.args.simplified_output {
1057                                cwriteln!(context.stream(), dimmed = true, "---");
1058                            } else {
1059                                cwriteln_rule!(
1060                                    context.stream(),
1061                                    fg = Color::Cyan,
1062                                    "{}",
1063                                    result.command.location
1064                                );
1065                            }
1066                            for line in &result.output {
1067                                cwriteln!(context.stream(), "{}", line);
1068                            }
1069                            if result.output.is_empty() {
1070                                cwriteln!(context.stream(), dimmed = true, "(no output)");
1071                            }
1072                            if context.args.simplified_output {
1073                                cwriteln!(context.stream(), dimmed = true, "---");
1074                            } else {
1075                                cwriteln_rule!(context.stream());
1076                            }
1077                            result.evaluate(context)?;
1078                        }
1079                    }
1080                }
1081            }
1082            if let Some(error) = pending_error {
1083                return Err(error);
1084            }
1085            Ok(results)
1086        })
1087    }
1088
1089    pub fn run(&self, context: &mut ScriptRunContext) -> Result<Vec<ScriptResult>, ScriptRunError> {
1090        let pwd = context.pwd();
1091        let res = pwd.exists();
1092        if !matches!(res, Ok(true)) {
1093            cwriteln!(
1094                context.stream(),
1095                fg = Color::Red,
1096                "$PWD {pwd:?} doesn't exist. Run `cd $INITIAL_PWD` to fix.",
1097            );
1098            return Err(ScriptRunError::IO(std::io::Error::new(
1099                std::io::ErrorKind::NotFound,
1100                format!("PWD does not exist: {pwd:?}"),
1101            )));
1102        }
1103
1104        match self {
1105            ScriptBlock::Command(command) => {
1106                if context.background == ScriptMode::Deferred {
1107                    cwrite!(context.stream(), dimmed = true, "(deferred) ");
1108                }
1109                let result = command.run(context)?;
1110                if context.background != ScriptMode::Background {
1111                    result.evaluate(context)?;
1112                    Ok(vec![])
1113                } else {
1114                    Ok(vec![result])
1115                }
1116            }
1117            ScriptBlock::If(condition, blocks) => {
1118                let condition = condition.expand(context)?;
1119                if condition.matches(context) {
1120                    Self::run_blocks(context, blocks)
1121                } else {
1122                    Ok(vec![])
1123                }
1124            }
1125            ScriptBlock::For(ForCondition::Env(env, values), blocks) => {
1126                let mut results = Vec::new();
1127                for value in values {
1128                    context.set_env(env, context.expand(value)?);
1129                    results.extend(Self::run_blocks(context, blocks)?);
1130                }
1131                Ok(results)
1132            }
1133            ScriptBlock::Retry(blocks) => {
1134                let start = Instant::now();
1135                let mut backoff = Duration::from_millis(100);
1136
1137                cwrite!(context.stream(), fg = Color::Green, "retry: ");
1138                cwriteln!(context.stream(), "running...");
1139
1140                loop {
1141                    let mut nested_context = context.new_background();
1142                    if let Ok(results) = Self::run_blocks(&mut nested_context, blocks) {
1143                        let mut all_ok = true;
1144                        for result in results {
1145                            if result.evaluate(&mut nested_context).is_err() {
1146                                all_ok = false;
1147                                break;
1148                            }
1149                        }
1150                        if all_ok {
1151                            let output = nested_context.take_output();
1152                            cwrite!(context.stream(), fg = Color::Green, "retry: ");
1153                            cwriteln!(context.stream(), "success");
1154                            cwriteln!(context.stream());
1155                            cwriteln!(context.stream(), "{output}");
1156                            return Ok(vec![]);
1157                        }
1158                    }
1159
1160                    if start.elapsed() > context.timeout {
1161                        let output = nested_context.take_output();
1162                        cwrite!(context.stream(), fg = Color::Green, "retry: ");
1163                        cwriteln!(context.stream(), fg = Color::Red, "timed out");
1164                        cwriteln!(context.stream());
1165                        cwriteln!(context.stream(), "{output}");
1166                        cwriteln_rule!(context.stream());
1167                        return Err(ScriptRunError::RetryTookTooLong);
1168                    }
1169                    std::thread::sleep(backoff);
1170                    backoff *= 2;
1171                }
1172            }
1173            ScriptBlock::GlobalIgnore(patterns) => {
1174                for pattern in patterns.iter() {
1175                    pattern.prepare(&context.grok)?;
1176                }
1177                context.global_ignore.extend(patterns);
1178                Ok(vec![])
1179            }
1180            ScriptBlock::GlobalReject(patterns) => {
1181                for pattern in patterns.iter() {
1182                    pattern.prepare(&context.grok)?;
1183                }
1184                context.global_reject.extend(patterns);
1185                Ok(vec![])
1186            }
1187            _ => unreachable!("Unexpected block type: {self:?}"),
1188        }
1189    }
1190}
1191
1192impl Serialize for ScriptBlock {
1193    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1194    where
1195        S: serde::Serializer,
1196    {
1197        match self {
1198            ScriptBlock::Command(command) => command.serialize(serializer),
1199            ScriptBlock::InternalCommand(_, command) => command.serialize(serializer),
1200            ScriptBlock::Background(blocks) => {
1201                let mut ser = serializer.serialize_map(Some(1))?;
1202                ser.serialize_entry("background", blocks)?;
1203                ser.end()
1204            }
1205            ScriptBlock::Defer(blocks) => {
1206                let mut ser = serializer.serialize_map(Some(1))?;
1207                ser.serialize_entry("defer", blocks)?;
1208                ser.end()
1209            }
1210            ScriptBlock::If(condition, blocks) => {
1211                let mut ser = serializer.serialize_map(Some(2))?;
1212                ser.serialize_entry("if", condition)?;
1213                ser.serialize_entry("blocks", blocks)?;
1214                ser.end()
1215            }
1216            ScriptBlock::For(condition, blocks) => {
1217                let mut ser = serializer.serialize_map(Some(2))?;
1218                ser.serialize_entry("for", condition)?;
1219                ser.serialize_entry("blocks", blocks)?;
1220                ser.end()
1221            }
1222            ScriptBlock::Retry(blocks) => {
1223                let mut ser = serializer.serialize_map(Some(1))?;
1224                ser.serialize_entry("retry", blocks)?;
1225                ser.end()
1226            }
1227            ScriptBlock::GlobalIgnore(patterns) => {
1228                let mut ser = serializer.serialize_map(Some(1))?;
1229                ser.serialize_entry("ignore", patterns)?;
1230                ser.end()
1231            }
1232            ScriptBlock::GlobalReject(patterns) => {
1233                let mut ser = serializer.serialize_map(Some(1))?;
1234                ser.serialize_entry("reject", patterns)?;
1235                ser.end()
1236            }
1237        }
1238    }
1239}
1240
/// A command implemented by the script runner itself rather than by an external process.
#[derive(Debug, Clone, Serialize)]
pub enum InternalCommand {
    /// Switch the working directory to a freshly created temporary directory, which is
    /// removed again by the cleanup step.
    UsingTempdir,
    /// Switch the working directory to the given directory, resolved against the current one.
    /// When the flag is `true` the directory is created first and removed by the cleanup step;
    /// otherwise it must already exist.
    UsingDir(ShellBit, bool),
    /// Change the working directory to the given path, resolved against the current one.
    ChangeDir(ShellBit),
    /// Set the named environment variable to the expanded value.
    Set(String, ShellBit),
    /// Run the script registered under this name in the context's includes.
    Include(String),
    /// Stop the script by raising `ScriptRunError::ExitScript`.
    ExitScript,
    /// Register a named grok pattern (name, pattern).
    Pattern(String, String),
}
1251
1252impl InternalCommand {
1253    pub fn run(
1254        &self,
1255        context: &mut ScriptRunContext,
1256    ) -> Result<
1257        Option<Box<dyn FnOnce(&mut ScriptRunContext) -> Result<(), ScriptRunError> + Send + Sync>>,
1258        ScriptRunError,
1259    > {
1260        match self.clone() {
1261            InternalCommand::Include(path) => {
1262                let Some(script) = context.includes.get(&path) else {
1263                    return Err(ScriptRunError::IncludedFileNotFound(path));
1264                };
1265                script.clone().run(context)?;
1266                Ok(None)
1267            }
1268            InternalCommand::Pattern(name, pattern) => {
1269                context.grok.add_pattern(name, pattern);
1270                Ok(None)
1271            }
1272            InternalCommand::UsingTempdir => {
1273                let current_pwd = context.pwd();
1274                let tempdir = NiceTempDir::new();
1275                cwrite!(context.stream(), fg = Color::Yellow, "using tempdir: ");
1276                cwriteln!(context.stream(), "{}", tempdir);
1277                cwriteln!(context.stream());
1278                context.set_pwd(&tempdir);
1279                let pwd = context.pwd();
1280                if !pwd.exists()? {
1281                    return Err(ScriptRunError::IO(std::io::Error::new(
1282                        std::io::ErrorKind::NotFound,
1283                        format!("newly created tempdir does not exist: {pwd:?}"),
1284                    )));
1285                }
1286                Ok(Some(Box::new(move |context: &mut ScriptRunContext| {
1287                    cwriteln!(
1288                        context.stream(),
1289                        fg = Color::Yellow,
1290                        "removing {} && cd {}",
1291                        tempdir,
1292                        current_pwd
1293                    );
1294                    cwriteln!(context.stream());
1295                    if !tempdir.exists()? {
1296                        cwriteln!(
1297                            context.stream(),
1298                            fg = Color::Red,
1299                            "tempdir does not exist: {tempdir}"
1300                        );
1301                    }
1302                    if let Err(e) = tempdir.remove_dir_all() {
1303                        cwriteln!(
1304                            context.stream(),
1305                            fg = Color::Red,
1306                            "error removing tempdir: {e:?}"
1307                        );
1308                    }
1309                    Ok::<_, ScriptRunError>(())
1310                })))
1311            }
1312            InternalCommand::UsingDir(dir, new) => {
1313                let current_pwd = context.pwd();
1314                let dir = context.expand(&dir)?;
1315                let new_pwd = current_pwd.join(dir);
1316                if new {
1317                    cwrite!(context.stream(), fg = Color::Yellow, "using new dir: ");
1318                } else {
1319                    cwrite!(context.stream(), fg = Color::Yellow, "using dir: ");
1320                }
1321                cwriteln!(context.stream(), "{}", new_pwd);
1322                cwriteln!(context.stream());
1323
1324                if new {
1325                    new_pwd.create_dir_all()?;
1326                } else if !new_pwd.exists()? {
1327                    return Err(ScriptRunError::IO(std::io::Error::new(
1328                        std::io::ErrorKind::NotFound,
1329                        "directory does not exist",
1330                    )));
1331                }
1332                context.set_pwd(&new_pwd);
1333                Ok(Some(Box::new(move |context: &mut ScriptRunContext| {
1334                    if new {
1335                        cwriteln!(
1336                            context.stream(),
1337                            fg = Color::Yellow,
1338                            "removing {} && cd {}",
1339                            new_pwd,
1340                            current_pwd
1341                        );
1342                        cwriteln!(context.stream());
1343                    } else {
1344                        cwriteln!(context.stream(), fg = Color::Yellow, "cd {}", current_pwd);
1345                        cwriteln!(context.stream());
1346                    }
1347                    if new {
1348                        new_pwd.remove_dir_all()?;
1349                    }
1350                    context.set_pwd(current_pwd);
1351                    Ok::<_, ScriptRunError>(())
1352                })))
1353            }
1354            InternalCommand::ChangeDir(dir) => {
1355                let dir = context.expand(&dir)?;
1356
1357                cwriteln!(context.stream(), fg = Color::Yellow, "cd {dir}");
1358                cwriteln!(context.stream());
1359                let current_pwd = context.pwd();
1360                let new_pwd = current_pwd.join(dir);
1361                context.set_pwd(new_pwd);
1362                Ok(None)
1363            }
1364            InternalCommand::Set(name, value) => {
1365                let value = context.expand(&value)?;
1366
1367                context.set_env(&name, &value);
1368                let new_value = context.get_env(&name).unwrap_or_default();
1369                if new_value != value {
1370                    cwriteln!(
1371                        context.stream(),
1372                        fg = Color::Yellow,
1373                        "set {name} {value} (-> {new_value})"
1374                    );
1375                } else {
1376                    cwriteln!(context.stream(), fg = Color::Yellow, "set {name} {value}");
1377                }
1378                cwriteln!(context.stream());
1379
1380                Ok(None)
1381            }
1382            InternalCommand::ExitScript => {
1383                cwriteln!(context.stream(), fg = Color::Yellow, "exiting script");
1384                cwriteln!(context.stream());
1385                Err(ScriptRunError::ExitScript)
1386            }
1387        }
1388    }
1389}
1390
/// The condition guarding an `if` block.
#[derive(Debug, Clone)]
pub enum IfCondition {
    /// Always matches.
    True,
    /// Never matches.
    False,
    /// Compares an environment variable with an expected value: `(negated, name, expected)`.
    /// When `negated` is `true` the condition matches if the values differ.
    EnvEq(bool, String, ShellBit),
}
1397
1398impl IfCondition {
1399    pub fn matches(&self, context: &ScriptRunContext) -> bool {
1400        match self {
1401            IfCondition::True => true,
1402            IfCondition::False => false,
1403            IfCondition::EnvEq(negated, name, expected) => {
1404                let value = context.get_env(name).unwrap_or_default();
1405                (expected == value) ^ negated
1406            }
1407        }
1408    }
1409
1410    pub fn expand(&self, context: &ScriptRunContext) -> Result<IfCondition, ScriptRunError> {
1411        match self {
1412            IfCondition::True => Ok(IfCondition::True),
1413            IfCondition::False => Ok(IfCondition::False),
1414            IfCondition::EnvEq(negated, name, expected) => {
1415                let value = context.expand(expected)?;
1416                Ok(IfCondition::EnvEq(
1417                    *negated,
1418                    name.clone(),
1419                    ShellBit::Literal(value),
1420                ))
1421            }
1422        }
1423    }
1424}
1425
1426impl Serialize for IfCondition {
1427    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1428    where
1429        S: serde::Serializer,
1430    {
1431        match self {
1432            IfCondition::True => "true".serialize(serializer),
1433            IfCondition::False => "false".serialize(serializer),
1434            IfCondition::EnvEq(negated, name, value) => {
1435                let mut ser = serializer.serialize_map(Some(3))?;
1436                ser.serialize_entry("op", if *negated { "!=" } else { "==" })?;
1437                ser.serialize_entry("env", name)?;
1438                ser.serialize_entry("value", value)?;
1439                ser.end()
1440            }
1441        }
1442    }
1443}
1444
/// The iteration source of a `for` block.
#[derive(Debug)]
pub enum ForCondition {
    /// Sets the named environment variable to each value in turn: `(name, values)`.
    Env(String, Vec<ShellBit>),
}
1449
1450impl Serialize for ForCondition {
1451    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1452    where
1453        S: serde::Serializer,
1454    {
1455        match self {
1456            ForCondition::Env(name, values) => {
1457                let mut ser = serializer.serialize_map(Some(2))?;
1458                ser.serialize_entry("env", name)?;
1459                ser.serialize_entry("values", values)?;
1460                ser.end()
1461            }
1462        }
1463    }
1464}
1465
/// Returns `true` when `b` is `false`.
///
/// Takes a reference because it is used as a serde `skip_serializing_if` predicate.
fn is_bool_false(b: &bool) -> bool {
    let flag = *b;
    !flag
}
1469
/// An external command plus everything needed to judge its outcome.
#[derive(Debug, Serialize)]
pub struct ScriptCommand {
    /// The command line to execute.
    pub command: CommandLine,
    /// The pattern the command's output is matched against.
    pub pattern: OutputPattern,
    /// The expected exit condition; omitted from serialization when it is `Success`.
    #[serde(skip_serializing_if = "CommandExit::is_success")]
    pub exit: CommandExit,
    /// When `true`, the output is expected NOT to match `pattern`; omitted from
    /// serialization when `false`.
    #[serde(skip_serializing_if = "is_bool_false")]
    pub expect_failure: bool,
    /// Environment variable that receives the command's trimmed output, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub set_var: Option<String>,
    /// Per-command timeout; the context's timeout is used when this is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout: Option<Duration>,
}
1483
1484impl ScriptCommand {
1485    pub fn run(&self, context: &mut ScriptRunContext) -> Result<ScriptResult, ScriptRunError> {
1486        let command = &self.command;
1487        let args = &context.args;
1488        let start = Instant::now();
1489
1490        if let Some(delay) = args.delay_steps {
1491            std::thread::sleep(std::time::Duration::from_millis(delay));
1492        }
1493
1494        for line in command.command.split('\n') {
1495            cwriteln!(context.stream(), fg = Color::Green, "{}", line);
1496        }
1497        if args.simplified_output {
1498            cwriteln!(context.stream(), dimmed = true, "---");
1499        } else {
1500            cwriteln_rule!(context.stream(), fg = Color::Cyan, "{}", command.location);
1501        }
1502        let (output, status) = command.run(
1503            &mut context.stream(),
1504            context.args.show_line_numbers,
1505            context.args.runner.clone(),
1506            self.timeout.unwrap_or(context.timeout),
1507            &context.env_vars,
1508            &context.kill,
1509            &context.kill_sender,
1510        )?;
1511
1512        let exit_result = if !self.exit.matches(status) {
1513            ExitResult::Mismatch(status)
1514        } else {
1515            ExitResult::Matches(status)
1516        };
1517
1518        // Side-effects
1519        if let Some(set_var) = &self.set_var {
1520            context.set_env(set_var, output.to_string().trim());
1521        }
1522
1523        let match_context = OutputMatchContext::new(context);
1524        self.pattern.prepare(&context.grok)?;
1525        let prepared_output = output
1526            .with_ignore(&context.global_ignore)
1527            .with_reject(&context.global_reject);
1528        let pattern_result = match self.pattern.matches(match_context.clone(), prepared_output) {
1529            Ok(_) => {
1530                if self.expect_failure {
1531                    PatternResult::ExpectedFailure
1532                } else {
1533                    PatternResult::Matches
1534                }
1535            }
1536            Err(e) => {
1537                if self.expect_failure {
1538                    PatternResult::MatchesFailure
1539                } else {
1540                    let mut trace = String::new();
1541                    for line in match_context.traces() {
1542                        trace.push_str(&format!("{line}\n"));
1543                    }
1544                    PatternResult::Mismatch(e, trace)
1545                }
1546            }
1547        };
1548
1549        if output.is_empty() {
1550            cwriteln!(context.stream(), dimmed = true, "(no output)");
1551        }
1552
1553        if context.args.simplified_output {
1554            cwriteln!(context.stream(), dimmed = true, "---");
1555        } else {
1556            cwriteln_rule!(context.stream());
1557        }
1558
1559        Ok(ScriptResult {
1560            command: command.clone(),
1561            pattern: pattern_result,
1562            exit: exit_result,
1563            elapsed: start.elapsed(),
1564            output,
1565        })
1566    }
1567}
1568
/// The recorded outcome of running one `ScriptCommand`.
#[derive(derive_more::Debug)]
pub struct ScriptResult {
    /// The command line that was executed.
    pub command: CommandLine,
    /// How the output compared with the expected pattern.
    pub pattern: PatternResult,
    /// How the command's result compared with the expected exit condition.
    pub exit: ExitResult,
    /// Time measured from the start of `ScriptCommand::run` until the result was built.
    pub elapsed: Duration,
    /// The command's output; left out of the `Debug` representation.
    #[debug(skip)]
    pub output: Lines,
}
1578
1579impl ScriptResult {
1580    pub fn evaluate(&self, context: &mut ScriptRunContext) -> Result<(), ScriptRunError> {
1581        let args = &context.args;
1582        let (success, failure, warning, arrow) = if *crate::term::IS_UTF8 {
1583            ("✅", "❌", "⚠️", "→")
1584        } else {
1585            ("[*]", "[X]", "[!]", "->")
1586        };
1587
1588        if let ExitResult::Mismatch(status) = self.exit {
1589            if args.ignore_exit_codes {
1590                cwriteln!(
1591                    context.stream(),
1592                    fg = Color::Yellow,
1593                    "{warning} Ignored incorrect exit code: {status}"
1594                );
1595                cwriteln!(context.stream());
1596            } else {
1597                cwriteln!(
1598                    context.stream(),
1599                    fg = Color::Red,
1600                    "{failure} FAIL: {status}"
1601                );
1602                cwriteln!(
1603                    context.stream(),
1604                    dimmed = true,
1605                    " {arrow} {}",
1606                    self.command.command
1607                );
1608                cwriteln!(context.stream());
1609                return Err(ScriptRunError::Exit(status));
1610            }
1611        }
1612
1613        if let PatternResult::Mismatch(e, trace) = &self.pattern {
1614            if args.ignore_matches {
1615                cwriteln!(
1616                    context.stream(),
1617                    fg = Color::Yellow,
1618                    "{warning} Ignored error: {e} (ignoring mismatches)"
1619                );
1620                cwriteln!(context.stream());
1621            } else {
1622                cwriteln!(context.stream(), fg = Color::Red, "ERROR: {e}");
1623                cwriteln!(context.stream(), dimmed = true, "{trace}");
1624                cwriteln!(context.stream(), fg = Color::Red, "{failure} FAIL");
1625                cwriteln!(context.stream());
1626                return Err(ScriptRunError::Pattern(e.clone()));
1627            }
1628        }
1629
1630        if let PatternResult::ExpectedFailure = self.pattern {
1631            if args.ignore_matches {
1632                cwriteln!(
1633                    context.stream(),
1634                    fg = Color::Yellow,
1635                    "{warning} Should not have matched! (ignoring mismatches)"
1636                );
1637                cwriteln!(context.stream());
1638            } else {
1639                cwriteln!(
1640                    context.stream(),
1641                    fg = Color::Red,
1642                    "{failure} FAIL (output shouldn't match)"
1643                );
1644                cwriteln!(
1645                    context.stream(),
1646                    dimmed = true,
1647                    " {arrow} {}",
1648                    self.command.command
1649                );
1650                cwriteln!(context.stream());
1651                return Err(ScriptRunError::ExpectedFailure);
1652            }
1653        }
1654
1655        if let ExitResult::Matches(status) = self.exit {
1656            if status.success() {
1657                cwrite!(context.stream(), fg = Color::Green, "{success} OK");
1658                if !context.args.simplified_output {
1659                    cwriteln!(
1660                        context.stream(),
1661                        dimmed = true,
1662                        " ({:.2}s)",
1663                        self.elapsed.as_secs_f32()
1664                    );
1665                } else {
1666                    cwriteln!(context.stream());
1667                }
1668            } else {
1669                cwrite!(
1670                    context.stream(),
1671                    fg = Color::Green,
1672                    "{success} OK ({status})"
1673                );
1674                if !context.args.simplified_output {
1675                    cwriteln!(
1676                        context.stream(),
1677                        dimmed = true,
1678                        " ({:.2}s)",
1679                        self.elapsed.as_secs_f32()
1680                    );
1681                } else {
1682                    cwriteln!(context.stream());
1683                }
1684            }
1685            cwriteln!(context.stream());
1686        }
1687
1688        Ok(())
1689    }
1690}
1691
1692#[derive(Debug)]
1693pub enum PatternResult {
1694    Matches,
1695    MatchesFailure,
1696    ExpectedFailure,
1697    Mismatch(OutputPatternMatchFailure, String),
1698}
1699
1700#[derive(Debug)]
1701pub enum ExitResult {
1702    Matches(CommandResult),
1703    Mismatch(CommandResult),
1704    TimedOut,
1705}
1706
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::v0::parse_script;
    use std::error::Error;

    /// File name attached to every script parsed by these tests.
    const TEST_FILE: &str = "test.cli";

    /// Parses a script containing a pattern definition and two commands, and
    /// expects three entries in the parsed `commands` list.
    #[test]
    fn test_script() -> Result<(), Box<dyn Error>> {
        let source = r#"
pattern VERSION \d+\.\d+\.\d+;

$ something --version || echo 1
? Something %{VERSION}

$ something --help
? Usage: something [OPTIONS]
repeat {
    choice {
? %{DATA} %{GREEDYDATA}
? %{DATA}=%{DATA} %{GREEDYDATA}
    }
}
"#;

        let parsed = parse_script(ScriptFile::new(TEST_FILE), source)?;
        assert_eq!(parsed.commands.len(), 3);
        // Dump the parsed structure so it is visible when running with `--nocapture`.
        eprintln!("{:?}", parsed);
        Ok(())
    }

    /// A command ending in `&` must be rejected with `BackgroundProcessNotAllowed`.
    #[test]
    fn test_bad_script() -> Result<(), Box<dyn Error>> {
        let source = r#"
$ (cmd; cmd)
$ cmd &
    "#;

        let outcome = parse_script(ScriptFile::new(TEST_FILE), source);
        assert!(matches!(
            outcome,
            Err(ScriptError {
                error: ScriptErrorType::BackgroundProcessNotAllowed,
                ..
            })
        ));
        Ok(())
    }

    /// Checks `$NAME` / `${NAME}` expansion and backslash escaping in `expand_str`.
    #[test]
    fn test_script_run_context_expand() {
        let mut ctx = ScriptRunContext::new(
            ScriptRunArgs::default(),
            Path::new("."),
            ScriptOutput::default(),
        );
        for (name, value) in [("A", "1"), ("B", "2"), ("C", "3")] {
            ctx.set_env(name, value);
        }

        // (input, expected expansion)
        let cases = [
            ("$A", "1"),
            ("$A $B ", "1 2 "),
            ("${A} ${B} ", "1 2 "),
            // A single backslash keeps the `$` literal.
            (r#"\$A"#, "$A"),
            (r#"\${A}"#, "${A}"),
            // A doubled backslash yields one backslash, then the expanded value.
            (r#"\\$A"#, r#"\1"#),
            (r#"\\${A}"#, r#"\1"#),
        ];
        for (input, expected) in cases {
            assert_eq!(ctx.expand_str(input).unwrap(), expected);
        }

        // Names containing an underscore expand in both spellings.
        ctx.set_env("TEMP_DIR", "/tmp");
        for input in ["$TEMP_DIR", "${TEMP_DIR}"] {
            assert_eq!(ctx.expand_str(input).unwrap(), "/tmp");
        }
    }
}