Skip to main content

clitest_lib/
script.rs

1use std::{
2    collections::{HashMap, VecDeque},
3    path::Path,
4    process::ExitStatus,
5    sync::{Arc, Mutex, atomic::AtomicBool},
6    thread::ScopedJoinHandle,
7    time::{Duration, Instant},
8};
9
10use grok::Grok;
11use keepcalm::SharedMut;
12use serde::{Serialize, ser::SerializeMap};
13use termcolor::{Color, ColorChoice, WriteColor};
14
15use crate::{
16    command::{CommandLine, CommandResult},
17    util::{NicePathBuf, NiceTempDir},
18};
19use crate::{cwrite, cwriteln, cwriteln_rule};
20use crate::{output::*, util::ShellBit};
21
/// Timeout applied to a script run when `ScriptRunArgs::global_timeout` is not set (30 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
23
24#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
25pub struct ScriptLocation {
26    pub file: ScriptFile,
27    pub line: usize,
28}
29
30impl ScriptLocation {
31    pub fn new(file: ScriptFile, line: usize) -> Self {
32        Self { file, line }
33    }
34}
35
36impl std::fmt::Display for ScriptLocation {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        write!(f, "{}:{}", self.file, self.line)
39    }
40}
41
42#[derive(
43    derive_more::Debug, derive_more::Display, Clone, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash,
44)]
45#[display("{}", file)]
46pub struct ScriptFile {
47    pub file: Arc<NicePathBuf>,
48}
49
50impl ScriptFile {
51    pub fn new(file: impl AsRef<Path>) -> Self {
52        Self {
53            file: Arc::new(NicePathBuf::new(file)),
54        }
55    }
56}
57
58impl<T: AsRef<Path>> From<T> for ScriptFile {
59    fn from(file: T) -> Self {
60        Self::new(file)
61    }
62}
63
64#[derive(Clone, derive_more::Debug, Serialize)]
65pub struct Script {
66    pub commands: Arc<Vec<ScriptBlock>>,
67    pub includes: Arc<HashMap<String, Script>>,
68    pub file: ScriptFile,
69}
70
/// Caller-supplied options for running a script. `Default` gives all-`false` / `None`.
#[derive(Clone, Debug, Default)]
pub struct ScriptRunArgs {
    /// NOTE(review): presumably a per-step delay; units not visible here — confirm.
    pub delay_steps: Option<u64>,
    /// NOTE(review): presumably disables exit-code checking — confirm at use site.
    pub ignore_exit_codes: bool,
    /// NOTE(review): presumably disables output-pattern checking — confirm at use site.
    pub ignore_matches: bool,
    /// When set, the PASSED/FAILED summary omits the elapsed time.
    pub simplified_output: bool,
    /// NOTE(review): presumably prefixes output with script line numbers — confirm.
    pub show_line_numbers: bool,
    /// NOTE(review): optional runner program; semantics defined elsewhere.
    pub runner: Option<String>,
    /// NOTE(review): presumably suppresses non-essential output — confirm.
    pub quiet: bool,
    /// When set, background contexts write to the parent's output instead of a private buffer.
    pub verbose: bool,
    /// Overrides `DEFAULT_TIMEOUT` for the top-level context.
    pub global_timeout: Option<Duration>,
    /// When set, quiet (buffered) outputs are created without colour codes.
    pub no_color: bool,
}
84
/// Variables visible to a script (including the special `PWD` / `INITIAL_PWD`).
#[derive(Clone, Debug, Default)]
pub struct ScriptEnv {
    /// Variable name -> value.
    env_vars: HashMap<String, String>,
}
89
90impl ScriptEnv {
91    pub fn set_defaults(&mut self, pwd: impl AsRef<Path>) {
92        macro_rules! target {
93            ($env:ident, $var:ident, [$($vals:expr),*]) => {
94                $(
95                if cfg!($var = $vals) {
96                    self.env_vars.insert(stringify!($env).to_string(), $vals.to_string());
97                }
98                )*
99            };
100        }
101
102        target!(
103            TARGET_OS,
104            target_os,
105            ["windows", "linux", "macos", "ios", "android"]
106        );
107        target!(TARGET_FAMILY, target_family, ["windows", "unix", "wasm"]);
108        target!(
109            TARGET_ARCH,
110            target_arch,
111            ["x86", "x86_64", "arm", "aarch64"]
112        );
113
114        // Set the current working directory as a special variable "PWD"
115        let pwd = NicePathBuf::from(pwd.as_ref()).env_string();
116        self.env_vars.insert(
117            "PWD".to_string(),
118            pwd,
119        );
120        // Save the initial PWD as INITIAL_PWD so it can easily be restored
121        self.env_vars
122            .insert("INITIAL_PWD".to_string(), self.env_vars["PWD"].clone());
123    }
124
125    pub fn pwd(&self) -> NicePathBuf {
126        self.env_vars
127            .get("PWD")
128            .cloned()
129            .map(NicePathBuf::from)
130            .unwrap_or_else(NicePathBuf::cwd)
131    }
132
133    pub fn get_env(&self, name: &str) -> Option<&str> {
134        self.env_vars.get(name).map(|s| s.as_str())
135    }
136
137    pub fn set_env(&mut self, name: impl Into<String>, value: impl Into<String>) {
138        let name = name.into();
139        if name == "PWD" {
140            self.set_pwd(value.into());
141        } else {
142            self.env_vars.insert(name, value.into());
143        }
144    }
145
146    pub fn set_pwd(&mut self, pwd: impl Into<NicePathBuf>) {
147        let pwd = pwd.into().env_string();
148        self.env_vars.insert("PWD".to_string(), pwd);
149    }
150
151    pub fn expand(&self, value: &ShellBit) -> Result<String, ScriptRunError> {
152        match value {
153            ShellBit::Literal(s) => Ok(s.clone()),
154            ShellBit::Quoted(s) => self.expand_str(s),
155        }
156    }
157
158    /// Perform shell expansion on a string.
159    pub fn expand_str(&self, value: impl AsRef<str>) -> Result<String, ScriptRunError> {
160        enum State {
161            Normal,
162            EscapeNext,
163            InCurly,
164            Dollar,
165            InDollar,
166        }
167
168        let value = value.as_ref();
169
170        // "\" triggers escaping
171        // ${A} expands to the value of A
172        // $A expands to the value of A (variable ends on first non-alphanumeric character)
173
174        let mut state = State::Normal;
175        let mut variable = String::new();
176        let mut expanded = String::new();
177
178        for c in value.chars() {
179            match state {
180                State::Normal => {
181                    if c == '$' {
182                        state = State::Dollar;
183                        continue;
184                    }
185                    if c == '\\' {
186                        state = State::EscapeNext;
187                        continue;
188                    }
189                    expanded.push(c);
190                }
191                State::EscapeNext => {
192                    expanded.push(c);
193                    state = State::Normal;
194                }
195                State::InCurly => {
196                    if c == '}' {
197                        if let Some(value) = self.get_env(&std::mem::take(&mut variable)) {
198                            expanded.push_str(value);
199                        } else {
200                            return Err(ScriptRunError::ExpansionError(format!(
201                                "undefined variable in ${{...}}: {variable:?} (in {value:?})"
202                            )));
203                        }
204                        state = State::Normal;
205                    } else {
206                        variable.push(c);
207                    }
208                }
209                State::Dollar => {
210                    if c.is_alphanumeric() || c == '_' {
211                        state = State::InDollar;
212                        variable.push(c);
213                    } else if c == '{' {
214                        state = State::InCurly;
215                    } else {
216                        return Err(ScriptRunError::ExpansionError(format!(
217                            "invalid variable: {c:?} (in {value:?})"
218                        )));
219                    }
220                }
221                State::InDollar => {
222                    if c.is_alphanumeric() || c == '_' {
223                        variable.push(c);
224                    } else {
225                        if let Some(value) = self.get_env(&std::mem::take(&mut variable)) {
226                            expanded.push_str(value);
227                        } else {
228                            return Err(ScriptRunError::ExpansionError(format!(
229                                "undefined variable in $...: {variable:?} (in {value:?})"
230                            )));
231                        }
232                        expanded.push(c);
233                        state = State::Normal;
234                    }
235                }
236            }
237        }
238        match state {
239            State::InDollar => {
240                if let Some(value) = self.get_env(&variable) {
241                    expanded.push_str(value);
242                } else {
243                    return Err(ScriptRunError::ExpansionError(format!(
244                        "undefined variable: {variable}"
245                    )));
246                }
247            }
248            State::Dollar => {
249                return Err(ScriptRunError::ExpansionError(
250                    "incomplete variable".to_string(),
251                ));
252            }
253            State::InCurly => {
254                return Err(ScriptRunError::ExpansionError(format!(
255                    "unclosed variable: {variable}"
256                )));
257            }
258            State::Normal => {}
259            State::EscapeNext => {
260                return Err(ScriptRunError::ExpansionError(
261                    "unclosed backslash".to_string(),
262                ));
263            }
264        }
265        Ok(expanded)
266    }
267
268    pub fn env_vars(&self) -> &HashMap<String, String> {
269        &self.env_vars
270    }
271}
272
273#[derive(derive_more::Debug, Clone)]
274pub struct ScriptOutput {
275    #[debug(skip)]
276    stream: SharedMut<Box<dyn WriteColorAny>>,
277}
278
279trait WriteColorAny: WriteColor + Send + Sync + std::any::Any + 'static + std::fmt::Debug {
280    /// Workaround for lack of upcasting
281    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String>;
282    fn clone_buffer(&self) -> Result<termcolor::Buffer, String>;
283}
284
285impl WriteColorAny for termcolor::StandardStream {
286    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String> {
287        Err("not a buffer".to_string())
288    }
289    fn clone_buffer(&self) -> Result<termcolor::Buffer, String> {
290        Err("not a buffer".to_string())
291    }
292}
293
294impl WriteColorAny for termcolor::Buffer {
295    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String> {
296        Ok(*self)
297    }
298    fn clone_buffer(&self) -> Result<termcolor::Buffer, String> {
299        Ok(self.clone())
300    }
301}
302
303impl ScriptOutput {
304    pub fn no_color() -> Self {
305        let stm = termcolor::StandardStream::stdout(ColorChoice::Never);
306        Self {
307            stream: SharedMut::new(Box::new(stm) as _),
308        }
309    }
310
311    pub fn quiet(no_color: bool) -> Self {
312        let stm = if no_color {
313            termcolor::Buffer::no_color()
314        } else {
315            termcolor::Buffer::ansi()
316        };
317        Self {
318            stream: SharedMut::new(Box::new(stm) as _),
319        }
320    }
321
322    pub fn take_buffer(self) -> String {
323        let stream = match SharedMut::try_unwrap(self.stream) {
324            Ok(stream) => stream.take_buffer().expect("wrong stream type"),
325            Err(shared) => shared.read().clone_buffer().expect("wrong stream type"),
326        };
327        String::from_utf8_lossy(&stream.into_inner()).to_string()
328    }
329}
330
331impl Default for ScriptOutput {
332    fn default() -> Self {
333        let stm = termcolor::StandardStream::stdout(ColorChoice::Auto);
334        Self {
335            stream: SharedMut::new(Box::new(stm) as _),
336        }
337    }
338}
339
340impl std::io::Write for ScriptOutputLock<'_> {
341    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
342        self.stream.write(buf)
343    }
344    fn flush(&mut self) -> std::io::Result<()> {
345        self.stream.flush()
346    }
347}
348
349impl termcolor::WriteColor for ScriptOutputLock<'_> {
350    fn supports_color(&self) -> bool {
351        self.stream.supports_color()
352    }
353    fn set_color(&mut self, spec: &termcolor::ColorSpec) -> std::io::Result<()> {
354        self.stream.set_color(spec)
355    }
356    fn reset(&mut self) -> std::io::Result<()> {
357        self.stream.reset()
358    }
359    fn is_synchronous(&self) -> bool {
360        self.stream.is_synchronous()
361    }
362    fn set_hyperlink(&mut self, _link: &termcolor::HyperlinkSpec) -> std::io::Result<()> {
363        self.stream.set_hyperlink(_link)
364    }
365    fn supports_hyperlinks(&self) -> bool {
366        self.stream.supports_hyperlinks()
367    }
368}
369
370struct ScriptOutputLock<'a> {
371    stream: keepcalm::SharedWriteLock<'a, Box<dyn WriteColorAny>>,
372}
373
/// How a `ScriptRunContext` was created: top-level, for a deferred block, or for a background block.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ScriptMode {
    /// Top-level (or default) context.
    Normal,
    /// Created by `ScriptRunContext::new_deferred`.
    Deferred,
    /// Created by `ScriptRunContext::new_background`.
    Background,
}
380
381#[derive(derive_more::Debug)]
382pub struct ScriptRunContext {
383    pub args: ScriptRunArgs,
384    pub grok: Grok,
385    timeout: Duration,
386    env: ScriptEnv,
387    includes: Arc<HashMap<String, Script>>,
388    background: ScriptMode,
389    #[debug(skip)]
390    kill: ScriptKillReceiver,
391    #[debug(skip)]
392    kill_sender: ScriptKillSender,
393    output: ScriptOutput,
394
395    global_ignore: OutputPatterns,
396    global_reject: OutputPatterns,
397}
398
399impl Default for ScriptRunContext {
400    fn default() -> Self {
401        let kill = Arc::new(AtomicBool::new(false));
402        Self {
403            args: ScriptRunArgs::default(),
404            grok: Grok::with_default_patterns(),
405            timeout: DEFAULT_TIMEOUT,
406            env: ScriptEnv::default(),
407            background: ScriptMode::Normal,
408            includes: Arc::new(HashMap::new()),
409            kill: ScriptKillReceiver::new(kill.clone()),
410            kill_sender: ScriptKillSender::new(kill.clone()),
411            output: ScriptOutput::default(),
412            global_ignore: OutputPatterns::default(),
413            global_reject: OutputPatterns::default(),
414        }
415    }
416}
417
418impl ScriptRunContext {
419    pub fn new_background(&self) -> Self {
420        let kill = Arc::new(AtomicBool::new(false));
421        Self {
422            args: self.args.clone(),
423            grok: self.grok.clone(),
424            // Background processes are not subject to timeouts
425            timeout: Duration::MAX,
426            env: self.env.clone(),
427            background: ScriptMode::Background,
428            kill: ScriptKillReceiver::new(kill.clone()),
429            kill_sender: ScriptKillSender::new(kill.clone()),
430            includes: self.includes.clone(),
431            output: if self.args.verbose {
432                self.output.clone()
433            } else {
434                ScriptOutput::quiet(self.args.no_color)
435            },
436            global_ignore: self.global_ignore.clone(),
437            global_reject: self.global_reject.clone(),
438        }
439    }
440
441    pub fn new_deferred(&self) -> Self {
442        Self {
443            args: self.args.clone(),
444            grok: self.grok.clone(),
445            timeout: self.timeout,
446            env: self.env.clone(),
447            background: ScriptMode::Deferred,
448            kill: self.kill.clone(),
449            kill_sender: self.kill_sender.clone(),
450            includes: self.includes.clone(),
451            output: self.output.clone(),
452            global_ignore: self.global_ignore.clone(),
453            global_reject: self.global_reject.clone(),
454        }
455    }
456
457    pub fn pwd(&self) -> NicePathBuf {
458        self.env.pwd()
459    }
460
461    pub fn get_env(&self, name: &str) -> Option<&str> {
462        self.env.get_env(name)
463    }
464
465    pub fn set_env(&mut self, name: impl Into<String>, value: impl Into<String>) {
466        self.env.set_env(name, value);
467    }
468
469    pub fn set_pwd(&mut self, pwd: impl Into<NicePathBuf>) {
470        self.env.set_pwd(pwd);
471    }
472
473    pub fn take_output(self) -> String {
474        self.output.take_buffer()
475    }
476
477    fn expand(&self, value: &ShellBit) -> Result<String, ScriptRunError> {
478        self.env.expand(value)
479    }
480
481    /// Get a mutable reference to the output stream.
482    pub fn stream(&self) -> impl termcolor::WriteColor + use<'_> {
483        ScriptOutputLock {
484            stream: self.output.stream.write(),
485        }
486    }
487}
488
/// Read side of a shared kill flag; see `ScriptKillSender` for the write side.
#[derive(Clone)]
pub struct ScriptKillReceiver {
    /// Set to `true` once a kill has been requested.
    kill_receiver: Arc<AtomicBool>,
}
493
494impl ScriptKillReceiver {
495    pub fn new(kill_receiver: Arc<AtomicBool>) -> Self {
496        Self { kill_receiver }
497    }
498
499    pub fn is_killed(&self) -> bool {
500        self.kill_receiver.load(std::sync::atomic::Ordering::SeqCst)
501    }
502
503    pub fn run_with<T>(&self, kill: impl FnOnce() + Send, wait: impl FnOnce() -> T) -> T {
504        std::thread::scope(|s| {
505            let done = Arc::new(AtomicBool::new(false));
506            let done_clone = done.clone();
507            let t = s.spawn(move || {
508                while !done_clone.load(std::sync::atomic::Ordering::SeqCst) {
509                    if self.is_killed() {
510                        kill();
511                        break;
512                    }
513                    std::thread::sleep(Duration::from_millis(10));
514                }
515            });
516            let res = wait();
517            done.store(true, std::sync::atomic::Ordering::SeqCst);
518            t.join().unwrap();
519            res
520        })
521    }
522
523    #[cfg(windows)]
524    pub fn run_cmd(
525        &self,
526        output: std::process::Child,
527        warn_time: Duration,
528    ) -> std::io::Result<ExitStatus> {
529        use std::os::windows::io::AsRawHandle;
530        use win32job::Job;
531
532        fn map_job_error(e: win32job::JobError) -> std::io::Error {
533            match e {
534                win32job::JobError::AssignFailed(e) => e,
535                win32job::JobError::CreateFailed(e) => e,
536                win32job::JobError::GetInfoFailed(e) => e,
537                win32job::JobError::SetInfoFailed(e) => e,
538                _ => std::io::Error::new(std::io::ErrorKind::Other, "Unknown error"),
539            }
540        }
541
542        // Create a new Job object
543        let job = Job::create().map_err(map_job_error)?;
544
545        // Configure the job to terminate all child processes when the job is closed
546        let mut info = job.query_extended_limit_info().map_err(map_job_error)?;
547        info.limit_kill_on_job_close();
548        job.set_extended_limit_info(&info).map_err(map_job_error)?;
549        job.assign_process(output.as_raw_handle() as _)?;
550
551        // Resume the main thread for the process
552        let id = output.id();
553        for thread_entry in tlhelp32::Snapshot::new_thread()? {
554            if thread_entry.owner_process_id == id {
555                use windows_sys::Win32::Foundation::CloseHandle;
556                use windows_sys::Win32::System::Threading::*;
557
558                unsafe {
559                    let thread = OpenThread(THREAD_SUSPEND_RESUME, 0, thread_entry.thread_id);
560                    if thread.is_null() {
561                        return Err(std::io::Error::last_os_error().into());
562                    }
563                    ResumeThread(thread);
564                    CloseHandle(thread);
565                }
566            }
567        }
568
569        let job = Mutex::new(Some(job));
570        let output = Mutex::new(output);
571        self.run_with(
572            || {
573                _ = job.lock().unwrap().take();
574                _ = output.lock().unwrap().kill();
575            },
576            || {
577                let start = std::time::Instant::now();
578                let mut warned = false;
579                loop {
580                    let res = output.lock().unwrap().try_wait()?;
581                    if let Some(status) = res {
582                        return Ok::<_, std::io::Error>(status);
583                    }
584                    if start.elapsed() > warn_time {
585                        if !warned {
586                            let child = output.lock().unwrap().id();
587                            eprintln!("Process #{child} taking too long to finish.");
588                            warned = true;
589                        }
590                    }
591                    std::thread::sleep(Duration::from_millis(10));
592                }
593            },
594        )
595    }
596
597    #[cfg(unix)]
598    pub fn run_cmd(
599        &self,
600        output: std::process::Child,
601        warn_time: Duration,
602    ) -> std::io::Result<ExitStatus> {
603        let output = Mutex::new(output);
604        self.run_with(
605            || {
606                use signal_child::{signal, signal::*};
607                let id = output.lock().unwrap().id() as i32;
608                _ = signal(-id, SIGINT);
609                std::thread::sleep(Duration::from_millis(10));
610                _ = signal(-id, SIGTERM);
611            },
612            || {
613                let start = std::time::Instant::now();
614                let mut warned = false;
615                loop {
616                    let res = output.lock().unwrap().try_wait()?;
617                    if let Some(status) = res {
618                        return Ok::<_, std::io::Error>(status);
619                    }
620                    if start.elapsed() > warn_time && !warned {
621                        let child = output.lock().unwrap().id();
622                        eprintln!("Process #{child} taking too long to finish.");
623                        warned = true;
624                    }
625                    std::thread::sleep(Duration::from_millis(10));
626                }
627            },
628        )
629    }
630}
631
/// Write side of a shared kill flag; see `ScriptKillReceiver` for the read side.
#[derive(Clone)]
pub struct ScriptKillSender {
    /// Flag set by [`ScriptKillSender::kill`].
    kill_sender: Arc<AtomicBool>,
}
636
637impl ScriptKillSender {
638    pub fn new(kill_sender: Arc<AtomicBool>) -> Self {
639        Self { kill_sender }
640    }
641
642    pub fn kill(&self) {
643        self.kill_sender
644            .store(true, std::sync::atomic::Ordering::SeqCst);
645    }
646}
647
648impl ScriptRunContext {
649    pub fn new(args: ScriptRunArgs, script_path: impl AsRef<Path>, output: ScriptOutput) -> Self {
650        let mut env = ScriptEnv::default();
651        env.set_defaults(script_path.as_ref().parent().unwrap());
652
653        let kill = Arc::new(AtomicBool::new(false));
654
655        Self {
656            timeout: args.global_timeout.unwrap_or(DEFAULT_TIMEOUT),
657            args,
658            env,
659            grok: Grok::with_default_patterns(),
660            includes: Arc::new(HashMap::new()),
661            background: ScriptMode::Normal,
662            kill: ScriptKillReceiver::new(kill.clone()),
663            kill_sender: ScriptKillSender::new(kill.clone()),
664            output,
665            global_ignore: OutputPatterns::default(),
666            global_reject: OutputPatterns::default(),
667        }
668    }
669}
670
671#[derive(Clone, Debug, PartialEq, Eq)]
672pub struct ScriptLine {
673    pub location: ScriptLocation,
674    text: String,
675}
676
677impl ScriptLine {
678    pub fn new(file: ScriptFile, line: usize, text: impl AsRef<str>) -> Self {
679        Self {
680            location: ScriptLocation::new(file, line),
681            text: text.as_ref().to_string(),
682        }
683    }
684
685    pub fn parse(file: ScriptFile, text: impl AsRef<str>) -> Vec<Self> {
686        text.as_ref()
687            .lines()
688            .enumerate()
689            .map(|(line, text)| Self {
690                location: ScriptLocation::new(file.clone(), line + 1),
691                text: text.to_string(),
692            })
693            .collect()
694    }
695
696    pub fn starts_with(&self, text: &str) -> bool {
697        self.text.trim().starts_with(text)
698    }
699
700    pub fn first_char(&self) -> Option<char> {
701        self.text.trim().chars().next()
702    }
703
704    pub fn text(&self) -> &str {
705        self.text.trim()
706    }
707
708    pub fn text_untrimmed(&self) -> &str {
709        &self.text
710    }
711
712    pub fn is_empty(&self) -> bool {
713        self.text.trim().is_empty()
714    }
715
716    pub fn strip_prefix(&self, prefix: &str) -> Option<&str> {
717        self.text.strip_prefix(prefix)
718    }
719}
720
721#[derive(Debug, thiserror::Error, derive_more::Display)]
722#[display("{error} at {location}{}", associated_data.as_deref().map_or("".to_string(), |d| format!(": {d}")))]
723pub struct ScriptError {
724    pub error: ScriptErrorType,
725    pub location: ScriptLocation,
726    pub associated_data: Option<String>,
727}
728
729impl ScriptError {
730    pub fn new(error: ScriptErrorType, location: ScriptLocation) -> Self {
731        if std::env::var("PANIC_ON_ERROR").is_ok() {
732            panic!("ScriptError: {error} at {location}");
733        }
734        Self {
735            error,
736            location,
737            associated_data: None,
738        }
739    }
740
741    pub fn new_with_data(
742        error: ScriptErrorType,
743        location: ScriptLocation,
744        associated_data: String,
745    ) -> Self {
746        if std::env::var("PANIC_ON_ERROR").is_ok() {
747            panic!("ScriptError: {error} at {location}: {associated_data}");
748        }
749        Self {
750            error,
751            location,
752            associated_data: Some(associated_data),
753        }
754    }
755}
756
757#[derive(Debug, thiserror::Error, Eq, PartialEq)]
758pub enum ScriptErrorType {
759    #[error("background process not allowed")]
760    BackgroundProcessNotAllowed,
761    #[error("unclosed quote")]
762    UnclosedQuote,
763    #[error("unclosed backslash")]
764    UnclosedBackslash,
765    #[error("illegal shell command format")]
766    IllegalShellCommand,
767    #[error("unsupported redirection")]
768    UnsupportedRedirection,
769    #[error("invalid pattern definition")]
770    InvalidPatternDefinition,
771    #[error("invalid pattern")]
772    InvalidPattern,
773    #[error("invalid meta command")]
774    InvalidMetaCommand,
775    #[error("invalid pattern at global level (only reject or ignore allowed here)")]
776    InvalidGlobalPattern,
777    #[error("invalid block type")]
778    InvalidBlockType,
779    #[error("invalid block arguments")]
780    InvalidBlockArgs,
781    #[error("unsupported command position")]
782    UnsupportedCommandPosition,
783    #[error("invalid trailing pattern after *")]
784    InvalidAnyPattern,
785    #[error("invalid exit status")]
786    InvalidExitStatus,
787    #[error("invalid set variable")]
788    InvalidSetVariable,
789    #[error("invalid version header, expected `#!/usr/bin/env clitest --v0`")]
790    InvalidVersion,
791    #[error("invalid internal command")]
792    InvalidInternalCommand,
793    #[error("missing command lines")]
794    MissingCommandLines,
795    #[error(
796        "block end without matching block start, too many closing braces or braces not properly nested"
797    )]
798    InvalidBlockEnd,
799    #[error("invalid if condition")]
800    InvalidIfCondition,
801    #[error("expected block or semi-colon (did you forget to add ';' at the end of this line?)")]
802    ExpectedBlockOrSemi,
803}
804
805#[derive(Debug, thiserror::Error)]
806pub enum ScriptRunError {
807    #[error("{0}")]
808    Pattern(#[from] OutputPatternMatchFailure),
809    #[error("{0}")]
810    PatternPrepareError(#[from] OutputPatternPrepareError),
811    #[error("{0} at line {1}")]
812    Exit(CommandResult, ScriptLocation),
813    #[error("included file not found: {0}")]
814    IncludedFileNotFound(String),
815    #[error("expected failure, but passed at line {0}")]
816    ExpectedFailure(ScriptLocation),
817    #[error("{0}")]
818    ExpansionError(String),
819    #[error("{0}")]
820    IO(#[from] std::io::Error),
821    #[error("killed")]
822    Killed,
823    #[error("background process took too long to finish")]
824    BackgroundProcessTookTooLong,
825    #[error("retry took too long to finish")]
826    RetryTookTooLong,
827    /// Internal flow control: exit the script
828    #[error("exiting script")]
829    ExitScript,
830}
831
832impl ScriptRunError {
833    #[expect(unused)]
834    pub fn short(&self) -> String {
835        match self {
836            Self::Pattern(_) => "Pattern".to_string(),
837            Self::PatternPrepareError(e) => format!("PatternPrepareError({e:?})"),
838            Self::Exit(status, _) => format!("Exit({status})"),
839            Self::ExpectedFailure(_) => "ExpectedFailure".to_string(),
840            Self::IO(e) => format!("IO({:?})", e.kind()),
841            Self::Killed => "Killed".to_string(),
842            Self::BackgroundProcessTookTooLong => "BackgroundProcessTookTooLong".to_string(),
843            Self::ExpansionError(e) => "ExpansionError".to_string(),
844            Self::RetryTookTooLong => "RetryTookTooLong".to_string(),
845            Self::ExitScript => unreachable!(),
846            Self::IncludedFileNotFound(path) => format!("IncludedFileNotFound({path})"),
847        }
848    }
849}
850
851impl Script {
852    pub fn new(file: ScriptFile) -> Self {
853        Self {
854            commands: Arc::new(vec![]),
855            includes: Arc::new(HashMap::new()),
856            file,
857        }
858    }
859
860    /// Collect all included script paths from the script.
861    pub fn includes(&self) -> Vec<(ScriptLocation, String)> {
862        self.commands
863            .iter()
864            .flat_map(|block| block.includes())
865            .collect()
866    }
867
868    pub fn run(&self, context: &mut ScriptRunContext) -> Result<(), ScriptRunError> {
869        let old_includes = context.includes.clone();
870        context.includes = self.includes.clone();
871        let res = ScriptBlock::run_blocks(context, &self.commands);
872        context.includes = old_includes;
873        let v = match res {
874            Ok(v) => v,
875            // Bypass normal script processing and exit successfully
876            Err(ScriptRunError::ExitScript) => return Ok(()),
877            Err(e) => return Err(e),
878        };
879        assert!(v.is_empty(), "script did not run to completion: {v:?}");
880        Ok(())
881    }
882
883    pub fn run_with_args(
884        &self,
885        args: ScriptRunArgs,
886        output: ScriptOutput,
887    ) -> Result<(), ScriptRunError> {
888        let start = Instant::now();
889        let script_path = &*self.file.file;
890        let mut context = ScriptRunContext::new(args, script_path, output);
891
892        // Write "Running..." message with colors
893        cwrite!(context.stream(), "Running ");
894        cwrite!(context.stream(), fg = Color::Cyan, "{}", script_path);
895        cwriteln!(context.stream(), " ...");
896        cwriteln!(context.stream());
897
898        let result = self.run(&mut context);
899
900        // Handle success and error output
901        if let Err(ref e) = result {
902            cwrite!(context.stream(), fg = Color::Cyan, "{} ", script_path);
903            cwrite!(context.stream(), fg = Color::Red, "FAILED");
904            if !context.args.simplified_output {
905                cwriteln!(context.stream(), " ({:.2}s)", start.elapsed().as_secs_f32());
906            } else {
907                cwriteln!(context.stream());
908            }
909            cwrite!(context.stream(), fg = Color::Red, "Error: ");
910            cwriteln!(context.stream(), "{}", e);
911            cwriteln!(context.stream());
912        } else {
913            cwrite!(context.stream(), fg = Color::Cyan, "{} ", script_path);
914            cwrite!(context.stream(), fg = Color::Green, "PASSED");
915            if !context.args.simplified_output {
916                cwriteln!(context.stream(), " ({:.2}s)", start.elapsed().as_secs_f32());
917            } else {
918                cwriteln!(context.stream());
919            }
920        }
921
922        result
923    }
924}
925
926#[derive(Debug, Default, Serialize)]
927pub enum CommandExit {
928    #[default]
929    Success,
930    Failure(i32),
931    Timeout,
932    Any,
933    AnyFailure,
934}
935
936impl CommandExit {
937    pub fn matches(&self, status: CommandResult) -> bool {
938        match (self, status) {
939            (CommandExit::Success, CommandResult::Exit(status)) => status.success(),
940            (CommandExit::Failure(code), CommandResult::Exit(status)) => {
941                *code == status.code().unwrap_or(-1)
942            }
943            (CommandExit::Timeout, CommandResult::TimedOut) => true,
944            (CommandExit::Any, _) => true,
945            (CommandExit::AnyFailure, CommandResult::Exit(status)) => !status.success(),
946            (CommandExit::AnyFailure, _) => true,
947            _ => false,
948        }
949    }
950
951    pub fn is_success(&self) -> bool {
952        matches!(self, CommandExit::Success)
953    }
954}
955
/// One step of a parsed script.
///
/// Most variants are executed by [`ScriptBlock::run`]. `InternalCommand`,
/// `Background` and `Defer` are handled directly by [`ScriptBlock::run_blocks`].
#[derive(derive_more::Debug)]
// NOTE(review): presumably `Command(ScriptCommand)` is the oversized variant;
// the lint is silenced rather than boxing it.
#[allow(clippy::large_enum_variant)]
pub enum ScriptBlock {
    /// A command to execute, together with its output/exit expectations.
    Command(ScriptCommand),
    /// A runner directive (see [`InternalCommand`]) and where it was written.
    InternalCommand(ScriptLocation, InternalCommand),
    /// Blocks run concurrently on a scoped thread in a background context.
    /// They are signalled to stop, joined and evaluated once the enclosing
    /// block list finishes.
    Background(Vec<ScriptBlock>),
    /// Blocks run after the enclosing block list finishes, newest `defer` first.
    Defer(Vec<ScriptBlock>),
    /// Blocks run only when the (expanded) condition matches.
    If(IfCondition, Vec<ScriptBlock>),
    /// Blocks run once per loop value.
    For(ForCondition, Vec<ScriptBlock>),
    /// Blocks re-run with exponential backoff until every result evaluates
    /// cleanly or the context timeout elapses.
    Retry(Vec<ScriptBlock>),
    /// Patterns appended to the context-wide ignore list for later commands.
    GlobalIgnore(OutputPatterns),
    /// Patterns appended to the context-wide reject list for later commands.
    GlobalReject(OutputPatterns),
}
969
970impl ScriptBlock {
971    pub fn includes(&self) -> Vec<(ScriptLocation, String)> {
972        match self {
973            ScriptBlock::Command(..) => vec![],
974            ScriptBlock::InternalCommand(location, InternalCommand::Include(path)) => {
975                vec![(location.clone(), path.clone())]
976            }
977            ScriptBlock::InternalCommand(..) => vec![],
978            ScriptBlock::Background(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
979            ScriptBlock::Defer(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
980            ScriptBlock::If(_, blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
981            ScriptBlock::For(_, blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
982            ScriptBlock::Retry(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
983            ScriptBlock::GlobalIgnore(_) => vec![],
984            ScriptBlock::GlobalReject(_) => vec![],
985        }
986    }
987
988    #[allow(clippy::type_complexity)]
989    pub fn run_blocks(
990        context: &mut ScriptRunContext,
991        blocks: &[ScriptBlock],
992    ) -> Result<Vec<ScriptResult>, ScriptRunError> {
993        enum Deferred<'a> {
994            Scripts(&'a [ScriptBlock]),
995            Internal(
996                Box<
997                    dyn FnOnce(&mut ScriptRunContext) -> Result<(), ScriptRunError>
998                        + Send
999                        + Sync
1000                        + 'a,
1001                >,
1002            ),
1003            Background(
1004                ScopedJoinHandle<'a, Result<Vec<ScriptResult>, ScriptRunError>>,
1005                ScriptKillSender,
1006            ),
1007        }
1008
1009        let mut results = Vec::new();
1010        std::thread::scope(|s| {
1011            let mut defer_blocks = VecDeque::new();
1012            let mut pending_error = None;
1013            for block in blocks {
1014                if context.kill.is_killed() {
1015                    return Err(ScriptRunError::Killed);
1016                }
1017                match block {
1018                    ScriptBlock::Background(blocks) => {
1019                        let mut context = context.new_background();
1020                        let kill_sender = context.kill_sender.clone();
1021                        let handle = s.spawn(move || Self::run_blocks(&mut context, blocks));
1022                        defer_blocks.push_front(Deferred::Background(handle, kill_sender));
1023                    }
1024                    ScriptBlock::Defer(blocks) => {
1025                        // Insert at the front of the queue by extending and
1026                        // then rotating
1027                        defer_blocks.push_front(Deferred::Scripts(blocks));
1028                    }
1029                    ScriptBlock::InternalCommand(_, command) => {
1030                        if context.background == ScriptMode::Deferred {
1031                            cwrite!(context.stream(), dimmed = true, "(deferred) ");
1032                        }
1033                        if let Some(f) = command.run(context)? {
1034                            defer_blocks.push_front(Deferred::Internal(f));
1035                        }
1036                    }
1037                    _ => match block.run(context) {
1038                        Ok(res) => results.extend(res),
1039                        Err(e) => {
1040                            pending_error = Some(e);
1041                            break;
1042                        }
1043                    },
1044                }
1045            }
1046            for block in defer_blocks {
1047                match block {
1048                    Deferred::Scripts(blocks) => {
1049                        let mut context = context.new_deferred();
1050                        ScriptBlock::run_blocks(&mut context, blocks)?;
1051                    }
1052                    Deferred::Internal(block) => {
1053                        cwrite!(context.stream(), dimmed = true, "(cleanup) ");
1054                        block(context)?;
1055                    }
1056                    Deferred::Background(handle, kill_sender) => {
1057                        kill_sender.kill();
1058                        let start = std::time::Instant::now();
1059                        let mut warned = false;
1060
1061                        let timeout = context.timeout;
1062                        let warn_at = timeout * 8 / 10;
1063
1064                        let results = loop {
1065                            if handle.is_finished() {
1066                                break handle.join().unwrap()?;
1067                            }
1068                            std::thread::sleep(std::time::Duration::from_millis(10));
1069                            if !warned && start.elapsed() > warn_at {
1070                                cwriteln!(
1071                                    context.stream(),
1072                                    fg = Color::Yellow,
1073                                    "Background process is taking too long to finish."
1074                                );
1075                                warned = true;
1076                            }
1077                            if start.elapsed() > timeout {
1078                                cwriteln!(
1079                                    context.stream(),
1080                                    fg = Color::Red,
1081                                    "Background process took too long to finish."
1082                                );
1083                                return Err(ScriptRunError::BackgroundProcessTookTooLong);
1084                            }
1085                        };
1086                        for result in results {
1087                            cwrite!(context.stream(), dimmed = true, "(background) ");
1088                            for line in result.command.command.split('\n') {
1089                                cwriteln!(context.stream(), fg = Color::Green, "{}", line);
1090                            }
1091                            if context.args.simplified_output {
1092                                cwriteln!(context.stream(), dimmed = true, "---");
1093                            } else {
1094                                cwriteln_rule!(
1095                                    context.stream(),
1096                                    fg = Color::Cyan,
1097                                    "{}",
1098                                    result.command.location
1099                                );
1100                            }
1101                            for line in &result.output {
1102                                cwriteln!(context.stream(), "{}", line);
1103                            }
1104                            if result.output.is_empty() {
1105                                cwriteln!(context.stream(), dimmed = true, "(no output)");
1106                            }
1107                            if context.args.simplified_output {
1108                                cwriteln!(context.stream(), dimmed = true, "---");
1109                            } else {
1110                                cwriteln_rule!(context.stream());
1111                            }
1112                            result.evaluate(context)?;
1113                        }
1114                    }
1115                }
1116            }
1117            if let Some(error) = pending_error {
1118                return Err(error);
1119            }
1120            Ok(results)
1121        })
1122    }
1123
1124    pub fn run(&self, context: &mut ScriptRunContext) -> Result<Vec<ScriptResult>, ScriptRunError> {
1125        let pwd = context.pwd();
1126        let res = pwd.exists();
1127        if !matches!(res, Ok(true)) {
1128            cwriteln!(
1129                context.stream(),
1130                fg = Color::Red,
1131                "$PWD {pwd:?} doesn't exist. Run `cd $INITIAL_PWD` to fix.",
1132            );
1133            return Err(ScriptRunError::IO(std::io::Error::new(
1134                std::io::ErrorKind::NotFound,
1135                format!("PWD does not exist: {pwd:?}"),
1136            )));
1137        }
1138
1139        match self {
1140            ScriptBlock::Command(command) => {
1141                if context.background == ScriptMode::Deferred {
1142                    cwrite!(context.stream(), dimmed = true, "(deferred) ");
1143                }
1144                let result = command.run(context)?;
1145                if context.background != ScriptMode::Background {
1146                    result.evaluate(context)?;
1147                    Ok(vec![])
1148                } else {
1149                    Ok(vec![result])
1150                }
1151            }
1152            ScriptBlock::If(condition, blocks) => {
1153                let condition = condition.expand(context)?;
1154                if condition.matches(context) {
1155                    Self::run_blocks(context, blocks)
1156                } else {
1157                    Ok(vec![])
1158                }
1159            }
1160            ScriptBlock::For(ForCondition::Env(env, values), blocks) => {
1161                let mut results = Vec::new();
1162                for value in values {
1163                    context.set_env(env, context.expand(value)?);
1164                    results.extend(Self::run_blocks(context, blocks)?);
1165                }
1166                Ok(results)
1167            }
1168            ScriptBlock::Retry(blocks) => {
1169                let start = Instant::now();
1170                let mut backoff = Duration::from_millis(100);
1171
1172                cwrite!(context.stream(), fg = Color::Green, "retry: ");
1173                cwriteln!(context.stream(), "running...");
1174
1175                loop {
1176                    let mut nested_context = context.new_background();
1177                    if let Ok(results) = Self::run_blocks(&mut nested_context, blocks) {
1178                        let mut all_ok = true;
1179                        for result in results {
1180                            if result.evaluate(&mut nested_context).is_err() {
1181                                all_ok = false;
1182                                break;
1183                            }
1184                        }
1185                        if all_ok {
1186                            let output = nested_context.take_output();
1187                            cwrite!(context.stream(), fg = Color::Green, "retry: ");
1188                            cwriteln!(context.stream(), "success");
1189                            cwriteln!(context.stream());
1190                            cwriteln!(context.stream(), "{output}");
1191                            return Ok(vec![]);
1192                        }
1193                    }
1194
1195                    if start.elapsed() > context.timeout {
1196                        let output = nested_context.take_output();
1197                        cwrite!(context.stream(), fg = Color::Green, "retry: ");
1198                        cwriteln!(context.stream(), fg = Color::Red, "timed out");
1199                        cwriteln!(context.stream());
1200                        cwriteln!(context.stream(), "{output}");
1201                        cwriteln_rule!(context.stream());
1202                        return Err(ScriptRunError::RetryTookTooLong);
1203                    }
1204                    std::thread::sleep(backoff);
1205                    backoff *= 2;
1206                }
1207            }
1208            ScriptBlock::GlobalIgnore(patterns) => {
1209                for pattern in patterns.iter() {
1210                    pattern.prepare(&context.grok)?;
1211                }
1212                context.global_ignore.extend(patterns);
1213                Ok(vec![])
1214            }
1215            ScriptBlock::GlobalReject(patterns) => {
1216                for pattern in patterns.iter() {
1217                    pattern.prepare(&context.grok)?;
1218                }
1219                context.global_reject.extend(patterns);
1220                Ok(vec![])
1221            }
1222            _ => unreachable!("Unexpected block type: {self:?}"),
1223        }
1224    }
1225}
1226
1227impl Serialize for ScriptBlock {
1228    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1229    where
1230        S: serde::Serializer,
1231    {
1232        match self {
1233            ScriptBlock::Command(command) => command.serialize(serializer),
1234            ScriptBlock::InternalCommand(_, command) => command.serialize(serializer),
1235            ScriptBlock::Background(blocks) => {
1236                let mut ser = serializer.serialize_map(Some(1))?;
1237                ser.serialize_entry("background", blocks)?;
1238                ser.end()
1239            }
1240            ScriptBlock::Defer(blocks) => {
1241                let mut ser = serializer.serialize_map(Some(1))?;
1242                ser.serialize_entry("defer", blocks)?;
1243                ser.end()
1244            }
1245            ScriptBlock::If(condition, blocks) => {
1246                let mut ser = serializer.serialize_map(Some(2))?;
1247                ser.serialize_entry("if", condition)?;
1248                ser.serialize_entry("blocks", blocks)?;
1249                ser.end()
1250            }
1251            ScriptBlock::For(condition, blocks) => {
1252                let mut ser = serializer.serialize_map(Some(2))?;
1253                ser.serialize_entry("for", condition)?;
1254                ser.serialize_entry("blocks", blocks)?;
1255                ser.end()
1256            }
1257            ScriptBlock::Retry(blocks) => {
1258                let mut ser = serializer.serialize_map(Some(1))?;
1259                ser.serialize_entry("retry", blocks)?;
1260                ser.end()
1261            }
1262            ScriptBlock::GlobalIgnore(patterns) => {
1263                let mut ser = serializer.serialize_map(Some(1))?;
1264                ser.serialize_entry("ignore", patterns)?;
1265                ser.end()
1266            }
1267            ScriptBlock::GlobalReject(patterns) => {
1268                let mut ser = serializer.serialize_map(Some(1))?;
1269                ser.serialize_entry("reject", patterns)?;
1270                ser.end()
1271            }
1272        }
1273    }
1274}
1275
/// A directive executed by the runner itself rather than by a shell.
#[derive(Debug, Clone, Serialize)]
pub enum InternalCommand {
    /// Switch `$PWD` into a freshly created temporary directory. The cleanup
    /// step removes it.
    UsingTempdir,
    /// Switch `$PWD` into the expanded directory, joined onto the current PWD.
    /// When the flag is `true` the directory is created first and removed
    /// during cleanup; otherwise it must already exist. Cleanup restores the
    /// previous PWD.
    UsingDir(ShellBit, bool),
    /// Change `$PWD` to the expanded directory, joined onto the current PWD;
    /// the directory must exist. No cleanup is registered.
    ChangeDir(ShellBit),
    /// Set an environment variable to the expanded value.
    Set(String, ShellBit),
    /// Run the named included script against the current context.
    Include(String),
    /// Stop the script; reported as `ScriptRunError::ExitScript`.
    ExitScript,
    /// Register a named grok pattern (name, pattern).
    Pattern(String, String),
}
1286
1287impl InternalCommand {
1288    #[allow(clippy::type_complexity)]
1289    pub fn run(
1290        &self,
1291        context: &mut ScriptRunContext,
1292    ) -> Result<
1293        Option<Box<dyn FnOnce(&mut ScriptRunContext) -> Result<(), ScriptRunError> + Send + Sync>>,
1294        ScriptRunError,
1295    > {
1296        match self.clone() {
1297            InternalCommand::Include(path) => {
1298                let Some(script) = context.includes.get(&path) else {
1299                    return Err(ScriptRunError::IncludedFileNotFound(path));
1300                };
1301                script.clone().run(context)?;
1302                Ok(None)
1303            }
1304            InternalCommand::Pattern(name, pattern) => {
1305                context.grok.add_pattern(name, pattern);
1306                Ok(None)
1307            }
1308            InternalCommand::UsingTempdir => {
1309                let current_pwd = context.pwd();
1310                let tempdir = NiceTempDir::new();
1311                cwrite!(context.stream(), fg = Color::Yellow, "using tempdir: ");
1312                cwriteln!(context.stream(), "{}", tempdir);
1313                cwriteln!(context.stream());
1314                context.set_pwd(&tempdir);
1315                let pwd = context.pwd();
1316                if !pwd.exists()? {
1317                    return Err(ScriptRunError::IO(std::io::Error::new(
1318                        std::io::ErrorKind::NotFound,
1319                        format!("newly created tempdir does not exist: {pwd:?}"),
1320                    )));
1321                }
1322                Ok(Some(Box::new(move |context: &mut ScriptRunContext| {
1323                    cwriteln!(
1324                        context.stream(),
1325                        fg = Color::Yellow,
1326                        "removing {} && cd {}",
1327                        tempdir,
1328                        current_pwd
1329                    );
1330                    cwriteln!(context.stream());
1331                    if !tempdir.exists()? {
1332                        cwriteln!(
1333                            context.stream(),
1334                            fg = Color::Red,
1335                            "tempdir does not exist: {tempdir}"
1336                        );
1337                    }
1338                    if let Err(e) = tempdir.remove_dir_all() {
1339                        cwriteln!(
1340                            context.stream(),
1341                            fg = Color::Red,
1342                            "error removing tempdir: {e:?}"
1343                        );
1344                    }
1345                    Ok::<_, ScriptRunError>(())
1346                })))
1347            }
1348            InternalCommand::UsingDir(dir, new) => {
1349                let current_pwd = context.pwd();
1350                let dir = context.expand(&dir)?;
1351                let new_pwd = current_pwd.join(dir);
1352                if new {
1353                    cwrite!(context.stream(), fg = Color::Yellow, "using new dir: ");
1354                } else {
1355                    cwrite!(context.stream(), fg = Color::Yellow, "using dir: ");
1356                }
1357                cwriteln!(context.stream(), "{}", new_pwd);
1358                cwriteln!(context.stream());
1359
1360                if new {
1361                    new_pwd.create_dir_all()?;
1362                } else if !new_pwd.exists()? {
1363                    return Err(ScriptRunError::IO(std::io::Error::new(
1364                        std::io::ErrorKind::NotFound,
1365                        "directory does not exist",
1366                    )));
1367                }
1368                context.set_pwd(&new_pwd);
1369                Ok(Some(Box::new(move |context: &mut ScriptRunContext| {
1370                    if new {
1371                        cwriteln!(
1372                            context.stream(),
1373                            fg = Color::Yellow,
1374                            "removing {} && cd {}",
1375                            new_pwd,
1376                            current_pwd
1377                        );
1378                        cwriteln!(context.stream());
1379                    } else {
1380                        cwriteln!(context.stream(), fg = Color::Yellow, "cd {}", current_pwd);
1381                        cwriteln!(context.stream());
1382                    }
1383                    if new {
1384                        new_pwd.remove_dir_all()?;
1385                    }
1386                    context.set_pwd(current_pwd);
1387                    Ok::<_, ScriptRunError>(())
1388                })))
1389            }
1390            InternalCommand::ChangeDir(dir) => {
1391                let dir = context.expand(&dir)?;
1392
1393                cwriteln!(context.stream(), fg = Color::Yellow, "cd {dir}");
1394                cwriteln!(context.stream());
1395                let current_pwd = context.pwd();
1396                let new_pwd = current_pwd.join(dir);
1397                if !new_pwd.exists()? {
1398                    return Err(ScriptRunError::IO(std::io::Error::new(
1399                        std::io::ErrorKind::NotFound,
1400                        format!("directory does not exist: {new_pwd}"),
1401                    )));
1402                }
1403                context.set_pwd(new_pwd);
1404                Ok(None)
1405            }
1406            InternalCommand::Set(name, value) => {
1407                let value = context.expand(&value)?;
1408
1409                context.set_env(&name, &value);
1410                let new_value = context.get_env(&name).unwrap_or_default();
1411                if new_value != value {
1412                    cwriteln!(
1413                        context.stream(),
1414                        fg = Color::Yellow,
1415                        "set {name} {value} (-> {new_value})"
1416                    );
1417                } else {
1418                    cwriteln!(context.stream(), fg = Color::Yellow, "set {name} {value}");
1419                }
1420                cwriteln!(context.stream());
1421
1422                Ok(None)
1423            }
1424            InternalCommand::ExitScript => {
1425                cwriteln!(context.stream(), fg = Color::Yellow, "exiting script");
1426                cwriteln!(context.stream());
1427                Err(ScriptRunError::ExitScript)
1428            }
1429        }
1430    }
1431}
1432
1433#[derive(Debug, Clone)]
1434pub enum IfCondition {
1435    True,
1436    False,
1437    EnvEq(bool, String, ShellBit),
1438}
1439
1440impl IfCondition {
1441    pub fn matches(&self, context: &ScriptRunContext) -> bool {
1442        match self {
1443            IfCondition::True => true,
1444            IfCondition::False => false,
1445            IfCondition::EnvEq(negated, name, expected) => {
1446                let value = context.get_env(name).unwrap_or_default();
1447                (expected == value) ^ negated
1448            }
1449        }
1450    }
1451
1452    pub fn expand(&self, context: &ScriptRunContext) -> Result<IfCondition, ScriptRunError> {
1453        match self {
1454            IfCondition::True => Ok(IfCondition::True),
1455            IfCondition::False => Ok(IfCondition::False),
1456            IfCondition::EnvEq(negated, name, expected) => {
1457                let value = context.expand(expected)?;
1458                Ok(IfCondition::EnvEq(
1459                    *negated,
1460                    name.clone(),
1461                    ShellBit::Literal(value),
1462                ))
1463            }
1464        }
1465    }
1466}
1467
1468impl Serialize for IfCondition {
1469    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1470    where
1471        S: serde::Serializer,
1472    {
1473        match self {
1474            IfCondition::True => "true".serialize(serializer),
1475            IfCondition::False => "false".serialize(serializer),
1476            IfCondition::EnvEq(negated, name, value) => {
1477                let mut ser = serializer.serialize_map(Some(3))?;
1478                ser.serialize_entry("op", if *negated { "!=" } else { "==" })?;
1479                ser.serialize_entry("env", name)?;
1480                ser.serialize_entry("value", value)?;
1481                ser.end()
1482            }
1483        }
1484    }
1485}
1486
1487#[derive(Debug)]
1488pub enum ForCondition {
1489    Env(String, Vec<ShellBit>),
1490}
1491
1492impl Serialize for ForCondition {
1493    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1494    where
1495        S: serde::Serializer,
1496    {
1497        match self {
1498            ForCondition::Env(name, values) => {
1499                let mut ser = serializer.serialize_map(Some(2))?;
1500                ser.serialize_entry("env", name)?;
1501                ser.serialize_entry("values", values)?;
1502                ser.end()
1503            }
1504        }
1505    }
1506}
1507
/// serde `skip_serializing_if` predicate: `true` when the flag is unset.
fn is_bool_false(b: &bool) -> bool {
    !*b
}
1511
/// A command step together with everything its outcome is checked against.
#[derive(Debug, Serialize)]
pub struct ScriptCommand {
    /// The command line to execute, including its source location.
    pub command: CommandLine,
    /// Expected output pattern.
    pub pattern: OutputPattern,

    /// Expected exit status; omitted from serialization when it is `Success`.
    #[serde(skip_serializing_if = "CommandExit::is_success")]
    pub exit: CommandExit,

    /// When set, the output pattern is expected NOT to match.
    #[serde(skip_serializing_if = "is_bool_false")]
    pub expect_failure: bool,

    /// Single set variable (entire command output trimmed)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub set_var: Option<String>,

    /// Specific set variables, assigned only after a successful pattern match.
    /// Values are expanded in a copy of the environment extended with the
    /// match context's expectations.
    pub set_vars: HashMap<String, ShellBit>,

    /// Specific command timeout (falls back to the context timeout when unset)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout: Option<Duration>,

    /// Input grok expectations (values are expanded before matching)
    pub expect: HashMap<String, ShellBit>,
}
1537
1538impl ScriptCommand {
1539    pub fn new(command: CommandLine) -> Self {
1540        let location = command.location.clone();
1541        Self {
1542            command,
1543            pattern: OutputPattern {
1544                pattern: OutputPatternType::None,
1545                ignore: Default::default(),
1546                reject: Default::default(),
1547                location,
1548            },
1549            exit: Default::default(),
1550            timeout: None,
1551            expect_failure: false,
1552            set_var: None,
1553            set_vars: Default::default(),
1554            expect: Default::default(),
1555        }
1556    }
1557
1558    pub fn run(&self, context: &mut ScriptRunContext) -> Result<ScriptResult, ScriptRunError> {
1559        let command = &self.command;
1560        let args = &context.args;
1561        let start = Instant::now();
1562
1563        if let Some(delay) = args.delay_steps {
1564            std::thread::sleep(std::time::Duration::from_millis(delay));
1565        }
1566
1567        for line in command.command.split('\n') {
1568            cwriteln!(context.stream(), fg = Color::Green, "{}", line);
1569        }
1570        if args.simplified_output {
1571            cwriteln!(context.stream(), dimmed = true, "---");
1572        } else {
1573            cwriteln_rule!(context.stream(), fg = Color::Cyan, "{}", command.location);
1574        }
1575        let (output, status) = command.run(
1576            &mut context.stream(),
1577            context.args.show_line_numbers,
1578            context.args.runner.clone(),
1579            self.timeout.unwrap_or(context.timeout),
1580            context.env.env_vars(),
1581            &context.kill,
1582            &context.kill_sender,
1583        )?;
1584
1585        let exit_result = if !self.exit.matches(status) {
1586            ExitResult::Mismatch(status)
1587        } else {
1588            ExitResult::Matches(status)
1589        };
1590
1591        // Side-effects
1592        if let Some(set_var) = &self.set_var {
1593            context.set_env(set_var, output.to_string().trim());
1594        }
1595
1596        let match_context = OutputMatchContext::new(context);
1597        for (key, value) in &self.expect {
1598            match_context.expect(key, context.expand(value)?);
1599        }
1600        self.pattern.prepare(&context.grok)?;
1601        let prepared_output = output
1602            .with_ignore(&context.global_ignore)
1603            .with_reject(&context.global_reject);
1604        let pattern_result = match self.pattern.matches(match_context.clone(), prepared_output) {
1605            Ok(_) => {
1606                let mut env = context.env.clone();
1607                for (key, value) in match_context.expects() {
1608                    env.set_env(key, value);
1609                }
1610                for (key, value) in &self.set_vars {
1611                    context.set_env(key, env.expand(value)?);
1612                }
1613
1614                if self.expect_failure {
1615                    PatternResult::ExpectedFailure
1616                } else {
1617                    PatternResult::Matches
1618                }
1619            }
1620            Err(e) => {
1621                if self.expect_failure {
1622                    PatternResult::MatchesFailure
1623                } else {
1624                    let mut trace = String::new();
1625                    for line in match_context.traces() {
1626                        trace.push_str(&format!("{line}\n"));
1627                    }
1628                    PatternResult::Mismatch(e, trace)
1629                }
1630            }
1631        };
1632
1633        if output.is_empty() {
1634            cwriteln!(context.stream(), dimmed = true, "(no output)");
1635        }
1636
1637        if context.args.simplified_output {
1638            cwriteln!(context.stream(), dimmed = true, "---");
1639        } else {
1640            cwriteln_rule!(context.stream());
1641        }
1642
1643        Ok(ScriptResult {
1644            command: command.clone(),
1645            pattern: pattern_result,
1646            exit: exit_result,
1647            elapsed: start.elapsed(),
1648            output,
1649        })
1650    }
1651}
1652
/// Outcome of running one [`ScriptCommand`], prior to evaluation.
#[derive(derive_more::Debug)]
pub struct ScriptResult {
    /// The command that produced this result.
    pub command: CommandLine,
    /// Outcome of the output-pattern check.
    pub pattern: PatternResult,
    /// Whether the exit status matched the expectation.
    pub exit: ExitResult,
    /// Wall-clock time spent in `ScriptCommand::run`, including any step delay.
    pub elapsed: Duration,
    /// Captured output lines (omitted from `Debug` output).
    #[debug(skip)]
    pub output: Lines,
}
1662
1663impl ScriptResult {
1664    pub fn evaluate(&self, context: &mut ScriptRunContext) -> Result<(), ScriptRunError> {
1665        let args = &context.args;
1666        let (success, failure, warning, arrow) = if *crate::term::IS_UTF8 {
1667            ("✅", "❌", "⚠️", "→")
1668        } else {
1669            ("[*]", "[X]", "[!]", "->")
1670        };
1671
1672        if let ExitResult::Mismatch(status) = self.exit {
1673            if args.ignore_exit_codes {
1674                cwriteln!(
1675                    context.stream(),
1676                    fg = Color::Yellow,
1677                    "{warning} Ignored incorrect exit code: {status}"
1678                );
1679                cwriteln!(context.stream());
1680            } else {
1681                cwriteln!(
1682                    context.stream(),
1683                    fg = Color::Red,
1684                    "{failure} FAIL: {status}"
1685                );
1686                cwriteln!(
1687                    context.stream(),
1688                    dimmed = true,
1689                    " {arrow} {}",
1690                    self.command.command
1691                );
1692                cwriteln!(context.stream());
1693                return Err(ScriptRunError::Exit(status, self.command.location.clone()));
1694            }
1695        }
1696
1697        if let PatternResult::Mismatch(e, trace) = &self.pattern {
1698            if args.ignore_matches {
1699                cwriteln!(
1700                    context.stream(),
1701                    fg = Color::Yellow,
1702                    "{warning} Ignored error: {e} (ignoring mismatches)"
1703                );
1704                cwriteln!(context.stream());
1705            } else {
1706                cwriteln!(context.stream(), fg = Color::Red, "ERROR: {e}");
1707                cwriteln!(context.stream(), dimmed = true, "{trace}");
1708                cwriteln!(context.stream(), fg = Color::Red, "{failure} FAIL");
1709                cwriteln!(context.stream());
1710                return Err(ScriptRunError::Pattern(e.clone()));
1711            }
1712        }
1713
1714        if let PatternResult::ExpectedFailure = self.pattern {
1715            if args.ignore_matches {
1716                cwriteln!(
1717                    context.stream(),
1718                    fg = Color::Yellow,
1719                    "{warning} Should not have matched! (ignoring mismatches)"
1720                );
1721                cwriteln!(context.stream());
1722            } else {
1723                cwriteln!(
1724                    context.stream(),
1725                    fg = Color::Red,
1726                    "{failure} FAIL (output shouldn't match)"
1727                );
1728                cwriteln!(
1729                    context.stream(),
1730                    dimmed = true,
1731                    " {arrow} {}",
1732                    self.command.command
1733                );
1734                cwriteln!(context.stream());
1735                return Err(ScriptRunError::ExpectedFailure(
1736                    self.command.location.clone(),
1737                ));
1738            }
1739        }
1740
1741        if let ExitResult::Matches(status) = self.exit {
1742            if status.success() {
1743                cwrite!(context.stream(), fg = Color::Green, "{success} OK");
1744                if !context.args.simplified_output {
1745                    cwriteln!(
1746                        context.stream(),
1747                        dimmed = true,
1748                        " ({:.2}s)",
1749                        self.elapsed.as_secs_f32()
1750                    );
1751                } else {
1752                    cwriteln!(context.stream());
1753                }
1754            } else {
1755                cwrite!(
1756                    context.stream(),
1757                    fg = Color::Green,
1758                    "{success} OK ({status})"
1759                );
1760                if !context.args.simplified_output {
1761                    cwriteln!(
1762                        context.stream(),
1763                        dimmed = true,
1764                        " ({:.2}s)",
1765                        self.elapsed.as_secs_f32()
1766                    );
1767                } else {
1768                    cwriteln!(context.stream());
1769                }
1770            }
1771            cwriteln!(context.stream());
1772        }
1773
1774        Ok(())
1775    }
1776}
1777
/// Outcome of checking a command's output against the script's expected pattern.
#[derive(Debug)]
pub enum PatternResult {
    /// The output matched; treated as a pass by `ScriptResult::evaluate`.
    Matches,
    /// Treated as a pass by `ScriptResult::evaluate` (it is not inspected there).
    /// NOTE(review): presumably "expected not to match, and did not" — confirm
    /// against the matcher that produces it.
    MatchesFailure,
    /// The output matched although the script expected it not to; reported as
    /// "output shouldn't match".
    ExpectedFailure,
    /// The output did not match: the match failure plus a trace string that is
    /// printed (dimmed) beneath the error.
    Mismatch(OutputPatternMatchFailure, String),
}
1785
/// Outcome of comparing a command's exit status with the one the script expects.
#[derive(Debug)]
pub enum ExitResult {
    /// The exit status was as expected; carries the actual result, which is
    /// shown in the `OK` line when it is not a plain success.
    Matches(CommandResult),
    /// The exit status was not as expected ("incorrect exit code"); carries the
    /// actual result.
    Mismatch(CommandResult),
    /// NOTE(review): presumably the command exceeded its time limit (the file
    /// defines `DEFAULT_TIMEOUT`); `ScriptResult::evaluate` does not report this
    /// variant itself — confirm the caller does.
    TimedOut,
}
1792
#[cfg(test)]
mod tests {
    use crate::parser::v0::parse_script;

    use super::*;
    use std::error::Error;

    /// A script with a named pattern, plain commands and a nested
    /// `repeat { choice { ... } }` block parses into three top-level blocks.
    #[test]
    fn test_script() -> Result<(), Box<dyn Error>> {
        let script = r#"
pattern VERSION \d+\.\d+\.\d+;

$ something --version || echo 1
? Something %{VERSION}

$ something --help
? Usage: something [OPTIONS]
repeat {
    choice {
? %{DATA} %{GREEDYDATA}
? %{DATA}=%{DATA} %{GREEDYDATA}
    }
}
"#;

        let script = parse_script(ScriptFile::new("test.cli"), script)?;
        assert_eq!(script.commands.len(), 3);
        eprintln!("{script:?}");
        Ok(())
    }

    /// A command that backgrounds a process (`cmd &`) is rejected by the parser.
    #[test]
    fn test_bad_script() {
        let script = r#"
$ (cmd; cmd)
$ cmd &
    "#;

        let result = parse_script(ScriptFile::new("test.cli"), script);
        // Report what was actually returned, so a failure is diagnosable.
        assert!(
            matches!(
                &result,
                Err(ScriptError {
                    error: ScriptErrorType::BackgroundProcessNotAllowed,
                    ..
                })
            ),
            "expected BackgroundProcessNotAllowed, got {result:?}"
        );
    }

    /// `$VAR` / `${VAR}` expand, a single backslash escapes the `$`, and a
    /// doubled backslash yields a literal backslash followed by the expansion.
    #[test]
    fn test_script_run_context_expand() {
        let mut env = ScriptEnv::default();
        for (key, value) in [("A", "1"), ("B", "2"), ("C", "3")] {
            env.set_env(key, value);
        }

        let cases = [
            ("$A", "1"),
            ("$A $B ", "1 2 "),
            ("${A} ${B} ", "1 2 "),
            (r#"\$A"#, "$A"),
            (r#"\${A}"#, "${A}"),
            (r#"\\$A"#, r#"\1"#),
            (r#"\\${A}"#, r#"\1"#),
        ];
        for (input, expected) in cases {
            assert_eq!(env.expand_str(input).unwrap(), expected, "expanding {input:?}");
        }

        env.set_env("TEMP_DIR", "/tmp");
        for input in ["$TEMP_DIR", "${TEMP_DIR}"] {
            assert_eq!(env.expand_str(input).unwrap(), "/tmp", "expanding {input:?}");
        }
    }
}