Skip to main content

clitest_lib/
script.rs

1use std::{
2    collections::{HashMap, VecDeque},
3    path::Path,
4    process::ExitStatus,
5    sync::{Arc, Mutex, atomic::AtomicBool},
6    thread::ScopedJoinHandle,
7    time::{Duration, Instant},
8};
9
10use grok::Grok;
11use keepcalm::SharedMut;
12use serde::{Serialize, ser::SerializeMap};
13use termcolor::{Color, ColorChoice, WriteColor};
14
15use crate::{
16    command::{CommandLine, CommandResult},
17    util::{NicePathBuf, NiceTempDir},
18};
19use crate::{cwrite, cwriteln, cwriteln_rule};
20use crate::{output::*, util::ShellBit};
21
/// Timeout used by `ScriptRunContext` when `ScriptRunArgs::global_timeout` is not set (30 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
23
/// A position inside a script: a file plus a line number within it.
///
/// Its `Display` form is `file:line`.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct ScriptLocation {
    /// The script file this location points into.
    pub file: ScriptFile,
    /// Line number within `file` (`ScriptLine::parse` numbers lines starting at 1).
    pub line: usize,
}
29
30impl ScriptLocation {
31    pub fn new(file: ScriptFile, line: usize) -> Self {
32        Self { file, line }
33    }
34}
35
36impl std::fmt::Display for ScriptLocation {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        write!(f, "{}:{}", self.file, self.line)
39    }
40}
41
/// A handle to a script file's path.
///
/// The path sits behind an `Arc`, so cloning a `ScriptFile` only bumps a reference count.
/// `Display` prints the wrapped `NicePathBuf`.
#[derive(
    derive_more::Debug, derive_more::Display, Clone, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash,
)]
#[display("{}", file)]
pub struct ScriptFile {
    /// Shared path of the script file.
    pub file: Arc<NicePathBuf>,
}
49
50impl ScriptFile {
51    pub fn new(file: impl AsRef<Path>) -> Self {
52        Self {
53            file: Arc::new(NicePathBuf::new(file)),
54        }
55    }
56}
57
58impl<T: AsRef<Path>> From<T> for ScriptFile {
59    fn from(file: T) -> Self {
60        Self::new(file)
61    }
62}
63
/// A parsed script: its top-level blocks, the scripts it includes, and its source file.
#[derive(Clone, derive_more::Debug, Serialize)]
pub struct Script {
    /// Top-level blocks, executed in order by `Script::run`.
    pub commands: Arc<Vec<ScriptBlock>>,
    /// Included scripts. `Script::run` installs this map on the context while the script runs.
    /// NOTE(review): keys are presumably the include paths as written in the script — confirm.
    pub includes: Arc<HashMap<String, Script>>,
    /// The file this script belongs to (used as the script path by `run_with_args`).
    pub file: ScriptFile,
}
70
/// Options controlling how a script is run and reported.
///
/// NOTE(review): `delay_steps`, `ignore_exit_codes`, `ignore_matches`, `show_line_numbers`,
/// `runner` and `quiet` are consumed outside this file section; their exact semantics are not
/// documented here.
#[derive(Debug, Clone, Default)]
pub struct ScriptRunArgs {
    pub delay_steps: Option<u64>,
    pub ignore_exit_codes: bool,
    pub ignore_matches: bool,
    /// When set, the PASSED/FAILED summary line omits the elapsed time.
    pub simplified_output: bool,
    pub show_line_numbers: bool,
    pub runner: Option<String>,
    pub quiet: bool,
    /// When set, background contexts write to the parent's output instead of a private buffer.
    pub verbose: bool,
    /// Timeout for the top-level context; `DEFAULT_TIMEOUT` when `None`.
    pub global_timeout: Option<Duration>,
    /// Passed to `ScriptOutput::quiet` when background output is buffered (possibly used
    /// elsewhere too).
    pub no_color: bool,
}
84
/// Variables visible to a running script, with shell-style expansion helpers.
#[derive(Debug, Clone, Default)]
pub struct ScriptEnv {
    /// Variable name → value. `PWD` is stored in `NicePathBuf::env_string` form.
    env_vars: HashMap<String, String>,
}
89
90impl ScriptEnv {
91    pub fn set_defaults(&mut self, pwd: impl AsRef<Path>) {
92        macro_rules! target {
93            ($env:ident, $var:ident, [$($vals:expr),*]) => {
94                $(
95                if cfg!($var = $vals) {
96                    self.env_vars.insert(stringify!($env).to_string(), $vals.to_string());
97                }
98                )*
99            };
100        }
101
102        target!(
103            TARGET_OS,
104            target_os,
105            [
106                "windows",
107                "linux",
108                "macos",
109                "ios",
110                "android",
111                "freebsd",
112                "netbsd",
113                "openbsd",
114                "dragonfly",
115                "haiku",
116                "aix"
117            ]
118        );
119        target!(TARGET_FAMILY, target_family, ["windows", "unix", "wasm"]);
120        target!(
121            TARGET_ARCH,
122            target_arch,
123            [
124                "aarch64",
125                "amdgpu",
126                "arm",
127                "arm64ec",
128                "avr",
129                "bpf",
130                "csky",
131                "hexagon",
132                "loongarch32",
133                "loongarch64",
134                "m68k",
135                "mips",
136                "mips32r6",
137                "mips64",
138                "mips64r6",
139                "msp430",
140                "nvptx64",
141                "powerpc",
142                "powerpc64",
143                "riscv32",
144                "riscv64",
145                "s390x",
146                "sparc",
147                "sparc64",
148                "wasm32",
149                "wasm64",
150                "x86",
151                "x86_64",
152                "xtensa"
153            ]
154        );
155
156        // Set the current working directory as a special variable "PWD"
157        let pwd = NicePathBuf::from(pwd.as_ref()).env_string();
158        self.env_vars.insert("PWD".to_string(), pwd);
159        // Save the initial PWD as INITIAL_PWD so it can easily be restored
160        self.env_vars
161            .insert("INITIAL_PWD".to_string(), self.env_vars["PWD"].clone());
162    }
163
164    pub fn pwd(&self) -> NicePathBuf {
165        self.env_vars
166            .get("PWD")
167            .cloned()
168            .map(NicePathBuf::from)
169            .unwrap_or_else(NicePathBuf::cwd)
170    }
171
172    pub fn get_env(&self, name: &str) -> Option<&str> {
173        self.env_vars.get(name).map(|s| s.as_str())
174    }
175
176    pub fn set_env(&mut self, name: impl Into<String>, value: impl Into<String>) {
177        let name = name.into();
178        if name == "PWD" {
179            self.set_pwd(value.into());
180        } else {
181            self.env_vars.insert(name, value.into());
182        }
183    }
184
185    pub fn set_pwd(&mut self, pwd: impl Into<NicePathBuf>) {
186        let pwd = pwd.into().env_string();
187        self.env_vars.insert("PWD".to_string(), pwd);
188    }
189
190    pub fn expand(&self, value: &ShellBit) -> Result<String, ScriptRunError> {
191        match value {
192            ShellBit::Literal(s) => Ok(s.clone()),
193            ShellBit::Quoted(s) => self.expand_str(s),
194        }
195    }
196
197    /// Perform shell expansion on a string.
198    pub fn expand_str(&self, value: impl AsRef<str>) -> Result<String, ScriptRunError> {
199        enum State {
200            Normal,
201            EscapeNext,
202            InCurly,
203            Dollar,
204            InDollar,
205        }
206
207        let value = value.as_ref();
208
209        // "\" triggers escaping
210        // ${A} expands to the value of A
211        // $A expands to the value of A (variable ends on first non-alphanumeric character)
212
213        let mut state = State::Normal;
214        let mut variable = String::new();
215        let mut expanded = String::new();
216
217        for c in value.chars() {
218            match state {
219                State::Normal => {
220                    if c == '$' {
221                        state = State::Dollar;
222                        continue;
223                    }
224                    if c == '\\' {
225                        state = State::EscapeNext;
226                        continue;
227                    }
228                    expanded.push(c);
229                }
230                State::EscapeNext => {
231                    expanded.push(c);
232                    state = State::Normal;
233                }
234                State::InCurly => {
235                    if c == '}' {
236                        if let Some(value) = self.get_env(&std::mem::take(&mut variable)) {
237                            expanded.push_str(value);
238                        } else {
239                            return Err(ScriptRunError::ExpansionError(format!(
240                                "undefined variable in ${{...}}: {variable:?} (in {value:?})"
241                            )));
242                        }
243                        state = State::Normal;
244                    } else {
245                        variable.push(c);
246                    }
247                }
248                State::Dollar => {
249                    if c.is_alphanumeric() || c == '_' {
250                        state = State::InDollar;
251                        variable.push(c);
252                    } else if c == '{' {
253                        state = State::InCurly;
254                    } else {
255                        return Err(ScriptRunError::ExpansionError(format!(
256                            "invalid variable: {c:?} (in {value:?})"
257                        )));
258                    }
259                }
260                State::InDollar => {
261                    if c.is_alphanumeric() || c == '_' {
262                        variable.push(c);
263                    } else {
264                        if let Some(value) = self.get_env(&std::mem::take(&mut variable)) {
265                            expanded.push_str(value);
266                        } else {
267                            return Err(ScriptRunError::ExpansionError(format!(
268                                "undefined variable in $...: {variable:?} (in {value:?})"
269                            )));
270                        }
271                        expanded.push(c);
272                        state = State::Normal;
273                    }
274                }
275            }
276        }
277        match state {
278            State::InDollar => {
279                if let Some(value) = self.get_env(&variable) {
280                    expanded.push_str(value);
281                } else {
282                    return Err(ScriptRunError::ExpansionError(format!(
283                        "undefined variable: {variable}"
284                    )));
285                }
286            }
287            State::Dollar => {
288                return Err(ScriptRunError::ExpansionError(
289                    "incomplete variable".to_string(),
290                ));
291            }
292            State::InCurly => {
293                return Err(ScriptRunError::ExpansionError(format!(
294                    "unclosed variable: {variable}"
295                )));
296            }
297            State::Normal => {}
298            State::EscapeNext => {
299                return Err(ScriptRunError::ExpansionError(
300                    "unclosed backslash".to_string(),
301                ));
302            }
303        }
304        Ok(expanded)
305    }
306
307    pub fn env_vars(&self) -> &HashMap<String, String> {
308        &self.env_vars
309    }
310}
311
/// Destination for script progress output: a terminal stream or an in-memory buffer.
///
/// Clones share the same underlying stream (`take_buffer` unwraps it only when it is the last
/// handle).
#[derive(derive_more::Debug, Clone)]
pub struct ScriptOutput {
    #[debug(skip)]
    stream: SharedMut<Box<dyn WriteColorAny>>,
}
317
/// Object-safe colored writer used behind `ScriptOutput`, which can also hand back its contents
/// when the concrete writer is an in-memory `termcolor::Buffer`.
trait WriteColorAny: WriteColor + Send + Sync + std::any::Any + 'static + std::fmt::Debug {
    /// Workaround for lack of upcasting
    ///
    /// Consumes the boxed writer and returns it as a `termcolor::Buffer`, or `Err` if it is not
    /// a buffer.
    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String>;
    /// Returns a copy of the buffered contents, or `Err` if the writer is not a buffer.
    fn clone_buffer(&self) -> Result<termcolor::Buffer, String>;
}
323
324impl WriteColorAny for termcolor::StandardStream {
325    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String> {
326        Err("not a buffer".to_string())
327    }
328    fn clone_buffer(&self) -> Result<termcolor::Buffer, String> {
329        Err("not a buffer".to_string())
330    }
331}
332
333impl WriteColorAny for termcolor::Buffer {
334    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String> {
335        Ok(*self)
336    }
337    fn clone_buffer(&self) -> Result<termcolor::Buffer, String> {
338        Ok(self.clone())
339    }
340}
341
342impl ScriptOutput {
343    pub fn no_color() -> Self {
344        let stm = termcolor::StandardStream::stdout(ColorChoice::Never);
345        Self {
346            stream: SharedMut::new(Box::new(stm) as _),
347        }
348    }
349
350    pub fn quiet(no_color: bool) -> Self {
351        let stm = if no_color {
352            termcolor::Buffer::no_color()
353        } else {
354            termcolor::Buffer::ansi()
355        };
356        Self {
357            stream: SharedMut::new(Box::new(stm) as _),
358        }
359    }
360
361    pub fn take_buffer(self) -> String {
362        let stream = match SharedMut::try_unwrap(self.stream) {
363            Ok(stream) => stream.take_buffer().expect("wrong stream type"),
364            Err(shared) => shared.read().clone_buffer().expect("wrong stream type"),
365        };
366        String::from_utf8_lossy(&stream.into_inner()).to_string()
367    }
368}
369
370impl Default for ScriptOutput {
371    fn default() -> Self {
372        let stm = termcolor::StandardStream::stdout(ColorChoice::Auto);
373        Self {
374            stream: SharedMut::new(Box::new(stm) as _),
375        }
376    }
377}
378
379impl std::io::Write for ScriptOutputLock<'_> {
380    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
381        self.stream.write(buf)
382    }
383    fn flush(&mut self) -> std::io::Result<()> {
384        self.stream.flush()
385    }
386}
387
388impl termcolor::WriteColor for ScriptOutputLock<'_> {
389    fn supports_color(&self) -> bool {
390        self.stream.supports_color()
391    }
392    fn set_color(&mut self, spec: &termcolor::ColorSpec) -> std::io::Result<()> {
393        self.stream.set_color(spec)
394    }
395    fn reset(&mut self) -> std::io::Result<()> {
396        self.stream.reset()
397    }
398    fn is_synchronous(&self) -> bool {
399        self.stream.is_synchronous()
400    }
401    fn set_hyperlink(&mut self, _link: &termcolor::HyperlinkSpec) -> std::io::Result<()> {
402        self.stream.set_hyperlink(_link)
403    }
404    fn supports_hyperlinks(&self) -> bool {
405        self.stream.supports_hyperlinks()
406    }
407}
408
/// Writer returned by `ScriptRunContext::stream`. It holds the write lock taken from the
/// output's `SharedMut` until it is dropped.
struct ScriptOutputLock<'a> {
    stream: keepcalm::SharedWriteLock<'a, Box<dyn WriteColorAny>>,
}
412
/// The kind of a `ScriptRunContext`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScriptMode {
    /// Top-level context (`ScriptRunContext::new` or `Default`).
    Normal,
    /// Created by `ScriptRunContext::new_deferred`.
    Deferred,
    /// Created by `ScriptRunContext::new_background`.
    Background,
}
419
/// State threaded through a script run: options, environment, includes, output and
/// cancellation.
#[derive(derive_more::Debug)]
pub struct ScriptRunContext {
    /// Options the run was started with.
    pub args: ScriptRunArgs,
    /// Grok pattern library (initialized with `Grok::with_default_patterns()`).
    pub grok: Grok,
    /// Timeout for this context; `Duration::MAX` for background contexts.
    timeout: Duration,
    /// Variables visible to the script, including `PWD`.
    env: ScriptEnv,
    /// Includes of the currently running `Script`; swapped in and restored by `Script::run`.
    includes: Arc<HashMap<String, Script>>,
    /// The kind of context. NOTE(review): despite its name this is a `ScriptMode`
    /// (Normal/Deferred/Background), not a boolean.
    background: ScriptMode,
    /// Read side of this context's kill flag.
    #[debug(skip)]
    kill: ScriptKillReceiver,
    /// Write side of the same kill flag as `kill`.
    #[debug(skip)]
    kill_sender: ScriptKillSender,
    /// Where progress output is written.
    output: ScriptOutput,

