Skip to main content

nu_test_support/tester/
mod.rs

1use std::{
2    env,
3    error::Error,
4    fmt::{Debug, Display},
5    panic::Location,
6    path::PathBuf,
7    sync::{Arc, LazyLock},
8};
9
10use nu_protocol::{
11    CompileError, Config, FromValue, IntoValue, LabeledError, ParseError, PipelineData,
12    PipelineExecutionData, ShellError, Span, Value,
13    ast::Block,
14    debugger::WithoutDebug,
15    engine::{Command, EngineState, Stack, StateDelta, StateWorkingSet},
16    shell_error::{io::IoError, network::NetworkError},
17};
18use nu_utils::{consts::ENV_PATH_SEPARATOR_CHAR, sync::KeyedLazyLock};
19
20use crate::harness::group::GroupKey;
21
22static ROOT: LazyLock<PathBuf> = LazyLock::new(|| {
23    PathBuf::from(env!("CARGO_MANIFEST_DIR"))
24        .join("../..")
25        .canonicalize()
26        .expect("could not canonicalize root")
27});
28
29// By using different engine states depending on the group key, we can ensure that behavior from
30// experimental options or environment variables take proper effect in the setup of an engine state.
31static INITIAL_ENGINE_STATES: KeyedLazyLock<GroupKey, EngineState> = KeyedLazyLock::new(|_| {
32    let engine_state = nu_cmd_lang::create_default_context();
33    let engine_state = nu_command::add_shell_command_context(engine_state);
34    let mut engine_state = nu_cmd_extra::add_extra_command_context(engine_state);
35
36    engine_state.generate_nu_constant();
37    [
38        ("PWD", Value::test_string(ROOT.to_string_lossy())),
39        ("config", Config::default().into_value(Span::unknown())),
40    ]
41    .into_iter()
42    .for_each(|(key, val)| engine_state.add_env_var(key.into(), val));
43
44    nu_std::load_standard_library(&mut engine_state).expect("could not load standard library");
45
46    engine_state
47});
48
49/// Create a [`NuTester`] for running Nushell snippets in tests.
50///
51/// Prefer this helper over the `nu!` macro for most tests.
52/// It runs snippets in-process instead of shelling out to a subprocess, which makes tests faster
53/// and lets you pass and read values directly without inferring from stdout or stderr.
54/// The `nu!` macro executes the `nu` binary, and changes in a single crate might not trigger a
55/// rebuild of that binary, so tests can run against stale behavior unless you run `cargo build`
56/// first.
57/// Using this helper avoids that by executing against the in-process engine components.
58///
59/// The tester starts from a default [`EngineState`] with the standard library loaded, and a fresh
60/// [`Stack`].
61/// Use the returned value to configure environment variables or the working directory before
62/// running code.
63///
64/// # Environment behavior
65///
66/// - This tester does not inherit process environment variables.
67/// - Any variables you want available to the engine must be added explicitly via
68///   [`NuTester::env`] (or convenience helpers like [`NuTester::locale`]).
69/// - Experimental options and other external environment settings are respected
70///   when constructing the underlying engine state for the current test group.
71///
72/// # Examples
73///
74/// ```rust
75/// use nu_test_support::prelude::*;
76///
77/// let code = "use std/util ellie; ellie | ansi strip";
78/// let value: String = test().run(code)?;
79/// assert_eq!(value, r#"
80///      __  ,
81///  .--()°'.'
82/// '|, . ,'
83///  !_-(_\
84/// "#.trim_matches('\n'));
85/// # Ok::<(), nu_test_support::tester::TestError>(())
86/// ```
87///
88/// ```rust
89/// use nu_test_support::prelude::*;
90///
91/// let mut tester = test()
92///     .env("FOO", "bar")
93///     .cwd("crates/nu-test-support");
94///
95/// let value: String = tester.run("$env.FOO")?;
96/// assert_eq!(value, "bar");
97/// # Ok::<(), nu_test_support::tester::TestError>(())
98/// ```
99pub fn test() -> NuTester {
100    NuTester::default()
101}
102
/// Helper for running Nushell code in tests.
///
/// `NuTester` owns an [`EngineState`] and [`Stack`] that are reused across invocations.
/// Configuration methods update the engine state before execution.
///
/// The type is [`Clone`], so a configured tester can be duplicated and each copy run
/// independently.
#[derive(Clone)]
pub struct NuTester {
    /// Engine state that snippets are parsed, compiled, and evaluated against; the builder
    /// methods store environment variables here.
    engine_state: EngineState,
    /// Evaluation stack handed to every `eval_block` call made by this tester.
    stack: Stack,
}
112
113impl Default for NuTester {
114    /// Create a default tester.
115    ///
116    /// Prefer [`test()`] for a shorter entry point that avoids naming [`NuTester`].
117    fn default() -> Self {
118        Self {
119            engine_state: INITIAL_ENGINE_STATES.get(&GroupKey::current()).clone(),
120            stack: Stack::new().collect_value(),
121        }
122    }
123}
124
125impl NuTester {
126    /// Create a default tester with the standard engine state.
127    ///
128    /// Prefer [`test()`] for a shorter entry point that avoids naming [`NuTester`].
129    pub fn new() -> Self {
130        Self::default()
131    }
132
133    /// Set the working directory used for evaluation.
134    ///
135    /// Relative paths are resolved from the repository root and canonicalized.
136    pub fn cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
137        let cwd = cwd.into();
138
139        let cwd = match cwd.is_absolute() {
140            true => cwd,
141            false => ROOT
142                .join(cwd)
143                .canonicalize()
144                .expect("could not canonicalize path"),
145        };
146
147        self.engine_state
148            .add_env_var("PWD".into(), Value::test_string(cwd.to_string_lossy()));
149        self
150    }
151
152    /// Set the locale used by tests via `NU_TEST_LOCALE_OVERRIDE`.
153    pub fn locale(mut self, locale: impl Into<String>) -> Self {
154        self.engine_state.add_env_var(
155            "NU_TEST_LOCALE_OVERRIDE".into(),
156            Value::test_string(locale.into()),
157        );
158        self
159    }
160
161    /// Set the locale to `en_US.utf8`.
162    pub fn locale_en(self) -> Self {
163        self.locale("en_US.utf8")
164    }
165
166    /// Inherit the `PATH` environment variable from the running process.
167    ///
168    /// This is useful for tests that spawn external commands and should resolve
169    /// binaries the same way as the parent test process.
170    ///
171    /// Panics if `PATH` is not set in the current process environment.
172    pub fn inherit_path(self) -> Self {
173        let path = env::var("PATH").expect("PATH not available in env");
174        self.env("PATH", path)
175    }
176
177    /// Inherit an environment variable from the running process, but only if it is set.
178    ///
179    /// This is useful for optional variables whose absence should not cause a panic.
180    pub fn inherit_env_if_set(self, key: impl AsRef<str>) -> Self {
181        let key = key.as_ref();
182        match env::var(key) {
183            Ok(val) => self.env(key, val),
184            Err(_) => self,
185        }
186    }
187
188    /// Inherit Rust toolchain related environment variables from the running process,
189    /// but only when they are set.
190    ///
191    /// This helps tests that spawn `cargo`, `rustc`, or `rustup` behave more like
192    /// the parent process, especially when the active toolchain or install location
193    /// is configured through environment variables.
194    ///
195    /// The following variables are inherited when present:
196    /// - `PATH`
197    /// - `CARGO_HOME`
198    /// - `RUSTUP_HOME`
199    /// - `RUSTUP_TOOLCHAIN`
200    /// - `RUSTUP_DIST_SERVER`
201    /// - `RUSTUP_UPDATE_ROOT`
202    ///
203    /// Proxy variables are also inherited when present since `rustup` may need them
204    /// to download or resolve toolchain metadata:
205    /// - `HTTP_PROXY`, `HTTPS_PROXY`, `NO_PROXY`
206    /// - `http_proxy`, `https_proxy`, `no_proxy`
207    ///
208    /// This does not guarantee identical behavior to an interactive shell since the
209    /// current working directory can still affect rustup toolchain resolution.
210    pub fn inherit_rust_toolchain_env(self) -> Self {
211        self.inherit_env_if_set("PATH")
212            .inherit_env_if_set("CARGO_HOME")
213            .inherit_env_if_set("RUSTUP_HOME")
214            .inherit_env_if_set("RUSTUP_TOOLCHAIN")
215            .inherit_env_if_set("RUSTUP_DIST_SERVER")
216            .inherit_env_if_set("RUSTUP_UPDATE_ROOT")
217            .inherit_env_if_set("HTTP_PROXY")
218            .inherit_env_if_set("HTTPS_PROXY")
219            .inherit_env_if_set("NO_PROXY")
220            .inherit_env_if_set("http_proxy")
221            .inherit_env_if_set("https_proxy")
222            .inherit_env_if_set("no_proxy")
223    }
224
225    /// Adds the "nu" binary for testing to the path.
226    ///
227    /// Calling [`inherit_path`](Self::inherit_path) after this methods removes the path entry.
228    pub fn add_nu_to_path(self) -> Self {
229        let nu_home = crate::fs::binaries();
230        let path = self.engine_state.get_env_var("PATH");
231        let path = match path {
232            None => nu_home.display().to_string(),
233            Some(path) => format!(
234                "{nu}{sep}{prev}",
235                nu = nu_home.display(),
236                sep = ENV_PATH_SEPARATOR_CHAR,
237                prev = path.as_str().expect("PATH should always be a string")
238            ),
239        };
240        self.env("PATH", path)
241    }
242
243    /// Add a custom environment variable to the engine state.
244    pub fn env(mut self, key: impl Into<String>, val: impl Into<String>) -> Self {
245        self.engine_state
246            .add_env_var(key.into(), Value::test_string(val.into()));
247        self
248    }
249
250    /// Run Nushell code and extract the value into `T`.
251    ///
252    /// Parsing, compilation, or evaluation failures are returned as [`TestError`].
253    #[track_caller]
254    pub fn run<T: FromValue>(&mut self, code: impl AsRef<str>) -> Result<T> {
255        Self::extract_value(self.run_raw(code)?)
256    }
257
258    /// Run Nushell code with input data and extract the value into `T`.
259    ///
260    /// The input value is converted into `PipelineData` using [`IntoValue`].
261    #[track_caller]
262    pub fn run_with_data<T: FromValue>(
263        &mut self,
264        code: impl AsRef<str>,
265        data: impl IntoValue,
266    ) -> Result<T> {
267        let input = PipelineData::value(data.into_value(Span::test_data()), None);
268        Self::extract_value(self.run_raw_with_data(code, input)?)
269    }
270
271    /// Run Nushell code and return the raw [`PipelineExecutionData`].
272    #[track_caller]
273    pub fn run_raw(&mut self, code: impl AsRef<str>) -> Result<PipelineExecutionData> {
274        self.run_raw_with_data(code, PipelineData::empty())
275    }
276
277    /// Run Nushell code with input data and return the raw execution results.
278    ///
279    /// This parses, compiles, and evaluates the code against the current engine state.
280    #[track_caller]
281    pub fn run_raw_with_data(
282        &mut self,
283        code: impl AsRef<str>,
284        data: PipelineData,
285    ) -> Result<PipelineExecutionData> {
286        let location = TestLocation(Location::caller());
287        let (delta, block) = self.parse_and_compile(code)?;
288        self.engine_state.merge_delta(delta)?;
289        nu_engine::eval_block::<WithoutDebug>(&self.engine_state, &mut self.stack, &block, data)
290            .map_err(|err| TestError {
291                location,
292                kind: TestErrorKind::Shell(err),
293            })
294    }
295
296    #[track_caller]
297    pub fn parse_and_compile(&self, code: impl AsRef<str>) -> Result<(StateDelta, Arc<Block>)> {
298        let location = TestLocation(Location::caller());
299        let code = code.as_ref().as_bytes();
300
301        let mut working_set = StateWorkingSet::new(&self.engine_state);
302        let block = nu_parser::parse(&mut working_set, None, code, false);
303
304        if let Some(err) = working_set.parse_errors.into_iter().next() {
305            return Err(TestError {
306                location,
307                kind: TestErrorKind::Parse(err),
308            });
309        }
310
311        if let Some(err) = working_set.compile_errors.into_iter().next() {
312            return Err(TestError {
313                location,
314                kind: TestErrorKind::Compile(err),
315            });
316        }
317
318        Ok((working_set.delta, block))
319    }
320
321    #[track_caller]
322    fn extract_value<T: FromValue>(
323        pipeline_execution_data: PipelineExecutionData,
324    ) -> Result<T, TestError> {
325        let pipeline_data = pipeline_execution_data.body;
326        let value = pipeline_data.into_value(Span::test_data())?;
327        let value = T::from_value(value)?;
328        Ok(value)
329    }
330
331    /// Test examples of a command.
332    #[track_caller]
333    pub fn examples(&self, command: impl Command + 'static) -> Result {
334        let location = TestLocation(Location::caller());
335        for example in command.examples() {
336            match example.result {
337                None => self
338                    .parse_and_compile(example.example)
339                    .map(|_| ())
340                    .map_err(|err| TestError {
341                        location,
342                        kind: TestErrorKind::ExampleFailed {
343                            command: command.name().to_string(),
344                            description: example.description.to_string(),
345                            code: example.example.to_string(),
346                            err: Box::new(err.kind),
347                        },
348                    })?,
349                Some(expected) => {
350                    let got = self.clone().run(example.example)?;
351                    if got != expected {
352                        return Err(TestError {
353                            location,
354                            kind: TestErrorKind::ExampleFailed {
355                                command: command.name().to_string(),
356                                description: example.description.to_string(),
357                                code: example.example.to_string(),
358                                err: Box::new(TestErrorKind::UnexpectedValue { expected, got }),
359                            },
360                        });
361                    }
362                }
363            }
364        }
365
366        Ok(())
367    }
368}
369
/// Error returned by [`NuTester`] and the test extension traits in this module.
///
/// Pairs what went wrong with the source location the failure is attributed to.
#[derive(Debug, Clone, PartialEq)]
pub struct TestError {
    /// Call site recorded through `#[track_caller]` when the error was created or last passed
    /// through [`TestError::update_location`].
    location: TestLocation,
    /// What went wrong.
    kind: TestErrorKind,
}
375
/// Source location attached to a [`TestError`].
///
/// Wraps a [`Location`]; the `Debug` output is the wrapped location's `Display` form
/// (see the `#[debug("{_0}")]` attribute).
#[derive(Clone, Copy, PartialEq, derive_more::Debug)]
#[debug("{_0}")]
pub struct TestLocation(&'static Location<'static>);
379
/// Errors emitted by `NuTester` when parsing, compiling, or evaluating code.
///
/// This enum is marked as non-exhaustive to allow adding new variants.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum TestErrorKind {
    /// The code failed to parse.
    Parse(ParseError),
    /// The code failed to compile.
    Compile(CompileError),
    /// A [`ShellError`] occurred, e.g. during evaluation or value conversion.
    Shell(ShellError),
    /// An error was expected, but the result was a value.
    GotValue {
        got: Value,
    },
    /// [`ShellErrorExt::into_inner`] found no inner error to extract.
    NoInner,
    /// A [`ShellError`] of a particular variant was expected, but another one was found.
    UnexpectedErrorKind {
        /// Short name of the expected variant, e.g. `"Generic"`.
        expected: &'static str,
        got: ShellError,
    },
    /// A value was produced, but it is not equal to the expected one.
    UnexpectedValue {
        expected: Value,
        got: Value,
    },
    /// A command example checked by [`NuTester::examples`] failed.
    ExampleFailed {
        /// Name of the command the example belongs to.
        command: String,
        /// Description of the failing example.
        description: String,
        /// Code of the failing example.
        code: String,
        /// The underlying cause of the failure.
        err: Box<TestErrorKind>,
    },
}
408
409impl Display for TestError {
410    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
411        write!(f, "{self:#?}")
412    }
413}
414
// Marker impl with all default methods, so `TestError` can be used wherever a
// `std::error::Error` is required (e.g. with `?` in tests returning `Box<dyn Error>`).
impl Error for TestError {}
416
417impl From<ShellError> for TestError {
418    #[track_caller]
419    fn from(err: ShellError) -> Self {
420        Self {
421            location: TestLocation(Location::caller()),
422            kind: TestErrorKind::Shell(err),
423        }
424    }
425}
426
427impl From<ParseError> for TestError {
428    #[track_caller]
429    fn from(err: ParseError) -> Self {
430        Self {
431            location: TestLocation(Location::caller()),
432            kind: TestErrorKind::Parse(err),
433        }
434    }
435}
436
437impl TestError {
438    /// Convert this error into a [`ParseError`], if it is one.
439    pub fn parse(self) -> Result<ParseError, TestError> {
440        match self.kind {
441            TestErrorKind::Parse(err) => Ok(err),
442            _ => Err(self),
443        }
444    }
445
446    /// Convert this error into a [`CompileError`], if it is one.
447    pub fn compile(self) -> Result<CompileError, TestError> {
448        match self.kind {
449            TestErrorKind::Compile(err) => Ok(err),
450            _ => Err(self),
451        }
452    }
453
454    /// Convert this error into a [`ShellError`], if it is one.
455    pub fn shell(self) -> Result<ShellError, TestError> {
456        match self.kind {
457            TestErrorKind::Shell(err) => Ok(err),
458            _ => Err(self),
459        }
460    }
461
462    /// Update it's inner location with the call site of this function.
463    #[track_caller]
464    pub fn update_location(self) -> Self {
465        Self {
466            location: TestLocation(Location::caller()),
467            ..self
468        }
469    }
470}
471
/// Convenience result type for test helpers.
///
/// Both parameters have defaults: a bare `Result` is `Result<(), TestError>` and `Result<T>` is
/// `Result<T, TestError>`.
pub type Result<T = (), E = TestError> = std::result::Result<T, E>;
474
/// Extensions for asserting error kinds from test helpers.
///
/// Every `expect_*_error` method turns the expected error into the `Ok` value, so it can be
/// inspected further; anything else (a successful value or a different error) becomes an `Err`.
pub trait TestResultExt: Sized {
    /// Expect the result to be a `Value` equal to the provided input.
    fn expect_value_eq<T: IntoValue>(self, value: T) -> Result;

