Skip to main content

clitest_lib/
script.rs

1use std::{
2    collections::{HashMap, VecDeque},
3    path::Path,
4    process::ExitStatus,
5    sync::{Arc, Mutex, atomic::AtomicBool},
6    thread::ScopedJoinHandle,
7    time::{Duration, Instant},
8};
9
10use grok::Grok;
11use keepcalm::SharedMut;
12use serde::{Serialize, ser::SerializeMap};
13use termcolor::{Color, ColorChoice, WriteColor};
14
15use crate::{
16    command::{CommandLine, CommandResult},
17    failure::{OutputPatternMatchFailure, format_match_trace_tree},
18    util::{NicePathBuf, NiceTempDir},
19};
20use crate::{cwrite, cwriteln, cwriteln_rule};
21use crate::{output::*, util::ShellBit};
22
/// Timeout given to a `ScriptRunContext` when no global timeout is configured.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
24
/// A position within a script: the file plus a line number.
///
/// Displayed as `file:line`.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct ScriptLocation {
    /// The script file this location refers to.
    pub file: ScriptFile,
    /// Line number within `file`. Lines produced by [`ScriptLine::parse`] are 1-based
    /// and offset by the file's `base_line`.
    pub line: usize,
}
30
31impl ScriptLocation {
32    pub fn new(file: ScriptFile, line: usize) -> Self {
33        Self { file, line }
34    }
35}
36
37impl std::fmt::Display for ScriptLocation {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        write!(f, "{}:{}", self.file, self.line)
40    }
41}
42
/// A script file path, shared cheaply through an [`Arc`], plus a line offset.
///
/// `Display` prints only the path. Field order matters: the derived
/// `PartialOrd`/`Ord` compare `base_line` first, then `file`.
#[derive(
    derive_more::Debug, derive_more::Display, Clone, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash,
)]
#[display("{}", file)]
pub struct ScriptFile {
    /// Offset added to line numbers when lines are parsed from this file
    /// (see [`ScriptLine::parse`]); `0` when created with [`ScriptFile::new`].
    pub base_line: usize,
    /// Path of the script file.
    pub file: Arc<NicePathBuf>,
}
51
52impl ScriptFile {
53    pub fn new(file: impl AsRef<Path>) -> Self {
54        Self {
55            base_line: 0,
56            file: Arc::new(NicePathBuf::new(file)),
57        }
58    }
59    pub fn new_with_line(file: impl AsRef<Path>, line: usize) -> Self {
60        Self {
61            base_line: line,
62            file: Arc::new(NicePathBuf::new(file)),
63        }
64    }
65}
66
67impl<T: AsRef<Path>> From<T> for ScriptFile {
68    fn from(file: T) -> Self {
69        Self::new(file)
70    }
71}
72
/// A parsed script: its top-level blocks, the scripts it includes, and its source file.
#[derive(Clone, derive_more::Debug, Serialize)]
pub struct Script {
    /// Top-level blocks, executed in order by [`Script::run`].
    pub commands: Arc<Vec<ScriptBlock>>,
    /// Included scripts, installed into the run context while this script runs.
    // NOTE(review): keys are presumably the include names/paths — confirm with the loader.
    pub includes: Arc<HashMap<String, Script>>,
    /// The file this script came from (printed by `run_with_args`).
    pub file: ScriptFile,
}
79
/// Options controlling how a script is run.
#[derive(Debug, Clone, Default)]
pub struct ScriptRunArgs {
    // NOTE(review): fields without a doc comment are not read in this part of the file;
    // their meaning is only suggested by their names — confirm against their uses.
    pub delay_steps: Option<u64>,
    pub ignore_exit_codes: bool,
    pub ignore_matches: bool,
    /// When set, the final PASSED/FAILED line omits the elapsed time.
    pub simplified_output: bool,
    pub show_line_numbers: bool,
    pub runner: Option<String>,
    pub quiet: bool,
    /// When set, background contexts write to the parent's output instead of a private buffer.
    pub verbose: bool,
    /// Overrides `DEFAULT_TIMEOUT` for the context built by `ScriptRunContext::new`.
    pub global_timeout: Option<Duration>,
    /// When set, the private buffers used by background contexts are created without color.
    pub no_color: bool,
}
93
/// Variables visible to a script, used for `$VAR` / `${VAR}` expansion.
#[derive(Debug, Clone, Default)]
pub struct ScriptEnv {
    /// Variable name to value. `PWD` holds the script's current working directory.
    env_vars: HashMap<String, String>,
}
98
99impl ScriptEnv {
100    pub fn set_defaults(&mut self, pwd: impl AsRef<Path>) {
101        macro_rules! target {
102            ($env:ident, $var:ident, [$($vals:expr),*]) => {
103                $(
104                if cfg!($var = $vals) {
105                    self.env_vars.insert(stringify!($env).to_string(), $vals.to_string());
106                }
107                )*
108            };
109        }
110
111        target!(
112            TARGET_OS,
113            target_os,
114            [
115                "windows",
116                "linux",
117                "macos",
118                "ios",
119                "android",
120                "freebsd",
121                "netbsd",
122                "openbsd",
123                "dragonfly",
124                "haiku",
125                "aix"
126            ]
127        );
128        target!(TARGET_FAMILY, target_family, ["windows", "unix", "wasm"]);
129        target!(
130            TARGET_ARCH,
131            target_arch,
132            [
133                "aarch64",
134                "amdgpu",
135                "arm",
136                "arm64ec",
137                "avr",
138                "bpf",
139                "csky",
140                "hexagon",
141                "loongarch32",
142                "loongarch64",
143                "m68k",
144                "mips",
145                "mips32r6",
146                "mips64",
147                "mips64r6",
148                "msp430",
149                "nvptx64",
150                "powerpc",
151                "powerpc64",
152                "riscv32",
153                "riscv64",
154                "s390x",
155                "sparc",
156                "sparc64",
157                "wasm32",
158                "wasm64",
159                "x86",
160                "x86_64",
161                "xtensa"
162            ]
163        );
164
165        // Set the current working directory as a special variable "PWD"
166        let pwd = NicePathBuf::from(pwd.as_ref()).env_string();
167        self.env_vars.insert("PWD".to_string(), pwd);
168        // Save the initial PWD as INITIAL_PWD so it can easily be restored
169        self.env_vars
170            .insert("INITIAL_PWD".to_string(), self.env_vars["PWD"].clone());
171    }
172
173    pub fn pwd(&self) -> NicePathBuf {
174        self.env_vars
175            .get("PWD")
176            .cloned()
177            .map(NicePathBuf::from)
178            .unwrap_or_else(NicePathBuf::cwd)
179    }
180
181    pub fn get_env(&self, name: &str) -> Option<&str> {
182        self.env_vars.get(name).map(|s| s.as_str())
183    }
184
185    pub fn set_env(&mut self, name: impl Into<String>, value: impl Into<String>) {
186        let name = name.into();
187        if name == "PWD" {
188            self.set_pwd(value.into());
189        } else {
190            self.env_vars.insert(name, value.into());
191        }
192    }
193
194    pub fn set_pwd(&mut self, pwd: impl Into<NicePathBuf>) {
195        let pwd = pwd.into().env_string();
196        self.env_vars.insert("PWD".to_string(), pwd);
197    }
198
199    pub fn expand(&self, value: &ShellBit) -> Result<String, ScriptRunError> {
200        match value {
201            ShellBit::Literal(s) => Ok(s.clone()),
202            ShellBit::Quoted(s) => self.expand_str(s),
203        }
204    }
205
206    /// Perform shell expansion on a string.
207    pub fn expand_str(&self, value: impl AsRef<str>) -> Result<String, ScriptRunError> {
208        enum State {
209            Normal,
210            EscapeNext,
211            InCurly,
212            Dollar,
213            InDollar,
214        }
215
216        let value = value.as_ref();
217
218        // "\" triggers escaping
219        // ${A} expands to the value of A
220        // $A expands to the value of A (variable ends on first non-alphanumeric character)
221
222        let mut state = State::Normal;
223        let mut variable = String::new();
224        let mut expanded = String::new();
225
226        for c in value.chars() {
227            match state {
228                State::Normal => {
229                    if c == '$' {
230                        state = State::Dollar;
231                        continue;
232                    }
233                    if c == '\\' {
234                        state = State::EscapeNext;
235                        continue;
236                    }
237                    expanded.push(c);
238                }
239                State::EscapeNext => {
240                    expanded.push(c);
241                    state = State::Normal;
242                }
243                State::InCurly => {
244                    if c == '}' {
245                        if let Some(value) = self.get_env(&std::mem::take(&mut variable)) {
246                            expanded.push_str(value);
247                        } else {
248                            return Err(ScriptRunError::ExpansionError(format!(
249                                "undefined variable in ${{...}}: {variable:?} (in {value:?})"
250                            )));
251                        }
252                        state = State::Normal;
253                    } else {
254                        variable.push(c);
255                    }
256                }
257                State::Dollar => {
258                    if c.is_alphanumeric() || c == '_' {
259                        state = State::InDollar;
260                        variable.push(c);
261                    } else if c == '{' {
262                        state = State::InCurly;
263                    } else {
264                        return Err(ScriptRunError::ExpansionError(format!(
265                            "invalid variable: {c:?} (in {value:?})"
266                        )));
267                    }
268                }
269                State::InDollar => {
270                    if c.is_alphanumeric() || c == '_' {
271                        variable.push(c);
272                    } else {
273                        if let Some(value) = self.get_env(&std::mem::take(&mut variable)) {
274                            expanded.push_str(value);
275                        } else {
276                            return Err(ScriptRunError::ExpansionError(format!(
277                                "undefined variable in $...: {variable:?} (in {value:?})"
278                            )));
279                        }
280                        expanded.push(c);
281                        state = State::Normal;
282                    }
283                }
284            }
285        }
286        match state {
287            State::InDollar => {
288                if let Some(value) = self.get_env(&variable) {
289                    expanded.push_str(value);
290                } else {
291                    return Err(ScriptRunError::ExpansionError(format!(
292                        "undefined variable: {variable}"
293                    )));
294                }
295            }
296            State::Dollar => {
297                return Err(ScriptRunError::ExpansionError(
298                    "incomplete variable".to_string(),
299                ));
300            }
301            State::InCurly => {
302                return Err(ScriptRunError::ExpansionError(format!(
303                    "unclosed variable: {variable}"
304                )));
305            }
306            State::Normal => {}
307            State::EscapeNext => {
308                return Err(ScriptRunError::ExpansionError(
309                    "unclosed backslash".to_string(),
310                ));
311            }
312        }
313        Ok(expanded)
314    }
315
316    pub fn env_vars(&self) -> &HashMap<String, String> {
317        &self.env_vars
318    }
319}
320
/// Destination for script output: a stdout stream or an in-memory buffer.
///
/// Clones share the same `SharedMut` stream. `take_buffer` copies the contents
/// instead of moving them out while other clones are still alive.
#[derive(derive_more::Debug, Clone)]
pub struct ScriptOutput {
    #[debug(skip)]
    stream: SharedMut<Box<dyn WriteColorAny>>,
}
326
/// Object-safe output sink.
///
/// A colored writer that can also hand back its contents when it is an in-memory
/// [`termcolor::Buffer`].
trait WriteColorAny: WriteColor + Send + Sync + std::any::Any + 'static + std::fmt::Debug {
    /// Workaround for lack of upcasting
    ///
    /// Consumes the sink and returns it as a buffer, or an error if it is not one.
    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String>;
    /// Returns a copy of the buffered contents, or an error if the sink is not a buffer.
    fn clone_buffer(&self) -> Result<termcolor::Buffer, String>;
}
332
333impl WriteColorAny for termcolor::StandardStream {
334    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String> {
335        Err("not a buffer".to_string())
336    }
337    fn clone_buffer(&self) -> Result<termcolor::Buffer, String> {
338        Err("not a buffer".to_string())
339    }
340}
341
342impl WriteColorAny for termcolor::Buffer {
343    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String> {
344        Ok(*self)
345    }
346    fn clone_buffer(&self) -> Result<termcolor::Buffer, String> {
347        Ok(self.clone())
348    }
349}
350
351impl ScriptOutput {
352    pub fn no_color() -> Self {
353        let stm = termcolor::StandardStream::stdout(ColorChoice::Never);
354        Self {
355            stream: SharedMut::new(Box::new(stm) as _),
356        }
357    }
358
359    pub fn quiet(no_color: bool) -> Self {
360        let stm = if no_color {
361            termcolor::Buffer::no_color()
362        } else {
363            termcolor::Buffer::ansi()
364        };
365        Self {
366            stream: SharedMut::new(Box::new(stm) as _),
367        }
368    }
369
370    pub fn take_buffer(self) -> String {
371        let stream = match SharedMut::try_unwrap(self.stream) {
372            Ok(stream) => stream.take_buffer().expect("wrong stream type"),
373            Err(shared) => shared.read().clone_buffer().expect("wrong stream type"),
374        };
375        String::from_utf8_lossy(&stream.into_inner()).to_string()
376    }
377}
378
379impl Default for ScriptOutput {
380    fn default() -> Self {
381        let stm = termcolor::StandardStream::stdout(ColorChoice::Auto);
382        Self {
383            stream: SharedMut::new(Box::new(stm) as _),
384        }
385    }
386}
387
/// Byte writes and flushes are forwarded to the locked underlying stream.
impl std::io::Write for ScriptOutputLock<'_> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.stream.write(buf)
    }
    fn flush(&mut self) -> std::io::Result<()> {
        self.stream.flush()
    }
}
396
397impl termcolor::WriteColor for ScriptOutputLock<'_> {
398    fn supports_color(&self) -> bool {
399        self.stream.supports_color()
400    }
401    fn set_color(&mut self, spec: &termcolor::ColorSpec) -> std::io::Result<()> {
402        self.stream.set_color(spec)
403    }
404    fn reset(&mut self) -> std::io::Result<()> {
405        self.stream.reset()
406    }
407    fn is_synchronous(&self) -> bool {
408        self.stream.is_synchronous()
409    }
410    fn set_hyperlink(&mut self, _link: &termcolor::HyperlinkSpec) -> std::io::Result<()> {
411        self.stream.set_hyperlink(_link)
412    }
413    fn supports_hyperlinks(&self) -> bool {
414        self.stream.supports_hyperlinks()
415    }
416}
417
/// Holds the write-lock guard on a [`ScriptOutput`]'s stream.
///
/// Obtained from `ScriptRunContext::stream`; all writes go through this guard.
struct ScriptOutputLock<'a> {
    stream: keepcalm::SharedWriteLock<'a, Box<dyn WriteColorAny>>,
}
421
/// How a [`ScriptRunContext`] was created.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScriptMode {
    /// A top-level context (`ScriptRunContext::new` or `Default`).
    Normal,
    /// From `ScriptRunContext::new_deferred`: shares the parent's timeout, kill signal and output.
    Deferred,
    /// From `ScriptRunContext::new_background`: its own kill signal and no timeout.
    Background,
}
428
/// State threaded through a script run.
///
/// Holds the run options, environment, includes, kill signaling, output, and the
/// global output patterns.
#[derive(derive_more::Debug)]
pub struct ScriptRunContext {
    /// Options the run was started with.
    pub args: ScriptRunArgs,
    /// Grok pattern engine, created with its default patterns.
    pub grok: Grok,
    /// Timeout for this context; `Duration::MAX` for background contexts.
    timeout: Duration,
    /// Variables used for expansion.
    env: ScriptEnv,
    /// Scripts available for inclusion; swapped in and out by `Script::run`.
    includes: Arc<HashMap<String, Script>>,
    /// Which kind of context this is.
    background: ScriptMode,
    /// Observes the kill flag shared with `kill_sender`.
    #[debug(skip)]
    kill: ScriptKillReceiver,
    /// Sets the kill flag observed by `kill`.
    #[debug(skip)]
    kill_sender: ScriptKillSender,
    /// Where this context writes its output.
    output: ScriptOutput,