    // NOTE(review): presumably output patterns applied globally (ignored / rejected lines);
    // this section only initializes and copies them — confirm against their users.
    global_ignore: OutputPatterns,
    global_reject: OutputPatterns,
}
437
438impl Default for ScriptRunContext {
439    fn default() -> Self {
440        let kill = Arc::new(AtomicBool::new(false));
441        Self {
442            args: ScriptRunArgs::default(),
443            grok: Grok::with_default_patterns(),
444            timeout: DEFAULT_TIMEOUT,
445            env: ScriptEnv::default(),
446            background: ScriptMode::Normal,
447            includes: Arc::new(HashMap::new()),
448            kill: ScriptKillReceiver::new(kill.clone()),
449            kill_sender: ScriptKillSender::new(kill.clone()),
450            output: ScriptOutput::default(),
451            global_ignore: OutputPatterns::default(),
452            global_reject: OutputPatterns::default(),
453        }
454    }
455}
456
457impl ScriptRunContext {
458    pub fn new_background(&self) -> Self {
459        let kill = Arc::new(AtomicBool::new(false));
460        Self {
461            args: self.args.clone(),
462            grok: self.grok.clone(),
463            // Background processes are not subject to timeouts
464            timeout: Duration::MAX,
465            env: self.env.clone(),
466            background: ScriptMode::Background,
467            kill: ScriptKillReceiver::new(kill.clone()),
468            kill_sender: ScriptKillSender::new(kill.clone()),
469            includes: self.includes.clone(),
470            output: if self.args.verbose {
471                self.output.clone()
472            } else {
473                ScriptOutput::quiet(self.args.no_color)
474            },
475            global_ignore: self.global_ignore.clone(),
476            global_reject: self.global_reject.clone(),
477        }
478    }
479
480    pub fn new_deferred(&self) -> Self {
481        Self {
482            args: self.args.clone(),
483            grok: self.grok.clone(),
484            timeout: self.timeout,
485            env: self.env.clone(),
486            background: ScriptMode::Deferred,
487            kill: self.kill.clone(),
488            kill_sender: self.kill_sender.clone(),
489            includes: self.includes.clone(),
490            output: self.output.clone(),
491            global_ignore: self.global_ignore.clone(),
492            global_reject: self.global_reject.clone(),
493        }
494    }
495
496    pub fn pwd(&self) -> NicePathBuf {
497        self.env.pwd()
498    }
499
500    pub fn get_env(&self, name: &str) -> Option<&str> {
501        self.env.get_env(name)
502    }
503
504    pub fn set_env(&mut self, name: impl Into<String>, value: impl Into<String>) {
505        self.env.set_env(name, value);
506    }
507
508    pub fn set_pwd(&mut self, pwd: impl Into<NicePathBuf>) {
509        self.env.set_pwd(pwd);
510    }
511
512    pub fn take_output(self) -> String {
513        self.output.take_buffer()
514    }
515
516    fn expand(&self, value: &ShellBit) -> Result<String, ScriptRunError> {
517        self.env.expand(value)
518    }
519
520    /// Get a mutable reference to the output stream.
521    pub fn stream(&self) -> impl termcolor::WriteColor + use<'_> {
522        ScriptOutputLock {
523            stream: self.output.stream.write(),
524        }
525    }
526}
527
/// Read side of a shared kill flag. `run_with` and `run_cmd` watch it while waiting.
/// Clones observe the same flag.
#[derive(Clone)]
pub struct ScriptKillReceiver {
    kill_receiver: Arc<AtomicBool>,
}
532
533impl ScriptKillReceiver {
534    pub fn new(kill_receiver: Arc<AtomicBool>) -> Self {
535        Self { kill_receiver }
536    }
537
538    pub fn is_killed(&self) -> bool {
539        self.kill_receiver.load(std::sync::atomic::Ordering::SeqCst)
540    }
541
542    pub fn run_with<T>(&self, kill: impl FnOnce() + Send, wait: impl FnOnce() -> T) -> T {
543        std::thread::scope(|s| {
544            let done = Arc::new(AtomicBool::new(false));
545            let done_clone = done.clone();
546            let t = s.spawn(move || {
547                while !done_clone.load(std::sync::atomic::Ordering::SeqCst) {
548                    if self.is_killed() {
549                        kill();
550                        break;
551                    }
552                    std::thread::sleep(Duration::from_millis(10));
553                }
554            });
555            let res = wait();
556            done.store(true, std::sync::atomic::Ordering::SeqCst);
557            t.join().unwrap();
558            res
559        })
560    }
561
562    #[cfg(windows)]
563    pub fn run_cmd(
564        &self,
565        output: std::process::Child,
566        warn_time: Duration,
567    ) -> std::io::Result<ExitStatus> {
568        use std::os::windows::io::AsRawHandle;
569        use win32job::Job;
570
571        fn map_job_error(e: win32job::JobError) -> std::io::Error {
572            match e {
573                win32job::JobError::AssignFailed(e) => e,
574                win32job::JobError::CreateFailed(e) => e,
575                win32job::JobError::GetInfoFailed(e) => e,
576                win32job::JobError::SetInfoFailed(e) => e,
577                _ => std::io::Error::new(std::io::ErrorKind::Other, "Unknown error"),
578            }
579        }
580
581        // Create a new Job object
582        let job = Job::create().map_err(map_job_error)?;
583
584        // Configure the job to terminate all child processes when the job is closed
585        let mut info = job.query_extended_limit_info().map_err(map_job_error)?;
586        info.limit_kill_on_job_close();
587        job.set_extended_limit_info(&info).map_err(map_job_error)?;
588        job.assign_process(output.as_raw_handle() as _)?;
589
590        // Resume the main thread for the process
591        let id = output.id();
592        for thread_entry in tlhelp32::Snapshot::new_thread()? {
593            if thread_entry.owner_process_id == id {
594                use windows_sys::Win32::Foundation::CloseHandle;
595                use windows_sys::Win32::System::Threading::*;
596
597                unsafe {
598                    let thread = OpenThread(THREAD_SUSPEND_RESUME, 0, thread_entry.thread_id);
599                    if thread.is_null() {
600                        return Err(std::io::Error::last_os_error().into());
601                    }
602                    ResumeThread(thread);
603                    CloseHandle(thread);
604                }
605            }
606        }
607
608        let job = Mutex::new(Some(job));
609        let output = Mutex::new(output);
610        self.run_with(
611            || {
612                _ = job.lock().unwrap().take();
613                _ = output.lock().unwrap().kill();
614            },
615            || {
616                let start = std::time::Instant::now();
617                let mut warned = false;
618                loop {
619                    let res = output.lock().unwrap().try_wait()?;
620                    if let Some(status) = res {
621                        return Ok::<_, std::io::Error>(status);
622                    }
623                    if start.elapsed() > warn_time {
624                        if !warned {
625                            let child = output.lock().unwrap().id();
626                            eprintln!("Process #{child} taking too long to finish.");
627                            warned = true;
628                        }
629                    }
630                    std::thread::sleep(Duration::from_millis(10));
631                }
632            },
633        )
634    }
635
636    #[cfg(unix)]
637    pub fn run_cmd(
638        &self,
639        output: std::process::Child,
640        warn_time: Duration,
641    ) -> std::io::Result<ExitStatus> {
642        let output = Mutex::new(output);
643        self.run_with(
644            || {
645                use signal_child::{signal, signal::*};
646                let id = output.lock().unwrap().id() as i32;
647                _ = signal(-id, SIGINT);
648                std::thread::sleep(Duration::from_millis(10));
649                _ = signal(-id, SIGTERM);
650            },
651            || {
652                let start = std::time::Instant::now();
653                let mut warned = false;
654                loop {
655                    let res = output.lock().unwrap().try_wait()?;
656                    if let Some(status) = res {
657                        return Ok::<_, std::io::Error>(status);
658                    }
659                    if start.elapsed() > warn_time && !warned {
660                        let child = output.lock().unwrap().id();
661                        eprintln!("Process #{child} taking too long to finish.");
662                        warned = true;
663                    }
664                    std::thread::sleep(Duration::from_millis(10));
665                }
666            },
667        )
668    }
669}
670
/// Write side of a shared kill flag. Calling `kill` is observed by every `ScriptKillReceiver`
/// built on the same `Arc`.
#[derive(Clone)]
pub struct ScriptKillSender {
    kill_sender: Arc<AtomicBool>,
}
675
676impl ScriptKillSender {
677    pub fn new(kill_sender: Arc<AtomicBool>) -> Self {
678        Self { kill_sender }
679    }
680
681    pub fn kill(&self) {
682        self.kill_sender
683            .store(true, std::sync::atomic::Ordering::SeqCst);
684    }
685}
686
687impl ScriptRunContext {
688    pub fn new(args: ScriptRunArgs, script_path: impl AsRef<Path>, output: ScriptOutput) -> Self {
689        let mut env = ScriptEnv::default();
690        env.set_defaults(script_path.as_ref().parent().unwrap());
691
692        let kill = Arc::new(AtomicBool::new(false));
693
694        Self {
695            timeout: args.global_timeout.unwrap_or(DEFAULT_TIMEOUT),
696            args,
697            env,
698            grok: Grok::with_default_patterns(),
699            includes: Arc::new(HashMap::new()),
700            background: ScriptMode::Normal,
701            kill: ScriptKillReceiver::new(kill.clone()),
702            kill_sender: ScriptKillSender::new(kill.clone()),
703            output,
704            global_ignore: OutputPatterns::default(),
705            global_reject: OutputPatterns::default(),
706        }
707    }
708}
709
/// One line of script source together with where it came from.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ScriptLine {
    /// Where the line appears in the script.
    pub location: ScriptLocation,
    /// Raw, untrimmed line text (most accessors trim it; see `text_untrimmed`).
    text: String,
}
715
716impl ScriptLine {
717    pub fn new(file: ScriptFile, line: usize, text: impl AsRef<str>) -> Self {
718        Self {
719            location: ScriptLocation::new(file, line),
720            text: text.as_ref().to_string(),
721        }
722    }
723
724    pub fn parse(file: ScriptFile, text: impl AsRef<str>) -> Vec<Self> {
725        text.as_ref()
726            .lines()
727            .enumerate()
728            .map(|(line, text)| Self {
729                location: ScriptLocation::new(file.clone(), line + 1),
730                text: text.to_string(),
731            })
732            .collect()
733    }
734
735    pub fn starts_with(&self, text: &str) -> bool {
736        self.text.trim().starts_with(text)
737    }
738
739    pub fn first_char(&self) -> Option<char> {
740        self.text.trim().chars().next()
741    }
742
743    pub fn text(&self) -> &str {
744        self.text.trim()
745    }
746
747    pub fn text_untrimmed(&self) -> &str {
748        &self.text
749    }
750
751    pub fn is_empty(&self) -> bool {
752        self.text.trim().is_empty()
753    }
754
755    pub fn strip_prefix(&self, prefix: &str) -> Option<&str> {
756        self.text.strip_prefix(prefix)
757    }
758}
759
/// An error tied to a location in a script's source (the kinds are listed in
/// `ScriptErrorType`).
///
/// Displays as `{error} at {location}`, followed by `: {data}` when `associated_data` is set.
#[derive(Debug, thiserror::Error, derive_more::Display)]
#[display("{error} at {location}{}", associated_data.as_deref().map_or("".to_string(), |d| format!(": {d}")))]
pub struct ScriptError {
    /// What went wrong.
    pub error: ScriptErrorType,
    /// Where it went wrong.
    pub location: ScriptLocation,
    /// Optional extra detail appended to the message.
    pub associated_data: Option<String>,
}
767
768impl ScriptError {
769    pub fn new(error: ScriptErrorType, location: ScriptLocation) -> Self {
770        if std::env::var("PANIC_ON_ERROR").is_ok() {
771            panic!("ScriptError: {error} at {location}");
772        }
773        Self {
774            error,
775            location,
776            associated_data: None,
777        }
778    }
779
780    pub fn new_with_data(
781        error: ScriptErrorType,
782        location: ScriptLocation,
783        associated_data: String,
784    ) -> Self {
785        if std::env::var("PANIC_ON_ERROR").is_ok() {
786            panic!("ScriptError: {error} at {location}: {associated_data}");
787        }
788        Self {
789            error,
790            location,
791            associated_data: Some(associated_data),
792        }
793    }
794}
795
/// Kinds of errors reported against a script location (wrapped by `ScriptError`).
///
/// Each variant's `#[error]` text is its user-facing message.
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
pub enum ScriptErrorType {
    #[error("background process not allowed")]
    BackgroundProcessNotAllowed,
    #[error("unclosed quote")]
    UnclosedQuote,
    #[error("unclosed backslash")]
    UnclosedBackslash,
    #[error("illegal shell command format")]
    IllegalShellCommand,
    #[error("unsupported redirection")]
    UnsupportedRedirection,
    #[error("invalid pattern definition")]
    InvalidPatternDefinition,
    #[error("invalid pattern")]
    InvalidPattern,
    #[error("invalid meta command")]
    InvalidMetaCommand,
    #[error("invalid pattern at global level (only reject or ignore allowed here)")]
    InvalidGlobalPattern,
    #[error("invalid block type")]
    InvalidBlockType,
    #[error("invalid block arguments")]
    InvalidBlockArgs,
    #[error("unsupported command position")]
    UnsupportedCommandPosition,
    #[error("invalid trailing pattern after *")]
    InvalidAnyPattern,
    #[error("invalid exit status")]
    InvalidExitStatus,
    #[error("invalid set variable")]
    InvalidSetVariable,
    #[error("invalid version header, expected `#!/usr/bin/env clitest --v0`")]
    InvalidVersion,
    #[error("invalid internal command")]
    InvalidInternalCommand,
    #[error("missing command lines")]
    MissingCommandLines,
    #[error(
        "block end without matching block start, too many closing braces or braces not properly nested"
    )]
    InvalidBlockEnd,
    #[error("invalid if condition")]
    InvalidIfCondition,
    #[error("expected block or semi-colon (did you forget to add ';' at the end of this line?)")]
    ExpectedBlockOrSemi,
}
843
/// Errors produced while running a script.
///
/// NOTE(review): `ScriptLocation` displays as `file:line`, so messages such as
/// "... at line {1}" render as "at line path/to/script:12".
#[derive(Debug, thiserror::Error)]
pub enum ScriptRunError {
    /// An output pattern did not match.
    #[error("{0}")]
    Pattern(#[from] OutputPatternMatchFailure),
    /// An output pattern could not be prepared.
    #[error("{0}")]
    PatternPrepareError(#[from] OutputPatternPrepareError),
    /// A command's result, with the location of the command. NOTE(review): presumably raised
    /// when the result does not match the expected `CommandExit` — confirm.
    #[error("{0} at line {1}")]
    Exit(CommandResult, ScriptLocation),
    #[error("included file not found: {0}")]
    IncludedFileNotFound(String),
    /// A block expected to fail succeeded instead.
    #[error("expected failure, but passed at line {0}")]
    ExpectedFailure(ScriptLocation),
    /// Variable expansion failed (see `ScriptEnv::expand_str`).
    #[error("{0}")]
    ExpansionError(String),
    #[error("{0}")]
    IO(#[from] std::io::Error),
    #[error("killed")]
    Killed,
    #[error("background process took too long to finish")]
    BackgroundProcessTookTooLong,
    #[error("retry took too long to finish")]
    RetryTookTooLong,
    /// Internal flow control: exit the script
    ///
    /// `Script::run` converts this into `Ok(())`.
    #[error("exiting script")]
    ExitScript,
}
870
871impl ScriptRunError {
872    #[expect(unused)]
873    pub fn short(&self) -> String {
874        match self {
875            Self::Pattern(_) => "Pattern".to_string(),
876            Self::PatternPrepareError(e) => format!("PatternPrepareError({e:?})"),
877            Self::Exit(status, _) => format!("Exit({status})"),
878            Self::ExpectedFailure(_) => "ExpectedFailure".to_string(),
879            Self::IO(e) => format!("IO({:?})", e.kind()),
880            Self::Killed => "Killed".to_string(),
881            Self::BackgroundProcessTookTooLong => "BackgroundProcessTookTooLong".to_string(),
882            Self::ExpansionError(e) => "ExpansionError".to_string(),
883            Self::RetryTookTooLong => "RetryTookTooLong".to_string(),
884            Self::ExitScript => unreachable!(),
885            Self::IncludedFileNotFound(path) => format!("IncludedFileNotFound({path})"),
886        }
887    }
888}
889
890impl Script {
891    pub fn new(file: ScriptFile) -> Self {
892        Self {
893            commands: Arc::new(vec![]),
894            includes: Arc::new(HashMap::new()),
895            file,
896        }
897    }
898
899    /// Collect all included script paths from the script.
900    pub fn includes(&self) -> Vec<(ScriptLocation, String)> {
901        self.commands
902            .iter()
903            .flat_map(|block| block.includes())
904            .collect()
905    }
906
907    pub fn run(&self, context: &mut ScriptRunContext) -> Result<(), ScriptRunError> {
908        let old_includes = context.includes.clone();
909        context.includes = self.includes.clone();
910        let res = ScriptBlock::run_blocks(context, &self.commands);
911        context.includes = old_includes;
912        let v = match res {
913            Ok(v) => v,
914            // Bypass normal script processing and exit successfully
915            Err(ScriptRunError::ExitScript) => return Ok(()),
916            Err(e) => return Err(e),
917        };
918        assert!(v.is_empty(), "script did not run to completion: {v:?}");
919        Ok(())
920    }
921
922    pub fn run_with_args(
923        &self,
924        args: ScriptRunArgs,
925        output: ScriptOutput,
926    ) -> Result<(), ScriptRunError> {
927        let start = Instant::now();
928        let script_path = &*self.file.file;
929        let mut context = ScriptRunContext::new(args, script_path, output);
930
931        // Write "Running..." message with colors
932        cwrite!(context.stream(), "Running ");
933        cwrite!(context.stream(), fg = Color::Cyan, "{}", script_path);
934        cwriteln!(context.stream(), " ...");
935        cwriteln!(context.stream());
936
937        let result = self.run(&mut context);
938
939        // Handle success and error output
940        if let Err(ref e) = result {
941            cwrite!(context.stream(), fg = Color::Cyan, "{} ", script_path);
942            cwrite!(context.stream(), fg = Color::Red, "FAILED");
943            if !context.args.simplified_output {
944                cwriteln!(context.stream(), " ({:.2}s)", start.elapsed().as_secs_f32());
945            } else {
946                cwriteln!(context.stream());
947            }
948            cwrite!(context.stream(), fg = Color::Red, "Error: ");
949            cwriteln!(context.stream(), "{}", e);
950            cwriteln!(context.stream());
951        } else {
952            cwrite!(context.stream(), fg = Color::Cyan, "{} ", script_path);
953            cwrite!(context.stream(), fg = Color::Green, "PASSED");
954            if !context.args.simplified_output {
955                cwriteln!(context.stream(), " ({:.2}s)", start.elapsed().as_secs_f32());
956            } else {
957                cwriteln!(context.stream());
958            }
959        }
960
961        result
962    }
963}
964
965#[derive(Debug, Default, Serialize)]
966pub enum CommandExit {
967    #[default]
968    Success,
969    Failure(i32),
970    Timeout,
971    Any,
972    AnyFailure,
973}
974
975impl CommandExit {
976    pub fn matches(&self, status: CommandResult) -> bool {
977        match (self, status) {
978            (CommandExit::Success, CommandResult::Exit(status)) => status.success(),
979            (CommandExit::Failure(code), CommandResult::Exit(status)) => {
980                *code == status.code().unwrap_or(-1)
981            }
982            (CommandExit::Timeout, CommandResult::TimedOut) => true,
983            (CommandExit::Any, _) => true,
984            (CommandExit::AnyFailure, CommandResult::Exit(status)) => !status.success(),
985            (CommandExit::AnyFailure, _) => true,
986            _ => false,
987        }
988    }
989
990    pub fn is_success(&self) -> bool {
991        matches!(self, CommandExit::Success)
992    }
993}
994
995#[derive(derive_more::Debug)]
996#[allow(clippy::large_enum_variant)]
997pub enum ScriptBlock {
998    Command(ScriptCommand),
999    InternalCommand(ScriptLocation, InternalCommand),
1000    Background(Vec<ScriptBlock>),
1001    Defer(Vec<ScriptBlock>),
1002    If(IfCondition, Vec<ScriptBlock>),
1003    For(ForCondition, Vec<ScriptBlock>),
1004    Retry(Vec<ScriptBlock>),
1005    GlobalIgnore(OutputPatterns),
1006    GlobalReject(OutputPatterns),
1007}
1008
1009impl ScriptBlock {
1010    pub fn includes(&self) -> Vec<(ScriptLocation, String)> {
1011        match self {
1012            ScriptBlock::Command(..) => vec![],
1013            ScriptBlock::InternalCommand(location, InternalCommand::Include(path)) => {
1014                vec![(location.clone(), path.clone())]
1015            }
1016            ScriptBlock::InternalCommand(..) => vec![],
1017            ScriptBlock::Background(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1018            ScriptBlock::Defer(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1019            ScriptBlock::If(_, blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1020            ScriptBlock::For(_, blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1021            ScriptBlock::Retry(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1022            ScriptBlock::GlobalIgnore(_) => vec![],
1023            ScriptBlock::GlobalReject(_) => vec![],
1024        }
1025    }
1026
1027    #[allow(clippy::type_complexity)]
1028    pub fn run_blocks(
1029        context: &mut ScriptRunContext,
1030        blocks: &[ScriptBlock],
1031    ) -> Result<Vec<ScriptResult>, ScriptRunError> {
1032        enum Deferred<'a> {
1033            Scripts(&'a [ScriptBlock]),
1034            Internal(
1035                Box<
1036                    dyn FnOnce(&mut ScriptRunContext) -> Result<(), ScriptRunError>
1037                        + Send
1038                        + Sync
1039                        + 'a,
1040                >,
1041            ),
1042            Background(
1043                ScopedJoinHandle<'a, Result<Vec<ScriptResult>, ScriptRunError>>,
1044                ScriptKillSender,
1045            ),
1046        }
1047
1048        let mut results = Vec::new();
1049        std::thread::scope(|s| {
1050            let mut defer_blocks = VecDeque::new();
1051            let mut pending_error = None;
1052            for block in blocks {
1053                if context.kill.is_killed() {
1054                    return Err(ScriptRunError::Killed);
1055                }
1056                match block {
1057                    ScriptBlock::Background(blocks) => {
1058                        let mut context = context.new_background();
1059                        let kill_sender = context.kill_sender.clone();
1060                        let handle = s.spawn(move || Self::run_blocks(&mut context, blocks));
1061                        defer_blocks.push_front(Deferred::Background(handle, kill_sender));
1062                    }
1063                    ScriptBlock::Defer(blocks) => {
1064                        // Insert at the front of the queue by extending and
1065                        // then rotating
1066                        defer_blocks.push_front(Deferred::Scripts(blocks));
1067                    }
1068                    ScriptBlock::InternalCommand(_, command) => {
1069                        if context.background == ScriptMode::Deferred {
1070                            cwrite!(context.stream(), dimmed = true, "(deferred) ");
1071                        }
1072                        if let Some(f) = command.run(context)? {
1073                            defer_blocks.push_front(Deferred::Internal(f));
1074                        }
1075                    }
1076                    _ => match block.run(context) {
1077                        Ok(res) => results.extend(res),
1078                        Err(e) => {
1079                            pending_error = Some(e);
1080                            break;
1081                        }
1082                    },
1083                }
1084            }
1085            for block in defer_blocks {
1086                match block {
1087                    Deferred::Scripts(blocks) => {
1088                        let mut context = context.new_deferred();
1089                        ScriptBlock::run_blocks(&mut context, blocks)?;
1090                    }
1091                    Deferred::Internal(block) => {
1092                        cwrite!(context.stream(), dimmed = true, "(cleanup) ");
1093                        block(context)?;
1094                    }
1095                    Deferred::Background(handle, kill_sender) => {
1096                        kill_sender.kill();
1097                        let start = std::time::Instant::now();
1098                        let mut warned = false;
1099
1100                        let timeout = context.timeout;
1101                        let warn_at = timeout * 8 / 10;
1102
1103                        let results = loop {
1104                            if handle.is_finished() {
1105                                break handle.join().unwrap()?;
1106                            }
1107                            std::thread::sleep(std::time::Duration::from_millis(10));
1108                            if !warned && start.elapsed() > warn_at {
1109                                cwriteln!(
1110                                    context.stream(),
1111                                    fg = Color::Yellow,
1112                                    "Background process is taking too long to finish."
1113                                );
1114                                warned = true;
1115                            }
1116                            if start.elapsed() > timeout {
1117                                cwriteln!(
1118                                    context.stream(),
1119                                    fg = Color::Red,
1120                                    "Background process took too long to finish."
1121                                );
1122                                return Err(ScriptRunError::BackgroundProcessTookTooLong);
1123                            }
1124                        };
1125                        for result in results {
1126                            cwrite!(context.stream(), dimmed = true, "(background) ");
1127                            for line in result.command.command.split('\n') {
1128                                cwriteln!(context.stream(), fg = Color::Green, "{}", line);
1129                            }
1130                            if context.args.simplified_output {
1131                                cwriteln!(context.stream(), dimmed = true, "---");
1132                            } else {
1133                                cwriteln_rule!(
1134                                    context.stream(),
1135                                    fg = Color::Cyan,
1136                                    "{}",
1137                                    result.command.location
1138                                );
1139                            }
1140                            for line in &result.output {
1141                                cwriteln!(context.stream(), "{}", line);
1142                            }
1143                            if result.output.is_empty() {
1144                                cwriteln!(context.stream(), dimmed = true, "(no output)");
1145                            }
1146                            if context.args.simplified_output {
1147                                cwriteln!(context.stream(), dimmed = true, "---");
1148                            } else {
1149                                cwriteln_rule!(context.stream());
1150                            }
1151                            result.evaluate(context)?;
1152                        }
1153                    }
1154                }
1155            }
1156            if let Some(error) = pending_error {
1157                return Err(error);
1158            }
1159            Ok(results)
1160        })
1161    }
1162
1163    pub fn run(&self, context: &mut ScriptRunContext) -> Result<Vec<ScriptResult>, ScriptRunError> {
1164        let pwd = context.pwd();
1165        let res = pwd.exists();
1166        if !matches!(res, Ok(true)) {
1167            cwriteln!(
1168                context.stream(),
1169                fg = Color::Red,
1170                "$PWD {pwd:?} doesn't exist. Run `cd $INITIAL_PWD` to fix.",
1171            );
1172            return Err(ScriptRunError::IO(std::io::Error::new(
1173                std::io::ErrorKind::NotFound,
1174                format!("PWD does not exist: {pwd:?}"),
1175            )));
1176        }
1177
1178        match self {
1179            ScriptBlock::Command(command) => {
1180                if context.background == ScriptMode::Deferred {
1181                    cwrite!(context.stream(), dimmed = true, "(deferred) ");
1182                }
1183                let result = command.run(context)?;
1184                if context.background != ScriptMode::Background {
1185                    result.evaluate(context)?;
1186                    Ok(vec![])
1187                } else {
1188                    Ok(vec![result])
1189                }
1190            }
1191            ScriptBlock::If(condition, blocks) => {
1192                let condition = condition.expand(context)?;
1193                if condition.matches(context) {
1194                    Self::run_blocks(context, blocks)
1195                } else {
1196                    Ok(vec![])
1197                }
1198            }
1199            ScriptBlock::For(ForCondition::Env(env, values), blocks) => {
1200                let mut results = Vec::new();
1201                for value in values {
1202                    context.set_env(env, context.expand(value)?);
1203                    results.extend(Self::run_blocks(context, blocks)?);
1204                }
1205                Ok(results)
1206            }
1207            ScriptBlock::Retry(blocks) => {
1208                let start = Instant::now();
1209                let mut backoff = Duration::from_millis(100);
1210
1211                cwrite!(context.stream(), fg = Color::Green, "retry: ");
1212                cwriteln!(context.stream(), "running...");
1213
1214                loop {
1215                    let mut nested_context = context.new_background();
1216                    if let Ok(results) = Self::run_blocks(&mut nested_context, blocks) {
1217                        let mut all_ok = true;
1218                        for result in results {
1219                            if result.evaluate(&mut nested_context).is_err() {
1220                                all_ok = false;
1221                                break;
1222                            }
1223                        }
1224                        if all_ok {
1225                            let output = nested_context.take_output();
1226                            cwrite!(context.stream(), fg = Color::Green, "retry: ");
1227                            cwriteln!(context.stream(), "success");
1228                            cwriteln!(context.stream());
1229                            cwriteln!(context.stream(), "{output}");
1230                            return Ok(vec![]);
1231                        }
1232                    }
1233
1234                    if start.elapsed() > context.timeout {
1235                        let output = nested_context.take_output();
1236                        cwrite!(context.stream(), fg = Color::Green, "retry: ");
1237                        cwriteln!(context.stream(), fg = Color::Red, "timed out");
1238                        cwriteln!(context.stream());
1239                        cwriteln!(context.stream(), "{output}");
1240                        cwriteln_rule!(context.stream());
1241                        return Err(ScriptRunError::RetryTookTooLong);
1242                    }
1243                    std::thread::sleep(backoff);
1244                    backoff *= 2;
1245                }
1246            }
1247            ScriptBlock::GlobalIgnore(patterns) => {
1248                for pattern in patterns.iter() {
1249                    pattern.prepare(&context.grok)?;
1250                }
1251                context.global_ignore.extend(patterns);
1252                Ok(vec![])
1253            }
1254            ScriptBlock::GlobalReject(patterns) => {
1255                for pattern in patterns.iter() {
1256                    pattern.prepare(&context.grok)?;
1257                }
1258                context.global_reject.extend(patterns);
1259                Ok(vec![])
1260            }
1261            _ => unreachable!("Unexpected block type: {self:?}"),
1262        }
1263    }
1264}
1265
1266impl Serialize for ScriptBlock {
1267    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1268    where
1269        S: serde::Serializer,
1270    {
1271        match self {
1272            ScriptBlock::Command(command) => command.serialize(serializer),
1273            ScriptBlock::InternalCommand(_, command) => command.serialize(serializer),
1274            ScriptBlock::Background(blocks) => {
1275                let mut ser = serializer.serialize_map(Some(1))?;
1276                ser.serialize_entry("background", blocks)?;
1277                ser.end()
1278            }
1279            ScriptBlock::Defer(blocks) => {
1280                let mut ser = serializer.serialize_map(Some(1))?;
1281                ser.serialize_entry("defer", blocks)?;
1282                ser.end()
1283            }
1284            ScriptBlock::If(condition, blocks) => {
1285                let mut ser = serializer.serialize_map(Some(2))?;
1286                ser.serialize_entry("if", condition)?;
1287                ser.serialize_entry("blocks", blocks)?;
1288                ser.end()
1289            }
1290            ScriptBlock::For(condition, blocks) => {
1291                let mut ser = serializer.serialize_map(Some(2))?;
1292                ser.serialize_entry("for", condition)?;
1293                ser.serialize_entry("blocks", blocks)?;
1294                ser.end()
1295            }
1296            ScriptBlock::Retry(blocks) => {
1297                let mut ser = serializer.serialize_map(Some(1))?;
1298                ser.serialize_entry("retry", blocks)?;
1299                ser.end()
1300            }
1301            ScriptBlock::GlobalIgnore(patterns) => {
1302                let mut ser = serializer.serialize_map(Some(1))?;
1303                ser.serialize_entry("ignore", patterns)?;
1304                ser.end()
1305            }
1306            ScriptBlock::GlobalReject(patterns) => {
1307                let mut ser = serializer.serialize_map(Some(1))?;
1308                ser.serialize_entry("reject", patterns)?;
1309                ser.end()
1310            }
1311        }
1312    }
1313}
1314
1315#[derive(Debug, Clone, Serialize)]
1316pub enum InternalCommand {
1317    UsingTempdir,
1318    UsingDir(ShellBit, bool),
1319    ChangeDir(ShellBit),
1320    Set(String, ShellBit),
1321    Include(String),
1322    ExitScript,
1323    Pattern(String, String),
1324}
1325
1326impl InternalCommand {
1327    #[allow(clippy::type_complexity)]
1328    pub fn run(
1329        &self,
1330        context: &mut ScriptRunContext,
1331    ) -> Result<
1332        Option<Box<dyn FnOnce(&mut ScriptRunContext) -> Result<(), ScriptRunError> + Send + Sync>>,
1333        ScriptRunError,
1334    > {
1335        match self.clone() {
1336            InternalCommand::Include(path) => {
1337                let Some(script) = context.includes.get(&path) else {
1338                    return Err(ScriptRunError::IncludedFileNotFound(path));
1339                };
1340                script.clone().run(context)?;
1341                Ok(None)
1342            }
1343            InternalCommand::Pattern(name, pattern) => {
1344                context.grok.add_pattern(name, pattern);
1345                Ok(None)
1346            }
1347            InternalCommand::UsingTempdir => {
1348                let current_pwd = context.pwd();
1349                let tempdir = NiceTempDir::new();
1350                cwrite!(context.stream(), fg = Color::Yellow, "using tempdir: ");
1351                cwriteln!(context.stream(), "{}", tempdir);
1352                cwriteln!(context.stream());
1353                context.set_pwd(&tempdir);
1354                let pwd = context.pwd();
1355                if !pwd.exists()? {
1356                    return Err(ScriptRunError::IO(std::io::Error::new(
1357                        std::io::ErrorKind::NotFound,
1358                        format!("newly created tempdir does not exist: {pwd:?}"),
1359                    )));
1360                }
1361                Ok(Some(Box::new(move |context: &mut ScriptRunContext| {
1362                    cwriteln!(
1363                        context.stream(),
1364                        fg = Color::Yellow,
1365                        "removing {} && cd {}",
1366                        tempdir,
1367                        current_pwd
1368                    );
1369                    cwriteln!(context.stream());
1370                    if !tempdir.exists()? {
1371                        cwriteln!(
1372                            context.stream(),
1373                            fg = Color::Red,
1374                            "tempdir does not exist: {tempdir}"
1375                        );
1376                    }
1377                    if let Err(e) = tempdir.remove_dir_all() {
1378                        cwriteln!(
1379                            context.stream(),
1380                            fg = Color::Red,
1381                            "error removing tempdir: {e:?}"
1382                        );
1383                    }
1384                    Ok::<_, ScriptRunError>(())
1385                })))
1386            }
1387            InternalCommand::UsingDir(dir, new) => {
1388                let current_pwd = context.pwd();
1389                let dir = context.expand(&dir)?;
1390                let new_pwd = current_pwd.join(dir);
1391                if new {
1392                    cwrite!(context.stream(), fg = Color::Yellow, "using new dir: ");
1393                } else {
1394                    cwrite!(context.stream(), fg = Color::Yellow, "using dir: ");
1395                }
1396                cwriteln!(context.stream(), "{}", new_pwd);
1397                cwriteln!(context.stream());
1398
1399                if new {
1400                    new_pwd.create_dir_all()?;
1401                } else if !new_pwd.exists()? {
1402                    return Err(ScriptRunError::IO(std::io::Error::new(
1403                        std::io::ErrorKind::NotFound,
1404                        "directory does not exist",
1405                    )));
1406                }
1407                context.set_pwd(&new_pwd);
1408                Ok(Some(Box::new(move |context: &mut ScriptRunContext| {
1409                    if new {
1410                        cwriteln!(
1411                            context.stream(),
1412                            fg = Color::Yellow,
1413                            "removing {} && cd {}",
1414                            new_pwd,
1415                            current_pwd
1416                        );
1417                        cwriteln!(context.stream());
1418                    } else {
1419                        cwriteln!(context.stream(), fg = Color::Yellow, "cd {}", current_pwd);
1420                        cwriteln!(context.stream());
1421                    }
1422                    if new {
1423                        new_pwd.remove_dir_all()?;
1424                    }
1425                    context.set_pwd(current_pwd);
1426                    Ok::<_, ScriptRunError>(())
1427                })))
1428            }
1429            InternalCommand::ChangeDir(dir) => {
1430                let dir = context.expand(&dir)?;
1431
1432                cwriteln!(context.stream(), fg = Color::Yellow, "cd {dir}");
1433                cwriteln!(context.stream());
1434                let current_pwd = context.pwd();
1435                let new_pwd = current_pwd.join(dir);
1436                if !new_pwd.exists()? {
1437                    return Err(ScriptRunError::IO(std::io::Error::new(
1438                        std::io::ErrorKind::NotFound,
1439                        format!("directory does not exist: {new_pwd}"),
1440                    )));
1441                }
1442                context.set_pwd(new_pwd);
1443                Ok(None)
1444            }
1445            InternalCommand::Set(name, value) => {
1446                let value = context.expand(&value)?;
1447
1448                context.set_env(&name, &value);
1449                let new_value = context.get_env(&name).unwrap_or_default();
1450                if new_value != value {
1451                    cwriteln!(
1452                        context.stream(),
1453                        fg = Color::Yellow,
1454                        "set {name} {value} (-> {new_value})"
1455                    );
1456                } else {
1457                    cwriteln!(context.stream(), fg = Color::Yellow, "set {name} {value}");
1458                }
1459                cwriteln!(context.stream());
1460
1461                Ok(None)
1462            }
1463            InternalCommand::ExitScript => {
1464                cwriteln!(context.stream(), fg = Color::Yellow, "exiting script");
1465                cwriteln!(context.stream());
1466                Err(ScriptRunError::ExitScript)
1467            }
1468        }
1469    }
1470}
1471
1472#[derive(Debug, Clone)]
1473pub enum IfCondition {
1474    True,
1475    False,
1476    EnvEq(bool, String, ShellBit),
1477}
1478
1479impl IfCondition {
1480    pub fn matches(&self, context: &ScriptRunContext) -> bool {
1481        match self {
1482            IfCondition::True => true,
1483            IfCondition::False => false,
1484            IfCondition::EnvEq(negated, name, expected) => {
1485                let value = context.get_env(name).unwrap_or_default();
1486                (expected == value) ^ negated
1487            }
1488        }
1489    }
1490
1491    pub fn expand(&self, context: &ScriptRunContext) -> Result<IfCondition, ScriptRunError> {
1492        match self {
1493            IfCondition::True => Ok(IfCondition::True),
1494            IfCondition::False => Ok(IfCondition::False),
1495            IfCondition::EnvEq(negated, name, expected) => {
1496                let value = context.expand(expected)?;
1497                Ok(IfCondition::EnvEq(
1498                    *negated,
1499                    name.clone(),
1500                    ShellBit::Literal(value),
1501                ))
1502            }
1503        }
1504    }
1505}
1506
1507impl Serialize for IfCondition {
1508    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1509    where
1510        S: serde::Serializer,
1511    {
1512        match self {
1513            IfCondition::True => "true".serialize(serializer),
1514            IfCondition::False => "false".serialize(serializer),
1515            IfCondition::EnvEq(negated, name, value) => {
1516                let mut ser = serializer.serialize_map(Some(3))?;
1517                ser.serialize_entry("op", if *negated { "!=" } else { "==" })?;
1518                ser.serialize_entry("env", name)?;
1519                ser.serialize_entry("value", value)?;
1520                ser.end()
1521            }
1522        }
1523    }
1524}
1525
1526#[derive(Debug)]
1527pub enum ForCondition {
1528    Env(String, Vec<ShellBit>),
1529}
1530
1531impl Serialize for ForCondition {
1532    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1533    where
1534        S: serde::Serializer,
1535    {
1536        match self {
1537            ForCondition::Env(name, values) => {
1538                let mut ser = serializer.serialize_map(Some(2))?;
1539                ser.serialize_entry("env", name)?;
1540                ser.serialize_entry("values", values)?;
1541                ser.end()
1542            }
1543        }
1544    }
1545}
1546
/// `skip_serializing_if` predicate: true when the flag is unset.
fn is_bool_false(b: &bool) -> bool {
    !*b
}
1550
1551#[derive(Debug, Serialize)]
1552pub struct ScriptCommand {
1553    pub command: CommandLine,
1554    pub pattern: OutputPattern,
1555
1556    #[serde(skip_serializing_if = "CommandExit::is_success")]
1557    pub exit: CommandExit,
1558
1559    #[serde(skip_serializing_if = "is_bool_false")]
1560    pub expect_failure: bool,
1561
1562    /// Single set variable (entire command output trimmed)
1563    #[serde(skip_serializing_if = "Option::is_none")]
1564    pub set_var: Option<String>,
1565
1566    /// Specific set variables
1567    pub set_vars: HashMap<String, ShellBit>,
1568
1569    /// Specific command timeout
1570    #[serde(skip_serializing_if = "Option::is_none")]
1571    pub timeout: Option<Duration>,
1572
1573    /// Input grok expectations
1574    pub expect: HashMap<String, ShellBit>,
1575}
1576
1577impl ScriptCommand {
1578    pub fn new(command: CommandLine) -> Self {
1579        let location = command.location.clone();
1580        Self {
1581            command,
1582            pattern: OutputPattern {
1583                pattern: OutputPatternType::None,
1584                ignore: Default::default(),
1585                reject: Default::default(),
1586                location,
1587            },
1588            exit: Default::default(),
1589            timeout: None,
1590            expect_failure: false,
1591            set_var: None,
1592            set_vars: Default::default(),
1593            expect: Default::default(),
1594        }
1595    }
1596
1597    pub fn run(&self, context: &mut ScriptRunContext) -> Result<ScriptResult, ScriptRunError> {
1598        let command = &self.command;
1599        let args = &context.args;
1600        let start = Instant::now();
1601
1602        if let Some(delay) = args.delay_steps {
1603            std::thread::sleep(std::time::Duration::from_millis(delay));
1604        }
1605
1606        for line in command.command.split('\n') {
1607            cwriteln!(context.stream(), fg = Color::Green, "{}", line);
1608        }
1609        if args.simplified_output {
1610            cwriteln!(context.stream(), dimmed = true, "---");
1611        } else {
1612            cwriteln_rule!(context.stream(), fg = Color::Cyan, "{}", command.location);
1613        }
1614        let (output, status) = command.run(
1615            &mut context.stream(),
1616            context.args.show_line_numbers,
1617            context.args.runner.clone(),
1618            self.timeout.unwrap_or(context.timeout),
1619            context.env.env_vars(),
1620            &context.kill,
1621            &context.kill_sender,
1622        )?;
1623
1624        let exit_result = if !self.exit.matches(status) {
1625            ExitResult::Mismatch(status)
1626        } else {
1627            ExitResult::Matches(status)
1628        };
1629
1630        // Side-effects
1631        if let Some(set_var) = &self.set_var {
1632            context.set_env(set_var, output.to_string().trim());
1633        }
1634
1635        let match_context = OutputMatchContext::new(context);
1636        for (key, value) in &self.expect {
1637            match_context.expect(key, context.expand(value)?);
1638        }
1639        self.pattern.prepare(&context.grok)?;
1640        let prepared_output = output
1641            .with_ignore(&context.global_ignore)
1642            .with_reject(&context.global_reject);
1643        let pattern_result = match self.pattern.matches(match_context.clone(), prepared_output) {
1644            Ok(_) => {
1645                let mut env = context.env.clone();
1646                for (key, value) in match_context.expects() {
1647                    env.set_env(key, value);
1648                }
1649                for (key, value) in &self.set_vars {
1650                    context.set_env(key, env.expand(value)?);
1651                }
1652
1653                if self.expect_failure {
1654                    PatternResult::ExpectedFailure
1655                } else {
1656                    PatternResult::Matches
1657                }
1658            }
1659            Err(e) => {
1660                if self.expect_failure {
1661                    PatternResult::MatchesFailure
1662                } else {
1663                    let mut trace = String::new();
1664                    for line in match_context.traces() {
1665                        trace.push_str(&format!("{line}\n"));
1666                    }
1667                    PatternResult::Mismatch(e, trace)
1668                }
1669            }
1670        };
1671
1672        if output.is_empty() {
1673            cwriteln!(context.stream(), dimmed = true, "(no output)");
1674        }
1675
1676        if context.args.simplified_output {
1677            cwriteln!(context.stream(), dimmed = true, "---");
1678        } else {
1679            cwriteln_rule!(context.stream());
1680        }
1681
1682        Ok(ScriptResult {
1683            command: command.clone(),
1684            pattern: pattern_result,
1685            exit: exit_result,
1686            elapsed: start.elapsed(),
1687            output,
1688        })
1689    }
1690}
1691
1692#[derive(derive_more::Debug)]
1693pub struct ScriptResult {
1694    pub command: CommandLine,
1695    pub pattern: PatternResult,
1696    pub exit: ExitResult,
1697    pub elapsed: Duration,
1698    #[debug(skip)]
1699    pub output: Lines,
1700}
1701
1702impl ScriptResult {
1703    pub fn evaluate(&self, context: &mut ScriptRunContext) -> Result<(), ScriptRunError> {
1704        let args = &context.args;
1705        let (success, failure, warning, arrow) = if *crate::term::IS_UTF8 {
1706            ("✅", "❌", "⚠️", "→")
1707        } else {
1708            ("[*]", "[X]", "[!]", "->")
1709        };
1710
1711        if let ExitResult::Mismatch(status) = self.exit {
1712            if args.ignore_exit_codes {
1713                cwriteln!(
1714                    context.stream(),
1715                    fg = Color::Yellow,
1716                    "{warning} Ignored incorrect exit code: {status}"
1717                );
1718                cwriteln!(context.stream());
1719            } else {
1720                cwriteln!(
1721                    context.stream(),
1722                    fg = Color::Red,
1723                    "{failure} FAIL: {status}"
1724                );
1725                cwriteln!(
1726                    context.stream(),
1727                    dimmed = true,
1728                    " {arrow} {}",
1729                    self.command.command
1730                );
1731                cwriteln!(context.stream());
1732                return Err(ScriptRunError::Exit(status, self.command.location.clone()));
1733            }
1734        }
1735
1736        if let PatternResult::Mismatch(e, trace) = &self.pattern {
1737            if args.ignore_matches {
1738                cwriteln!(
1739                    context.stream(),
1740                    fg = Color::Yellow,
1741                    "{warning} Ignored error: {e} (ignoring mismatches)"
1742                );
1743                cwriteln!(context.stream());
1744            } else {
1745                cwriteln!(context.stream(), fg = Color::Red, "ERROR: {e}");
1746                cwriteln!(context.stream(), dimmed = true, "{trace}");
1747                cwriteln!(context.stream(), fg = Color::Red, "{failure} FAIL");
1748                cwriteln!(context.stream());
1749                return Err(ScriptRunError::Pattern(e.clone()));
1750            }
1751        }
1752
1753        if let PatternResult::ExpectedFailure = self.pattern {
1754            if args.ignore_matches {
1755                cwriteln!(
1756                    context.stream(),
1757                    fg = Color::Yellow,
1758                    "{warning} Should not have matched! (ignoring mismatches)"
1759                );
1760                cwriteln!(context.stream());
1761            } else {
1762                cwriteln!(
1763                    context.stream(),
1764                    fg = Color::Red,
1765                    "{failure} FAIL (output shouldn't match)"
1766                );
1767                cwriteln!(
1768                    context.stream(),
1769                    dimmed = true,
1770                    " {arrow} {}",
1771                    self.command.command
1772                );
1773                cwriteln!(context.stream());
1774                return Err(ScriptRunError::ExpectedFailure(
1775                    self.command.location.clone(),
1776                ));
1777            }
1778        }
1779
1780        if let ExitResult::Matches(status) = self.exit {
1781            if status.success() {
1782                cwrite!(context.stream(), fg = Color::Green, "{success} OK");
1783                if !context.args.simplified_output {
1784                    cwriteln!(
1785                        context.stream(),
1786                        dimmed = true,
1787                        " ({:.2}s)",
1788                        self.elapsed.as_secs_f32()
1789                    );
1790                } else {
1791                    cwriteln!(context.stream());
1792                }
1793            } else {
1794                cwrite!(
1795                    context.stream(),
1796                    fg = Color::Green,
1797                    "{success} OK ({status})"
1798                );
1799                if !context.args.simplified_output {
1800                    cwriteln!(
1801                        context.stream(),
1802                        dimmed = true,
1803                        " ({:.2}s)",
1804                        self.elapsed.as_secs_f32()
1805                    );
1806                } else {
1807                    cwriteln!(context.stream());
1808                }
1809            }
1810            cwriteln!(context.stream());
1811        }
1812
1813        Ok(())
1814    }
1815}
1816
/// How a command's output compared with its expected pattern.
///
/// Note the naming: `ExpectedFailure` is the *bad* outcome (a failure was expected but the
/// output matched), while `MatchesFailure` is the *good* outcome for an expected failure.
#[derive(Debug)]
pub enum PatternResult {
    /// The output matched, and a match was expected.
    Matches,
    /// The output did not match, and the step expected it not to match.
    MatchesFailure,
    /// The output matched even though the step expected it not to; reported as
    /// "output shouldn't match" by `ScriptResult::evaluate`.
    ExpectedFailure,
    /// The output did not match when it should have. Carries the match failure and the
    /// matcher's trace lines, each terminated by a newline.
    Mismatch(OutputPatternMatchFailure, String),
}
1824
/// How a command's exit status compared with what the step expected.
#[derive(Debug)]
pub enum ExitResult {
    /// The exit status was the expected one (which may itself be a non-success status).
    Matches(CommandResult),
    /// The exit status differed from the expected one.
    Mismatch(CommandResult),
    /// Presumably the command exceeded its time limit (constructed outside this view —
    /// confirm). NOTE(review): `ScriptResult::evaluate` neither reports nor fails on this
    /// variant.
    TimedOut,
}
1831
#[cfg(test)]
mod tests {
    use crate::parser::v0::parse_script;

    use super::*;
    use std::error::Error;

    /// A script mixing a named pattern, plain commands and a `repeat`/`choice` block
    /// parses and yields exactly three top-level script blocks.
    #[test]
    fn test_script() -> Result<(), Box<dyn Error>> {
        let source = r#"
pattern VERSION \d+\.\d+\.\d+;

$ something --version || echo 1
? Something %{VERSION}

$ something --help
? Usage: something [OPTIONS]
repeat {
    choice {
? %{DATA} %{GREEDYDATA}
? %{DATA}=%{DATA} %{GREEDYDATA}
    }
}
"#;

        let parsed = parse_script(ScriptFile::new("test.cli"), source)?;
        assert_eq!(parsed.commands.len(), 3);
        eprintln!("{parsed:?}");
        Ok(())
    }

    /// Parsing a script that contains a trailing-`&` command fails with
    /// `ScriptErrorType::BackgroundProcessNotAllowed`.
    #[test]
    fn test_bad_script() -> Result<(), Box<dyn Error>> {
        let source = r#"
$ (cmd; cmd)
$ cmd &
    "#;

        let outcome = parse_script(ScriptFile::new("test.cli"), source);
        assert!(matches!(
            outcome,
            Err(ScriptError {
                error: ScriptErrorType::BackgroundProcessNotAllowed,
                ..
            })
        ));
        Ok(())
    }

    /// `ScriptEnv::expand_str` substitutes `$NAME` / `${NAME}` references and honours
    /// backslash escapes (`\$` stays literal, `\\` collapses to one backslash).
    #[test]
    fn test_script_run_context_expand() {
        let mut env = ScriptEnv::default();
        for (name, value) in [("A", "1"), ("B", "2"), ("C", "3")] {
            env.set_env(name, value);
        }

        let cases: &[(&str, &str)] = &[
            ("$A", "1"),
            ("$A $B ", "1 2 "),
            ("${A} ${B} ", "1 2 "),
            (r"\$A", "$A"),
            (r"\${A}", "${A}"),
            (r"\\$A", r"\1"),
            (r"\\${A}", r"\1"),
        ];
        for &(input, expected) in cases {
            assert_eq!(env.expand_str(input).unwrap(), expected);
        }

        env.set_env("TEMP_DIR", "/tmp");
        for input in ["$TEMP_DIR", "${TEMP_DIR}"] {
            assert_eq!(env.expand_str(input).unwrap(), "/tmp");
        }
    }
}