Skip to main content

clitest_lib/
script.rs

1use std::{
2    collections::{HashMap, VecDeque},
3    path::Path,
4    process::ExitStatus,
5    sync::{Arc, Mutex, atomic::AtomicBool},
6    thread::ScopedJoinHandle,
7    time::{Duration, Instant},
8};
9
10use grok::Grok;
11use keepcalm::SharedMut;
12use serde::{Serialize, ser::SerializeMap};
13use termcolor::{Color, ColorChoice, WriteColor};
14
15use crate::{
16    command::{CommandLine, CommandResult},
17    util::{NicePathBuf, NiceTempDir},
18};
19use crate::{cwrite, cwriteln, cwriteln_rule};
20use crate::{output::*, util::ShellBit};
21
22const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
23
24#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
25pub struct ScriptLocation {
26    pub file: ScriptFile,
27    pub line: usize,
28}
29
30impl ScriptLocation {
31    pub fn new(file: ScriptFile, line: usize) -> Self {
32        Self { file, line }
33    }
34}
35
36impl std::fmt::Display for ScriptLocation {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        write!(f, "{}:{}", self.file, self.line)
39    }
40}
41
42#[derive(
43    derive_more::Debug, derive_more::Display, Clone, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash,
44)]
45#[display("{}", file)]
46pub struct ScriptFile {
47    pub base_line: usize,
48    pub file: Arc<NicePathBuf>,
49}
50
51impl ScriptFile {
52    pub fn new(file: impl AsRef<Path>) -> Self {
53        Self {
54            base_line: 0,
55            file: Arc::new(NicePathBuf::new(file)),
56        }
57    }
58    pub fn new_with_line(file: impl AsRef<Path>, line: usize) -> Self {
59        Self {
60            base_line: line,
61            file: Arc::new(NicePathBuf::new(file)),
62        }
63    }
64}
65
66impl<T: AsRef<Path>> From<T> for ScriptFile {
67    fn from(file: T) -> Self {
68        Self::new(file)
69    }
70}
71
72#[derive(Clone, derive_more::Debug, Serialize)]
73pub struct Script {
74    pub commands: Arc<Vec<ScriptBlock>>,
75    pub includes: Arc<HashMap<String, Script>>,
76    pub file: ScriptFile,
77}
78
/// Options that control how a script is run and how results are reported.
///
/// Everything defaults to "off"/`None`.
#[derive(Clone, Debug, Default)]
pub struct ScriptRunArgs {
    pub delay_steps: Option<u64>,
    pub ignore_exit_codes: bool,
    pub ignore_matches: bool,
    /// When set, the final PASSED/FAILED line omits the elapsed time.
    pub simplified_output: bool,
    pub show_line_numbers: bool,
    pub runner: Option<String>,
    pub quiet: bool,
    /// When set, background contexts share the parent's output instead of
    /// writing to a private buffer.
    pub verbose: bool,
    /// Timeout for the root run context; `DEFAULT_TIMEOUT` is used when `None`.
    pub global_timeout: Option<Duration>,
    /// Passed to `ScriptOutput::quiet` when a private background buffer is made.
    pub no_color: bool,
}
92
93#[derive(Debug, Clone, Default)]
94pub struct ScriptEnv {
95    env_vars: HashMap<String, String>,
96}
97
98impl ScriptEnv {
99    pub fn set_defaults(&mut self, pwd: impl AsRef<Path>) {
100        macro_rules! target {
101            ($env:ident, $var:ident, [$($vals:expr),*]) => {
102                $(
103                if cfg!($var = $vals) {
104                    self.env_vars.insert(stringify!($env).to_string(), $vals.to_string());
105                }
106                )*
107            };
108        }
109
110        target!(
111            TARGET_OS,
112            target_os,
113            [
114                "windows",
115                "linux",
116                "macos",
117                "ios",
118                "android",
119                "freebsd",
120                "netbsd",
121                "openbsd",
122                "dragonfly",
123                "haiku",
124                "aix"
125            ]
126        );
127        target!(TARGET_FAMILY, target_family, ["windows", "unix", "wasm"]);
128        target!(
129            TARGET_ARCH,
130            target_arch,
131            [
132                "aarch64",
133                "amdgpu",
134                "arm",
135                "arm64ec",
136                "avr",
137                "bpf",
138                "csky",
139                "hexagon",
140                "loongarch32",
141                "loongarch64",
142                "m68k",
143                "mips",
144                "mips32r6",
145                "mips64",
146                "mips64r6",
147                "msp430",
148                "nvptx64",
149                "powerpc",
150                "powerpc64",
151                "riscv32",
152                "riscv64",
153                "s390x",
154                "sparc",
155                "sparc64",
156                "wasm32",
157                "wasm64",
158                "x86",
159                "x86_64",
160                "xtensa"
161            ]
162        );
163
164        // Set the current working directory as a special variable "PWD"
165        let pwd = NicePathBuf::from(pwd.as_ref()).env_string();
166        self.env_vars.insert("PWD".to_string(), pwd);
167        // Save the initial PWD as INITIAL_PWD so it can easily be restored
168        self.env_vars
169            .insert("INITIAL_PWD".to_string(), self.env_vars["PWD"].clone());
170    }
171
172    pub fn pwd(&self) -> NicePathBuf {
173        self.env_vars
174            .get("PWD")
175            .cloned()
176            .map(NicePathBuf::from)
177            .unwrap_or_else(NicePathBuf::cwd)
178    }
179
180    pub fn get_env(&self, name: &str) -> Option<&str> {
181        self.env_vars.get(name).map(|s| s.as_str())
182    }
183
184    pub fn set_env(&mut self, name: impl Into<String>, value: impl Into<String>) {
185        let name = name.into();
186        if name == "PWD" {
187            self.set_pwd(value.into());
188        } else {
189            self.env_vars.insert(name, value.into());
190        }
191    }
192
193    pub fn set_pwd(&mut self, pwd: impl Into<NicePathBuf>) {
194        let pwd = pwd.into().env_string();
195        self.env_vars.insert("PWD".to_string(), pwd);
196    }
197
198    pub fn expand(&self, value: &ShellBit) -> Result<String, ScriptRunError> {
199        match value {
200            ShellBit::Literal(s) => Ok(s.clone()),
201            ShellBit::Quoted(s) => self.expand_str(s),
202        }
203    }
204
205    /// Perform shell expansion on a string.
206    pub fn expand_str(&self, value: impl AsRef<str>) -> Result<String, ScriptRunError> {
207        enum State {
208            Normal,
209            EscapeNext,
210            InCurly,
211            Dollar,
212            InDollar,
213        }
214
215        let value = value.as_ref();
216
217        // "\" triggers escaping
218        // ${A} expands to the value of A
219        // $A expands to the value of A (variable ends on first non-alphanumeric character)
220
221        let mut state = State::Normal;
222        let mut variable = String::new();
223        let mut expanded = String::new();
224
225        for c in value.chars() {
226            match state {
227                State::Normal => {
228                    if c == '$' {
229                        state = State::Dollar;
230                        continue;
231                    }
232                    if c == '\\' {
233                        state = State::EscapeNext;
234                        continue;
235                    }
236                    expanded.push(c);
237                }
238                State::EscapeNext => {
239                    expanded.push(c);
240                    state = State::Normal;
241                }
242                State::InCurly => {
243                    if c == '}' {
244                        if let Some(value) = self.get_env(&std::mem::take(&mut variable)) {
245                            expanded.push_str(value);
246                        } else {
247                            return Err(ScriptRunError::ExpansionError(format!(
248                                "undefined variable in ${{...}}: {variable:?} (in {value:?})"
249                            )));
250                        }
251                        state = State::Normal;
252                    } else {
253                        variable.push(c);
254                    }
255                }
256                State::Dollar => {
257                    if c.is_alphanumeric() || c == '_' {
258                        state = State::InDollar;
259                        variable.push(c);
260                    } else if c == '{' {
261                        state = State::InCurly;
262                    } else {
263                        return Err(ScriptRunError::ExpansionError(format!(
264                            "invalid variable: {c:?} (in {value:?})"
265                        )));
266                    }
267                }
268                State::InDollar => {
269                    if c.is_alphanumeric() || c == '_' {
270                        variable.push(c);
271                    } else {
272                        if let Some(value) = self.get_env(&std::mem::take(&mut variable)) {
273                            expanded.push_str(value);
274                        } else {
275                            return Err(ScriptRunError::ExpansionError(format!(
276                                "undefined variable in $...: {variable:?} (in {value:?})"
277                            )));
278                        }
279                        expanded.push(c);
280                        state = State::Normal;
281                    }
282                }
283            }
284        }
285        match state {
286            State::InDollar => {
287                if let Some(value) = self.get_env(&variable) {
288                    expanded.push_str(value);
289                } else {
290                    return Err(ScriptRunError::ExpansionError(format!(
291                        "undefined variable: {variable}"
292                    )));
293                }
294            }
295            State::Dollar => {
296                return Err(ScriptRunError::ExpansionError(
297                    "incomplete variable".to_string(),
298                ));
299            }
300            State::InCurly => {
301                return Err(ScriptRunError::ExpansionError(format!(
302                    "unclosed variable: {variable}"
303                )));
304            }
305            State::Normal => {}
306            State::EscapeNext => {
307                return Err(ScriptRunError::ExpansionError(
308                    "unclosed backslash".to_string(),
309                ));
310            }
311        }
312        Ok(expanded)
313    }
314
315    pub fn env_vars(&self) -> &HashMap<String, String> {
316        &self.env_vars
317    }
318}
319
320#[derive(derive_more::Debug, Clone)]
321pub struct ScriptOutput {
322    #[debug(skip)]
323    stream: SharedMut<Box<dyn WriteColorAny>>,
324}
325
326trait WriteColorAny: WriteColor + Send + Sync + std::any::Any + 'static + std::fmt::Debug {
327    /// Workaround for lack of upcasting
328    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String>;
329    fn clone_buffer(&self) -> Result<termcolor::Buffer, String>;
330}
331
332impl WriteColorAny for termcolor::StandardStream {
333    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String> {
334        Err("not a buffer".to_string())
335    }
336    fn clone_buffer(&self) -> Result<termcolor::Buffer, String> {
337        Err("not a buffer".to_string())
338    }
339}
340
341impl WriteColorAny for termcolor::Buffer {
342    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String> {
343        Ok(*self)
344    }
345    fn clone_buffer(&self) -> Result<termcolor::Buffer, String> {
346        Ok(self.clone())
347    }
348}
349
350impl ScriptOutput {
351    pub fn no_color() -> Self {
352        let stm = termcolor::StandardStream::stdout(ColorChoice::Never);
353        Self {
354            stream: SharedMut::new(Box::new(stm) as _),
355        }
356    }
357
358    pub fn quiet(no_color: bool) -> Self {
359        let stm = if no_color {
360            termcolor::Buffer::no_color()
361        } else {
362            termcolor::Buffer::ansi()
363        };
364        Self {
365            stream: SharedMut::new(Box::new(stm) as _),
366        }
367    }
368
369    pub fn take_buffer(self) -> String {
370        let stream = match SharedMut::try_unwrap(self.stream) {
371            Ok(stream) => stream.take_buffer().expect("wrong stream type"),
372            Err(shared) => shared.read().clone_buffer().expect("wrong stream type"),
373        };
374        String::from_utf8_lossy(&stream.into_inner()).to_string()
375    }
376}
377
378impl Default for ScriptOutput {
379    fn default() -> Self {
380        let stm = termcolor::StandardStream::stdout(ColorChoice::Auto);
381        Self {
382            stream: SharedMut::new(Box::new(stm) as _),
383        }
384    }
385}
386
387impl std::io::Write for ScriptOutputLock<'_> {
388    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
389        self.stream.write(buf)
390    }
391    fn flush(&mut self) -> std::io::Result<()> {
392        self.stream.flush()
393    }
394}
395
396impl termcolor::WriteColor for ScriptOutputLock<'_> {
397    fn supports_color(&self) -> bool {
398        self.stream.supports_color()
399    }
400    fn set_color(&mut self, spec: &termcolor::ColorSpec) -> std::io::Result<()> {
401        self.stream.set_color(spec)
402    }
403    fn reset(&mut self) -> std::io::Result<()> {
404        self.stream.reset()
405    }
406    fn is_synchronous(&self) -> bool {
407        self.stream.is_synchronous()
408    }
409    fn set_hyperlink(&mut self, _link: &termcolor::HyperlinkSpec) -> std::io::Result<()> {
410        self.stream.set_hyperlink(_link)
411    }
412    fn supports_hyperlinks(&self) -> bool {
413        self.stream.supports_hyperlinks()
414    }
415}
416
417struct ScriptOutputLock<'a> {
418    stream: keepcalm::SharedWriteLock<'a, Box<dyn WriteColorAny>>,
419}
420
421#[derive(Debug, Clone, Copy, PartialEq, Eq)]
422enum ScriptMode {
423    Normal,
424    Deferred,
425    Background,
426}
427
428#[derive(derive_more::Debug)]
429pub struct ScriptRunContext {
430    pub args: ScriptRunArgs,
431    pub grok: Grok,
432    timeout: Duration,
433    env: ScriptEnv,
434    includes: Arc<HashMap<String, Script>>,
435    background: ScriptMode,
436    #[debug(skip)]
437    kill: ScriptKillReceiver,
438    #[debug(skip)]
439    kill_sender: ScriptKillSender,
440    output: ScriptOutput,
441
442    global_ignore: OutputPatterns,
443    global_reject: OutputPatterns,
444}
445
446impl Default for ScriptRunContext {
447    fn default() -> Self {
448        let kill = Arc::new(AtomicBool::new(false));
449        Self {
450            args: ScriptRunArgs::default(),
451            grok: Grok::with_default_patterns(),
452            timeout: DEFAULT_TIMEOUT,
453            env: ScriptEnv::default(),
454            background: ScriptMode::Normal,
455            includes: Arc::new(HashMap::new()),
456            kill: ScriptKillReceiver::new(kill.clone()),
457            kill_sender: ScriptKillSender::new(kill.clone()),
458            output: ScriptOutput::default(),
459            global_ignore: OutputPatterns::default(),
460            global_reject: OutputPatterns::default(),
461        }
462    }
463}
464
465impl ScriptRunContext {
466    pub fn new_background(&self) -> Self {
467        let kill = Arc::new(AtomicBool::new(false));
468        Self {
469            args: self.args.clone(),
470            grok: self.grok.clone(),
471            // Background processes are not subject to timeouts
472            timeout: Duration::MAX,
473            env: self.env.clone(),
474            background: ScriptMode::Background,
475            kill: ScriptKillReceiver::new(kill.clone()),
476            kill_sender: ScriptKillSender::new(kill.clone()),
477            includes: self.includes.clone(),
478            output: if self.args.verbose {
479                self.output.clone()
480            } else {
481                ScriptOutput::quiet(self.args.no_color)
482            },
483            global_ignore: self.global_ignore.clone(),
484            global_reject: self.global_reject.clone(),
485        }
486    }
487
488    pub fn new_deferred(&self) -> Self {
489        Self {
490            args: self.args.clone(),
491            grok: self.grok.clone(),
492            timeout: self.timeout,
493            env: self.env.clone(),
494            background: ScriptMode::Deferred,
495            kill: self.kill.clone(),
496            kill_sender: self.kill_sender.clone(),
497            includes: self.includes.clone(),
498            output: self.output.clone(),
499            global_ignore: self.global_ignore.clone(),
500            global_reject: self.global_reject.clone(),
501        }
502    }
503
504    pub fn pwd(&self) -> NicePathBuf {
505        self.env.pwd()
506    }
507
508    pub fn get_env(&self, name: &str) -> Option<&str> {
509        self.env.get_env(name)
510    }
511
512    pub fn set_env(&mut self, name: impl Into<String>, value: impl Into<String>) {
513        self.env.set_env(name, value);
514    }
515
516    pub fn set_pwd(&mut self, pwd: impl Into<NicePathBuf>) {
517        self.env.set_pwd(pwd);
518    }
519
520    pub fn take_output(self) -> String {
521        self.output.take_buffer()
522    }
523
524    fn expand(&self, value: &ShellBit) -> Result<String, ScriptRunError> {
525        self.env.expand(value)
526    }
527
528    /// Get a mutable reference to the output stream.
529    pub fn stream(&self) -> impl termcolor::WriteColor + use<'_> {
530        ScriptOutputLock {
531            stream: self.output.stream.write(),
532        }
533    }
534}
535
536#[derive(Clone)]
537pub struct ScriptKillReceiver {
538    kill_receiver: Arc<AtomicBool>,
539}
540
541impl ScriptKillReceiver {
542    pub fn new(kill_receiver: Arc<AtomicBool>) -> Self {
543        Self { kill_receiver }
544    }
545
546    pub fn is_killed(&self) -> bool {
547        self.kill_receiver.load(std::sync::atomic::Ordering::SeqCst)
548    }
549
550    pub fn run_with<T>(&self, kill: impl FnOnce() + Send, wait: impl FnOnce() -> T) -> T {
551        std::thread::scope(|s| {
552            let done = Arc::new(AtomicBool::new(false));
553            let done_clone = done.clone();
554            let t = s.spawn(move || {
555                while !done_clone.load(std::sync::atomic::Ordering::SeqCst) {
556                    if self.is_killed() {
557                        kill();
558                        break;
559                    }
560                    std::thread::sleep(Duration::from_millis(10));
561                }
562            });
563            let res = wait();
564            done.store(true, std::sync::atomic::Ordering::SeqCst);
565            t.join().unwrap();
566            res
567        })
568    }
569
570    #[cfg(windows)]
571    pub fn run_cmd(
572        &self,
573        output: std::process::Child,
574        warn_time: Duration,
575    ) -> std::io::Result<ExitStatus> {
576        use std::os::windows::io::AsRawHandle;
577        use win32job::Job;
578
579        fn map_job_error(e: win32job::JobError) -> std::io::Error {
580            match e {
581                win32job::JobError::AssignFailed(e) => e,
582                win32job::JobError::CreateFailed(e) => e,
583                win32job::JobError::GetInfoFailed(e) => e,
584                win32job::JobError::SetInfoFailed(e) => e,
585                _ => std::io::Error::new(std::io::ErrorKind::Other, "Unknown error"),
586            }
587        }
588
589        // Create a new Job object
590        let job = Job::create().map_err(map_job_error)?;
591
592        // Configure the job to terminate all child processes when the job is closed
593        let mut info = job.query_extended_limit_info().map_err(map_job_error)?;
594        info.limit_kill_on_job_close();
595        job.set_extended_limit_info(&info).map_err(map_job_error)?;
596        job.assign_process(output.as_raw_handle() as _)?;
597
598        // Resume the main thread for the process
599        let id = output.id();
600        for thread_entry in tlhelp32::Snapshot::new_thread()? {
601            if thread_entry.owner_process_id == id {
602                use windows_sys::Win32::Foundation::CloseHandle;
603                use windows_sys::Win32::System::Threading::*;
604
605                unsafe {
606                    let thread = OpenThread(THREAD_SUSPEND_RESUME, 0, thread_entry.thread_id);
607                    if thread.is_null() {
608                        return Err(std::io::Error::last_os_error().into());
609                    }
610                    ResumeThread(thread);
611                    CloseHandle(thread);
612                }
613            }
614        }
615
616        let job = Mutex::new(Some(job));
617        let output = Mutex::new(output);
618        self.run_with(
619            || {
620                _ = job.lock().unwrap().take();
621                _ = output.lock().unwrap().kill();
622            },
623            || {
624                let start = std::time::Instant::now();
625                let mut warned = false;
626                loop {
627                    let res = output.lock().unwrap().try_wait()?;
628                    if let Some(status) = res {
629                        return Ok::<_, std::io::Error>(status);
630                    }
631                    if start.elapsed() > warn_time {
632                        if !warned {
633                            let child = output.lock().unwrap().id();
634                            eprintln!("Process #{child} taking too long to finish.");
635                            warned = true;
636                        }
637                    }
638                    std::thread::sleep(Duration::from_millis(10));
639                }
640            },
641        )
642    }
643
644    #[cfg(unix)]
645    pub fn run_cmd(
646        &self,
647        output: std::process::Child,
648        warn_time: Duration,
649    ) -> std::io::Result<ExitStatus> {
650        let output = Mutex::new(output);
651        self.run_with(
652            || {
653                use signal_child::{signal, signal::*};
654                let id = output.lock().unwrap().id() as i32;
655                _ = signal(-id, SIGINT);
656                std::thread::sleep(Duration::from_millis(10));
657                _ = signal(-id, SIGTERM);
658            },
659            || {
660                let start = std::time::Instant::now();
661                let mut warned = false;
662                loop {
663                    let res = output.lock().unwrap().try_wait()?;
664                    if let Some(status) = res {
665                        return Ok::<_, std::io::Error>(status);
666                    }
667                    if start.elapsed() > warn_time && !warned {
668                        let child = output.lock().unwrap().id();
669                        eprintln!("Process #{child} taking too long to finish.");
670                        warned = true;
671                    }
672                    std::thread::sleep(Duration::from_millis(10));
673                }
674            },
675        )
676    }
677}
678
/// Write side of a shared kill flag.
///
/// Raising it is observed by every `ScriptKillReceiver` built from the same
/// `Arc<AtomicBool>`.
#[derive(Clone)]
pub struct ScriptKillSender {
    kill_sender: Arc<AtomicBool>,
}

