clitest_lib/
script.rs

1use std::{
2    collections::{HashMap, VecDeque},
3    path::Path,
4    process::ExitStatus,
5    sync::{Arc, Mutex, atomic::AtomicBool},
6    thread::ScopedJoinHandle,
7    time::{Duration, Instant},
8};
9
10use grok::Grok;
11use keepcalm::SharedMut;
12use serde::{Serialize, ser::SerializeMap};
13use termcolor::{Color, ColorChoice, WriteColor};
14
15use crate::{
16    command::{CommandLine, CommandResult},
17    util::{NicePathBuf, NiceTempDir},
18};
19use crate::{cwrite, cwriteln, cwriteln_rule};
20use crate::{output::*, util::ShellBit};
21
/// Timeout used for a run context when `ScriptRunArgs::global_timeout` is not set
/// (also the timeout of `ScriptRunContext::default()`): 30 seconds.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
23
/// A position inside a script: the file plus a line number in that file.
///
/// Rendered as `file:line` by its `Display` impl.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct ScriptLocation {
    /// Script file the location refers to.
    pub file: ScriptFile,
    /// Line number within `file` (`ScriptLine::parse` numbers lines starting at 1).
    pub line: usize,
}
29
30impl ScriptLocation {
31    pub fn new(file: ScriptFile, line: usize) -> Self {
32        Self { file, line }
33    }
34}
35
36impl std::fmt::Display for ScriptLocation {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        write!(f, "{}:{}", self.file, self.line)
39    }
40}
41
/// Cheaply clonable handle to a script's path.
///
/// The path lives behind an `Arc`, so clones share one allocation. `Display` prints the
/// wrapped path; equality, ordering and hashing are derived from the wrapped `NicePathBuf`.
#[derive(
    derive_more::Debug, derive_more::Display, Clone, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash,
)]
#[display("{}", file)]
pub struct ScriptFile {
    /// Shared path of the script file.
    pub file: Arc<NicePathBuf>,
}
49
50impl ScriptFile {
51    pub fn new(file: impl AsRef<Path>) -> Self {
52        Self {
53            file: Arc::new(NicePathBuf::new(file)),
54        }
55    }
56}
57
58impl<T: AsRef<Path>> From<T> for ScriptFile {
59    fn from(file: T) -> Self {
60        Self::new(file)
61    }
62}
63
/// A parsed script: its top-level blocks, the scripts it includes, and its source file.
#[derive(Clone, derive_more::Debug, Serialize)]
pub struct Script {
    /// Top-level blocks, executed in order by `Script::run`.
    pub commands: Arc<Vec<ScriptBlock>>,
    /// Included scripts keyed by a string; `Script::run` installs this map as the context's
    /// includes while it runs. NOTE(review): presumably keyed by the include path as written
    /// in the script — confirm with the parser.
    pub includes: Arc<HashMap<String, Script>>,
    /// File the script was loaded from.
    pub file: ScriptFile,
}
70
/// Options controlling how a script is run.
#[derive(Debug, Clone, Default)]
pub struct ScriptRunArgs {
    /// NOTE(review): presumably a delay between steps — unit not visible here, confirm.
    pub delay_steps: Option<u64>,
    /// NOTE(review): presumably disables exit-code checks — confirm with the command runner.
    pub ignore_exit_codes: bool,
    /// NOTE(review): presumably disables output pattern checks — confirm with the matcher.
    pub ignore_matches: bool,
    /// When set, the final PASSED/FAILED line omits the elapsed time.
    pub simplified_output: bool,
    /// NOTE(review): presumably prefixes output with script line numbers — confirm.
    pub show_line_numbers: bool,
    /// NOTE(review): optional external runner command — semantics not visible here.
    pub runner: Option<String>,
    /// NOTE(review): presumably suppresses progress output — confirm with the caller.
    pub quiet: bool,
    /// When set, background blocks write to the parent's output instead of a private buffer.
    pub verbose: bool,
    /// Overrides [`DEFAULT_TIMEOUT`] for the top-level run context.
    pub global_timeout: Option<Duration>,
    /// Disables ANSI colors in buffered (background) output.
    pub no_color: bool,
}
84
/// Variables visible to a script (including the special `PWD` / `INITIAL_PWD`).
#[derive(Debug, Clone, Default)]
pub struct ScriptEnv {
    /// Variable name -> value. `PWD` is stored in `NicePathBuf::env_string` form.
    env_vars: HashMap<String, String>,
}
89
90impl ScriptEnv {
91    pub fn set_defaults(&mut self, pwd: impl AsRef<Path>) {
92        macro_rules! target {
93            ($env:ident, $var:ident, [$($vals:expr),*]) => {
94                $(
95                if cfg!($var = $vals) {
96                    self.env_vars.insert(stringify!($env).to_string(), $vals.to_string());
97                }
98                )*
99            };
100        }
101
102        target!(
103            TARGET_OS,
104            target_os,
105            ["windows", "linux", "macos", "ios", "android"]
106        );
107        target!(TARGET_FAMILY, target_family, ["windows", "unix", "wasm"]);
108        target!(
109            TARGET_ARCH,
110            target_arch,
111            ["x86", "x86_64", "arm", "aarch64"]
112        );
113
114        // Set the current working directory as a special variable "PWD"
115        self.env_vars.insert(
116            "PWD".to_string(),
117            NicePathBuf::from(pwd.as_ref()).env_string(),
118        );
119        // Save the initial PWD as INITIAL_PWD so it can easily be restored
120        self.env_vars
121            .insert("INITIAL_PWD".to_string(), self.env_vars["PWD"].clone());
122    }
123
124    pub fn pwd(&self) -> NicePathBuf {
125        self.env_vars
126            .get("PWD")
127            .cloned()
128            .map(NicePathBuf::from)
129            .unwrap_or_else(NicePathBuf::cwd)
130    }
131
132    pub fn get_env(&self, name: &str) -> Option<&str> {
133        self.env_vars.get(name).map(|s| s.as_str())
134    }
135
136    pub fn set_env(&mut self, name: impl Into<String>, value: impl Into<String>) {
137        let name = name.into();
138        if name == "PWD" {
139            self.set_pwd(value.into());
140        } else {
141            self.env_vars.insert(name, value.into());
142        }
143    }
144
145    pub fn set_pwd(&mut self, pwd: impl Into<NicePathBuf>) {
146        let pwd = pwd.into().env_string();
147        self.env_vars.insert("PWD".to_string(), pwd);
148    }
149
150    pub fn expand(&self, value: &ShellBit) -> Result<String, ScriptRunError> {
151        match value {
152            ShellBit::Literal(s) => Ok(s.clone()),
153            ShellBit::Quoted(s) => self.expand_str(s),
154        }
155    }
156
157    /// Perform shell expansion on a string.
158    pub fn expand_str(&self, value: impl AsRef<str>) -> Result<String, ScriptRunError> {
159        enum State {
160            Normal,
161            EscapeNext,
162            InCurly,
163            Dollar,
164            InDollar,
165        }
166
167        let value = value.as_ref();
168
169        // "\" triggers escaping
170        // ${A} expands to the value of A
171        // $A expands to the value of A (variable ends on first non-alphanumeric character)
172
173        let mut state = State::Normal;
174        let mut variable = String::new();
175        let mut expanded = String::new();
176
177        for c in value.chars() {
178            match state {
179                State::Normal => {
180                    if c == '$' {
181                        state = State::Dollar;
182                        continue;
183                    }
184                    if c == '\\' {
185                        state = State::EscapeNext;
186                        continue;
187                    }
188                    expanded.push(c);
189                }
190                State::EscapeNext => {
191                    expanded.push(c);
192                    state = State::Normal;
193                }
194                State::InCurly => {
195                    if c == '}' {
196                        if let Some(value) = self.get_env(&std::mem::take(&mut variable)) {
197                            expanded.push_str(value);
198                        } else {
199                            return Err(ScriptRunError::ExpansionError(format!(
200                                "undefined variable in ${{...}}: {:?} (in {value:?})",
201                                variable
202                            )));
203                        }
204                        state = State::Normal;
205                    } else {
206                        variable.push(c);
207                    }
208                }
209                State::Dollar => {
210                    if c.is_alphanumeric() || c == '_' {
211                        state = State::InDollar;
212                        variable.push(c);
213                    } else if c == '{' {
214                        state = State::InCurly;
215                    } else {
216                        return Err(ScriptRunError::ExpansionError(format!(
217                            "invalid variable: {:?} (in {value:?})",
218                            c
219                        )));
220                    }
221                }
222                State::InDollar => {
223                    if c.is_alphanumeric() || c == '_' {
224                        variable.push(c);
225                    } else {
226                        if let Some(value) = self.get_env(&std::mem::take(&mut variable)) {
227                            expanded.push_str(value);
228                        } else {
229                            return Err(ScriptRunError::ExpansionError(format!(
230                                "undefined variable in $...: {:?} (in {value:?})",
231                                variable
232                            )));
233                        }
234                        expanded.push(c);
235                        state = State::Normal;
236                    }
237                }
238            }
239        }
240        match state {
241            State::InDollar => {
242                if let Some(value) = self.get_env(&variable) {
243                    expanded.push_str(value);
244                } else {
245                    return Err(ScriptRunError::ExpansionError(format!(
246                        "undefined variable: {}",
247                        variable
248                    )));
249                }
250            }
251            State::Dollar => {
252                return Err(ScriptRunError::ExpansionError(
253                    "incomplete variable".to_string(),
254                ));
255            }
256            State::InCurly => {
257                return Err(ScriptRunError::ExpansionError(format!(
258                    "unclosed variable: {}",
259                    variable
260                )));
261            }
262            State::Normal => {}
263            State::EscapeNext => {
264                return Err(ScriptRunError::ExpansionError(
265                    "unclosed backslash".to_string(),
266                ));
267            }
268        }
269        Ok(expanded)
270    }
271
272    pub fn env_vars(&self) -> &HashMap<String, String> {
273        &self.env_vars
274    }
275}
276
/// Clonable sink for script progress output.
///
/// Clones share the same underlying stream, which sits behind a `SharedMut`. Writers take
/// the write lock for the duration of a write via `ScriptRunContext::stream`.
#[derive(derive_more::Debug, Clone)]
pub struct ScriptOutput {
    /// Either a terminal stream or an in-memory buffer (see `WriteColorAny`).
    #[debug(skip)]
    stream: SharedMut<Box<dyn WriteColorAny>>,
}
282
/// A colored writer that can also hand back its contents when it is an in-memory buffer.
///
/// Implemented for `termcolor::StandardStream` (no buffer; both methods return `Err`) and
/// `termcolor::Buffer`.
trait WriteColorAny: WriteColor + Send + Sync + std::any::Any + 'static + std::fmt::Debug {
    /// Workaround for lack of upcasting
    ///
    /// Consume the boxed writer and return its buffer, or `Err` if it is not a buffer.
    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String>;
    /// Return a copy of the buffer, or `Err` if the writer is not a buffer.
    fn clone_buffer(&self) -> Result<termcolor::Buffer, String>;
}
288
289impl WriteColorAny for termcolor::StandardStream {
290    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String> {
291        Err("not a buffer".to_string())
292    }
293    fn clone_buffer(&self) -> Result<termcolor::Buffer, String> {
294        Err("not a buffer".to_string())
295    }
296}
297
298impl WriteColorAny for termcolor::Buffer {
299    fn take_buffer(self: Box<Self>) -> Result<termcolor::Buffer, String> {
300        Ok(*self)
301    }
302    fn clone_buffer(&self) -> Result<termcolor::Buffer, String> {
303        Ok(self.clone())
304    }
305}
306
307impl ScriptOutput {
308    pub fn no_color() -> Self {
309        let stm = termcolor::StandardStream::stdout(ColorChoice::Never);
310        Self {
311            stream: SharedMut::new(Box::new(stm) as _),
312        }
313    }
314
315    pub fn quiet(no_color: bool) -> Self {
316        let stm = if no_color {
317            termcolor::Buffer::no_color()
318        } else {
319            termcolor::Buffer::ansi()
320        };
321        Self {
322            stream: SharedMut::new(Box::new(stm) as _),
323        }
324    }
325
326    pub fn take_buffer(self) -> String {
327        let stream = match SharedMut::try_unwrap(self.stream) {
328            Ok(stream) => stream.take_buffer().expect("wrong stream type"),
329            Err(shared) => shared.read().clone_buffer().expect("wrong stream type"),
330        };
331        String::from_utf8_lossy(&stream.into_inner()).to_string()
332    }
333}
334
335impl Default for ScriptOutput {
336    fn default() -> Self {
337        let stm = termcolor::StandardStream::stdout(ColorChoice::Auto);
338        Self {
339            stream: SharedMut::new(Box::new(stm) as _),
340        }
341    }
342}
343
344impl std::io::Write for ScriptOutputLock<'_> {
345    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
346        self.stream.write(buf)
347    }
348    fn flush(&mut self) -> std::io::Result<()> {
349        self.stream.flush()
350    }
351}
352
353impl termcolor::WriteColor for ScriptOutputLock<'_> {
354    fn supports_color(&self) -> bool {
355        self.stream.supports_color()
356    }
357    fn set_color(&mut self, spec: &termcolor::ColorSpec) -> std::io::Result<()> {
358        self.stream.set_color(spec)
359    }
360    fn reset(&mut self) -> std::io::Result<()> {
361        self.stream.reset()
362    }
363    fn is_synchronous(&self) -> bool {
364        self.stream.is_synchronous()
365    }
366    fn set_hyperlink(&mut self, _link: &termcolor::HyperlinkSpec) -> std::io::Result<()> {
367        self.stream.set_hyperlink(_link)
368    }
369    fn supports_hyperlinks(&self) -> bool {
370        self.stream.supports_hyperlinks()
371    }
372}
373
/// Exclusive access to a `ScriptOutput`'s stream; the write lock is held while this value lives.
struct ScriptOutputLock<'a> {
    stream: keepcalm::SharedWriteLock<'a, Box<dyn WriteColorAny>>,
}
377
/// How a `ScriptRunContext` was created.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScriptMode {
    /// Top-level context (`ScriptRunContext::new` / `default`).
    Normal,
    /// Created by `ScriptRunContext::new_deferred`.
    Deferred,
    /// Created by `ScriptRunContext::new_background`.
    Background,
}
384
/// Mutable state threaded through a script run.
#[derive(derive_more::Debug)]
pub struct ScriptRunContext {
    /// Options the run was started with.
    pub args: ScriptRunArgs,
    /// Grok pattern library (initialized with the default patterns).
    pub grok: Grok,
    /// Timeout for this context; `Duration::MAX` for background contexts.
    timeout: Duration,
    /// Script variables, including `PWD`.
    env: ScriptEnv,
    /// Includes of the script currently being run (swapped in by `Script::run`).
    includes: Arc<HashMap<String, Script>>,
    /// How this context was created (normal / deferred / background).
    background: ScriptMode,
    /// Observes the kill flag shared with `kill_sender`.
    #[debug(skip)]
    kill: ScriptKillReceiver,
    /// Sets the kill flag observed by `kill`.
    #[debug(skip)]
    kill_sender: ScriptKillSender,
    /// Where progress output is written.
    output: ScriptOutput,