    /// Expect the result to be a [`ShellError`].
    fn expect_shell_error(self) -> Result<ShellError>;
    /// Expect the result to be a [`ParseError`].
    fn expect_parse_error(self) -> Result<ParseError>;
    /// Expect the result to be a [`CompileError`].
    fn expect_compile_error(self) -> Result<CompileError>;

    /// Expect the result to be a [`ShellError::Io`].
    fn expect_io_error(self) -> Result<IoError>;
    /// Expect the result to be a [`ShellError::Network`].
    fn expect_network_error(self) -> Result<NetworkError>;
    /// Expect the result to be a [`ShellError::LabeledError`].
    fn expect_labeled_error(self) -> Result<LabeledError>;

    /// Expect the result to be a [`ShellError`].
    ///
    /// Shorthand that forwards to [`expect_shell_error`](Self::expect_shell_error).
    #[track_caller]
    fn expect_error(self) -> Result<ShellError> {
        self.expect_shell_error()
    }
}
500
501impl TestResultExt for Result<Value> {
502    #[track_caller]
503    fn expect_value_eq<T: IntoValue>(self, expected: T) -> Result {
504        let expected = expected.into_value(Span::test_data());
505        match self {
506            Err(err) => Err(err.update_location()),
507            Ok(actual) if actual == expected => Ok(()),
508            Ok(actual) => Err(TestError {
509                location: TestLocation(Location::caller()),
510                kind: TestErrorKind::UnexpectedValue {
511                    expected,
512                    got: actual,
513                },
514            }),
515        }
516    }
517
518    #[track_caller]
519    fn expect_shell_error(self) -> Result<ShellError> {
520        match self {
521            Ok(got) => Err(TestError {
522                location: TestLocation(Location::caller()),
523                kind: TestErrorKind::GotValue { got },
524            }),
525            Err(TestError {
526                kind: TestErrorKind::Shell(err),
527                ..
528            }) => Ok(err),
529            Err(err) => Err(err.update_location()),
530        }
531    }
532
533    #[track_caller]
534    fn expect_parse_error(self) -> Result<ParseError> {
535        match self {
536            Ok(got) => Err(TestError {
537                location: TestLocation(Location::caller()),
538                kind: TestErrorKind::GotValue { got },
539            }),
540            Err(TestError {
541                kind: TestErrorKind::Parse(err),
542                ..
543            }) => Ok(err),
544            Err(err) => Err(err.update_location()),
545        }
546    }
547
548    #[track_caller]
549    fn expect_compile_error(self) -> Result<CompileError> {
550        match self {
551            Ok(got) => Err(TestError {
552                location: TestLocation(Location::caller()),
553                kind: TestErrorKind::GotValue { got },
554            }),
555            Err(TestError {
556                kind: TestErrorKind::Compile(err),
557                ..
558            }) => Ok(err),
559            Err(err) => Err(err.update_location()),
560        }
561    }
562
563    #[track_caller]
564    fn expect_io_error(self) -> Result<IoError> {
565        match self {
566            Ok(got) => Err(TestError {
567                location: TestLocation(Location::caller()),
568                kind: TestErrorKind::GotValue { got },
569            }),
570            Err(TestError {
571                kind: TestErrorKind::Shell(ShellError::Io(err)),
572                ..
573            }) => Ok(err),
574            Err(err) => Err(err.update_location()),
575        }
576    }
577
578    #[track_caller]
579    fn expect_network_error(self) -> Result<NetworkError> {
580        match self {
581            Ok(got) => Err(TestError {
582                location: TestLocation(Location::caller()),
583                kind: TestErrorKind::GotValue { got },
584            }),
585            Err(TestError {
586                kind: TestErrorKind::Shell(ShellError::Network(err)),
587                ..
588            }) => Ok(err),
589            Err(err) => Err(err.update_location()),
590        }
591    }
592
593    #[track_caller]
594    fn expect_labeled_error(self) -> Result<LabeledError> {
595        match self {
596            Ok(got) => Err(TestError {
597                location: TestLocation(Location::caller()),
598                kind: TestErrorKind::GotValue { got },
599            }),
600            Err(TestError {
601                kind: TestErrorKind::Shell(ShellError::LabeledError(err)),
602                ..
603            }) => Ok(*err),
604            Err(err) => Err(err.update_location()),
605        }
606    }
607}
608
/// Extensions for interrogating [`ShellError`] values in tests.
pub trait ShellErrorExt {
    /// Tries to convert into an inner value from a [`ShellError`].
    ///
    /// Useful if the error is expected to be a generic error that contains an inner error or a
    /// chained error that chained another error.
    ///
    /// However, this function returns an error of kind [`TestErrorKind::NoInner`]
    /// - if `inner` of [`ShellError::Generic`] is empty
    /// - if `sources` of [`ShellError::ChainedError`] is empty
    /// - if the error is none of the above types
    ///
    /// So make sure that a [`TestErrorKind::NoInner`] error is not a surprise.
    fn into_inner(self) -> Result<ShellError>;