impl ScriptKillSender {
    /// Wraps an existing kill flag.
    pub fn new(kill_sender: Arc<AtomicBool>) -> Self {
        ScriptKillSender { kill_sender }
    }

    /// Requests a kill by raising the shared flag; calling it again is harmless.
    pub fn kill(&self) {
        use std::sync::atomic::Ordering;
        self.kill_sender.store(true, Ordering::SeqCst);
    }
}
694
695impl ScriptRunContext {
696    pub fn new(args: ScriptRunArgs, script_path: impl AsRef<Path>, output: ScriptOutput) -> Self {
697        let mut env = ScriptEnv::default();
698        env.set_defaults(script_path.as_ref().parent().unwrap());
699
700        let kill = Arc::new(AtomicBool::new(false));
701
702        Self {
703            timeout: args.global_timeout.unwrap_or(DEFAULT_TIMEOUT),
704            args,
705            env,
706            grok: Grok::with_default_patterns(),
707            includes: Arc::new(HashMap::new()),
708            background: ScriptMode::Normal,
709            kill: ScriptKillReceiver::new(kill.clone()),
710            kill_sender: ScriptKillSender::new(kill.clone()),
711            output,
712            global_ignore: OutputPatterns::default(),
713            global_reject: OutputPatterns::default(),
714        }
715    }
716}
717
718#[derive(Clone, Debug, PartialEq, Eq)]
719pub struct ScriptLine {
720    pub location: ScriptLocation,
721    text: String,
722}
723
724impl ScriptLine {
725    pub fn new(file: ScriptFile, line: usize, text: impl AsRef<str>) -> Self {
726        Self {
727            location: ScriptLocation::new(file, line),
728            text: text.as_ref().to_string(),
729        }
730    }
731
732    pub fn parse(file: ScriptFile, text: impl AsRef<str>) -> Vec<Self> {
733        text.as_ref()
734            .lines()
735            .enumerate()
736            .map(|(line, text)| Self {
737                location: ScriptLocation::new(file.clone(), line + file.base_line + 1),
738                text: text.to_string(),
739            })
740            .collect()
741    }
742
743    pub fn starts_with(&self, text: &str) -> bool {
744        self.text.trim().starts_with(text)
745    }
746
747    pub fn first_char(&self) -> Option<char> {
748        self.text.trim().chars().next()
749    }
750
751    pub fn text(&self) -> &str {
752        self.text.trim()
753    }
754
755    pub fn text_untrimmed(&self) -> &str {
756        &self.text
757    }
758
759    pub fn is_empty(&self) -> bool {
760        self.text.trim().is_empty()
761    }
762
763    pub fn strip_prefix(&self, prefix: &str) -> Option<&str> {
764        self.text.strip_prefix(prefix)
765    }
766}
767
768#[derive(Debug, thiserror::Error, derive_more::Display)]
769#[display("{error} at {location}{}", associated_data.as_deref().map_or("".to_string(), |d| format!(": {d}")))]
770pub struct ScriptError {
771    pub error: ScriptErrorType,
772    pub location: ScriptLocation,
773    pub associated_data: Option<String>,
774}
775
776impl ScriptError {
777    pub fn new(error: ScriptErrorType, location: ScriptLocation) -> Self {
778        if std::env::var("PANIC_ON_ERROR").is_ok() {
779            panic!("ScriptError: {error} at {location}");
780        }
781        Self {
782            error,
783            location,
784            associated_data: None,
785        }
786    }
787
788    pub fn new_with_data(
789        error: ScriptErrorType,
790        location: ScriptLocation,
791        associated_data: String,
792    ) -> Self {
793        if std::env::var("PANIC_ON_ERROR").is_ok() {
794            panic!("ScriptError: {error} at {location}: {associated_data}");
795        }
796        Self {
797            error,
798            location,
799            associated_data: Some(associated_data),
800        }
801    }
802}
803
/// Error kinds carried by `ScriptError`, which pairs them with a script
/// location. Each variant's `#[error]` text is its user-facing message.
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
pub enum ScriptErrorType {
    #[error("background process not allowed")]
    BackgroundProcessNotAllowed,
    #[error("unclosed quote")]
    UnclosedQuote,
    #[error("unclosed backslash")]
    UnclosedBackslash,
    #[error("illegal shell command format")]
    IllegalShellCommand,
    #[error("unsupported redirection")]
    UnsupportedRedirection,
    #[error("invalid pattern definition")]
    InvalidPatternDefinition,
    #[error("invalid pattern")]
    InvalidPattern,
    #[error("invalid meta command")]
    InvalidMetaCommand,
    #[error("invalid pattern at global level (only reject or ignore allowed here)")]
    InvalidGlobalPattern,
    #[error("invalid block type")]
    InvalidBlockType,
    #[error("invalid block arguments")]
    InvalidBlockArgs,
    #[error("unsupported command position")]
    UnsupportedCommandPosition,
    #[error("invalid trailing pattern after *")]
    InvalidAnyPattern,
    #[error("invalid exit status")]
    InvalidExitStatus,
    #[error("invalid set variable")]
    InvalidSetVariable,
    #[error("invalid version header, expected `#!/usr/bin/env clitest --v0`")]
    InvalidVersion,
    #[error("invalid internal command")]
    InvalidInternalCommand,
    #[error("missing command lines")]
    MissingCommandLines,
    #[error(
        "block end without matching block start, too many closing braces or braces not properly nested"
    )]
    InvalidBlockEnd,
    #[error("invalid if condition")]
    InvalidIfCondition,
    #[error("expected block or semi-colon (did you forget to add ';' at the end of this line?)")]
    ExpectedBlockOrSemi,
}
851
/// Errors raised while running a script.
///
/// `ExitScript` is not a real failure: `Script::run` converts it into success.
#[derive(Debug, thiserror::Error)]
pub enum ScriptRunError {
    #[error("{0}")]
    Pattern(#[from] OutputPatternMatchFailure),
    #[error("{0}")]
    PatternPrepareError(#[from] OutputPatternPrepareError),
    // NOTE(review): `ScriptLocation` already displays as `file:line`, so this
    // message reads "... at line <file>:<n>"; same for `ExpectedFailure`.
    #[error("{0} at line {1}")]
    Exit(CommandResult, ScriptLocation),
    #[error("included file not found: {0}")]
    IncludedFileNotFound(String),
    #[error("expected failure, but passed at line {0}")]
    ExpectedFailure(ScriptLocation),
    /// Failure from `$VAR` / `${VAR}` expansion; the string is the full message.
    #[error("{0}")]
    ExpansionError(String),
    #[error("{0}")]
    IO(#[from] std::io::Error),
    #[error("killed")]
    Killed,
    #[error("background process took too long to finish")]
    BackgroundProcessTookTooLong,
    #[error("retry took too long to finish")]
    RetryTookTooLong,
    /// Internal flow control: exit the script
    #[error("exiting script")]
    ExitScript,
}
878
879impl ScriptRunError {
880    #[expect(unused)]
881    pub fn short(&self) -> String {
882        match self {
883            Self::Pattern(_) => "Pattern".to_string(),
884            Self::PatternPrepareError(e) => format!("PatternPrepareError({e:?})"),
885            Self::Exit(status, _) => format!("Exit({status})"),
886            Self::ExpectedFailure(_) => "ExpectedFailure".to_string(),
887            Self::IO(e) => format!("IO({:?})", e.kind()),
888            Self::Killed => "Killed".to_string(),
889            Self::BackgroundProcessTookTooLong => "BackgroundProcessTookTooLong".to_string(),
890            Self::ExpansionError(e) => "ExpansionError".to_string(),
891            Self::RetryTookTooLong => "RetryTookTooLong".to_string(),
892            Self::ExitScript => unreachable!(),
893            Self::IncludedFileNotFound(path) => format!("IncludedFileNotFound({path})"),
894        }
895    }
896}
897
898impl Script {
899    pub fn new(file: ScriptFile) -> Self {
900        Self {
901            commands: Arc::new(vec![]),
902            includes: Arc::new(HashMap::new()),
903            file,
904        }
905    }
906
907    /// Collect all included script paths from the script.
908    pub fn includes(&self) -> Vec<(ScriptLocation, String)> {
909        self.commands
910            .iter()
911            .flat_map(|block| block.includes())
912            .collect()
913    }
914
915    pub fn run(&self, context: &mut ScriptRunContext) -> Result<(), ScriptRunError> {
916        let old_includes = context.includes.clone();
917        context.includes = self.includes.clone();
918        let res = ScriptBlock::run_blocks(context, &self.commands);
919        context.includes = old_includes;
920        let v = match res {
921            Ok(v) => v,
922            // Bypass normal script processing and exit successfully
923            Err(ScriptRunError::ExitScript) => return Ok(()),
924            Err(e) => return Err(e),
925        };
926        assert!(v.is_empty(), "script did not run to completion: {v:?}");
927        Ok(())
928    }
929
930    pub fn run_with_args(
931        &self,
932        args: ScriptRunArgs,
933        output: ScriptOutput,
934    ) -> Result<(), ScriptRunError> {
935        let start = Instant::now();
936        let script_path = &*self.file.file;
937        let mut context = ScriptRunContext::new(args, script_path, output);
938
939        // Write "Running..." message with colors
940        cwrite!(context.stream(), "Running ");
941        cwrite!(context.stream(), fg = Color::Cyan, "{}", script_path);
942        cwriteln!(context.stream(), " ...");
943        cwriteln!(context.stream());
944
945        let result = self.run(&mut context);
946
947        // Handle success and error output
948        if let Err(ref e) = result {
949            cwrite!(context.stream(), fg = Color::Cyan, "{} ", script_path);
950            cwrite!(context.stream(), fg = Color::Red, "FAILED");
951            if !context.args.simplified_output {
952                cwriteln!(context.stream(), " ({:.2}s)", start.elapsed().as_secs_f32());
953            } else {
954                cwriteln!(context.stream());
955            }
956            cwrite!(context.stream(), fg = Color::Red, "Error: ");
957            cwriteln!(context.stream(), "{}", e);
958            cwriteln!(context.stream());
959        } else {
960            cwrite!(context.stream(), fg = Color::Cyan, "{} ", script_path);
961            cwrite!(context.stream(), fg = Color::Green, "PASSED");
962            if !context.args.simplified_output {
963                cwriteln!(context.stream(), " ({:.2}s)", start.elapsed().as_secs_f32());
964            } else {
965                cwriteln!(context.stream());
966            }
967        }
968
969        result
970    }
971}
972
/// Expected outcome of a script command, checked against the actual
/// `CommandResult` by `CommandExit::matches`.
#[derive(Debug, Default, Serialize)]
pub enum CommandExit {
    /// The process exited successfully (the default expectation).
    #[default]
    Success,
    /// The process exited with exactly this code; a status without an exit
    /// code is compared as `-1`.
    Failure(i32),
    /// The command timed out.
    Timeout,
    /// Any outcome is accepted.
    Any,
    /// Anything except a successful exit, including a timeout.
    AnyFailure,
}
982
983impl CommandExit {
984    pub fn matches(&self, status: CommandResult) -> bool {
985        match (self, status) {
986            (CommandExit::Success, CommandResult::Exit(status)) => status.success(),
987            (CommandExit::Failure(code), CommandResult::Exit(status)) => {
988                *code == status.code().unwrap_or(-1)
989            }
990            (CommandExit::Timeout, CommandResult::TimedOut) => true,
991            (CommandExit::Any, _) => true,
992            (CommandExit::AnyFailure, CommandResult::Exit(status)) => !status.success(),
993            (CommandExit::AnyFailure, _) => true,
994            _ => false,
995        }
996    }
997
998    pub fn is_success(&self) -> bool {
999        matches!(self, CommandExit::Success)
1000    }
1001}
1002
/// One element of a parsed script, executed by `ScriptBlock::run_blocks`.
#[derive(derive_more::Debug)]
// Variants differ considerably in size; they are intentionally not boxed.
#[allow(clippy::large_enum_variant)]
pub enum ScriptBlock {
    /// A command to run and check against its expected output/exit.
    Command(ScriptCommand),
    /// A built-in directive together with the location it was declared at.
    InternalCommand(ScriptLocation, InternalCommand),
    /// Blocks run on a scoped thread; the thread is sent a kill, joined, and
    /// its results evaluated when the enclosing block list finishes.
    Background(Vec<ScriptBlock>),
    /// Blocks run after the enclosing block list finishes, most recently
    /// registered first.
    Defer(Vec<ScriptBlock>),
    /// Blocks run only if the (expanded) condition matches.
    If(IfCondition, Vec<ScriptBlock>),
    /// Blocks run once per value, with the loop variable set in the env.
    For(ForCondition, Vec<ScriptBlock>),
    /// Blocks re-run with backoff until all commands pass or the timeout expires.
    Retry(Vec<ScriptBlock>),
    /// Patterns added to `global_ignore` for later commands in this context.
    GlobalIgnore(OutputPatterns),
    /// Patterns added to `global_reject` for later commands in this context.
    GlobalReject(OutputPatterns),
}
1016
1017impl ScriptBlock {
1018    pub fn includes(&self) -> Vec<(ScriptLocation, String)> {
1019        match self {
1020            ScriptBlock::Command(..) => vec![],
1021            ScriptBlock::InternalCommand(location, InternalCommand::Include(path)) => {
1022                vec![(location.clone(), path.clone())]
1023            }
1024            ScriptBlock::InternalCommand(..) => vec![],
1025            ScriptBlock::Background(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1026            ScriptBlock::Defer(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1027            ScriptBlock::If(_, blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1028            ScriptBlock::For(_, blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1029            ScriptBlock::Retry(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
1030            ScriptBlock::GlobalIgnore(_) => vec![],
1031            ScriptBlock::GlobalReject(_) => vec![],
1032        }
1033    }
1034
1035    #[allow(clippy::type_complexity)]
1036    pub fn run_blocks(
1037        context: &mut ScriptRunContext,
1038        blocks: &[ScriptBlock],
1039    ) -> Result<Vec<ScriptResult>, ScriptRunError> {
1040        enum Deferred<'a> {
1041            Scripts(&'a [ScriptBlock]),
1042            Internal(
1043                Box<
1044                    dyn FnOnce(&mut ScriptRunContext) -> Result<(), ScriptRunError>
1045                        + Send
1046                        + Sync
1047                        + 'a,
1048                >,
1049            ),
1050            Background(
1051                ScopedJoinHandle<'a, Result<Vec<ScriptResult>, ScriptRunError>>,
1052                ScriptKillSender,
1053            ),
1054        }
1055
1056        let mut results = Vec::new();
1057        std::thread::scope(|s| {
1058            let mut defer_blocks = VecDeque::new();
1059            let mut pending_error = None;
1060            for block in blocks {
1061                if context.kill.is_killed() {
1062                    return Err(ScriptRunError::Killed);
1063                }
1064                match block {
1065                    ScriptBlock::Background(blocks) => {
1066                        let mut context = context.new_background();
1067                        let kill_sender = context.kill_sender.clone();
1068                        let handle = s.spawn(move || Self::run_blocks(&mut context, blocks));
1069                        defer_blocks.push_front(Deferred::Background(handle, kill_sender));
1070                    }
1071                    ScriptBlock::Defer(blocks) => {
1072                        // Insert at the front of the queue by extending and
1073                        // then rotating
1074                        defer_blocks.push_front(Deferred::Scripts(blocks));
1075                    }
1076                    ScriptBlock::InternalCommand(_, command) => {
1077                        if context.background == ScriptMode::Deferred {
1078                            cwrite!(context.stream(), dimmed = true, "(deferred) ");
1079                        }
1080                        if let Some(f) = command.run(context)? {
1081                            defer_blocks.push_front(Deferred::Internal(f));
1082                        }
1083                    }
1084                    _ => match block.run(context) {
1085                        Ok(res) => results.extend(res),
1086                        Err(e) => {
1087                            pending_error = Some(e);
1088                            break;
1089                        }
1090                    },
1091                }
1092            }
1093            for block in defer_blocks {
1094                match block {
1095                    Deferred::Scripts(blocks) => {
1096                        let mut context = context.new_deferred();
1097                        ScriptBlock::run_blocks(&mut context, blocks)?;
1098                    }
1099                    Deferred::Internal(block) => {
1100                        cwrite!(context.stream(), dimmed = true, "(cleanup) ");
1101                        block(context)?;
1102                    }
1103                    Deferred::Background(handle, kill_sender) => {
1104                        kill_sender.kill();
1105                        let start = std::time::Instant::now();
1106                        let mut warned = false;
1107
1108                        let timeout = context.timeout;
1109                        let warn_at = timeout * 8 / 10;
1110
1111                        let results = loop {
1112                            if handle.is_finished() {
1113                                break handle.join().unwrap()?;
1114                            }
1115                            std::thread::sleep(std::time::Duration::from_millis(10));
1116                            if !warned && start.elapsed() > warn_at {
1117                                cwriteln!(
1118                                    context.stream(),
1119                                    fg = Color::Yellow,
1120                                    "Background process is taking too long to finish."
1121                                );
1122                                warned = true;
1123                            }
1124                            if start.elapsed() > timeout {
1125                                cwriteln!(
1126                                    context.stream(),
1127                                    fg = Color::Red,
1128                                    "Background process took too long to finish."
1129                                );
1130                                return Err(ScriptRunError::BackgroundProcessTookTooLong);
1131                            }
1132                        };
1133                        for result in results {
1134                            cwrite!(context.stream(), dimmed = true, "(background) ");
1135                            for line in result.command.command.split('\n') {
1136                                cwriteln!(context.stream(), fg = Color::Green, "{}", line);
1137                            }
1138                            if context.args.simplified_output {
1139                                cwriteln!(context.stream(), dimmed = true, "---");
1140                            } else {
1141                                cwriteln_rule!(
1142                                    context.stream(),
1143                                    fg = Color::Cyan,
1144                                    "{}",
1145                                    result.command.location
1146                                );
1147                            }
1148                            for line in &result.output {
1149                                cwriteln!(context.stream(), "{}", line);
1150                            }
1151                            if result.output.is_empty() {
1152                                cwriteln!(context.stream(), dimmed = true, "(no output)");
1153                            }
1154                            if context.args.simplified_output {
1155                                cwriteln!(context.stream(), dimmed = true, "---");
1156                            } else {
1157                                cwriteln_rule!(context.stream());
1158                            }
1159                            result.evaluate(context)?;
1160                        }
1161                    }
1162                }
1163            }
1164            if let Some(error) = pending_error {
1165                return Err(error);
1166            }
1167            Ok(results)
1168        })
1169    }
1170
1171    pub fn run(&self, context: &mut ScriptRunContext) -> Result<Vec<ScriptResult>, ScriptRunError> {
1172        let pwd = context.pwd();
1173        let res = pwd.exists();
1174        if !matches!(res, Ok(true)) {
1175            cwriteln!(
1176                context.stream(),
1177                fg = Color::Red,
1178                "$PWD {pwd:?} doesn't exist. Run `cd $INITIAL_PWD` to fix.",
1179            );
1180            return Err(ScriptRunError::IO(std::io::Error::new(
1181                std::io::ErrorKind::NotFound,
1182                format!("PWD does not exist: {pwd:?}"),
1183            )));
1184        }
1185
1186        match self {
1187            ScriptBlock::Command(command) => {
1188                if context.background == ScriptMode::Deferred {
1189                    cwrite!(context.stream(), dimmed = true, "(deferred) ");
1190                }
1191                let result = command.run(context)?;
1192                if context.background != ScriptMode::Background {
1193                    result.evaluate(context)?;
1194                    Ok(vec![])
1195                } else {
1196                    Ok(vec![result])
1197                }
1198            }
1199            ScriptBlock::If(condition, blocks) => {
1200                let condition = condition.expand(context)?;
1201                if condition.matches(context) {
1202                    Self::run_blocks(context, blocks)
1203                } else {
1204                    Ok(vec![])
1205                }
1206            }
1207            ScriptBlock::For(ForCondition::Env(env, values), blocks) => {
1208                let mut results = Vec::new();
1209                for value in values {
1210                    context.set_env(env, context.expand(value)?);
1211                    results.extend(Self::run_blocks(context, blocks)?);
1212                }
1213                Ok(results)
1214            }
1215            ScriptBlock::Retry(blocks) => {
1216                let start = Instant::now();
1217                let mut backoff = Duration::from_millis(100);
1218
1219                cwrite!(context.stream(), fg = Color::Green, "retry: ");
1220                cwriteln!(context.stream(), "running...");
1221
1222                loop {
1223                    let mut nested_context = context.new_background();
1224                    if let Ok(results) = Self::run_blocks(&mut nested_context, blocks) {
1225                        let mut all_ok = true;
1226                        for result in results {
1227                            if result.evaluate(&mut nested_context).is_err() {
1228                                all_ok = false;
1229                                break;
1230                            }
1231                        }
1232                        if all_ok {
1233                            let output = nested_context.take_output();
1234                            cwrite!(context.stream(), fg = Color::Green, "retry: ");
1235                            cwriteln!(context.stream(), "success");
1236                            cwriteln!(context.stream());
1237                            cwriteln!(context.stream(), "{output}");
1238                            return Ok(vec![]);
1239                        }
1240                    }
1241
1242                    if start.elapsed() > context.timeout {
1243                        let output = nested_context.take_output();
1244                        cwrite!(context.stream(), fg = Color::Green, "retry: ");
1245                        cwriteln!(context.stream(), fg = Color::Red, "timed out");
1246                        cwriteln!(context.stream());
1247                        cwriteln!(context.stream(), "{output}");
1248                        cwriteln_rule!(context.stream());
1249                        return Err(ScriptRunError::RetryTookTooLong);
1250                    }
1251                    std::thread::sleep(backoff);
1252                    backoff *= 2;
1253                }
1254            }
1255            ScriptBlock::GlobalIgnore(patterns) => {
1256                for pattern in patterns.iter() {
1257                    pattern.prepare(&context.grok)?;
1258                }
1259                context.global_ignore.extend(patterns);
1260                Ok(vec![])
1261            }
1262            ScriptBlock::GlobalReject(patterns) => {
1263                for pattern in patterns.iter() {
1264                    pattern.prepare(&context.grok)?;
1265                }
1266                context.global_reject.extend(patterns);
1267                Ok(vec![])
1268            }
1269            _ => unreachable!("Unexpected block type: {self:?}"),
1270        }
1271    }
1272}
1273
1274impl Serialize for ScriptBlock {
1275    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1276    where
1277        S: serde::Serializer,
1278    {
1279        match self {
1280            ScriptBlock::Command(command) => command.serialize(serializer),
1281            ScriptBlock::InternalCommand(_, command) => command.serialize(serializer),
1282            ScriptBlock::Background(blocks) => {
1283                let mut ser = serializer.serialize_map(Some(1))?;
1284                ser.serialize_entry("background", blocks)?;
1285                ser.end()
1286            }
1287            ScriptBlock::Defer(blocks) => {
1288                let mut ser = serializer.serialize_map(Some(1))?;
1289                ser.serialize_entry("defer", blocks)?;
1290                ser.end()
1291            }
1292            ScriptBlock::If(condition, blocks) => {
1293                let mut ser = serializer.serialize_map(Some(2))?;
1294                ser.serialize_entry("if", condition)?;
1295                ser.serialize_entry("blocks", blocks)?;
1296                ser.end()
1297            }
1298            ScriptBlock::For(condition, blocks) => {
1299                let mut ser = serializer.serialize_map(Some(2))?;
1300                ser.serialize_entry("for", condition)?;
1301                ser.serialize_entry("blocks", blocks)?;
1302                ser.end()
1303            }
1304            ScriptBlock::Retry(blocks) => {
1305                let mut ser = serializer.serialize_map(Some(1))?;
1306                ser.serialize_entry("retry", blocks)?;
1307                ser.end()
1308            }
1309            ScriptBlock::GlobalIgnore(patterns) => {
1310                let mut ser = serializer.serialize_map(Some(1))?;
1311                ser.serialize_entry("ignore", patterns)?;
1312                ser.end()
1313            }
1314            ScriptBlock::GlobalReject(patterns) => {
1315                let mut ser = serializer.serialize_map(Some(1))?;
1316                ser.serialize_entry("reject", patterns)?;
1317                ser.end()
1318            }
1319        }
1320    }
1321}
1322
/// Built-in script directives, executed by `InternalCommand::run`.
#[derive(Debug, Clone, Serialize)]
pub enum InternalCommand {
    /// Create a temporary directory and make it `$PWD`; its removal is
    /// registered as deferred cleanup.
    UsingTempdir,
    /// Make the given directory (relative to `$PWD`) the current one. When the
    /// flag is true the directory is created first and removed again by the
    /// deferred cleanup; either way the cleanup restores the previous `$PWD`.
    UsingDir(ShellBit, bool),
    /// `cd` to an existing directory relative to `$PWD`.
    ChangeDir(ShellBit),
    /// Set an environment variable to the expanded value.
    Set(String, ShellBit),
    /// Run the script registered under this name in the context's includes.
    Include(String),
    /// Stop the script; `Script::run` reports this as success.
    ExitScript,
    /// Register a named grok pattern: (name, pattern).
    Pattern(String, String),
}
1333
1334impl InternalCommand {
1335    #[allow(clippy::type_complexity)]
1336    pub fn run(
1337        &self,
1338        context: &mut ScriptRunContext,
1339    ) -> Result<
1340        Option<Box<dyn FnOnce(&mut ScriptRunContext) -> Result<(), ScriptRunError> + Send + Sync>>,
1341        ScriptRunError,
1342    > {
1343        match self.clone() {
1344            InternalCommand::Include(path) => {
1345                let Some(script) = context.includes.get(&path) else {
1346                    return Err(ScriptRunError::IncludedFileNotFound(path));
1347                };
1348                script.clone().run(context)?;
1349                Ok(None)
1350            }
1351            InternalCommand::Pattern(name, pattern) => {
1352                context.grok.add_pattern(name, pattern);
1353                Ok(None)
1354            }
1355            InternalCommand::UsingTempdir => {
1356                let current_pwd = context.pwd();
1357                let tempdir = NiceTempDir::new();
1358                cwrite!(context.stream(), fg = Color::Yellow, "using tempdir: ");
1359                cwriteln!(context.stream(), "{}", tempdir);
1360                cwriteln!(context.stream());
1361                context.set_pwd(&tempdir);
1362                let pwd = context.pwd();
1363                if !pwd.exists()? {
1364                    return Err(ScriptRunError::IO(std::io::Error::new(
1365                        std::io::ErrorKind::NotFound,
1366                        format!("newly created tempdir does not exist: {pwd:?}"),
1367                    )));
1368                }
1369                Ok(Some(Box::new(move |context: &mut ScriptRunContext| {
1370                    cwriteln!(
1371                        context.stream(),
1372                        fg = Color::Yellow,
1373                        "removing {} && cd {}",
1374                        tempdir,
1375                        current_pwd
1376                    );
1377                    cwriteln!(context.stream());
1378                    if !tempdir.exists()? {
1379                        cwriteln!(
1380                            context.stream(),
1381                            fg = Color::Red,
1382                            "tempdir does not exist: {tempdir}"
1383                        );
1384                    }
1385                    if let Err(e) = tempdir.remove_dir_all() {
1386                        cwriteln!(
1387                            context.stream(),
1388                            fg = Color::Red,
1389                            "error removing tempdir: {e:?}"
1390                        );
1391                    }
1392                    Ok::<_, ScriptRunError>(())
1393                })))
1394            }
1395            InternalCommand::UsingDir(dir, new) => {
1396                let current_pwd = context.pwd();
1397                let dir = context.expand(&dir)?;
1398                let new_pwd = current_pwd.join(dir);
1399                if new {
1400                    cwrite!(context.stream(), fg = Color::Yellow, "using new dir: ");
1401                } else {
1402                    cwrite!(context.stream(), fg = Color::Yellow, "using dir: ");
1403                }
1404                cwriteln!(context.stream(), "{}", new_pwd);
1405                cwriteln!(context.stream());
1406
1407                if new {
1408                    new_pwd.create_dir_all()?;
1409                } else if !new_pwd.exists()? {
1410                    return Err(ScriptRunError::IO(std::io::Error::new(
1411                        std::io::ErrorKind::NotFound,
1412                        "directory does not exist",
1413                    )));
1414                }
1415                context.set_pwd(&new_pwd);
1416                Ok(Some(Box::new(move |context: &mut ScriptRunContext| {
1417                    if new {
1418                        cwriteln!(
1419                            context.stream(),
1420                            fg = Color::Yellow,
1421                            "removing {} && cd {}",
1422                            new_pwd,
1423                            current_pwd
1424                        );
1425                        cwriteln!(context.stream());
1426                    } else {
1427                        cwriteln!(context.stream(), fg = Color::Yellow, "cd {}", current_pwd);
1428                        cwriteln!(context.stream());
1429                    }
1430                    if new {
1431                        new_pwd.remove_dir_all()?;
1432                    }
1433                    context.set_pwd(current_pwd);
1434                    Ok::<_, ScriptRunError>(())
1435                })))
1436            }
1437            InternalCommand::ChangeDir(dir) => {
1438                let dir = context.expand(&dir)?;
1439
1440                cwriteln!(context.stream(), fg = Color::Yellow, "cd {dir}");
1441                cwriteln!(context.stream());
1442                let current_pwd = context.pwd();
1443                let new_pwd = current_pwd.join(dir);
1444                if !new_pwd.exists()? {
1445                    return Err(ScriptRunError::IO(std::io::Error::new(
1446                        std::io::ErrorKind::NotFound,
1447                        format!("directory does not exist: {new_pwd}"),
1448                    )));
1449                }
1450                context.set_pwd(new_pwd);
1451                Ok(None)
1452            }
1453            InternalCommand::Set(name, value) => {
1454                let value = context.expand(&value)?;
1455
1456                context.set_env(&name, &value);
1457                let new_value = context.get_env(&name).unwrap_or_default();
1458                if new_value != value {
1459                    cwriteln!(
1460                        context.stream(),
1461                        fg = Color::Yellow,
1462                        "set {name} {value} (-> {new_value})"
1463                    );
1464                } else {
1465                    cwriteln!(context.stream(), fg = Color::Yellow, "set {name} {value}");
1466                }
1467                cwriteln!(context.stream());
1468
1469                Ok(None)
1470            }
1471            InternalCommand::ExitScript => {
1472                cwriteln!(context.stream(), fg = Color::Yellow, "exiting script");
1473                cwriteln!(context.stream());
1474                Err(ScriptRunError::ExitScript)
1475            }
1476        }
1477    }
1478}
1479
/// Condition guarding an `if` block.
#[derive(Debug, Clone)]
pub enum IfCondition {
    /// Always matches.
    True,
    /// Never matches.
    False,
    /// (negated, env var name, expected value): matches when the variable
    /// equals the expected value, or differs from it when `negated` is true.
    EnvEq(bool, String, ShellBit),
}
1486
1487impl IfCondition {
1488    pub fn matches(&self, context: &ScriptRunContext) -> bool {
1489        match self {
1490            IfCondition::True => true,
1491            IfCondition::False => false,
1492            IfCondition::EnvEq(negated, name, expected) => {
1493                let value = context.get_env(name).unwrap_or_default();
1494                (expected == value) ^ negated
1495            }
1496        }
1497    }
1498
1499    pub fn expand(&self, context: &ScriptRunContext) -> Result<IfCondition, ScriptRunError> {
1500        match self {
1501            IfCondition::True => Ok(IfCondition::True),
1502            IfCondition::False => Ok(IfCondition::False),
1503            IfCondition::EnvEq(negated, name, expected) => {
1504                let value = context.expand(expected)?;
1505                Ok(IfCondition::EnvEq(
1506                    *negated,
1507                    name.clone(),
1508                    ShellBit::Literal(value),
1509                ))
1510            }
1511        }
1512    }
1513}
1514
1515impl Serialize for IfCondition {
1516    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1517    where
1518        S: serde::Serializer,
1519    {
1520        match self {
1521            IfCondition::True => "true".serialize(serializer),
1522            IfCondition::False => "false".serialize(serializer),
1523            IfCondition::EnvEq(negated, name, value) => {
1524                let mut ser = serializer.serialize_map(Some(3))?;
1525                ser.serialize_entry("op", if *negated { "!=" } else { "==" })?;
1526                ser.serialize_entry("env", name)?;
1527                ser.serialize_entry("value", value)?;
1528                ser.end()
1529            }
1530        }
1531    }
1532}
1533
/// Iteration spec for a `for` block.
#[derive(Debug)]
pub enum ForCondition {
    /// (env var name, values): the block runs once per value, with the
    /// variable set to that value (expanded at run time).
    Env(String, Vec<ShellBit>),
}
1538
1539impl Serialize for ForCondition {
1540    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1541    where
1542        S: serde::Serializer,
1543    {
1544        match self {
1545            ForCondition::Env(name, values) => {
1546                let mut ser = serializer.serialize_map(Some(2))?;
1547                ser.serialize_entry("env", name)?;
1548                ser.serialize_entry("values", values)?;
1549                ser.end()
1550            }
1551        }
1552    }
1553}
1554
/// Serde `skip_serializing_if` predicate: true when the flag is unset.
fn is_bool_false(b: &bool) -> bool {
    let flag = *b;
    !flag
}
1558
/// A command line plus everything the script expects of it.
#[derive(Debug, Serialize)]
pub struct ScriptCommand {
    /// The command to run, including its source location.
    pub command: CommandLine,
    /// Pattern the command's output is matched against.
    pub pattern: OutputPattern,

    /// Expected exit outcome (omitted from serialization when `Success`).
    #[serde(skip_serializing_if = "CommandExit::is_success")]
    pub exit: CommandExit,

    /// When set, the output pattern is expected NOT to match.
    #[serde(skip_serializing_if = "is_bool_false")]
    pub expect_failure: bool,

    /// Single set variable (entire command output trimmed)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub set_var: Option<String>,

    /// Specific set variables
    /// (applied only when the pattern matches; values are expanded with the
    /// grok captures in scope)
    pub set_vars: HashMap<String, ShellBit>,

    /// Specific command timeout
    /// (falls back to the context's timeout when `None`)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout: Option<Duration>,

    /// Input grok expectations
    /// (expanded and registered on the match context before matching)
    pub expect: HashMap<String, ShellBit>,
}
1584
1585impl ScriptCommand {
1586    pub fn new(command: CommandLine) -> Self {
1587        let location = command.location.clone();
1588        Self {
1589            command,
1590            pattern: OutputPattern {
1591                pattern: OutputPatternType::None,
1592                ignore: Default::default(),
1593                reject: Default::default(),
1594                location,
1595            },
1596            exit: Default::default(),
1597            timeout: None,
1598            expect_failure: false,
1599            set_var: None,
1600            set_vars: Default::default(),
1601            expect: Default::default(),
1602        }
1603    }
1604
1605    pub fn run(&self, context: &mut ScriptRunContext) -> Result<ScriptResult, ScriptRunError> {
1606        let command = &self.command;
1607        let args = &context.args;
1608        let start = Instant::now();
1609
1610        if let Some(delay) = args.delay_steps {
1611            std::thread::sleep(std::time::Duration::from_millis(delay));
1612        }
1613
1614        for line in command.command.split('\n') {
1615            cwriteln!(context.stream(), fg = Color::Green, "{}", line);
1616        }
1617        if args.simplified_output {
1618            cwriteln!(context.stream(), dimmed = true, "---");
1619        } else {
1620            cwriteln_rule!(context.stream(), fg = Color::Cyan, "{}", command.location);
1621        }
1622        let (output, status) = command.run(
1623            &mut context.stream(),
1624            context.args.show_line_numbers,
1625            context.args.runner.clone(),
1626            self.timeout.unwrap_or(context.timeout),
1627            context.env.env_vars(),
1628            &context.kill,
1629            &context.kill_sender,
1630        )?;
1631
1632        let exit_result = if !self.exit.matches(status) {
1633            ExitResult::Mismatch(status)
1634        } else {
1635            ExitResult::Matches(status)
1636        };
1637
1638        // Side-effects
1639        if let Some(set_var) = &self.set_var {
1640            context.set_env(set_var, output.to_string().trim());
1641        }
1642
1643        let match_context = OutputMatchContext::new(context);
1644        for (key, value) in &self.expect {
1645            match_context.expect(key, context.expand(value)?);
1646        }
1647        self.pattern.prepare(&context.grok)?;
1648        let prepared_output = output
1649            .with_ignore(&context.global_ignore)
1650            .with_reject(&context.global_reject);
1651        let pattern_result = match self.pattern.matches(match_context.clone(), prepared_output) {
1652            Ok(_) => {
1653                let mut env = context.env.clone();
1654                for (key, value) in match_context.expects() {
1655                    env.set_env(key, value);
1656                }
1657                for (key, value) in &self.set_vars {
1658                    context.set_env(key, env.expand(value)?);
1659                }
1660
1661                if self.expect_failure {
1662                    PatternResult::ExpectedFailure
1663                } else {
1664                    PatternResult::Matches
1665                }
1666            }
1667            Err(e) => {
1668                if self.expect_failure {
1669                    PatternResult::MatchesFailure
1670                } else {
1671                    let mut trace = String::new();
1672                    for line in match_context.traces() {
1673                        trace.push_str(&format!("{line}\n"));
1674                    }
1675                    PatternResult::Mismatch(e, trace)
1676                }
1677            }
1678        };
1679
1680        if output.is_empty() {
1681            cwriteln!(context.stream(), dimmed = true, "(no output)");
1682        }
1683
1684        if context.args.simplified_output {
1685            cwriteln!(context.stream(), dimmed = true, "---");
1686        } else {
1687            cwriteln_rule!(context.stream());
1688        }
1689
1690        Ok(ScriptResult {
1691            command: command.clone(),
1692            pattern: pattern_result,
1693            exit: exit_result,
1694            elapsed: start.elapsed(),
1695            output,
1696        })
1697    }
1698}
1699
/// Outcome of running one `ScriptCommand`, checked later by
/// `ScriptResult::evaluate`.
#[derive(derive_more::Debug)]
pub struct ScriptResult {
    /// The command that was run.
    pub command: CommandLine,
    /// How the output compared against the expected pattern.
    pub pattern: PatternResult,
    /// How the exit status compared against the expected exit.
    pub exit: ExitResult,
    /// Wall time measured from the start of `ScriptCommand::run`.
    pub elapsed: Duration,
    /// Captured output lines (left out of the `Debug` representation).
    #[debug(skip)]
    pub output: Lines,
}
1709
1710impl ScriptResult {
1711    pub fn evaluate(&self, context: &mut ScriptRunContext) -> Result<(), ScriptRunError> {
1712        let args = &context.args;
1713        let (success, failure, warning, arrow) = if *crate::term::IS_UTF8 {
1714            ("✅", "❌", "⚠️", "→")
1715        } else {
1716            ("[*]", "[X]", "[!]", "->")
1717        };
1718
1719        if let ExitResult::Mismatch(status) = self.exit {
1720            if args.ignore_exit_codes {
1721                cwriteln!(
1722                    context.stream(),
1723                    fg = Color::Yellow,
1724                    "{warning} Ignored incorrect exit code: {status}"
1725                );
1726                cwriteln!(context.stream());
1727            } else {
1728                cwriteln!(
1729                    context.stream(),
1730                    fg = Color::Red,
1731                    "{failure} FAIL: {status}"
1732                );
1733                cwriteln!(
1734                    context.stream(),
1735                    dimmed = true,
1736                    " {arrow} {}",
1737                    self.command.command
1738                );
1739                cwriteln!(context.stream());
1740                return Err(ScriptRunError::Exit(status, self.command.location.clone()));
1741            }
1742        }
1743
1744        if let PatternResult::Mismatch(e, trace) = &self.pattern {
1745            if args.ignore_matches {
1746                cwriteln!(
1747                    context.stream(),
1748                    fg = Color::Yellow,
1749                    "{warning} Ignored error: {e} (ignoring mismatches)"
1750                );
1751                cwriteln!(context.stream());
1752            } else {
1753                cwriteln!(context.stream(), fg = Color::Red, "ERROR: {e}");
1754                cwriteln!(context.stream(), dimmed = true, "{trace}");
1755                cwriteln!(context.stream(), fg = Color::Red, "{failure} FAIL");
1756                cwriteln!(context.stream());
1757                return Err(ScriptRunError::Pattern(e.clone()));
1758            }
1759        }
1760
1761        if let PatternResult::ExpectedFailure = self.pattern {
1762            if args.ignore_matches {
1763                cwriteln!(
1764                    context.stream(),
1765                    fg = Color::Yellow,
1766                    "{warning} Should not have matched! (ignoring mismatches)"
1767                );
1768                cwriteln!(context.stream());
1769            } else {
1770                cwriteln!(
1771                    context.stream(),
1772                    fg = Color::Red,
1773                    "{failure} FAIL (output shouldn't match)"
1774                );
1775                cwriteln!(
1776                    context.stream(),
1777                    dimmed = true,
1778                    " {arrow} {}",
1779                    self.command.command
1780                );
1781                cwriteln!(context.stream());
1782                return Err(ScriptRunError::ExpectedFailure(
1783                    self.command.location.clone(),
1784                ));
1785            }
1786        }
1787
1788        if let ExitResult::Matches(status) = self.exit {
1789            if status.success() {
1790                cwrite!(context.stream(), fg = Color::Green, "{success} OK");
1791                if !context.args.simplified_output {
1792                    cwriteln!(
1793                        context.stream(),
1794                        dimmed = true,
1795                        " ({:.2}s)",
1796                        self.elapsed.as_secs_f32()
1797                    );
1798                } else {
1799                    cwriteln!(context.stream());
1800                }
1801            } else {
1802                cwrite!(
1803                    context.stream(),
1804                    fg = Color::Green,
1805                    "{success} OK ({status})"
1806                );
1807                if !context.args.simplified_output {
1808                    cwriteln!(
1809                        context.stream(),
1810                        dimmed = true,
1811                        " ({:.2}s)",
1812                        self.elapsed.as_secs_f32()
1813                    );
1814                } else {
1815                    cwriteln!(context.stream());
1816                }
1817            }
1818            cwriteln!(context.stream());
1819        }
1820
1821        Ok(())
1822    }
1823}
1824
/// How a command's output compared with its pattern.
///
/// The block's expect-failure flag is already taken into account, so a
/// variant's name alone does not say whether the block passed. See each
/// variant.
#[derive(Debug)]
pub enum PatternResult {
    /// The pattern matched and a match was expected. This is a pass.
    Matches,
    /// The pattern did not match, and the block expected it not to. This is a pass.
    MatchesFailure,
    /// The pattern matched even though the block expected it not to. This is a
    /// failure, reported as "output shouldn't match".
    ExpectedFailure,
    /// The pattern did not match when a match was expected. This is a failure.
    /// It carries the match error and the collected match-trace text, one
    /// trace line per `\n`-terminated line.
    Mismatch(OutputPatternMatchFailure, String),
}
1832
/// How a command's exit status compared with what the block expected.
#[derive(Debug)]
pub enum ExitResult {
    /// The exit status was acceptable. It may still be a non-success status,
    /// in which case it is displayed alongside the OK marker.
    Matches(CommandResult),
    /// The exit status was not the expected one. This is treated as a failure
    /// unless exit codes are being ignored.
    Mismatch(CommandResult),
    /// Presumably the command exceeded its timeout (TODO: confirm against the
    /// code that constructs this variant). `ScriptResult::evaluate` has no
    /// branch for this variant.
    TimedOut,
}
1839
#[cfg(test)]
mod tests {
    use crate::parser::v0::parse_script;

    use super::*;
    use std::error::Error;

    /// A script combining a named pattern, plain commands and a nested
    /// `repeat { choice { ... } }` block parses into exactly three blocks.
    #[test]
    fn test_script() -> Result<(), Box<dyn Error>> {
        let source = r#"
pattern VERSION \d+\.\d+\.\d+;

$ something --version || echo 1
? Something %{VERSION}

$ something --help
? Usage: something [OPTIONS]
repeat {
    choice {
? %{DATA} %{GREEDYDATA}
? %{DATA}=%{DATA} %{GREEDYDATA}
    }
}
"#;

        let parsed = parse_script(ScriptFile::new("test.cli"), source)?;
        assert_eq!(parsed.commands.len(), 3);
        eprintln!("{parsed:?}");
        Ok(())
    }

    /// A command ending in `&` is rejected as a background process.
    #[test]
    fn test_bad_script() -> Result<(), Box<dyn Error>> {
        let source = r#"
$ (cmd; cmd)
$ cmd &
    "#;

        let outcome = parse_script(ScriptFile::new("test.cli"), source);
        let rejected = matches!(
            outcome,
            Err(ScriptError {
                error: ScriptErrorType::BackgroundProcessNotAllowed,
                ..
            })
        );
        assert!(rejected);
        Ok(())
    }

    /// Variable expansion covers `$VAR` and `${VAR}`, backslash escapes, and
    /// escaped backslashes.
    #[test]
    fn test_script_run_context_expand() {
        let mut env = ScriptEnv::default();
        for (name, value) in [("A", "1"), ("B", "2"), ("C", "3")] {
            env.set_env(name, value);
        }

        let cases = [
            ("$A", "1"),
            ("$A $B ", "1 2 "),
            ("${A} ${B} ", "1 2 "),
            (r#"\$A"#, "$A"),
            (r#"\${A}"#, "${A}"),
            (r#"\\$A"#, r#"\1"#),
            (r#"\\${A}"#, r#"\1"#),
        ];
        for (input, expected) in cases {
            assert_eq!(env.expand_str(input).unwrap(), expected, "expanding {input:?}");
        }

        env.set_env("TEMP_DIR", "/tmp");
        for input in ["$TEMP_DIR", "${TEMP_DIR}"] {
            assert_eq!(env.expand_str(input).unwrap(), "/tmp", "expanding {input:?}");
        }
    }
}