    /// Patterns declared at global level with `ignore`.
    global_ignore: OutputPatterns,
    /// Patterns declared at global level with `reject`.
    global_reject: OutputPatterns,
}
402
403impl Default for ScriptRunContext {
404    fn default() -> Self {
405        let kill = Arc::new(AtomicBool::new(false));
406        Self {
407            args: ScriptRunArgs::default(),
408            grok: Grok::with_default_patterns(),
409            timeout: DEFAULT_TIMEOUT,
410            env: ScriptEnv::default(),
411            background: ScriptMode::Normal,
412            includes: Arc::new(HashMap::new()),
413            kill: ScriptKillReceiver::new(kill.clone()),
414            kill_sender: ScriptKillSender::new(kill.clone()),
415            output: ScriptOutput::default(),
416            global_ignore: OutputPatterns::default(),
417            global_reject: OutputPatterns::default(),
418        }
419    }
420}
421
422impl ScriptRunContext {
423    pub fn new_background(&self) -> Self {
424        let kill = Arc::new(AtomicBool::new(false));
425        Self {
426            args: self.args.clone(),
427            grok: self.grok.clone(),
428            // Background processes are not subject to timeouts
429            timeout: Duration::MAX,
430            env: self.env.clone(),
431            background: ScriptMode::Background,
432            kill: ScriptKillReceiver::new(kill.clone()),
433            kill_sender: ScriptKillSender::new(kill.clone()),
434            includes: self.includes.clone(),
435            output: if self.args.verbose {
436                self.output.clone()
437            } else {
438                ScriptOutput::quiet(self.args.no_color)
439            },
440            global_ignore: self.global_ignore.clone(),
441            global_reject: self.global_reject.clone(),
442        }
443    }
444
445    pub fn new_deferred(&self) -> Self {
446        Self {
447            args: self.args.clone(),
448            grok: self.grok.clone(),
449            timeout: self.timeout,
450            env: self.env.clone(),
451            background: ScriptMode::Deferred,
452            kill: self.kill.clone(),
453            kill_sender: self.kill_sender.clone(),
454            includes: self.includes.clone(),
455            output: self.output.clone(),
456            global_ignore: self.global_ignore.clone(),
457            global_reject: self.global_reject.clone(),
458        }
459    }
460
461    pub fn pwd(&self) -> NicePathBuf {
462        self.env.pwd()
463    }
464
465    pub fn get_env(&self, name: &str) -> Option<&str> {
466        self.env.get_env(name)
467    }
468
469    pub fn set_env(&mut self, name: impl Into<String>, value: impl Into<String>) {
470        self.env.set_env(name, value);
471    }
472
473    pub fn set_pwd(&mut self, pwd: impl Into<NicePathBuf>) {
474        self.env.set_pwd(pwd);
475    }
476
477    pub fn take_output(self) -> String {
478        self.output.take_buffer()
479    }
480
481    fn expand(&self, value: &ShellBit) -> Result<String, ScriptRunError> {
482        self.env.expand(value)
483    }
484
485    /// Get a mutable reference to the output stream.
486    pub fn stream(&self) -> impl termcolor::WriteColor + use<'_> {
487        ScriptOutputLock {
488            stream: self.output.stream.write(),
489        }
490    }
491}
492
/// Read side of a kill flag: lets running commands notice that they should stop.
#[derive(Clone)]
pub struct ScriptKillReceiver {
    /// Flag shared with the matching `ScriptKillSender`.
    kill_receiver: Arc<AtomicBool>,
}
497
498impl ScriptKillReceiver {
499    pub fn new(kill_receiver: Arc<AtomicBool>) -> Self {
500        Self { kill_receiver }
501    }
502
503    pub fn is_killed(&self) -> bool {
504        self.kill_receiver.load(std::sync::atomic::Ordering::SeqCst)
505    }
506
507    pub fn run_with<T>(&self, kill: impl FnOnce() + Send, wait: impl FnOnce() -> T) -> T {
508        std::thread::scope(|s| {
509            let done = Arc::new(AtomicBool::new(false));
510            let done_clone = done.clone();
511            let t = s.spawn(move || {
512                while !done_clone.load(std::sync::atomic::Ordering::SeqCst) {
513                    if self.is_killed() {
514                        kill();
515                        break;
516                    }
517                    std::thread::sleep(Duration::from_millis(10));
518                }
519            });
520            let res = wait();
521            done.store(true, std::sync::atomic::Ordering::SeqCst);
522            t.join().unwrap();
523            res
524        })
525    }
526
527    #[cfg(windows)]
528    pub fn run_cmd(
529        &self,
530        output: std::process::Child,
531        warn_time: Duration,
532    ) -> std::io::Result<ExitStatus> {
533        use std::os::windows::io::AsRawHandle;
534        use win32job::Job;
535
536        fn map_job_error(e: win32job::JobError) -> std::io::Error {
537            match e {
538                win32job::JobError::AssignFailed(e) => e,
539                win32job::JobError::CreateFailed(e) => e,
540                win32job::JobError::GetInfoFailed(e) => e,
541                win32job::JobError::SetInfoFailed(e) => e,
542                _ => std::io::Error::new(std::io::ErrorKind::Other, "Unknown error"),
543            }
544        }
545
546        // Create a new Job object
547        let job = Job::create().map_err(map_job_error)?;
548
549        // Configure the job to terminate all child processes when the job is closed
550        let mut info = job.query_extended_limit_info().map_err(map_job_error)?;
551        info.limit_kill_on_job_close();
552        job.set_extended_limit_info(&info).map_err(map_job_error)?;
553        job.assign_process(output.as_raw_handle() as _)?;
554
555        // Resume the main thread for the process
556        let id = output.id();
557        for thread_entry in tlhelp32::Snapshot::new_thread()? {
558            if thread_entry.owner_process_id == id {
559                use windows_sys::Win32::Foundation::CloseHandle;
560                use windows_sys::Win32::System::Threading::*;
561
562                // TODO: error handling
563                unsafe {
564                    let thread = OpenThread(THREAD_SUSPEND_RESUME, 0, thread_entry.thread_id);
565                    ResumeThread(thread);
566                    CloseHandle(thread);
567                }
568            }
569        }
570
571        let job = Mutex::new(Some(job));
572        let output = Mutex::new(output);
573        self.run_with(
574            || {
575                _ = job.lock().unwrap().take();
576                _ = output.lock().unwrap().kill();
577            },
578            || {
579                let start = std::time::Instant::now();
580                let mut warned = false;
581                loop {
582                    let res = output.lock().unwrap().try_wait()?;
583                    if let Some(status) = res {
584                        return Ok::<_, std::io::Error>(status);
585                    }
586                    if start.elapsed() > warn_time {
587                        if !warned {
588                            let child = output.lock().unwrap().id();
589                            eprintln!("Process #{child} taking too long to finish.");
590                            warned = true;
591                        }
592                    }
593                    std::thread::sleep(Duration::from_millis(10));
594                }
595            },
596        )
597    }
598
599    #[cfg(unix)]
600    pub fn run_cmd(
601        &self,
602        output: std::process::Child,
603        warn_time: Duration,
604    ) -> std::io::Result<ExitStatus> {
605        let output = Mutex::new(output);
606        self.run_with(
607            || {
608                use signal_child::{signal, signal::*};
609                let id = output.lock().unwrap().id() as i32;
610                _ = signal(-id, SIGINT);
611                std::thread::sleep(Duration::from_millis(10));
612                _ = signal(-id, SIGTERM);
613            },
614            || {
615                let start = std::time::Instant::now();
616                let mut warned = false;
617                loop {
618                    let res = output.lock().unwrap().try_wait()?;
619                    if let Some(status) = res {
620                        return Ok::<_, std::io::Error>(status);
621                    }
622                    if start.elapsed() > warn_time && !warned {
623                        let child = output.lock().unwrap().id();
624                        eprintln!("Process #{child} taking too long to finish.");
625                        warned = true;
626                    }
627                    std::thread::sleep(Duration::from_millis(10));
628                }
629            },
630        )
631    }
632}
633
/// Write side of a kill flag: asks commands watching the same flag to stop.
#[derive(Clone)]
pub struct ScriptKillSender {
    /// Flag shared with the matching `ScriptKillReceiver`.
    kill_sender: Arc<AtomicBool>,
}
638
639impl ScriptKillSender {
640    pub fn new(kill_sender: Arc<AtomicBool>) -> Self {
641        Self { kill_sender }
642    }
643
644    pub fn kill(&self) {
645        self.kill_sender
646            .store(true, std::sync::atomic::Ordering::SeqCst);
647    }
648}
649
650impl ScriptRunContext {
651    pub fn new(args: ScriptRunArgs, script_path: impl AsRef<Path>, output: ScriptOutput) -> Self {
652        let mut env = ScriptEnv::default();
653        env.set_defaults(script_path.as_ref().parent().unwrap());
654
655        let kill = Arc::new(AtomicBool::new(false));
656
657        Self {
658            timeout: args.global_timeout.unwrap_or(DEFAULT_TIMEOUT),
659            args,
660            env,
661            grok: Grok::with_default_patterns(),
662            includes: Arc::new(HashMap::new()),
663            background: ScriptMode::Normal,
664            kill: ScriptKillReceiver::new(kill.clone()),
665            kill_sender: ScriptKillSender::new(kill.clone()),
666            output,
667            global_ignore: OutputPatterns::default(),
668            global_reject: OutputPatterns::default(),
669        }
670    }
671}
672
/// One line of script source together with where it came from.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ScriptLine {
    /// File and line number of this line.
    pub location: ScriptLocation,
    /// Raw line text, untrimmed (no line terminator when produced by `ScriptLine::parse`).
    text: String,
}
678
679impl ScriptLine {
680    pub fn new(file: ScriptFile, line: usize, text: impl AsRef<str>) -> Self {
681        Self {
682            location: ScriptLocation::new(file, line),
683            text: text.as_ref().to_string(),
684        }
685    }
686
687    pub fn parse(file: ScriptFile, text: impl AsRef<str>) -> Vec<Self> {
688        text.as_ref()
689            .lines()
690            .enumerate()
691            .map(|(line, text)| Self {
692                location: ScriptLocation::new(file.clone(), line + 1),
693                text: text.to_string(),
694            })
695            .collect()
696    }
697
698    pub fn starts_with(&self, text: &str) -> bool {
699        self.text.trim().starts_with(text)
700    }
701
702    pub fn first_char(&self) -> Option<char> {
703        self.text.trim().chars().next()
704    }
705
706    pub fn text(&self) -> &str {
707        self.text.trim()
708    }
709
710    pub fn text_untrimmed(&self) -> &str {
711        &self.text
712    }
713
714    pub fn is_empty(&self) -> bool {
715        self.text.trim().is_empty()
716    }
717
718    pub fn strip_prefix(&self, prefix: &str) -> Option<&str> {
719        self.text.strip_prefix(prefix)
720    }
721}
722
/// An error found in a script's source, with its location and optional extra detail.
///
/// Displays as `<error> at <file>:<line>`, followed by `: <data>` when data is present.
#[derive(Debug, thiserror::Error, derive_more::Display)]
#[display("{error} at {location}{}", associated_data.as_deref().map_or("".to_string(), |d| format!(": {d}")))]
pub struct ScriptError {
    /// What went wrong.
    pub error: ScriptErrorType,
    /// Where it went wrong.
    pub location: ScriptLocation,
    /// Optional extra detail appended to the message.
    pub associated_data: Option<String>,
}
730
731impl ScriptError {
732    pub fn new(error: ScriptErrorType, location: ScriptLocation) -> Self {
733        if std::env::var("PANIC_ON_ERROR").is_ok() {
734            panic!("ScriptError: {error} at {location}");
735        }
736        Self {
737            error,
738            location,
739            associated_data: None,
740        }
741    }
742
743    pub fn new_with_data(
744        error: ScriptErrorType,
745        location: ScriptLocation,
746        associated_data: String,
747    ) -> Self {
748        if std::env::var("PANIC_ON_ERROR").is_ok() {
749            panic!("ScriptError: {error} at {location}: {associated_data}");
750        }
751        Self {
752            error,
753            location,
754            associated_data: Some(associated_data),
755        }
756    }
757}
758
/// Kinds of problems detected in a script's source (see `ScriptError`).
///
/// Each variant's `#[error]` text is the user-visible message.
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
pub enum ScriptErrorType {
    #[error("background process not allowed")]
    BackgroundProcessNotAllowed,
    #[error("unclosed quote")]
    UnclosedQuote,
    #[error("unclosed backslash")]
    UnclosedBackslash,
    #[error("illegal shell command format")]
    IllegalShellCommand,
    #[error("unsupported redirection")]
    UnsupportedRedirection,
    #[error("invalid pattern definition")]
    InvalidPatternDefinition,
    #[error("invalid pattern")]
    InvalidPattern,
    #[error("invalid meta command")]
    InvalidMetaCommand,
    #[error("invalid pattern at global level (only reject or ignore allowed here)")]
    InvalidGlobalPattern,
    #[error("invalid block type")]
    InvalidBlockType,
    #[error("invalid block arguments")]
    InvalidBlockArgs,
    #[error("unsupported command position")]
    UnsupportedCommandPosition,
    #[error("invalid trailing pattern after *")]
    InvalidAnyPattern,
    #[error("invalid exit status")]
    InvalidExitStatus,
    #[error("invalid set variable")]
    InvalidSetVariable,
    #[error("invalid version header, expected `#!/usr/bin/env clitest --v0`")]
    InvalidVersion,
    #[error("invalid internal command")]
    InvalidInternalCommand,
    #[error("missing command lines")]
    MissingCommandLines,
    #[error(
        "block end without matching block start, too many closing braces or braces not properly nested"
    )]
    InvalidBlockEnd,
    #[error("invalid if condition")]
    InvalidIfCondition,
    #[error("expected block or semi-colon (did you forget to add ';' at the end of this line?)")]
    ExpectedBlockOrSemi,
}
806
/// Errors raised while *running* a script (as opposed to parsing it).
#[derive(Debug, thiserror::Error)]
pub enum ScriptRunError {
    /// Output did not match the expected patterns.
    #[error("{0}")]
    Pattern(#[from] OutputPatternMatchFailure),
    /// Patterns could not be prepared for matching.
    #[error("{0}")]
    PatternPrepareError(#[from] OutputPatternPrepareError),
    /// A command ended with an unexpected result.
    #[error("{0}")]
    Exit(CommandResult),
    #[error("included file not found: {0}")]
    IncludedFileNotFound(String),
    #[error("expected failure, but passed")]
    ExpectedFailure,
    /// Variable expansion failed (see `ScriptEnv::expand_str`).
    #[error("{0}")]
    ExpansionError(String),
    #[error("{0}")]
    IO(#[from] std::io::Error),
    #[error("killed")]
    Killed,
    #[error("background process took too long to finish")]
    BackgroundProcessTookTooLong,
    #[error("retry took too long to finish")]
    RetryTookTooLong,
    /// Internal flow control: exit the script
    ///
    /// `Script::run` converts this into `Ok(())`.
    #[error("exiting script")]
    ExitScript,
}
833
834impl ScriptRunError {
835    #[allow(unused)]
836    pub fn short(&self) -> String {
837        match self {
838            Self::Pattern(_) => "Pattern".to_string(),
839            Self::PatternPrepareError(e) => format!("PatternPrepareError({:?})", e),
840            Self::Exit(status) => format!("Exit({})", status),
841            Self::ExpectedFailure => "ExpectedFailure".to_string(),
842            Self::IO(e) => format!("IO({:?})", e.kind()),
843            Self::Killed => "Killed".to_string(),
844            Self::BackgroundProcessTookTooLong => "BackgroundProcessTookTooLong".to_string(),
845            Self::ExpansionError(e) => "ExpansionError".to_string(),
846            Self::RetryTookTooLong => "RetryTookTooLong".to_string(),
847            Self::ExitScript => unreachable!(),
848            Self::IncludedFileNotFound(path) => format!("IncludedFileNotFound({})", path),
849        }
850    }
851}
852
853impl Script {
854    pub fn new(file: ScriptFile) -> Self {
855        Self {
856            commands: Arc::new(vec![]),
857            includes: Arc::new(HashMap::new()),
858            file,
859        }
860    }
861
862    /// Collect all included script paths from the script.
863    pub fn includes(&self) -> Vec<(ScriptLocation, String)> {
864        self.commands
865            .iter()
866            .flat_map(|block| block.includes())
867            .collect()
868    }
869
870    pub fn run(&self, context: &mut ScriptRunContext) -> Result<(), ScriptRunError> {
871        let old_includes = context.includes.clone();
872        context.includes = self.includes.clone();
873        let res = ScriptBlock::run_blocks(context, &self.commands);
874        context.includes = old_includes;
875        let v = match res {
876            Ok(v) => v,
877            // Bypass normal script processing and exit successfully
878            Err(ScriptRunError::ExitScript) => return Ok(()),
879            Err(e) => return Err(e),
880        };
881        assert!(v.is_empty(), "script did not run to completion: {v:?}");
882        Ok(())
883    }
884
885    pub fn run_with_args(
886        &self,
887        args: ScriptRunArgs,
888        output: ScriptOutput,
889    ) -> Result<(), ScriptRunError> {
890        let start = Instant::now();
891        let script_path = &*self.file.file;
892        let mut context = ScriptRunContext::new(args, script_path, output);
893
894        // Write "Running..." message with colors
895        cwrite!(context.stream(), "Running ");
896        cwrite!(context.stream(), fg = Color::Cyan, "{}", script_path);
897        cwriteln!(context.stream(), " ...");
898        cwriteln!(context.stream());
899
900        let result = self.run(&mut context);
901
902        // Handle success and error output
903        if let Err(ref e) = result {
904            cwrite!(context.stream(), fg = Color::Cyan, "{} ", script_path);
905            cwrite!(context.stream(), fg = Color::Red, "FAILED");
906            if !context.args.simplified_output {
907                cwriteln!(context.stream(), " ({:.2}s)", start.elapsed().as_secs_f32());
908            } else {
909                cwriteln!(context.stream());
910            }
911            cwrite!(context.stream(), fg = Color::Red, "Error: ");
912            cwriteln!(context.stream(), "{}", e);
913            cwriteln!(context.stream());
914        } else {
915            cwrite!(context.stream(), fg = Color::Cyan, "{} ", script_path);
916            cwrite!(context.stream(), fg = Color::Green, "PASSED");
917            if !context.args.simplified_output {
918                cwriteln!(context.stream(), " ({:.2}s)", start.elapsed().as_secs_f32());
919            } else {
920                cwriteln!(context.stream());
921            }
922        }
923
924        result
925    }
926}
927
/// The outcome a command is expected to have (see `CommandExit::matches`).
#[derive(Debug, Default, Serialize)]
pub enum CommandExit {
    /// Exits successfully (the default).
    #[default]
    Success,
    /// Exits with this code. A process without an exit code compares as `-1`.
    Failure(i32),
    /// Times out.
    Timeout,
    /// Anything is acceptable.
    Any,
    /// Anything except a successful exit.
    AnyFailure,
}
937
938impl CommandExit {
939    pub fn matches(&self, status: CommandResult) -> bool {
940        match (self, status) {
941            (CommandExit::Success, CommandResult::Exit(status)) => status.success(),
942            (CommandExit::Failure(code), CommandResult::Exit(status)) => {
943                *code == status.code().unwrap_or(-1)
944            }
945            (CommandExit::Timeout, CommandResult::TimedOut) => true,
946            (CommandExit::Any, _) => true,
947            (CommandExit::AnyFailure, CommandResult::Exit(status)) => !status.success(),
948            (CommandExit::AnyFailure, _) => true,
949            _ => false,
950        }
951    }
952
953    pub fn is_success(&self) -> bool {
954        matches!(self, CommandExit::Success)
955    }
956}
957
/// One node of a parsed script. Structural variants own their nested blocks.
#[derive(derive_more::Debug)]
pub enum ScriptBlock {
    /// A command to execute together with its output/exit expectations.
    Command(ScriptCommand),
    /// A built-in directive and the script location where it appears.
    InternalCommand(ScriptLocation, InternalCommand),
    /// Blocks started on a separate (scoped) thread alongside the rest of the enclosing
    /// list; the thread is sent a kill signal and joined after that list's main sequence
    /// (see `ScriptBlock::run_blocks`).
    Background(Vec<ScriptBlock>),
    /// Blocks run after the enclosing list's main sequence, newest registration first.
    Defer(Vec<ScriptBlock>),
    /// Blocks run only when the (expanded) condition holds.
    If(IfCondition, Vec<ScriptBlock>),
    /// Blocks run once per loop value.
    For(ForCondition, Vec<ScriptBlock>),
    /// Blocks re-run with exponential backoff until every command passes or the
    /// context timeout expires.
    Retry(Vec<ScriptBlock>),
    /// Patterns added to the context's global ignore list, applied to the output of
    /// commands run afterwards with that context.
    GlobalIgnore(OutputPatterns),
    /// Patterns added to the context's global reject list, applied to the output of
    /// commands run afterwards with that context.
    GlobalReject(OutputPatterns),
}
970
971impl ScriptBlock {
972    pub fn includes(&self) -> Vec<(ScriptLocation, String)> {
973        match self {
974            ScriptBlock::Command(..) => vec![],
975            ScriptBlock::InternalCommand(location, InternalCommand::Include(path)) => {
976                vec![(location.clone(), path.clone())]
977            }
978            ScriptBlock::InternalCommand(..) => vec![],
979            ScriptBlock::Background(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
980            ScriptBlock::Defer(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
981            ScriptBlock::If(_, blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
982            ScriptBlock::For(_, blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
983            ScriptBlock::Retry(blocks) => blocks.iter().flat_map(|b| b.includes()).collect(),
984            ScriptBlock::GlobalIgnore(_) => vec![],
985            ScriptBlock::GlobalReject(_) => vec![],
986        }
987    }
988
989    pub fn run_blocks(
990        context: &mut ScriptRunContext,
991        blocks: &[ScriptBlock],
992    ) -> Result<Vec<ScriptResult>, ScriptRunError> {
993        enum Deferred<'a> {
994            Scripts(&'a [ScriptBlock]),
995            Internal(
996                Box<
997                    dyn FnOnce(&mut ScriptRunContext) -> Result<(), ScriptRunError>
998                        + Send
999                        + Sync
1000                        + 'a,
1001                >,
1002            ),
1003            Background(
1004                ScopedJoinHandle<'a, Result<Vec<ScriptResult>, ScriptRunError>>,
1005                ScriptKillSender,
1006            ),
1007        }
1008
1009        let mut results = Vec::new();
1010        std::thread::scope(|s| {
1011            let mut defer_blocks = VecDeque::new();
1012            let mut pending_error = None;
1013            for block in blocks {
1014                if context.kill.is_killed() {
1015                    return Err(ScriptRunError::Killed);
1016                }
1017                match block {
1018                    ScriptBlock::Background(blocks) => {
1019                        let mut context = context.new_background();
1020                        let kill_sender = context.kill_sender.clone();
1021                        let handle = s.spawn(move || Self::run_blocks(&mut context, blocks));
1022                        defer_blocks.push_front(Deferred::Background(handle, kill_sender));
1023                    }
1024                    ScriptBlock::Defer(blocks) => {
1025                        // Insert at the front of the queue by extending and
1026                        // then rotating
1027                        defer_blocks.push_front(Deferred::Scripts(blocks));
1028                    }
1029                    ScriptBlock::InternalCommand(_, command) => {
1030                        if context.background == ScriptMode::Deferred {
1031                            cwrite!(context.stream(), dimmed = true, "(deferred) ");
1032                        }
1033                        if let Some(f) = command.run(context)? {
1034                            defer_blocks.push_front(Deferred::Internal(f));
1035                        }
1036                    }
1037                    _ => match block.run(context) {
1038                        Ok(res) => results.extend(res),
1039                        Err(e) => {
1040                            pending_error = Some(e);
1041                            break;
1042                        }
1043                    },
1044                }
1045            }
1046            for block in defer_blocks {
1047                match block {
1048                    Deferred::Scripts(blocks) => {
1049                        let mut context = context.new_deferred();
1050                        ScriptBlock::run_blocks(&mut context, blocks)?;
1051                    }
1052                    Deferred::Internal(block) => {
1053                        cwrite!(context.stream(), dimmed = true, "(cleanup) ");
1054                        block(context)?;
1055                    }
1056                    Deferred::Background(handle, kill_sender) => {
1057                        kill_sender.kill();
1058                        let start = std::time::Instant::now();
1059                        let mut warned = false;
1060
1061                        let timeout = context.timeout;
1062                        let warn_at = timeout * 8 / 10;
1063
1064                        let results = loop {
1065                            if handle.is_finished() {
1066                                break handle.join().unwrap()?;
1067                            }
1068                            std::thread::sleep(std::time::Duration::from_millis(10));
1069                            if !warned && start.elapsed() > warn_at {
1070                                cwriteln!(
1071                                    context.stream(),
1072                                    fg = Color::Yellow,
1073                                    "Background process is taking too long to finish."
1074                                );
1075                                warned = true;
1076                            }
1077                            if start.elapsed() > timeout {
1078                                cwriteln!(
1079                                    context.stream(),
1080                                    fg = Color::Red,
1081                                    "Background process took too long to finish."
1082                                );
1083                                return Err(ScriptRunError::BackgroundProcessTookTooLong);
1084                            }
1085                        };
1086                        for result in results {
1087                            cwrite!(context.stream(), dimmed = true, "(background) ");
1088                            for line in result.command.command.split('\n') {
1089                                cwriteln!(context.stream(), fg = Color::Green, "{}", line);
1090                            }
1091                            if context.args.simplified_output {
1092                                cwriteln!(context.stream(), dimmed = true, "---");
1093                            } else {
1094                                cwriteln_rule!(
1095                                    context.stream(),
1096                                    fg = Color::Cyan,
1097                                    "{}",
1098                                    result.command.location
1099                                );
1100                            }
1101                            for line in &result.output {
1102                                cwriteln!(context.stream(), "{}", line);
1103                            }
1104                            if result.output.is_empty() {
1105                                cwriteln!(context.stream(), dimmed = true, "(no output)");
1106                            }
1107                            if context.args.simplified_output {
1108                                cwriteln!(context.stream(), dimmed = true, "---");
1109                            } else {
1110                                cwriteln_rule!(context.stream());
1111                            }
1112                            result.evaluate(context)?;
1113                        }
1114                    }
1115                }
1116            }
1117            if let Some(error) = pending_error {
1118                return Err(error);
1119            }
1120            Ok(results)
1121        })
1122    }
1123
1124    pub fn run(&self, context: &mut ScriptRunContext) -> Result<Vec<ScriptResult>, ScriptRunError> {
1125        let pwd = context.pwd();
1126        let res = pwd.exists();
1127        if !matches!(res, Ok(true)) {
1128            cwriteln!(
1129                context.stream(),
1130                fg = Color::Red,
1131                "$PWD {pwd:?} doesn't exist. Run `cd $INITIAL_PWD` to fix.",
1132            );
1133            return Err(ScriptRunError::IO(std::io::Error::new(
1134                std::io::ErrorKind::NotFound,
1135                format!("PWD does not exist: {pwd:?}"),
1136            )));
1137        }
1138
1139        match self {
1140            ScriptBlock::Command(command) => {
1141                if context.background == ScriptMode::Deferred {
1142                    cwrite!(context.stream(), dimmed = true, "(deferred) ");
1143                }
1144                let result = command.run(context)?;
1145                if context.background != ScriptMode::Background {
1146                    result.evaluate(context)?;
1147                    Ok(vec![])
1148                } else {
1149                    Ok(vec![result])
1150                }
1151            }
1152            ScriptBlock::If(condition, blocks) => {
1153                let condition = condition.expand(context)?;
1154                if condition.matches(context) {
1155                    Self::run_blocks(context, blocks)
1156                } else {
1157                    Ok(vec![])
1158                }
1159            }
1160            ScriptBlock::For(ForCondition::Env(env, values), blocks) => {
1161                let mut results = Vec::new();
1162                for value in values {
1163                    context.set_env(env, context.expand(value)?);
1164                    results.extend(Self::run_blocks(context, blocks)?);
1165                }
1166                Ok(results)
1167            }
1168            ScriptBlock::Retry(blocks) => {
1169                let start = Instant::now();
1170                let mut backoff = Duration::from_millis(100);
1171
1172                cwrite!(context.stream(), fg = Color::Green, "retry: ");
1173                cwriteln!(context.stream(), "running...");
1174
1175                loop {
1176                    let mut nested_context = context.new_background();
1177                    if let Ok(results) = Self::run_blocks(&mut nested_context, blocks) {
1178                        let mut all_ok = true;
1179                        for result in results {
1180                            if result.evaluate(&mut nested_context).is_err() {
1181                                all_ok = false;
1182                                break;
1183                            }
1184                        }
1185                        if all_ok {
1186                            let output = nested_context.take_output();
1187                            cwrite!(context.stream(), fg = Color::Green, "retry: ");
1188                            cwriteln!(context.stream(), "success");
1189                            cwriteln!(context.stream());
1190                            cwriteln!(context.stream(), "{output}");
1191                            return Ok(vec![]);
1192                        }
1193                    }
1194
1195                    if start.elapsed() > context.timeout {
1196                        let output = nested_context.take_output();
1197                        cwrite!(context.stream(), fg = Color::Green, "retry: ");
1198                        cwriteln!(context.stream(), fg = Color::Red, "timed out");
1199                        cwriteln!(context.stream());
1200                        cwriteln!(context.stream(), "{output}");
1201                        cwriteln_rule!(context.stream());
1202                        return Err(ScriptRunError::RetryTookTooLong);
1203                    }
1204                    std::thread::sleep(backoff);
1205                    backoff *= 2;
1206                }
1207            }
1208            ScriptBlock::GlobalIgnore(patterns) => {
1209                for pattern in patterns.iter() {
1210                    pattern.prepare(&context.grok)?;
1211                }
1212                context.global_ignore.extend(patterns);
1213                Ok(vec![])
1214            }
1215            ScriptBlock::GlobalReject(patterns) => {
1216                for pattern in patterns.iter() {
1217                    pattern.prepare(&context.grok)?;
1218                }
1219                context.global_reject.extend(patterns);
1220                Ok(vec![])
1221            }
1222            _ => unreachable!("Unexpected block type: {self:?}"),
1223        }
1224    }
1225}
1226
1227impl Serialize for ScriptBlock {
1228    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1229    where
1230        S: serde::Serializer,
1231    {
1232        match self {
1233            ScriptBlock::Command(command) => command.serialize(serializer),
1234            ScriptBlock::InternalCommand(_, command) => command.serialize(serializer),
1235            ScriptBlock::Background(blocks) => {
1236                let mut ser = serializer.serialize_map(Some(1))?;
1237                ser.serialize_entry("background", blocks)?;
1238                ser.end()
1239            }
1240            ScriptBlock::Defer(blocks) => {
1241                let mut ser = serializer.serialize_map(Some(1))?;
1242                ser.serialize_entry("defer", blocks)?;
1243                ser.end()
1244            }
1245            ScriptBlock::If(condition, blocks) => {
1246                let mut ser = serializer.serialize_map(Some(2))?;
1247                ser.serialize_entry("if", condition)?;
1248                ser.serialize_entry("blocks", blocks)?;
1249                ser.end()
1250            }
1251            ScriptBlock::For(condition, blocks) => {
1252                let mut ser = serializer.serialize_map(Some(2))?;
1253                ser.serialize_entry("for", condition)?;
1254                ser.serialize_entry("blocks", blocks)?;
1255                ser.end()
1256            }
1257            ScriptBlock::Retry(blocks) => {
1258                let mut ser = serializer.serialize_map(Some(1))?;
1259                ser.serialize_entry("retry", blocks)?;
1260                ser.end()
1261            }
1262            ScriptBlock::GlobalIgnore(patterns) => {
1263                let mut ser = serializer.serialize_map(Some(1))?;
1264                ser.serialize_entry("ignore", patterns)?;
1265                ser.end()
1266            }
1267            ScriptBlock::GlobalReject(patterns) => {
1268                let mut ser = serializer.serialize_map(Some(1))?;
1269                ser.serialize_entry("reject", patterns)?;
1270                ser.end()
1271            }
1272        }
1273    }
1274}
1275
/// A built-in script directive, handled directly by the script runner instead of being
/// executed as a command (see `InternalCommand::run`).
#[derive(Debug, Clone, Serialize)]
pub enum InternalCommand {
    /// Create a temporary directory and make it the working directory; a cleanup action
    /// removes the directory later.
    UsingTempdir,
    /// Switch into a directory, resolved against the current working directory. When the
    /// flag is `true` the directory is created first and removed on cleanup; otherwise it
    /// must already exist. Cleanup switches back to the previous working directory.
    UsingDir(ShellBit, bool),
    /// Change the working directory, resolved against the current one. The target's
    /// existence is not checked here.
    ChangeDir(ShellBit),
    /// Set an environment variable to the expanded value: `(name, value)`.
    Set(String, ShellBit),
    /// Run an included script, looked up by this name in the context's includes.
    Include(String),
    /// Stop running the script; signalled through `ScriptRunError::ExitScript`.
    ExitScript,
    /// Register a named grok pattern: `(name, pattern)`.
    Pattern(String, String),
}
1286
1287impl InternalCommand {
1288    pub fn run(
1289        &self,
1290        context: &mut ScriptRunContext,
1291    ) -> Result<
1292        Option<Box<dyn FnOnce(&mut ScriptRunContext) -> Result<(), ScriptRunError> + Send + Sync>>,
1293        ScriptRunError,
1294    > {
1295        match self.clone() {
1296            InternalCommand::Include(path) => {
1297                let Some(script) = context.includes.get(&path) else {
1298                    return Err(ScriptRunError::IncludedFileNotFound(path));
1299                };
1300                script.clone().run(context)?;
1301                Ok(None)
1302            }
1303            InternalCommand::Pattern(name, pattern) => {
1304                context.grok.add_pattern(name, pattern);
1305                Ok(None)
1306            }
1307            InternalCommand::UsingTempdir => {
1308                let current_pwd = context.pwd();
1309                let tempdir = NiceTempDir::new();
1310                cwrite!(context.stream(), fg = Color::Yellow, "using tempdir: ");
1311                cwriteln!(context.stream(), "{}", tempdir);
1312                cwriteln!(context.stream());
1313                context.set_pwd(&tempdir);
1314                let pwd = context.pwd();
1315                if !pwd.exists()? {
1316                    return Err(ScriptRunError::IO(std::io::Error::new(
1317                        std::io::ErrorKind::NotFound,
1318                        format!("newly created tempdir does not exist: {pwd:?}"),
1319                    )));
1320                }
1321                Ok(Some(Box::new(move |context: &mut ScriptRunContext| {
1322                    cwriteln!(
1323                        context.stream(),
1324                        fg = Color::Yellow,
1325                        "removing {} && cd {}",
1326                        tempdir,
1327                        current_pwd
1328                    );
1329                    cwriteln!(context.stream());
1330                    if !tempdir.exists()? {
1331                        cwriteln!(
1332                            context.stream(),
1333                            fg = Color::Red,
1334                            "tempdir does not exist: {tempdir}"
1335                        );
1336                    }
1337                    if let Err(e) = tempdir.remove_dir_all() {
1338                        cwriteln!(
1339                            context.stream(),
1340                            fg = Color::Red,
1341                            "error removing tempdir: {e:?}"
1342                        );
1343                    }
1344                    Ok::<_, ScriptRunError>(())
1345                })))
1346            }
1347            InternalCommand::UsingDir(dir, new) => {
1348                let current_pwd = context.pwd();
1349                let dir = context.expand(&dir)?;
1350                let new_pwd = current_pwd.join(dir);
1351                if new {
1352                    cwrite!(context.stream(), fg = Color::Yellow, "using new dir: ");
1353                } else {
1354                    cwrite!(context.stream(), fg = Color::Yellow, "using dir: ");
1355                }
1356                cwriteln!(context.stream(), "{}", new_pwd);
1357                cwriteln!(context.stream());
1358
1359                if new {
1360                    new_pwd.create_dir_all()?;
1361                } else if !new_pwd.exists()? {
1362                    return Err(ScriptRunError::IO(std::io::Error::new(
1363                        std::io::ErrorKind::NotFound,
1364                        "directory does not exist",
1365                    )));
1366                }
1367                context.set_pwd(&new_pwd);
1368                Ok(Some(Box::new(move |context: &mut ScriptRunContext| {
1369                    if new {
1370                        cwriteln!(
1371                            context.stream(),
1372                            fg = Color::Yellow,
1373                            "removing {} && cd {}",
1374                            new_pwd,
1375                            current_pwd
1376                        );
1377                        cwriteln!(context.stream());
1378                    } else {
1379                        cwriteln!(context.stream(), fg = Color::Yellow, "cd {}", current_pwd);
1380                        cwriteln!(context.stream());
1381                    }
1382                    if new {
1383                        new_pwd.remove_dir_all()?;
1384                    }
1385                    context.set_pwd(current_pwd);
1386                    Ok::<_, ScriptRunError>(())
1387                })))
1388            }
1389            InternalCommand::ChangeDir(dir) => {
1390                let dir = context.expand(&dir)?;
1391
1392                cwriteln!(context.stream(), fg = Color::Yellow, "cd {dir}");
1393                cwriteln!(context.stream());
1394                let current_pwd = context.pwd();
1395                let new_pwd = current_pwd.join(dir);
1396                context.set_pwd(new_pwd);
1397                Ok(None)
1398            }
1399            InternalCommand::Set(name, value) => {
1400                let value = context.expand(&value)?;
1401
1402                context.set_env(&name, &value);
1403                let new_value = context.get_env(&name).unwrap_or_default();
1404                if new_value != value {
1405                    cwriteln!(
1406                        context.stream(),
1407                        fg = Color::Yellow,
1408                        "set {name} {value} (-> {new_value})"
1409                    );
1410                } else {
1411                    cwriteln!(context.stream(), fg = Color::Yellow, "set {name} {value}");
1412                }
1413                cwriteln!(context.stream());
1414
1415                Ok(None)
1416            }
1417            InternalCommand::ExitScript => {
1418                cwriteln!(context.stream(), fg = Color::Yellow, "exiting script");
1419                cwriteln!(context.stream());
1420                Err(ScriptRunError::ExitScript)
1421            }
1422        }
1423    }
1424}
1425
/// Condition guarding an `if` block.
#[derive(Debug, Clone)]
pub enum IfCondition {
    /// Always holds.
    True,
    /// Never holds.
    False,
    /// Compares an environment variable with a value: `(negated, name, expected)`.
    /// A missing variable is compared as its default value (presumably the empty string).
    /// When `negated` is `true` the condition holds if the values differ.
    EnvEq(bool, String, ShellBit),
}
1432
1433impl IfCondition {
1434    pub fn matches(&self, context: &ScriptRunContext) -> bool {
1435        match self {
1436            IfCondition::True => true,
1437            IfCondition::False => false,
1438            IfCondition::EnvEq(negated, name, expected) => {
1439                let value = context.get_env(name).unwrap_or_default();
1440                (expected == value) ^ negated
1441            }
1442        }
1443    }
1444
1445    pub fn expand(&self, context: &ScriptRunContext) -> Result<IfCondition, ScriptRunError> {
1446        match self {
1447            IfCondition::True => Ok(IfCondition::True),
1448            IfCondition::False => Ok(IfCondition::False),
1449            IfCondition::EnvEq(negated, name, expected) => {
1450                let value = context.expand(expected)?;
1451                Ok(IfCondition::EnvEq(
1452                    *negated,
1453                    name.clone(),
1454                    ShellBit::Literal(value),
1455                ))
1456            }
1457        }
1458    }
1459}
1460
1461impl Serialize for IfCondition {
1462    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1463    where
1464        S: serde::Serializer,
1465    {
1466        match self {
1467            IfCondition::True => "true".serialize(serializer),
1468            IfCondition::False => "false".serialize(serializer),
1469            IfCondition::EnvEq(negated, name, value) => {
1470                let mut ser = serializer.serialize_map(Some(3))?;
1471                ser.serialize_entry("op", if *negated { "!=" } else { "==" })?;
1472                ser.serialize_entry("env", name)?;
1473                ser.serialize_entry("value", value)?;
1474                ser.end()
1475            }
1476        }
1477    }
1478}
1479
/// Iteration spec for a `for` block.
#[derive(Debug)]
pub enum ForCondition {
    /// `(name, values)`: for each value in order, expand it, assign it to the named
    /// environment variable, and run the loop body once.
    Env(String, Vec<ShellBit>),
}
1484
1485impl Serialize for ForCondition {
1486    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1487    where
1488        S: serde::Serializer,
1489    {
1490        match self {
1491            ForCondition::Env(name, values) => {
1492                let mut ser = serializer.serialize_map(Some(2))?;
1493                ser.serialize_entry("env", name)?;
1494                ser.serialize_entry("values", values)?;
1495                ser.end()
1496            }
1497        }
1498    }
1499}
1500
/// `skip_serializing_if` predicate: `true` when the flag is unset.
fn is_bool_false(b: &bool) -> bool {
    !*b
}
1504
/// A command from the script together with everything it is expected to do.
#[derive(Debug, Serialize)]
pub struct ScriptCommand {
    /// The command line to execute.
    pub command: CommandLine,
    /// Pattern the command's output (after global ignore/reject filtering) must match.
    pub pattern: OutputPattern,

    /// Expected exit outcome; omitted from serialized output when it is the default
    /// (successful exit).
    #[serde(skip_serializing_if = "CommandExit::is_success")]
    pub exit: CommandExit,

    /// Inverts the output expectation: a matching pattern is reported as a failure and a
    /// mismatch as success.
    #[serde(skip_serializing_if = "is_bool_false")]
    pub expect_failure: bool,

    /// Single set variable (entire command output trimmed). It is assigned before the
    /// output is matched, whether or not the pattern matches.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub set_var: Option<String>,

    /// Specific set variables, assigned only when the pattern matches. Values are expanded
    /// in a copy of the environment extended with the match context's `expects()` values.
    pub set_vars: HashMap<String, ShellBit>,

    /// Specific command timeout; overrides the context timeout when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout: Option<Duration>,

    /// Input grok expectations: each value is expanded and registered on the match context
    /// before the output is matched.
    pub expect: HashMap<String, ShellBit>,
}
1530
1531impl ScriptCommand {
1532    pub fn new(command: CommandLine) -> Self {
1533        let location = command.location.clone();
1534        Self {
1535            command,
1536            pattern: OutputPattern {
1537                pattern: OutputPatternType::None,
1538                ignore: Default::default(),
1539                reject: Default::default(),
1540                location,
1541            },
1542            exit: Default::default(),
1543            timeout: None,
1544            expect_failure: false,
1545            set_var: None,
1546            set_vars: Default::default(),
1547            expect: Default::default(),
1548        }
1549    }
1550
1551    pub fn run(&self, context: &mut ScriptRunContext) -> Result<ScriptResult, ScriptRunError> {
1552        let command = &self.command;
1553        let args = &context.args;
1554        let start = Instant::now();
1555
1556        if let Some(delay) = args.delay_steps {
1557            std::thread::sleep(std::time::Duration::from_millis(delay));
1558        }
1559
1560        for line in command.command.split('\n') {
1561            cwriteln!(context.stream(), fg = Color::Green, "{}", line);
1562        }
1563        if args.simplified_output {
1564            cwriteln!(context.stream(), dimmed = true, "---");
1565        } else {
1566            cwriteln_rule!(context.stream(), fg = Color::Cyan, "{}", command.location);
1567        }
1568        let (output, status) = command.run(
1569            &mut context.stream(),
1570            context.args.show_line_numbers,
1571            context.args.runner.clone(),
1572            self.timeout.unwrap_or(context.timeout),
1573            context.env.env_vars(),
1574            &context.kill,
1575            &context.kill_sender,
1576        )?;
1577
1578        let exit_result = if !self.exit.matches(status) {
1579            ExitResult::Mismatch(status)
1580        } else {
1581            ExitResult::Matches(status)
1582        };
1583
1584        // Side-effects
1585        if let Some(set_var) = &self.set_var {
1586            context.set_env(set_var, output.to_string().trim());
1587        }
1588
1589        let match_context = OutputMatchContext::new(context);
1590        for (key, value) in &self.expect {
1591            match_context.expect(key, context.expand(value)?);
1592        }
1593        self.pattern.prepare(&context.grok)?;
1594        let prepared_output = output
1595            .with_ignore(&context.global_ignore)
1596            .with_reject(&context.global_reject);
1597        let pattern_result = match self.pattern.matches(match_context.clone(), prepared_output) {
1598            Ok(_) => {
1599                let mut env = context.env.clone();
1600                for (key, value) in match_context.expects() {
1601                    env.set_env(key, value);
1602                }
1603                for (key, value) in &self.set_vars {
1604                    context.set_env(key, env.expand(value)?);
1605                }
1606
1607                if self.expect_failure {
1608                    PatternResult::ExpectedFailure
1609                } else {
1610                    PatternResult::Matches
1611                }
1612            }
1613            Err(e) => {
1614                if self.expect_failure {
1615                    PatternResult::MatchesFailure
1616                } else {
1617                    let mut trace = String::new();
1618                    for line in match_context.traces() {
1619                        trace.push_str(&format!("{line}\n"));
1620                    }
1621                    PatternResult::Mismatch(e, trace)
1622                }
1623            }
1624        };
1625
1626        if output.is_empty() {
1627            cwriteln!(context.stream(), dimmed = true, "(no output)");
1628        }
1629
1630        if context.args.simplified_output {
1631            cwriteln!(context.stream(), dimmed = true, "---");
1632        } else {
1633            cwriteln_rule!(context.stream());
1634        }
1635
1636        Ok(ScriptResult {
1637            command: command.clone(),
1638            pattern: pattern_result,
1639            exit: exit_result,
1640            elapsed: start.elapsed(),
1641            output,
1642        })
1643    }
1644}
1645
/// Outcome of running a single [`ScriptCommand`], produced by `ScriptCommand::run` and
/// checked by `ScriptResult::evaluate`.
#[derive(derive_more::Debug)]
pub struct ScriptResult {
    /// The command line that was run.
    pub command: CommandLine,
    /// How the output compared with the expected pattern.
    pub pattern: PatternResult,
    /// Whether the exit status matched the expected exit outcome.
    pub exit: ExitResult,
    /// Wall-clock time measured by `ScriptCommand::run`, including any configured step delay.
    pub elapsed: Duration,
    /// Captured output lines; omitted from `Debug` output.
    #[debug(skip)]
    pub output: Lines,
}
1655
1656impl ScriptResult {
1657    pub fn evaluate(&self, context: &mut ScriptRunContext) -> Result<(), ScriptRunError> {
1658        let args = &context.args;
1659        let (success, failure, warning, arrow) = if *crate::term::IS_UTF8 {
1660            ("✅", "❌", "⚠️", "→")
1661        } else {
1662            ("[*]", "[X]", "[!]", "->")
1663        };
1664
1665        if let ExitResult::Mismatch(status) = self.exit {
1666            if args.ignore_exit_codes {
1667                cwriteln!(
1668                    context.stream(),
1669                    fg = Color::Yellow,
1670                    "{warning} Ignored incorrect exit code: {status}"
1671                );
1672                cwriteln!(context.stream());
1673            } else {
1674                cwriteln!(
1675                    context.stream(),
1676                    fg = Color::Red,
1677                    "{failure} FAIL: {status}"
1678                );
1679                cwriteln!(
1680                    context.stream(),
1681                    dimmed = true,
1682                    " {arrow} {}",
1683                    self.command.command
1684                );
1685                cwriteln!(context.stream());
1686                return Err(ScriptRunError::Exit(status));
1687            }
1688        }
1689
1690        if let PatternResult::Mismatch(e, trace) = &self.pattern {
1691            if args.ignore_matches {
1692                cwriteln!(
1693                    context.stream(),
1694                    fg = Color::Yellow,
1695                    "{warning} Ignored error: {e} (ignoring mismatches)"
1696                );
1697                cwriteln!(context.stream());
1698            } else {
1699                cwriteln!(context.stream(), fg = Color::Red, "ERROR: {e}");
1700                cwriteln!(context.stream(), dimmed = true, "{trace}");
1701                cwriteln!(context.stream(), fg = Color::Red, "{failure} FAIL");
1702                cwriteln!(context.stream());
1703                return Err(ScriptRunError::Pattern(e.clone()));
1704            }
1705        }
1706
1707        if let PatternResult::ExpectedFailure = self.pattern {
1708            if args.ignore_matches {
1709                cwriteln!(
1710                    context.stream(),
1711                    fg = Color::Yellow,
1712                    "{warning} Should not have matched! (ignoring mismatches)"
1713                );
1714                cwriteln!(context.stream());
1715            } else {
1716                cwriteln!(
1717                    context.stream(),
1718                    fg = Color::Red,
1719                    "{failure} FAIL (output shouldn't match)"
1720                );
1721                cwriteln!(
1722                    context.stream(),
1723                    dimmed = true,
1724                    " {arrow} {}",
1725                    self.command.command
1726                );
1727                cwriteln!(context.stream());
1728                return Err(ScriptRunError::ExpectedFailure);
1729            }
1730        }
1731
1732        if let ExitResult::Matches(status) = self.exit {
1733            if status.success() {
1734                cwrite!(context.stream(), fg = Color::Green, "{success} OK");
1735                if !context.args.simplified_output {
1736                    cwriteln!(
1737                        context.stream(),
1738                        dimmed = true,
1739                        " ({:.2}s)",
1740                        self.elapsed.as_secs_f32()
1741                    );
1742                } else {
1743                    cwriteln!(context.stream());
1744                }
1745            } else {
1746                cwrite!(
1747                    context.stream(),
1748                    fg = Color::Green,
1749                    "{success} OK ({status})"
1750                );
1751                if !context.args.simplified_output {
1752                    cwriteln!(
1753                        context.stream(),
1754                        dimmed = true,
1755                        " ({:.2}s)",
1756                        self.elapsed.as_secs_f32()
1757                    );
1758                } else {
1759                    cwriteln!(context.stream());
1760                }
1761            }
1762            cwriteln!(context.stream());
1763        }
1764
1765        Ok(())
1766    }
1767}
1768
/// Outcome of checking a command's output against its expected pattern.
#[derive(Debug)]
#[derive(Debug)]
pub enum PatternResult {
    /// The output matched; `ScriptResult::evaluate` treats this as a pass.
    Matches,
    /// Treated as a pass by `ScriptResult::evaluate`, exactly like `Matches`.
    /// NOTE(review): presumably a pattern that was expected to fail did fail to match;
    /// confirm against the pattern matcher that produces it.
    MatchesFailure,
    /// The output matched a pattern it should not have. This fails evaluation unless
    /// `ignore_matches` is set.
    ExpectedFailure,
    /// The output did not match. Holds the match failure, which is reported as the error and
    /// returned in `ScriptRunError::Pattern`, plus a textual trace printed dimmed beneath it.
    Mismatch(OutputPatternMatchFailure, String),
}
1776
/// Outcome of comparing a command's exit status with the script's expectation.
#[derive(Debug)]
pub enum ExitResult {
    /// The exit status was as expected. `ScriptResult::evaluate` prints `OK`, adding the
    /// status when it was not a success.
    Matches(CommandResult),
    /// The exit status was not as expected. This fails evaluation unless `ignore_exit_codes`
    /// is set.
    Mismatch(CommandResult),
    /// Presumably the command exceeded its timeout.
    /// NOTE(review): `ScriptResult::evaluate` has no branch for this variant. On its own, it
    /// therefore produces no output and no error there. Confirm that timeouts are reported
    /// elsewhere.
    TimedOut,
}
1783
#[cfg(test)]
mod tests {
    use crate::parser::v0::parse_script;

    use super::*;
    use std::error::Error;

    /// Parses a script with a custom pattern and nested `repeat`/`choice` blocks, then checks
    /// how many top-level commands the parser produced.
    #[test]
    fn test_script() -> Result<(), Box<dyn Error>> {
        let source = r#"
pattern VERSION \d+\.\d+\.\d+;

$ something --version || echo 1
? Something %{VERSION}

$ something --help
? Usage: something [OPTIONS]
repeat {
    choice {
? %{DATA} %{GREEDYDATA}
? %{DATA}=%{DATA} %{GREEDYDATA}
    }
}
"#;

        let parsed = parse_script(ScriptFile::new("test.cli"), source)?;
        assert_eq!(parsed.commands.len(), 3);
        eprintln!("{parsed:?}");
        Ok(())
    }

    /// A script containing a backgrounded command (`cmd &`) is rejected with
    /// `BackgroundProcessNotAllowed`.
    #[test]
    fn test_bad_script() -> Result<(), Box<dyn Error>> {
        let source = r#"
$ (cmd; cmd)
$ cmd &
    "#;

        let outcome = parse_script(ScriptFile::new("test.cli"), source);
        assert!(matches!(
            outcome,
            Err(ScriptError {
                error: ScriptErrorType::BackgroundProcessNotAllowed,
                ..
            })
        ));
        Ok(())
    }

    /// Covers `$NAME` / `${NAME}` expansion, backslash escaping of `$`, and escaped
    /// backslashes. It also covers names longer than one character.
    #[test]
    fn test_script_run_context_expand() {
        let mut env = ScriptEnv::default();
        for (name, value) in [("A", "1"), ("B", "2"), ("C", "3")] {
            env.set_env(name, value);
        }

        let single_letter_cases = [
            ("$A", "1"),
            ("$A $B ", "1 2 "),
            ("${A} ${B} ", "1 2 "),
            (r#"\$A"#, "$A"),
            (r#"\${A}"#, "${A}"),
            (r#"\\$A"#, r#"\1"#),
            (r#"\\${A}"#, r#"\1"#),
        ];
        for (input, expected) in single_letter_cases {
            assert_eq!(env.expand_str(input).unwrap(), expected);
        }

        env.set_env("TEMP_DIR", "/tmp");
        for input in ["$TEMP_DIR", "${TEMP_DIR}"] {
            assert_eq!(env.expand_str(input).unwrap(), "/tmp");
        }
    }
}