    /// Extract the [`LabeledError`] from [`ShellError::LabeledError`], if it is one.
    ///
    /// Any other error is returned as [`TestErrorKind::UnexpectedErrorKind`].
    fn into_labeled(self) -> Result<LabeledError>;

    /// Extract the error field from [`ShellError::Generic`], if it is one.
    ///
    /// Any other error is returned as [`TestErrorKind::UnexpectedErrorKind`].
    fn generic_error(self) -> Result<String>;

    /// Extract the message field from [`ShellError::Generic`], if it is one.
    ///
    /// Any other error is returned as [`TestErrorKind::UnexpectedErrorKind`].
    fn generic_msg(self) -> Result<String>;
}
633
634impl ShellErrorExt for ShellError {
635    #[track_caller]
636    fn into_inner(self) -> Result<ShellError> {
637        let no_inner = TestError {
638            location: TestLocation(Location::caller()),
639            kind: TestErrorKind::NoInner,
640        };
641        match self {
642            ShellError::Generic(err) => err.inner.into_iter().next().ok_or(no_inner),
643            ShellError::ChainedError(err) => err.sources_iter().next().ok_or(no_inner),
644            _ => Err(no_inner),
645        }
646    }
647
648    #[track_caller]
649    fn into_labeled(self) -> Result<LabeledError> {
650        match self {
651            ShellError::LabeledError(err) => Ok(*err),
652            got => Err(TestError {
653                location: TestLocation(Location::caller()),
654                kind: TestErrorKind::UnexpectedErrorKind {
655                    expected: "Labeled",
656                    got,
657                },
658            }),
659        }
660    }
661
662    #[track_caller]
663    fn generic_error(self) -> Result<String> {
664        match self {
665            ShellError::Generic(err) => Ok(err.error.into_owned()),
666            got => Err(TestError {
667                location: TestLocation(Location::caller()),
668                kind: TestErrorKind::UnexpectedErrorKind {
669                    expected: "Generic",
670                    got,
671                },
672            }),
673        }
674    }
675
676    #[track_caller]
677    fn generic_msg(self) -> Result<String> {
678        match self {
679            ShellError::Generic(err) => Ok(err.msg.into_owned()),
680            got => Err(TestError {
681                location: TestLocation(Location::caller()),
682                kind: TestErrorKind::UnexpectedErrorKind {
683                    expected: "Generic",
684                    got,
685                },
686            }),
687        }
688    }
689}