    // NOTE(review): not read in this part of the file; presumably patterns applied to
    // every command's output (ignored / rejected lines) — confirm against their uses.
    global_ignore: OutputPatterns,
    global_reject: OutputPatterns,
}
446
447impl Default for ScriptRunContext {
448    fn default() -> Self {
449        let kill = Arc::new(AtomicBool::new(false));
450        Self {
451            args: ScriptRunArgs::default(),
452            grok: Grok::with_default_patterns(),
453            timeout: DEFAULT_TIMEOUT,
454            env: ScriptEnv::default(),
455            background: ScriptMode::Normal,
456            includes: Arc::new(HashMap::new()),
457            kill: ScriptKillReceiver::new(kill.clone()),
458            kill_sender: ScriptKillSender::new(kill.clone()),
459            output: ScriptOutput::default(),
460            global_ignore: OutputPatterns::default(),
461            global_reject: OutputPatterns::default(),
462        }
463    }
464}
465
466impl ScriptRunContext {
467    pub fn new_background(&self) -> Self {
468        let kill = Arc::new(AtomicBool::new(false));
469        Self {
470            args: self.args.clone(),
471            grok: self.grok.clone(),
472            // Background processes are not subject to timeouts
473            timeout: Duration::MAX,
474            env: self.env.clone(),
475            background: ScriptMode::Background,
476            kill: ScriptKillReceiver::new(kill.clone()),
477            kill_sender: ScriptKillSender::new(kill.clone()),
478            includes: self.includes.clone(),
479            output: if self.args.verbose {
480                self.output.clone()
481            } else {
482                ScriptOutput::quiet(self.args.no_color)
483            },
484            global_ignore: self.global_ignore.clone(),
485            global_reject: self.global_reject.clone(),
486        }
487    }
488
489    pub fn new_deferred(&self) -> Self {
490        Self {
491            args: self.args.clone(),
492            grok: self.grok.clone(),
493            timeout: self.timeout,
494            env: self.env.clone(),
495            background: ScriptMode::Deferred,
496            kill: self.kill.clone(),
497            kill_sender: self.kill_sender.clone(),
498            includes: self.includes.clone(),
499            output: self.output.clone(),
500            global_ignore: self.global_ignore.clone(),
501            global_reject: self.global_reject.clone(),
502        }
503    }
504
505    pub fn pwd(&self) -> NicePathBuf {
506        self.env.pwd()
507    }
508
509    pub fn get_env(&self, name: &str) -> Option<&str> {
510        self.env.get_env(name)
511    }
512
513    pub fn set_env(&mut self, name: impl Into<String>, value: impl Into<String>) {
514        self.env.set_env(name, value);
515    }
516
517    pub fn set_pwd(&mut self, pwd: impl Into<NicePathBuf>) {
518        self.env.set_pwd(pwd);
519    }
520
521    pub fn take_output(self) -> String {
522        self.output.take_buffer()
523    }
524
525    fn expand(&self, value: &ShellBit) -> Result<String, ScriptRunError> {
526        self.env.expand(value)
527    }
528
529    /// Get a mutable reference to the output stream.
530    pub fn stream(&self) -> impl termcolor::WriteColor + use<'_> {
531        ScriptOutputLock {
532            stream: self.output.stream.write(),
533        }
534    }
535}
536
/// Read side of a kill flag shared with a [`ScriptKillSender`].
#[derive(Clone)]
pub struct ScriptKillReceiver {
    kill_receiver: Arc<AtomicBool>,
}
541
542impl ScriptKillReceiver {
543    pub fn new(kill_receiver: Arc<AtomicBool>) -> Self {
544        Self { kill_receiver }
545    }
546
547    pub fn is_killed(&self) -> bool {
548        self.kill_receiver.load(std::sync::atomic::Ordering::SeqCst)
549    }
550
551    pub fn run_with<T>(&self, kill: impl FnOnce() + Send, wait: impl FnOnce() -> T) -> T {
552        std::thread::scope(|s| {
553            let done = Arc::new(AtomicBool::new(false));
554            let done_clone = done.clone();
555            let t = s.spawn(move || {
556                while !done_clone.load(std::sync::atomic::Ordering::SeqCst) {
557                    if self.is_killed() {
558                        kill();
559                        break;
560                    }
561                    std::thread::sleep(Duration::from_millis(10));
562                }
563            });
564            let res = wait();
565            done.store(true, std::sync::atomic::Ordering::SeqCst);
566            t.join().unwrap();
567            res
568        })
569    }
570
571    #[cfg(windows)]
572    pub fn run_cmd(
573        &self,
574        output: std::process::Child,
575        warn_time: Duration,
576    ) -> std::io::Result<ExitStatus> {
577        use std::os::windows::io::AsRawHandle;
578        use win32job::Job;
579
580        fn map_job_error(e: win32job::JobError) -> std::io::Error {
581            match e {
582                win32job::JobError::AssignFailed(e) => e,
583                win32job::JobError::CreateFailed(e) => e,
584                win32job::JobError::GetInfoFailed(e) => e,
585                win32job::JobError::SetInfoFailed(e) => e,
586                _ => std::io::Error::new(std::io::ErrorKind::Other, "Unknown error"),
587            }
588        }
589
590        // Create a new Job object
591        let job = Job::create().map_err(map_job_error)?;
592
593        // Configure the job to terminate all child processes when the job is closed
594        let mut info = job.query_extended_limit_info().map_err(map_job_error)?;
595        info.limit_kill_on_job_close();
596        job.set_extended_limit_info(&info).map_err(map_job_error)?;
597        job.assign_process(output.as_raw_handle() as _)?;
598
599        // Resume the main thread for the process
600        let id = output.id();
601        for thread_entry in tlhelp32::Snapshot::new_thread()? {
602            if thread_entry.owner_process_id == id {
603                use windows_sys::Win32::Foundation::CloseHandle;
604                use windows_sys::Win32::System::Threading::*;
605
606                unsafe {
607                    let thread = OpenThread(THREAD_SUSPEND_RESUME, 0, thread_entry.thread_id);
608                    if thread.is_null() {
609                        return Err(std::io::Error::last_os_error().into());
610                    }
611                    ResumeThread(thread);
612                    CloseHandle(thread);
613                }
614            }
615        }
616
617        let job = Mutex::new(Some(job));
618        let output = Mutex::new(output);
619        self.run_with(
620            || {
621                _ = job.lock().unwrap().take();
622                _ = output.lock().unwrap().kill();
623            },
624            || {
625                let start = std::time::Instant::now();
626                let mut warned = false;
627                loop {
628                    let res = output.lock().unwrap().try_wait()?;
629                    if let Some(status) = res {
630                        return Ok::<_, std::io::Error>(status);
631                    }
632                    if start.elapsed() > warn_time {
633                        if !warned {
634                            let child = output.lock().unwrap().id();
635                            eprintln!("Process #{child} taking too long to finish.");
636                            warned = true;
637                        }
638                    }
639                    std::thread::sleep(Duration::from_millis(10));
640                }
641            },
642        )
643    }
644
645    #[cfg(unix)]
646    pub fn run_cmd(
647        &self,
648        output: std::process::Child,
649        warn_time: Duration,
650    ) -> std::io::Result<ExitStatus> {
651        let output = Mutex::new(output);
652        self.run_with(
653            || {
654                use signal_child::{signal, signal::*};
655                let id = output.lock().unwrap().id() as i32;
656                _ = signal(-id, SIGINT);
657                std::thread::sleep(Duration::from_millis(10));
658                _ = signal(-id, SIGTERM);
659            },
660            || {
661                let start = std::time::Instant::now();
662                let mut warned = false;
663                loop {
664                    let res = output.lock().unwrap().try_wait()?;
665                    if let Some(status) = res {
666                        return Ok::<_, std::io::Error>(status);
667                    }
668                    if start.elapsed() > warn_time && !warned {
669                        let child = output.lock().unwrap().id();
670                        eprintln!("Process #{child} taking too long to finish.");
671                        warned = true;
672                    }
673                    std::thread::sleep(Duration::from_millis(10));
674                }
675            },
676        )
677    }
678}
679
/// Write side of a kill flag.
///
/// Setting it is observed by every [`ScriptKillReceiver`] that shares the same flag.
#[derive(Clone)]
pub struct ScriptKillSender {
    kill_sender: Arc<AtomicBool>,
}
684
685impl ScriptKillSender {
686    pub fn new(kill_sender: Arc<AtomicBool>) -> Self {
687        Self { kill_sender }
688    }
689
690    pub fn kill(&self) {
691        self.kill_sender
692            .store(true, std::sync::atomic::Ordering::SeqCst);
693    }
694}
695
696impl ScriptRunContext {
697    pub fn new(args: ScriptRunArgs, script_path: impl AsRef<Path>, output: ScriptOutput) -> Self {
698        let mut env = ScriptEnv::default();
699        env.set_defaults(script_path.as_ref().parent().unwrap());
700
701        let kill = Arc::new(AtomicBool::new(false));
702
703        Self {
704            timeout: args.global_timeout.unwrap_or(DEFAULT_TIMEOUT),
705            args,
706            env,
707            grok: Grok::with_default_patterns(),
708            includes: Arc::new(HashMap::new()),
709            background: ScriptMode::Normal,
710            kill: ScriptKillReceiver::new(kill.clone()),
711            kill_sender: ScriptKillSender::new(kill.clone()),
712            output,
713            global_ignore: OutputPatterns::default(),
714            global_reject: OutputPatterns::default(),
715        }
716    }
717}
718
/// One line of script source, together with where it came from.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ScriptLine {
    /// Where this line appears.
    pub location: ScriptLocation,
    /// The line text as given. It is stored untrimmed; most accessors trim on demand.
    text: String,
}
724
725impl ScriptLine {
726    pub fn new(file: ScriptFile, line: usize, text: impl AsRef<str>) -> Self {
727        Self {
728            location: ScriptLocation::new(file, line),
729            text: text.as_ref().to_string(),
730        }
731    }
732
733    pub fn parse(file: ScriptFile, text: impl AsRef<str>) -> Vec<Self> {
734        text.as_ref()
735            .lines()
736            .enumerate()
737            .map(|(line, text)| Self {
738                location: ScriptLocation::new(file.clone(), line + file.base_line + 1),
739                text: text.to_string(),
740            })
741            .collect()
742    }
743
744    pub fn starts_with(&self, text: &str) -> bool {
745        self.text.trim().starts_with(text)
746    }
747
748    pub fn first_char(&self) -> Option<char> {
749        self.text.trim().chars().next()
750    }
751
752    pub fn text(&self) -> &str {
753        self.text.trim()
754    }
755
756    pub fn text_untrimmed(&self) -> &str {
757        &self.text
758    }
759
760    pub fn is_empty(&self) -> bool {
761        self.text.trim().is_empty()
762    }
763
764    pub fn strip_prefix(&self, prefix: &str) -> Option<&str> {
765        self.text.strip_prefix(prefix)
766    }
767}
768
/// An error found in a script, tied to the location where it occurred.
///
/// Displayed as `<error> at <location>`, followed by `: <data>` when
/// `associated_data` is present.
#[derive(Debug, thiserror::Error, derive_more::Display)]
#[display("{error} at {location}{}", associated_data.as_deref().map_or("".to_string(), |d| format!(": {d}")))]
pub struct ScriptError {
    /// What went wrong.
    pub error: ScriptErrorType,
    /// Where it went wrong.
    pub location: ScriptLocation,
    /// Optional extra detail appended to the message.
    pub associated_data: Option<String>,
}
776
777impl ScriptError {
778    pub fn new(error: ScriptErrorType, location: ScriptLocation) -> Self {
779        if std::env::var("PANIC_ON_ERROR").is_ok() {
780            panic!("ScriptError: {error} at {location}");
781        }
782        Self {
783            error,
784            location,
785            associated_data: None,
786        }
787    }
788
789    pub fn new_with_data(
790        error: ScriptErrorType,
791        location: ScriptLocation,
792        associated_data: String,
793    ) -> Self {
794        if std::env::var("PANIC_ON_ERROR").is_ok() {
795            panic!("ScriptError: {error} at {location}: {associated_data}");
796        }
797        Self {
798            error,
799            location,
800            associated_data: Some(associated_data),
801        }
802    }
803}
804
/// The kinds of problems reported through [`ScriptError`].
///
/// Each variant's `#[error]` text is its `Display` message.
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
pub enum ScriptErrorType {
    #[error("background process not allowed")]
    BackgroundProcessNotAllowed,
    #[error("unclosed quote")]
    UnclosedQuote,
    #[error("unclosed backslash")]
    UnclosedBackslash,
    #[error("illegal shell command format")]
    IllegalShellCommand,
    #[error("unsupported redirection")]
    UnsupportedRedirection,
    #[error("invalid pattern definition")]
    InvalidPatternDefinition,
    #[error("invalid pattern")]
    InvalidPattern,
    #[error("invalid meta command")]
    InvalidMetaCommand,
    #[error("invalid pattern at global level (only reject or ignore allowed here)")]
    InvalidGlobalPattern,
    #[error("invalid block type")]
    InvalidBlockType,
    #[error("invalid block arguments")]
    InvalidBlockArgs,
    #[error("unsupported command position")]
    UnsupportedCommandPosition,
    #[error("invalid trailing pattern after *")]
    InvalidAnyPattern,
    #[error("invalid exit status")]
    InvalidExitStatus,
    #[error("invalid set variable")]
    InvalidSetVariable,
    #[error("invalid version header, expected `#!/usr/bin/env clitest --v0`")]
    InvalidVersion,
    #[error("invalid internal command")]
    InvalidInternalCommand,
    #[error("missing command lines")]
    MissingCommandLines,
    #[error(
        "block end without matching block start, too many closing braces or braces not properly nested"
    )]
    InvalidBlockEnd,
    #[error("invalid if condition")]
    InvalidIfCondition,
    #[error("expected block or semi-colon (did you forget to add ';' at the end of this line?)")]
    ExpectedBlockOrSemi,
}
852
/// Errors raised while running a script.
#[derive(Debug, thiserror::Error)]
pub enum ScriptRunError {
    /// Command output did not match the expected pattern.
    #[error("{0}")]
    Pattern(#[from] OutputPatternMatchFailure),
    /// A pattern could not be prepared for matching.
    #[error("{0}")]
    PatternPrepareError(#[from] OutputPatternPrepareError),
    /// A command ended with an unacceptable result; carries the result and its location.
    // NOTE(review): `ScriptLocation` displays as `file:line`, so this renders as
    // "... at line <file>:<n>". Consider "{0} at {1}" (check snapshot tests first).
    #[error("{0} at line {1}")]
    Exit(CommandResult, ScriptLocation),
    /// An include referenced a script that is not available.
    #[error("included file not found: {0}")]
    IncludedFileNotFound(String),
    /// A command that was expected to fail succeeded.
    // NOTE(review): same "at line <file>:<n>" wording issue as `Exit`.
    #[error("expected failure, but passed at line {0}")]
    ExpectedFailure(ScriptLocation),
    /// Variable expansion failed (see `ScriptEnv::expand_str`).
    #[error("{0}")]
    ExpansionError(String),
    /// An underlying I/O operation failed.
    #[error("{0}")]
    IO(#[from] std::io::Error),
    /// The run was stopped through its kill signal (presumably; set outside this chunk).
    #[error("killed")]
    Killed,
    #[error("background process took too long to finish")]
    BackgroundProcessTookTooLong,
    #[error("retry took too long to finish")]
    RetryTookTooLong,
    /// Internal flow control: exit the script
    ///
    /// `Script::run` converts this into a successful result.
    #[error("exiting script")]
    ExitScript,
}
879
880impl ScriptRunError {
881    #[expect(unused)]
882    pub fn short(&self) -> String {
883        match self {
884            Self::Pattern(_) => "Pattern".to_string(),
885            Self::PatternPrepareError(e) => format!("PatternPrepareError({e:?})"),
886            Self::Exit(status, _) => format!("Exit({status})"),
887            Self::ExpectedFailure(_) => "ExpectedFailure".to_string(),
888            Self::IO(e) => format!("IO({:?})", e.kind()),
889            Self::Killed => "Killed".to_string(),
890            Self::BackgroundProcessTookTooLong => "BackgroundProcessTookTooLong".to_string(),
891            Self::ExpansionError(e) => "ExpansionError".to_string(),
892            Self::RetryTookTooLong => "RetryTookTooLong".to_string(),
893            Self::ExitScript => unreachable!(),
894            Self::IncludedFileNotFound(path) => format!("IncludedFileNotFound({path})"),
895        }
896    }
897}
898
899impl Script {
900    pub fn new(file: ScriptFile) -> Self {
901        Self {
902            commands: Arc::new(vec![]),
903            includes: Arc::new(HashMap::new()),
904            file,
905        }
906    }
907
908    /// Collect all included script paths from the script.
909    pub fn includes(&self) -> Vec<(ScriptLocation, String)> {
910        self.commands
911            .iter()
912            .flat_map(|block| block.includes())
913            .collect()
914    }
915
916    pub fn run(&self, context: &mut ScriptRunContext) -> Result<(), ScriptRunError> {
917        let old_includes = context.includes.clone();
918        context.includes = self.includes.clone();
919        let res = ScriptBlock::run_blocks(context, &self.commands);
920        context.includes = old_includes;
921        let v = match res {
922            Ok(v) => v,
923            // Bypass normal script processing and exit successfully
924            Err(ScriptRunError::ExitScript) => return Ok(()),
925            Err(e) => return Err(e),
926        };
927        assert!(v.is_empty(), "script did not run to completion: {v:?}");
928        Ok(())
929    }
930
931    pub fn run_with_args(
932        &self,
933        args: ScriptRunArgs,
934        output: ScriptOutput,
935    ) -> Result<(), ScriptRunError> {
936        let start = Instant::now();
937        let script_path = &*self.file.file;
938        let mut context = ScriptRunContext::new(args, script_path, output);
939
940        // Write "Running..." message with colors
941        cwrite!(context.stream(), "Running ");
942        cwrite!(context.stream(), fg = Color::Cyan, "{}", script_path);
943        cwriteln!(context.stream(), " ...");
944        cwriteln!(context.stream());
945
946        let result = self.run(&mut context);
947
948        // Handle success and error output
949        if let Err(ref e) = result {
950            cwrite!(context.stream(), fg = Color::Cyan, "{} ", script_path);
951            cwrite!(context.stream(), fg = Color::Red, "FAILED");
952            if !context.args.simplified_output {
953                cwriteln!(context.stream(), " ({:.2}s)", start.elapsed().as_secs_f32());
954            } else {
955                cwriteln!(context.stream());
956            }
957            cwrite!(context.stream(), fg = Color::Red, "Error: ");
958            cwriteln!(context.stream(), "{}", e);
959            cwriteln!(context.stream());
960        } else {
961            cwrite!(context.stream(), fg = Color::Cyan, "{} ", script_path);
962            cwrite!(context.stream(), fg = Color::Green, "PASSED");
963            if !context.args.simplified_output {
964                cwriteln!(context.stream(), " ({:.2}s)", start.elapsed().as_secs_f32());
965            } else {
966                cwriteln!(context.stream());
967            }
968        }
969
970        result
971    }
972}
973
974#[derive(Debug, Default, Serialize)]
975pub enum CommandExit {
976    #[default]
977    Success,
978    Failure(i32),
979    Timeout,
980    Any,
981    AnyFailure,
982}
983
984impl CommandExit {
985    pub fn matches(&self, status: CommandResult) -> bool {
986        match (self, status) {
987            (CommandExit::Success, CommandResult::Exit(status)) => status.success(),
988            (CommandExit::Failure(code), CommandResult::Exit(status)) => {
989                *code == status.code().unwrap_or(-1)
990            }
991            (CommandExit::Timeout, CommandResult::TimedOut) => true,
992            (CommandExit::Any, _) => true,
993            (CommandExit::AnyFailure, CommandResult::Exit(status)) => !status.success(),
994            (CommandExit::AnyFailure, _) => true,
995            _ => false,
996        }
997    }
998
999    pub fn is_success(&self) -> bool {
1000        matches!(self, CommandExit::Success)
1001    }
1002}
1003
/// One step of a parsed script.
///
/// Container variants hold nested block lists that are executed through
/// `ScriptBlock::run_blocks`.
#[derive(derive_more::Debug)]
// NOTE(review): the lint is silenced rather than boxing the large variant
// (presumably `Command(ScriptCommand)`) — confirm this is intentional.
#[allow(clippy::large_enum_variant)]
pub enum ScriptBlock {
    /// A single shell command with its output/exit expectations.
    Command(ScriptCommand),
    /// A runner directive (`cd`, `set`, `include`, ...) and where it appears.
    InternalCommand(ScriptLocation, InternalCommand),
    /// Blocks run on a scoped thread; killed and joined once the enclosing
    /// block list has finished.
    Background(Vec<ScriptBlock>),
    /// Blocks run after the enclosing block list has finished; later defers
    /// run first.
    Defer(Vec<ScriptBlock>),
    /// Blocks run only when the condition holds.
    If(IfCondition, Vec<ScriptBlock>),
    /// Blocks run once per value of a loop variable.
    For(ForCondition, Vec<ScriptBlock>),
    /// Blocks re-run with exponential backoff until every result passes or
    /// the context timeout elapses.
    Retry(Vec<ScriptBlock>),
    /// Patterns ignored in the output of subsequent commands run with the
    /// same context.
    GlobalIgnore(OutputPatterns),
    /// Patterns rejected in the output of subsequent commands run with the
    /// same context.
    GlobalReject(OutputPatterns),
}
1017
1018impl ScriptBlock {
1019    pub fn includes(&self) -> Vec<(ScriptLocation, String)> {
1020        match self {
1021            ScriptBlock::Command(..) => vec![],
1022            ScriptBlock::InternalCommand(location, InternalCommand::Include(path)) => {
1023                vec![(location.clone(), path.clone())]
1024            }
1025            ScriptBlock::InternalCommand(..) => vec![],
1026            ScriptBlock::Background(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1027            ScriptBlock::Defer(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1028            ScriptBlock::If(_, blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1029            ScriptBlock::For(_, blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1030            ScriptBlock::Retry(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1031            ScriptBlock::GlobalIgnore(_) => vec![],
1032            ScriptBlock::GlobalReject(_) => vec![],
1033        }
1034    }
1035
1036    #[allow(clippy::type_complexity)]
1037    pub fn run_blocks(
1038        context: &mut ScriptRunContext,
1039        blocks: &[ScriptBlock],
1040    ) -> Result<Vec<ScriptResult>, ScriptRunError> {
1041        enum Deferred<'a> {
1042            Scripts(&'a [ScriptBlock]),
1043            Internal(
1044                Box<
1045                    dyn FnOnce(&mut ScriptRunContext) -> Result<(), ScriptRunError>
1046                        + Send
1047                        + Sync
1048                        + 'a,
1049                >,
1050            ),
1051            Background(
1052                ScopedJoinHandle<'a, Result<Vec<ScriptResult>, ScriptRunError>>,
1053                ScriptKillSender,
1054            ),
1055        }
1056
1057        let mut results = Vec::new();
1058        std::thread::scope(|s| {
1059            let mut defer_blocks = VecDeque::new();
1060            let mut pending_error = None;
1061            for block in blocks {
1062                if context.kill.is_killed() {
1063                    return Err(ScriptRunError::Killed);
1064                }
1065                match block {
1066                    ScriptBlock::Background(blocks) => {
1067                        let mut context = context.new_background();
1068                        let kill_sender = context.kill_sender.clone();
1069                        let handle = s.spawn(move || Self::run_blocks(&mut context, blocks));
1070                        defer_blocks.push_front(Deferred::Background(handle, kill_sender));
1071                    }
1072                    ScriptBlock::Defer(blocks) => {
1073                        // Insert at the front of the queue by extending and
1074                        // then rotating
1075                        defer_blocks.push_front(Deferred::Scripts(blocks));
1076                    }
1077                    ScriptBlock::InternalCommand(_, command) => {
1078                        if context.background == ScriptMode::Deferred {
1079                            cwrite!(context.stream(), dimmed = true, "(deferred) ");
1080                        }
1081                        if let Some(f) = command.run(context)? {
1082                            defer_blocks.push_front(Deferred::Internal(f));
1083                        }
1084                    }
1085                    _ => match block.run(context) {
1086                        Ok(res) => results.extend(res),
1087                        Err(e) => {
1088                            pending_error = Some(e);
1089                            break;
1090                        }
1091                    },
1092                }
1093            }
1094            for block in defer_blocks {
1095                match block {
1096                    Deferred::Scripts(blocks) => {
1097                        let mut context = context.new_deferred();
1098                        ScriptBlock::run_blocks(&mut context, blocks)?;
1099                    }
1100                    Deferred::Internal(block) => {
1101                        cwrite!(context.stream(), dimmed = true, "(cleanup) ");
1102                        block(context)?;
1103                    }
1104                    Deferred::Background(handle, kill_sender) => {
1105                        kill_sender.kill();
1106                        let start = std::time::Instant::now();
1107                        let mut warned = false;
1108
1109                        let timeout = context.timeout;
1110                        let warn_at = timeout * 8 / 10;
1111
1112                        let results = loop {
1113                            if handle.is_finished() {
1114                                break handle.join().unwrap()?;
1115                            }
1116                            std::thread::sleep(std::time::Duration::from_millis(10));
1117                            if !warned && start.elapsed() > warn_at {
1118                                cwriteln!(
1119                                    context.stream(),
1120                                    fg = Color::Yellow,
1121                                    "Background process is taking too long to finish."
1122                                );
1123                                warned = true;
1124                            }
1125                            if start.elapsed() > timeout {
1126                                cwriteln!(
1127                                    context.stream(),
1128                                    fg = Color::Red,
1129                                    "Background process took too long to finish."
1130                                );
1131                                return Err(ScriptRunError::BackgroundProcessTookTooLong);
1132                            }
1133                        };
1134                        for result in results {
1135                            cwrite!(context.stream(), dimmed = true, "(background) ");
1136                            for line in result.command.command.split('\n') {
1137                                cwriteln!(context.stream(), fg = Color::Green, "{}", line);
1138                            }
1139                            if context.args.simplified_output {
1140                                cwriteln!(context.stream(), dimmed = true, "---");
1141                            } else {
1142                                cwriteln_rule!(
1143                                    context.stream(),
1144                                    fg = Color::Cyan,
1145                                    "{}",
1146                                    result.command.location
1147                                );
1148                            }
1149                            for line in &result.output {
1150                                cwriteln!(context.stream(), "{}", line);
1151                            }
1152                            if result.output.is_empty() {
1153                                cwriteln!(context.stream(), dimmed = true, "(no output)");
1154                            }
1155                            if context.args.simplified_output {
1156                                cwriteln!(context.stream(), dimmed = true, "---");
1157                            } else {
1158                                cwriteln_rule!(context.stream());
1159                            }
1160                            result.evaluate(context)?;
1161                        }
1162                    }
1163                }
1164            }
1165            if let Some(error) = pending_error {
1166                return Err(error);
1167            }
1168            Ok(results)
1169        })
1170    }
1171
1172    pub fn run(&self, context: &mut ScriptRunContext) -> Result<Vec<ScriptResult>, ScriptRunError> {
1173        let pwd = context.pwd();
1174        let res = pwd.exists();
1175        if !matches!(res, Ok(true)) {
1176            cwriteln!(
1177                context.stream(),
1178                fg = Color::Red,
1179                "$PWD {pwd:?} doesn't exist. Run `cd $INITIAL_PWD` to fix.",
1180            );
1181            return Err(ScriptRunError::IO(std::io::Error::new(
1182                std::io::ErrorKind::NotFound,
1183                format!("PWD does not exist: {pwd:?}"),
1184            )));
1185        }
1186
1187        match self {
1188            ScriptBlock::Command(command) => {
1189                if context.background == ScriptMode::Deferred {
1190                    cwrite!(context.stream(), dimmed = true, "(deferred) ");
1191                }
1192                let result = command.run(context)?;
1193                if context.background != ScriptMode::Background {
1194                    result.evaluate(context)?;
1195                    Ok(vec![])
1196                } else {
1197                    Ok(vec![result])
1198                }
1199            }
1200            ScriptBlock::If(condition, blocks) => {
1201                let condition = condition.expand(context)?;
1202                if condition.matches(context) {
1203                    Self::run_blocks(context, blocks)
1204                } else {
1205                    Ok(vec![])
1206                }
1207            }
1208            ScriptBlock::For(ForCondition::Env(env, values), blocks) => {
1209                let mut results = Vec::new();
1210                for value in values {
1211                    context.set_env(env, context.expand(value)?);
1212                    results.extend(Self::run_blocks(context, blocks)?);
1213                }
1214                Ok(results)
1215            }
1216            ScriptBlock::Retry(blocks) => {
1217                let start = Instant::now();
1218                let mut backoff = Duration::from_millis(100);
1219
1220                cwrite!(context.stream(), fg = Color::Green, "retry: ");
1221                cwriteln!(context.stream(), "running...");
1222
1223                loop {
1224                    let mut nested_context = context.new_background();
1225                    if let Ok(results) = Self::run_blocks(&mut nested_context, blocks) {
1226                        let mut all_ok = true;
1227                        for result in results {
1228                            if result.evaluate(&mut nested_context).is_err() {
1229                                all_ok = false;
1230                                break;
1231                            }
1232                        }
1233                        if all_ok {
1234                            let output = nested_context.take_output();
1235                            cwrite!(context.stream(), fg = Color::Green, "retry: ");
1236                            cwriteln!(context.stream(), "success");
1237                            cwriteln!(context.stream());
1238                            cwriteln!(context.stream(), "{output}");
1239                            return Ok(vec![]);
1240                        }
1241                    }
1242
1243                    if start.elapsed() > context.timeout {
1244                        let output = nested_context.take_output();
1245                        cwrite!(context.stream(), fg = Color::Green, "retry: ");
1246                        cwriteln!(context.stream(), fg = Color::Red, "timed out");
1247                        cwriteln!(context.stream());
1248                        cwriteln!(context.stream(), "{output}");
1249                        cwriteln_rule!(context.stream());
1250                        return Err(ScriptRunError::RetryTookTooLong);
1251                    }
1252                    std::thread::sleep(backoff);
1253                    backoff *= 2;
1254                }
1255            }
1256            ScriptBlock::GlobalIgnore(patterns) => {
1257                for pattern in patterns.iter() {
1258                    pattern.prepare(&context.grok)?;
1259                }
1260                context.global_ignore.extend(patterns);
1261                Ok(vec![])
1262            }
1263            ScriptBlock::GlobalReject(patterns) => {
1264                for pattern in patterns.iter() {
1265                    pattern.prepare(&context.grok)?;
1266                }
1267                context.global_reject.extend(patterns);
1268                Ok(vec![])
1269            }
1270            _ => unreachable!("Unexpected block type: {self:?}"),
1271        }
1272    }
1273}
1274
1275impl Serialize for ScriptBlock {
1276    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1277    where
1278        S: serde::Serializer,
1279    {
1280        match self {
1281            ScriptBlock::Command(command) => command.serialize(serializer),
1282            ScriptBlock::InternalCommand(_, command) => command.serialize(serializer),
1283            ScriptBlock::Background(blocks) => {
1284                let mut ser = serializer.serialize_map(Some(1))?;
1285                ser.serialize_entry("background", blocks)?;
1286                ser.end()
1287            }
1288            ScriptBlock::Defer(blocks) => {
1289                let mut ser = serializer.serialize_map(Some(1))?;
1290                ser.serialize_entry("defer", blocks)?;
1291                ser.end()
1292            }
1293            ScriptBlock::If(condition, blocks) => {
1294                let mut ser = serializer.serialize_map(Some(2))?;
1295                ser.serialize_entry("if", condition)?;
1296                ser.serialize_entry("blocks", blocks)?;
1297                ser.end()
1298            }
1299            ScriptBlock::For(condition, blocks) => {
1300                let mut ser = serializer.serialize_map(Some(2))?;
1301                ser.serialize_entry("for", condition)?;
1302                ser.serialize_entry("blocks", blocks)?;
1303                ser.end()
1304            }
1305            ScriptBlock::Retry(blocks) => {
1306                let mut ser = serializer.serialize_map(Some(1))?;
1307                ser.serialize_entry("retry", blocks)?;
1308                ser.end()
1309            }
1310            ScriptBlock::GlobalIgnore(patterns) => {
1311                let mut ser = serializer.serialize_map(Some(1))?;
1312                ser.serialize_entry("ignore", patterns)?;
1313                ser.end()
1314            }
1315            ScriptBlock::GlobalReject(patterns) => {
1316                let mut ser = serializer.serialize_map(Some(1))?;
1317                ser.serialize_entry("reject", patterns)?;
1318                ser.end()
1319            }
1320        }
1321    }
1322}
1323
1324#[derive(Debug, Clone, Serialize)]
1325pub enum InternalCommand {
1326    UsingTempdir,
1327    UsingDir(ShellBit, bool),
1328    ChangeDir(ShellBit),
1329    Set(String, ShellBit),
1330    Include(String),
1331    ExitScript,
1332    Pattern(String, String),
1333}
1334
1335impl InternalCommand {
1336    #[allow(clippy::type_complexity)]
1337    pub fn run(
1338        &self,
1339        context: &mut ScriptRunContext,
1340    ) -> Result<
1341        Option<Box<dyn FnOnce(&mut ScriptRunContext) -> Result<(), ScriptRunError> + Send + Sync>>,
1342        ScriptRunError,
1343    > {
1344        match self.clone() {
1345            InternalCommand::Include(path) => {
1346                let Some(script) = context.includes.get(&path) else {
1347                    return Err(ScriptRunError::IncludedFileNotFound(path));
1348                };
1349                script.clone().run(context)?;
1350                Ok(None)
1351            }
1352            InternalCommand::Pattern(name, pattern) => {
1353                context.grok.add_pattern(name, pattern);
1354                Ok(None)
1355            }
1356            InternalCommand::UsingTempdir => {
1357                let current_pwd = context.pwd();
1358                let tempdir = NiceTempDir::new();
1359                cwrite!(context.stream(), fg = Color::Yellow, "using tempdir: ");
1360                cwriteln!(context.stream(), "{}", tempdir);
1361                cwriteln!(context.stream());
1362                context.set_pwd(&tempdir);
1363                let pwd = context.pwd();
1364                if !pwd.exists()? {
1365                    return Err(ScriptRunError::IO(std::io::Error::new(
1366                        std::io::ErrorKind::NotFound,
1367                        format!("newly created tempdir does not exist: {pwd:?}"),
1368                    )));
1369                }
1370                Ok(Some(Box::new(move |context: &mut ScriptRunContext| {
1371                    cwriteln!(
1372                        context.stream(),
1373                        fg = Color::Yellow,
1374                        "removing {} && cd {}",
1375                        tempdir,
1376                        current_pwd
1377                    );
1378                    cwriteln!(context.stream());
1379                    if !tempdir.exists()? {
1380                        cwriteln!(
1381                            context.stream(),
1382                            fg = Color::Red,
1383                            "tempdir does not exist: {tempdir}"
1384                        );
1385                    }
1386                    if let Err(e) = tempdir.remove_dir_all() {
1387                        cwriteln!(
1388                            context.stream(),
1389                            fg = Color::Red,
1390                            "error removing tempdir: {e:?}"
1391                        );
1392                    }
1393                    Ok::<_, ScriptRunError>(())
1394                })))
1395            }
1396            InternalCommand::UsingDir(dir, new) => {
1397                let current_pwd = context.pwd();
1398                let dir = context.expand(&dir)?;
1399                let new_pwd = current_pwd.join(dir);
1400                if new {
1401                    cwrite!(context.stream(), fg = Color::Yellow, "using new dir: ");
1402                } else {
1403                    cwrite!(context.stream(), fg = Color::Yellow, "using dir: ");
1404                }
1405                cwriteln!(context.stream(), "{}", new_pwd);
1406                cwriteln!(context.stream());
1407
1408                if new {
1409                    new_pwd.create_dir_all()?;
1410                } else if !new_pwd.exists()? {
1411                    return Err(ScriptRunError::IO(std::io::Error::new(
1412                        std::io::ErrorKind::NotFound,
1413                        "directory does not exist",
1414                    )));
1415                }
1416                context.set_pwd(&new_pwd);
1417                Ok(Some(Box::new(move |context: &mut ScriptRunContext| {
1418                    if new {
1419                        cwriteln!(
1420                            context.stream(),
1421                            fg = Color::Yellow,
1422                            "removing {} && cd {}",
1423                            new_pwd,
1424                            current_pwd
1425                        );
1426                        cwriteln!(context.stream());
1427                    } else {
1428                        cwriteln!(context.stream(), fg = Color::Yellow, "cd {}", current_pwd);
1429                        cwriteln!(context.stream());
1430                    }
1431                    if new {
1432                        new_pwd.remove_dir_all()?;
1433                    }
1434                    context.set_pwd(current_pwd);
1435                    Ok::<_, ScriptRunError>(())
1436                })))
1437            }
1438            InternalCommand::ChangeDir(dir) => {
1439                let dir = context.expand(&dir)?;
1440
1441                cwriteln!(context.stream(), fg = Color::Yellow, "cd {dir}");
1442                cwriteln!(context.stream());
1443                let current_pwd = context.pwd();
1444                let new_pwd = current_pwd.join(dir);
1445                if !new_pwd.exists()? {
1446                    return Err(ScriptRunError::IO(std::io::Error::new(
1447                        std::io::ErrorKind::NotFound,
1448                        format!("directory does not exist: {new_pwd}"),
1449                    )));
1450                }
1451                context.set_pwd(new_pwd);
1452                Ok(None)
1453            }
1454            InternalCommand::Set(name, value) => {
1455                let value = context.expand(&value)?;
1456
1457                context.set_env(&name, &value);
1458                let new_value = context.get_env(&name).unwrap_or_default();
1459                if new_value != value {
1460                    cwriteln!(
1461                        context.stream(),
1462                        fg = Color::Yellow,
1463                        "set {name} {value} (-> {new_value})"
1464                    );
1465                } else {
1466                    cwriteln!(context.stream(), fg = Color::Yellow, "set {name} {value}");
1467                }
1468                cwriteln!(context.stream());
1469
1470                Ok(None)
1471            }
1472            InternalCommand::ExitScript => {
1473                cwriteln!(context.stream(), fg = Color::Yellow, "exiting script");
1474                cwriteln!(context.stream());
1475                Err(ScriptRunError::ExitScript)
1476            }
1477        }
1478    }
1479}
1480
1481#[derive(Debug, Clone)]
1482pub enum IfCondition {
1483    True,
1484    False,
1485    EnvEq(bool, String, ShellBit),
1486}
1487
1488impl IfCondition {
1489    pub fn matches(&self, context: &ScriptRunContext) -> bool {
1490        match self {
1491            IfCondition::True => true,
1492            IfCondition::False => false,
1493            IfCondition::EnvEq(negated, name, expected) => {
1494                let value = context.get_env(name).unwrap_or_default();
1495                (expected == value) ^ negated
1496            }
1497        }
1498    }
1499
1500    pub fn expand(&self, context: &ScriptRunContext) -> Result<IfCondition, ScriptRunError> {
1501        match self {
1502            IfCondition::True => Ok(IfCondition::True),
1503            IfCondition::False => Ok(IfCondition::False),
1504            IfCondition::EnvEq(negated, name, expected) => {
1505                let value = context.expand(expected)?;
1506                Ok(IfCondition::EnvEq(
1507                    *negated,
1508                    name.clone(),
1509                    ShellBit::Literal(value),
1510                ))
1511            }
1512        }
1513    }
1514}
1515
1516impl Serialize for IfCondition {
1517    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1518    where
1519        S: serde::Serializer,
1520    {
1521        match self {
1522            IfCondition::True => "true".serialize(serializer),
1523            IfCondition::False => "false".serialize(serializer),
1524            IfCondition::EnvEq(negated, name, value) => {
1525                let mut ser = serializer.serialize_map(Some(3))?;
1526                ser.serialize_entry("op", if *negated { "!=" } else { "==" })?;
1527                ser.serialize_entry("env", name)?;
1528                ser.serialize_entry("value", value)?;
1529                ser.end()
1530            }
1531        }
1532    }
1533}
1534
1535#[derive(Debug)]
1536pub enum ForCondition {
1537    Env(String, Vec<ShellBit>),
1538}
1539
1540impl Serialize for ForCondition {
1541    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1542    where
1543        S: serde::Serializer,
1544    {
1545        match self {
1546            ForCondition::Env(name, values) => {
1547                let mut ser = serializer.serialize_map(Some(2))?;
1548                ser.serialize_entry("env", name)?;
1549                ser.serialize_entry("values", values)?;
1550                ser.end()
1551            }
1552        }
1553    }
1554}
1555
/// `skip_serializing_if` predicate: true when the flag is unset.
fn is_bool_false(b: &bool) -> bool {
    !*b
}
1559
/// A shell command together with everything the script expects of it.
#[derive(Debug, Serialize)]
pub struct ScriptCommand {
    /// The command line to execute and where it appears in the script.
    pub command: CommandLine,
    /// Pattern the command's output is matched against.
    pub pattern: OutputPattern,

    /// Expected exit condition; omitted from serialization when it is the
    /// default (success).
    #[serde(skip_serializing_if = "CommandExit::is_success")]
    pub exit: CommandExit,

    /// When set, a pattern mismatch is the expected outcome (see the
    /// `PatternResult` chosen in `ScriptCommand::run`).
    #[serde(skip_serializing_if = "is_bool_false")]
    pub expect_failure: bool,

    /// Single set variable (entire command output trimmed), assigned whether
    /// or not the pattern matches.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub set_var: Option<String>,

    /// Specific set variables, assigned only when the pattern matches; values
    /// are expanded with the matched grok expectations layered over the
    /// environment.
    pub set_vars: HashMap<String, ShellBit>,

    /// Specific command timeout (overrides the context timeout when set).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout: Option<Duration>,

    /// Input grok expectations: named values (expanded before matching) that
    /// are handed to the output match context.
    pub expect: HashMap<String, ShellBit>,
}
1585
1586impl ScriptCommand {
1587    pub fn new(command: CommandLine) -> Self {
1588        let location = command.location.clone();
1589        Self {
1590            command,
1591            pattern: OutputPattern {
1592                pattern: OutputPatternType::None,
1593                ignore: Default::default(),
1594                reject: Default::default(),
1595                location,
1596            },
1597            exit: Default::default(),
1598            timeout: None,
1599            expect_failure: false,
1600            set_var: None,
1601            set_vars: Default::default(),
1602            expect: Default::default(),
1603        }
1604    }
1605
1606    pub fn run(&self, context: &mut ScriptRunContext) -> Result<ScriptResult, ScriptRunError> {
1607        let command = &self.command;
1608        let args = &context.args;
1609        let start = Instant::now();
1610
1611        if let Some(delay) = args.delay_steps {
1612            std::thread::sleep(std::time::Duration::from_millis(delay));
1613        }
1614
1615        for line in command.command.split('\n') {
1616            cwriteln!(context.stream(), fg = Color::Green, "{}", line);
1617        }
1618        if args.simplified_output {
1619            cwriteln!(context.stream(), dimmed = true, "---");
1620        } else {
1621            cwriteln_rule!(context.stream(), fg = Color::Cyan, "{}", command.location);
1622        }
1623        let (output, status) = command.run(
1624            &mut context.stream(),
1625            context.args.show_line_numbers,
1626            context.args.runner.clone(),
1627            self.timeout.unwrap_or(context.timeout),
1628            context.env.env_vars(),
1629            &context.kill,
1630            &context.kill_sender,
1631        )?;
1632
1633        let exit_result = if !self.exit.matches(status) {
1634            ExitResult::Mismatch(status)
1635        } else {
1636            ExitResult::Matches(status)
1637        };
1638
1639        // Side-effects
1640        if let Some(set_var) = &self.set_var {
1641            context.set_env(set_var, output.to_string().trim());
1642        }
1643
1644        let match_context = OutputMatchContext::new(context);
1645        for (key, value) in &self.expect {
1646            match_context.expect(key, context.expand(value)?);
1647        }
1648        self.pattern.prepare(&context.grok)?;
1649        let prepared_output = output
1650            .with_ignore(&context.global_ignore)
1651            .with_reject(&context.global_reject);
1652        let pattern_result = match self.pattern.matches(match_context.clone(), prepared_output) {
1653            Ok(_) => {
1654                let mut env = context.env.clone();
1655                for (key, value) in match_context.expects() {
1656                    env.set_env(key, value);
1657                }
1658                for (key, value) in &self.set_vars {
1659                    context.set_env(key, env.expand(value)?);
1660                }
1661
1662                if self.expect_failure {
1663                    PatternResult::ExpectedFailure
1664                } else {
1665                    PatternResult::Matches
1666                }
1667            }
1668            Err(e) => {
1669                if self.expect_failure {
1670                    PatternResult::MatchesFailure
1671                } else {
1672                    let trace = format_match_trace_tree(&match_context.traces());
1673                    PatternResult::Mismatch(e, trace)
1674                }
1675            }
1676        };
1677
1678        if output.is_empty() {
1679            cwriteln!(context.stream(), dimmed = true, "(no output)");
1680        }
1681
1682        if context.args.simplified_output {
1683            cwriteln!(context.stream(), dimmed = true, "---");
1684        } else {
1685            cwriteln_rule!(context.stream());
1686        }
1687
1688        Ok(ScriptResult {
1689            command: command.clone(),
1690            pattern: pattern_result,
1691            exit: exit_result,
1692            elapsed: start.elapsed(),
1693            output,
1694        })
1695    }
1696}
1697
/// Outcome of running a single `ScriptCommand`, before it is evaluated.
#[derive(derive_more::Debug)]
pub struct ScriptResult {
    /// The command line that was run.
    pub command: CommandLine,
    /// Result of matching the output pattern.
    pub pattern: PatternResult,
    /// Whether the exit status satisfied the expected `CommandExit`.
    pub exit: ExitResult,
    /// Wall-clock time measured from the start of `ScriptCommand::run`
    /// (so it includes any configured step delay).
    pub elapsed: Duration,
    /// Captured output lines; omitted from `Debug` output.
    #[debug(skip)]
    pub output: Lines,
}
1707
1708impl ScriptResult {
1709    pub fn evaluate(&self, context: &mut ScriptRunContext) -> Result<(), ScriptRunError> {
1710        let args = &context.args;
1711        let (success, failure, warning, arrow) = if *crate::term::IS_UTF8 {
1712            ("✅", "❌", "⚠️", "→")
1713        } else {
1714            ("[*]", "[X]", "[!]", "->")
1715        };
1716
1717        if let ExitResult::Mismatch(status) = self.exit {
1718            if args.ignore_exit_codes {
1719                cwriteln!(
1720                    context.stream(),
1721                    fg = Color::Yellow,
1722                    "{warning} Ignored incorrect exit code: {status}"
1723                );
1724                cwriteln!(context.stream());
1725            } else {
1726                cwriteln!(
1727                    context.stream(),
1728                    fg = Color::Red,
1729                    "{failure} FAIL: {status}"
1730                );
1731                cwriteln!(
1732                    context.stream(),
1733                    dimmed = true,
1734                    " {arrow} {}",
1735                    self.command.command
1736                );
1737                cwriteln!(context.stream());
1738                return Err(ScriptRunError::Exit(status, self.command.location.clone()));
1739            }
1740        }
1741
1742        if let PatternResult::Mismatch(e, trace) = &self.pattern {
1743            if args.ignore_matches {
1744                cwriteln!(
1745                    context.stream(),
1746                    fg = Color::Yellow,
1747                    "{warning} Ignored error: {e} (ignoring mismatches)"
1748                );
1749                cwriteln!(context.stream());
1750            } else {
1751                cwriteln!(context.stream(), fg = Color::Red, "ERROR: {e}");
1752                cwriteln!(context.stream(), dimmed = true, "{trace}");
1753                cwriteln!(context.stream(), fg = Color::Red, "{failure} FAIL");
1754                cwriteln!(context.stream());
1755                return Err(ScriptRunError::Pattern(e.clone()));
1756            }
1757        }
1758
1759        if let PatternResult::ExpectedFailure = self.pattern {
1760            if args.ignore_matches {
1761                cwriteln!(
1762                    context.stream(),
1763                    fg = Color::Yellow,
1764                    "{warning} Should not have matched! (ignoring mismatches)"
1765                );
1766                cwriteln!(context.stream());
1767            } else {
1768                cwriteln!(
1769                    context.stream(),
1770                    fg = Color::Red,
1771                    "{failure} FAIL (output shouldn't match)"
1772                );
1773                cwriteln!(
1774                    context.stream(),
1775                    dimmed = true,
1776                    " {arrow} {}",
1777                    self.command.command
1778                );
1779                cwriteln!(context.stream());
1780                return Err(ScriptRunError::ExpectedFailure(
1781                    self.command.location.clone(),
1782                ));
1783            }
1784        }
1785
1786        if let ExitResult::Matches(status) = self.exit {
1787            if status.success() {
1788                cwrite!(context.stream(), fg = Color::Green, "{success} OK");
1789                if !context.args.simplified_output {
1790                    cwriteln!(
1791                        context.stream(),
1792                        dimmed = true,
1793                        " ({:.2}s)",
1794                        self.elapsed.as_secs_f32()
1795                    );
1796                } else {
1797                    cwriteln!(context.stream());
1798                }
1799            } else {
1800                cwrite!(
1801                    context.stream(),
1802                    fg = Color::Green,
1803                    "{success} OK ({status})"
1804                );
1805                if !context.args.simplified_output {
1806                    cwriteln!(
1807                        context.stream(),
1808                        dimmed = true,
1809                        " ({:.2}s)",
1810                        self.elapsed.as_secs_f32()
1811                    );
1812                } else {
1813                    cwriteln!(context.stream());
1814                }
1815            }
1816            cwriteln!(context.stream());
1817        }
1818
1819        Ok(())
1820    }
1821}
1822
/// Verdict from matching a command's output against its expected pattern.
///
/// Each variant combines the raw match outcome with the command's
/// `expect_failure` flag, so the names are easy to misread; see each variant.
#[derive(Debug)]
pub enum PatternResult {
    /// The output matched, and a match was expected.
    Matches,
    /// The output did NOT match, and the command was marked `expect_failure`.
    /// `evaluate` treats this as success.
    MatchesFailure,
    /// The output matched even though the command was marked `expect_failure`.
    /// `evaluate` reports this as "output shouldn't match".
    ExpectedFailure,
    /// The output did not match and no failure was expected. Carries the match
    /// error and a formatted trace tree of the match attempt.
    Mismatch(OutputPatternMatchFailure, String),
}
1830
/// Verdict from checking a command's exit status.
#[derive(Debug)]
pub enum ExitResult {
    /// The exit status was accepted; `evaluate` reports it as `OK` (showing the
    /// status when it is not a plain success).
    Matches(CommandResult),
    /// The exit status was not accepted; `evaluate` fails with
    /// `ScriptRunError::Exit` unless `ignore_exit_codes` is set.
    Mismatch(CommandResult),
    /// NOTE(review): presumably the command exceeded its timeout (cf.
    /// `DEFAULT_TIMEOUT`) — confirm. `evaluate` emits no output or error for
    /// this variant.
    TimedOut,
}
1837
#[cfg(test)]
mod tests {
    use crate::parser::v0::parse_script;

    use super::*;
    use std::error::Error;

    /// A script using a named pattern plus nested `repeat`/`choice` blocks
    /// parses into three top-level commands.
    #[test]
    fn test_script() -> Result<(), Box<dyn Error>> {
        let script = r#"
pattern VERSION \d+\.\d+\.\d+;

$ something --version || echo 1
? Something %{VERSION}

$ something --help
? Usage: something [OPTIONS]
repeat {
    choice {
? %{DATA} %{GREEDYDATA}
? %{DATA}=%{DATA} %{GREEDYDATA}
    }
}
"#;

        let script = parse_script(ScriptFile::new("test.cli"), script)?;
        // Only dump the parsed script when the assertion fails.
        assert_eq!(script.commands.len(), 3, "parsed script: {script:?}");
        Ok(())
    }

    /// Backgrounded commands (`cmd &`) are rejected at parse time.
    #[test]
    fn test_bad_script() {
        let script = r#"
$ (cmd; cmd)
$ cmd &
    "#;

        let result = parse_script(ScriptFile::new("test.cli"), script);
        assert!(
            matches!(
                result,
                Err(ScriptError {
                    error: ScriptErrorType::BackgroundProcessNotAllowed,
                    ..
                })
            ),
            "expected BackgroundProcessNotAllowed, got {result:?}"
        );
    }

    /// `$VAR` / `${VAR}` expansion, including backslash escaping.
    #[test]
    fn test_script_run_context_expand() {
        let mut env = ScriptEnv::default();
        env.set_env("A", "1");
        env.set_env("B", "2");
        env.set_env("C", "3");

        let cases = [
            ("$A", "1"),
            ("$A $B ", "1 2 "),
            ("${A} ${B} ", "1 2 "),
            // A single backslash escapes the expansion...
            (r"\$A", "$A"),
            (r"\${A}", "${A}"),
            // ...while a doubled backslash is a literal backslash.
            (r"\\$A", r"\1"),
            (r"\\${A}", r"\1"),
        ];
        for (input, expected) in cases {
            assert_eq!(
                env.expand_str(input).unwrap(),
                expected,
                "expanding {input:?}"
            );
        }

        env.set_env("TEMP_DIR", "/tmp");
        assert_eq!(env.expand_str("$TEMP_DIR").unwrap(), "/tmp");
        assert_eq!(env.expand_str("${TEMP_DIR}").unwrap(), "/tmp");
    }
}