Skip to main content

nu_test_support/tester/
mod.rs

1use std::{
2    env,
3    error::Error,
4    fmt::{Debug, Display},
5    panic::Location,
6    path::PathBuf,
7    sync::{Arc, LazyLock},
8};
9
10use nu_protocol::{
11    CompileError, Config, FromValue, IntoValue, LabeledError, ParseError, PipelineData,
12    PipelineExecutionData, ShellError, Span, Value,
13    ast::Block,
14    debugger::WithoutDebug,
15    engine::{Command, EngineState, Stack, StateDelta, StateWorkingSet},
16    shell_error::{io::IoError, network::NetworkError},
17};
18use nu_utils::{consts::ENV_PATH_SEPARATOR_CHAR, sync::KeyedLazyLock};
19
20use crate::harness::group::GroupKey;
21
22static ROOT: LazyLock<PathBuf> = LazyLock::new(|| {
23    PathBuf::from(env!("CARGO_MANIFEST_DIR"))
24        .join("../..")
25        .canonicalize()
26        .expect("could not canonicalize root")
27});
28
29// By using different engine states depending on the group key, we can ensure that behavior from
30// experimental options or environment variables take proper effect in the setup of an engine state.
31static INITIAL_ENGINE_STATES: KeyedLazyLock<GroupKey, EngineState> = KeyedLazyLock::new(|_| {
32    let engine_state = nu_cmd_lang::create_default_context();
33    let engine_state = nu_command::add_shell_command_context(engine_state);
34    let mut engine_state = nu_cmd_extra::add_extra_command_context(engine_state);
35
36    engine_state.generate_nu_constant();
37    [
38        ("PWD", Value::test_string(ROOT.to_string_lossy())),
39        ("config", Config::default().into_value(Span::unknown())),
40    ]
41    .into_iter()
42    .for_each(|(key, val)| engine_state.add_env_var(key.into(), val));
43
44    nu_std::load_standard_library(&mut engine_state).expect("could not load standard library");
45
46    engine_state
47});
48
49/// Create a [`NuTester`] for running Nushell snippets in tests.
50///
51/// Prefer this helper over the `nu!` macro for most tests.
52/// It runs snippets in-process instead of shelling out to a subprocess, which makes tests faster
53/// and lets you pass and read values directly without inferring from stdout or stderr.
54/// The `nu!` macro executes the `nu` binary, and changes in a single crate might not trigger a
55/// rebuild of that binary, so tests can run against stale behavior unless you run `cargo build`
56/// first.
57/// Using this helper avoids that by executing against the in-process engine components.
58///
59/// The tester starts from a default [`EngineState`] with the standard library loaded, and a fresh
60/// [`Stack`].
61/// Use the returned value to configure environment variables or the working directory before
62/// running code.
63///
64/// # Environment behavior
65///
66/// - This tester does not inherit process environment variables.
67/// - Any variables you want available to the engine must be added explicitly via
68///   [`NuTester::env`] (or convenience helpers like [`NuTester::locale`]).
69/// - Experimental options and other external environment settings are respected
70///   when constructing the underlying engine state for the current test group.
71///
72/// # Examples
73///
74/// ```rust
75/// use nu_test_support::prelude::*;
76///
77/// let code = "use std/util ellie; ellie | ansi strip";
78/// let value: String = test().run(code)?;
79/// assert_eq!(value, r#"
80///      __  ,
81///  .--()°'.'
82/// '|, . ,'
83///  !_-(_\
84/// "#.trim_matches('\n'));
85/// # Ok::<(), nu_test_support::tester::TestError>(())
86/// ```
87///
88/// ```rust
89/// use nu_test_support::prelude::*;
90///
91/// let mut tester = test()
92///     .env("FOO", "bar")
93///     .cwd("crates/nu-test-support");
94///
95/// let value: String = tester.run("$env.FOO")?;
96/// assert_eq!(value, "bar");
97/// # Ok::<(), nu_test_support::tester::TestError>(())
98/// ```
99pub fn test() -> NuTester {
100    NuTester::default()
101}
102
/// Helper for running Nushell code in tests.
///
/// `NuTester` owns an [`EngineState`] and [`Stack`] that are reused across invocations.
/// Configuration methods update the engine state before execution.
///
/// Cloning a tester clones both the engine state and the stack.
#[derive(Clone)]
pub struct NuTester {
    // Engine state that code is parsed against; environment overrides and merged deltas
    // from earlier runs are stored here.
    engine_state: EngineState,
    // Stack handed to the evaluator on every run.
    stack: Stack,
}
112
113impl Default for NuTester {
114    /// Create a default tester.
115    ///
116    /// Prefer [`test()`] for a shorter entry point that avoids naming [`NuTester`].
117    fn default() -> Self {
118        Self {
119            engine_state: INITIAL_ENGINE_STATES.get(&GroupKey::current()).clone(),
120            stack: Stack::new().collect_value(),
121        }
122    }
123}
124
125impl NuTester {
126    /// Create a default tester with the standard engine state.
127    ///
128    /// Prefer [`test()`] for a shorter entry point that avoids naming [`NuTester`].
129    pub fn new() -> Self {
130        Self::default()
131    }
132
133    /// Set the working directory used for evaluation.
134    ///
135    /// Relative paths are resolved from the repository root and canonicalized.
136    pub fn cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
137        let cwd = cwd.into();
138
139        let cwd = match cwd.is_absolute() {
140            true => cwd,
141            false => ROOT
142                .join(cwd)
143                .canonicalize()
144                .expect("could not canonicalize path"),
145        };
146
147        self.engine_state
148            .add_env_var("PWD".into(), Value::test_string(cwd.to_string_lossy()));
149        self
150    }
151
152    /// Set the locale used by tests via `NU_TEST_LOCALE_OVERRIDE`.
153    pub fn locale(mut self, locale: impl Into<String>) -> Self {
154        self.engine_state.add_env_var(
155            "NU_TEST_LOCALE_OVERRIDE".into(),
156            Value::test_string(locale.into()),
157        );
158        self
159    }
160
161    /// Set the locale to `en_US.utf8`.
162    pub fn locale_en(self) -> Self {
163        self.locale("en_US.utf8")
164    }
165
166    /// Inherit the `PATH` environment variable from the running process.
167    ///
168    /// This is useful for tests that spawn external commands and should resolve
169    /// binaries the same way as the parent test process.
170    ///
171    /// Panics if `PATH` is not set in the current process environment.
172    pub fn inherit_path(self) -> Self {
173        let path = env::var("PATH").expect("PATH not available in env");
174        self.env("PATH", path)
175    }
176
177    /// Inherit an environment variable from the running process, but only if it is set.
178    ///
179    /// This is useful for optional variables whose absence should not cause a panic.
180    pub fn inherit_env_if_set(self, key: impl AsRef<str>) -> Self {
181        let key = key.as_ref();
182        match env::var(key) {
183            Ok(val) => self.env(key, val),
184            Err(_) => self,
185        }
186    }
187
188    /// Inherit Rust toolchain related environment variables from the running process,
189    /// but only when they are set.
190    ///
191    /// This helps tests that spawn `cargo`, `rustc`, or `rustup` behave more like
192    /// the parent process, especially when the active toolchain or install location
193    /// is configured through environment variables.
194    ///
195    /// The following variables are inherited when present:
196    /// - `PATH`
197    /// - `CARGO_HOME`
198    /// - `RUSTUP_HOME`
199    /// - `RUSTUP_TOOLCHAIN`
200    /// - `RUSTUP_DIST_SERVER`
201    /// - `RUSTUP_UPDATE_ROOT`
202    ///
203    /// Proxy variables are also inherited when present since `rustup` may need them
204    /// to download or resolve toolchain metadata:
205    /// - `HTTP_PROXY`, `HTTPS_PROXY`, `NO_PROXY`
206    /// - `http_proxy`, `https_proxy`, `no_proxy`
207    ///
208    /// This does not guarantee identical behavior to an interactive shell since the
209    /// current working directory can still affect rustup toolchain resolution.
210    pub fn inherit_rust_toolchain_env(self) -> Self {
211        self.inherit_env_if_set("PATH")
212            .inherit_env_if_set("CARGO_HOME")
213            .inherit_env_if_set("RUSTUP_HOME")
214            .inherit_env_if_set("RUSTUP_TOOLCHAIN")
215            .inherit_env_if_set("RUSTUP_DIST_SERVER")
216            .inherit_env_if_set("RUSTUP_UPDATE_ROOT")
217            .inherit_env_if_set("HTTP_PROXY")
218            .inherit_env_if_set("HTTPS_PROXY")
219            .inherit_env_if_set("NO_PROXY")
220            .inherit_env_if_set("http_proxy")
221            .inherit_env_if_set("https_proxy")
222            .inherit_env_if_set("no_proxy")
223    }
224
225    /// Adds the "nu" binary for testing to the path.
226    ///
227    /// Calling [`inherit_path`](Self::inherit_path) after this methods removes the path entry.
228    pub fn add_nu_to_path(self) -> Self {
229        let nu_home = crate::fs::binaries();
230        let path = self.engine_state.get_env_var("PATH");
231        let path = match path {
232            None => nu_home.display().to_string(),
233            Some(path) => format!(
234                "{nu}{sep}{prev}",
235                nu = nu_home.display(),
236                sep = ENV_PATH_SEPARATOR_CHAR,
237                prev = path.as_str().expect("PATH should always be a string")
238            ),
239        };
240        self.env("PATH", path)
241    }
242
243    /// Add a custom environment variable to the engine state.
244    pub fn env(mut self, key: impl Into<String>, val: impl Into<String>) -> Self {
245        self.engine_state
246            .add_env_var(key.into(), Value::test_string(val.into()));
247        self
248    }
249
250    /// Run Nushell code and extract the value into `T`.
251    ///
252    /// Parsing, compilation, or evaluation failures are returned as [`TestError`].
253    #[track_caller]
254    pub fn run<T: FromValue>(&mut self, code: impl AsRef<str>) -> Result<T> {
255        Self::extract_value(self.run_raw(code)?)
256    }
257
258    /// Run Nushell code with input data and extract the value into `T`.
259    ///
260    /// The input value is converted into `PipelineData` using [`IntoValue`].
261    #[track_caller]
262    pub fn run_with_data<T: FromValue>(
263        &mut self,
264        code: impl AsRef<str>,
265        data: impl IntoValue,
266    ) -> Result<T> {
267        let input = PipelineData::value(data.into_value(Span::test_data()), None);
268        Self::extract_value(self.run_raw_with_data(code, input)?)
269    }
270
271    /// Run Nushell code and return the raw [`PipelineExecutionData`].
272    #[track_caller]
273    pub fn run_raw(&mut self, code: impl AsRef<str>) -> Result<PipelineExecutionData> {
274        self.run_raw_with_data(code, PipelineData::empty())
275    }
276
277    /// Run Nushell code with input data and return the raw execution results.
278    ///
279    /// This parses, compiles, and evaluates the code against the current engine state.
280    #[track_caller]
281    pub fn run_raw_with_data(
282        &mut self,
283        code: impl AsRef<str>,
284        data: PipelineData,
285    ) -> Result<PipelineExecutionData> {
286        let location = TestLocation(Location::caller());
287        let (delta, block) = self.parse_and_compile(code)?;
288        self.engine_state.merge_delta(delta)?;
289        nu_engine::eval_block::<WithoutDebug>(&self.engine_state, &mut self.stack, &block, data)
290            .map_err(|err| TestError {
291                location,
292                kind: TestErrorKind::Shell(err),
293            })
294    }
295
296    #[track_caller]
297    pub fn parse_and_compile(&self, code: impl AsRef<str>) -> Result<(StateDelta, Arc<Block>)> {
298        let location = TestLocation(Location::caller());
299        let code = code.as_ref().as_bytes();
300
301        let mut working_set = StateWorkingSet::new(&self.engine_state);
302        let block = nu_parser::parse(&mut working_set, None, code, false);
303
304        if let Some(err) = working_set.parse_errors.into_iter().next() {
305            return Err(TestError {
306                location,
307                kind: TestErrorKind::Parse(err),
308            });
309        }
310
311        if let Some(err) = working_set.compile_errors.into_iter().next() {
312            return Err(TestError {
313                location,
314                kind: TestErrorKind::Compile(err),
315            });
316        }
317
318        Ok((working_set.delta, block))
319    }
320
321    #[track_caller]
322    fn extract_value<T: FromValue>(
323        pipeline_execution_data: PipelineExecutionData,
324    ) -> Result<T, TestError> {
325        let pipeline_data = pipeline_execution_data.body;
326        let value = pipeline_data.into_value(Span::test_data())?;
327        let value = T::from_value(value)?;
328        Ok(value)
329    }
330
331    /// Test examples of a command.
332    #[track_caller]
333    pub fn examples(&self, command: impl Command + 'static) -> Result {
334        let location = TestLocation(Location::caller());
335        for example in command.examples() {
336            match example.result {
337                None => self
338                    .parse_and_compile(example.example)
339                    .map(|_| ())
340                    .map_err(|err| TestError {
341                        location,
342                        kind: TestErrorKind::ExampleFailed {
343                            command: command.name().to_string(),
344                            description: example.description.to_string(),
345                            code: example.example.to_string(),
346                            err: Box::new(err.kind),
347                        },
348                    })?,
349                Some(expected) => {
350                    let got = self.clone().run(example.example)?;
351                    if got != expected {
352                        return Err(TestError {
353                            location,
354                            kind: TestErrorKind::ExampleFailed {
355                                command: command.name().to_string(),
356                                description: example.description.to_string(),
357                                code: example.example.to_string(),
358                                err: Box::new(TestErrorKind::UnexpectedValue { expected, got }),
359                            },
360                        });
361                    }
362                }
363            }
364        }
365
366        Ok(())
367    }
368}
369
/// Error returned by `NuTester` and the test helper extension traits.
///
/// Pairs the kind of failure with the source location that was recorded for it.
/// The derived `PartialEq` compares both fields, so two errors of the same kind recorded at
/// different locations are not equal.
#[derive(Debug, Clone, PartialEq)]
pub struct TestError {
    // Source location recorded for the failure via `Location::caller()`.
    location: TestLocation,
    // What actually went wrong.
    kind: TestErrorKind,
}
375
/// Source location attached to a test error.
///
/// Thin wrapper around [`Location`] whose `Debug` output is the location's `Display` form.
#[derive(Clone, Copy, PartialEq)]
pub struct TestLocation(&'static Location<'static>);

impl Debug for TestLocation {
    /// Writes the wrapped location through its `Display` implementation, ignoring any
    /// formatting flags set on `f`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{}", self.0))
    }
}
384
/// Errors emitted by `NuTester` when parsing, compiling, or evaluating code.
///
/// This enum is marked as non-exhaustive to allow adding new variants.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum TestErrorKind {
    /// The code failed to parse.
    Parse(ParseError),
    /// The code parsed but failed to compile.
    Compile(CompileError),
    /// Evaluation, or a later conversion step, failed with a shell error.
    Shell(ShellError),
    /// An error was expected but the result was a value.
    GotValue {
        /// The value that was produced instead of an error.
        got: Value,
    },
    /// A shell error was asked for its inner error but had none.
    NoInner,
    /// A shell error was expected to be of a specific kind but was something else.
    UnexpectedErrorKind {
        /// Short name of the kind that was expected.
        expected: &'static str,
        /// The shell error that was found instead.
        got: ShellError,
    },
    /// A value was produced but it differs from the expected one.
    UnexpectedValue {
        /// The value the caller asked for.
        expected: Value,
        /// The value that was actually produced.
        got: Value,
    },
    /// An example of a command did not behave as documented.
    ExampleFailed {
        /// Name of the command the example belongs to.
        command: String,
        /// Description of the failing example.
        description: String,
        /// Source code of the failing example.
        code: String,
        /// The underlying failure.
        err: Box<TestErrorKind>,
    },
}
413
414impl Display for TestError {
415    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416        write!(f, "{self:#?}")
417    }
418}
419
// Marker implementation: every provided method of `Error` keeps its default, so no source
// error is exposed through `Error::source`.
impl Error for TestError {}
421
422impl From<ShellError> for TestError {
423    #[track_caller]
424    fn from(err: ShellError) -> Self {
425        Self {
426            location: TestLocation(Location::caller()),
427            kind: TestErrorKind::Shell(err),
428        }
429    }
430}
431
432impl From<ParseError> for TestError {
433    #[track_caller]
434    fn from(err: ParseError) -> Self {
435        Self {
436            location: TestLocation(Location::caller()),
437            kind: TestErrorKind::Parse(err),
438        }
439    }
440}
441
442impl TestError {
443    /// Convert this error into a [`ParseError`], if it is one.
444    pub fn parse(self) -> Result<ParseError, TestError> {
445        match self.kind {
446            TestErrorKind::Parse(err) => Ok(err),
447            _ => Err(self),
448        }
449    }
450
451    /// Convert this error into a [`CompileError`], if it is one.
452    pub fn compile(self) -> Result<CompileError, TestError> {
453        match self.kind {
454            TestErrorKind::Compile(err) => Ok(err),
455            _ => Err(self),
456        }
457    }
458
459    /// Convert this error into a [`ShellError`], if it is one.
460    pub fn shell(self) -> Result<ShellError, TestError> {
461        match self.kind {
462            TestErrorKind::Shell(err) => Ok(err),
463            _ => Err(self),
464        }
465    }
466
467    /// Update it's inner location with the call site of this function.
468    #[track_caller]
469    pub fn update_location(self) -> Self {
470        Self {
471            location: TestLocation(Location::caller()),
472            ..self
473        }
474    }
475}
476
/// Convenience result type for test helpers.
///
/// Both parameters have defaults: a bare `Result` is `Result<(), TestError>` and `Result<T>` is
/// `Result<T, TestError>`.
pub type Result<T = (), E = TestError> = std::result::Result<T, E>;
479
/// Extensions for asserting error kinds from test helpers.
///
/// Each `expect_*` method turns an expectation about a result into a new [`Result`]: `Ok` when
/// the expectation holds and `Err` with a [`TestError`] when it does not.
pub trait TestResultExt: Sized {
    /// Expect the result to be a `Value` equal to the provided input.
    ///
    /// The input is converted with [`IntoValue`] before the comparison.
    fn expect_value_eq<T: IntoValue>(self, value: T) -> Result;

    /// Expect the result to be a [`ShellError`].
    fn expect_shell_error(self) -> Result<ShellError>;
    /// Expect the result to be a [`ParseError`].
    fn expect_parse_error(self) -> Result<ParseError>;
    /// Expect the result to be a [`CompileError`].
    fn expect_compile_error(self) -> Result<CompileError>;

    /// Expect the result to be a [`ShellError::Io`].
    fn expect_io_error(self) -> Result<IoError>;
    /// Expect the result to be a [`ShellError::Network`].
    fn expect_network_error(self) -> Result<NetworkError>;
    /// Expect the result to be a [`ShellError::LabeledError`].
    fn expect_labeled_error(self) -> Result<LabeledError>;

    /// Expect the result to be a [`ShellError`].
    ///
    /// Shorthand that forwards to [`expect_shell_error`](Self::expect_shell_error).
    #[track_caller]
    fn expect_error(self) -> Result<ShellError> {
        self.expect_shell_error()
    }
}
505
506impl TestResultExt for Result<Value> {
507    #[track_caller]
508    fn expect_value_eq<T: IntoValue>(self, expected: T) -> Result {
509        let expected = expected.into_value(Span::test_data());
510        match self {
511            Err(err) => Err(err.update_location()),
512            Ok(actual) if actual == expected => Ok(()),
513            Ok(actual) => Err(TestError {
514                location: TestLocation(Location::caller()),
515                kind: TestErrorKind::UnexpectedValue {
516                    expected,
517                    got: actual,
518                },
519            }),
520        }
521    }
522
523    #[track_caller]
524    fn expect_shell_error(self) -> Result<ShellError> {
525        match self {
526            Ok(got) => Err(TestError {
527                location: TestLocation(Location::caller()),
528                kind: TestErrorKind::GotValue { got },
529            }),
530            Err(TestError {
531                kind: TestErrorKind::Shell(err),
532                ..
533            }) => Ok(err),
534            Err(err) => Err(err.update_location()),
535        }
536    }
537
538    #[track_caller]
539    fn expect_parse_error(self) -> Result<ParseError> {
540        match self {
541            Ok(got) => Err(TestError {
542                location: TestLocation(Location::caller()),
543                kind: TestErrorKind::GotValue { got },
544            }),
545            Err(TestError {
546                kind: TestErrorKind::Parse(err),
547                ..
548            }) => Ok(err),
549            Err(err) => Err(err.update_location()),
550        }
551    }
552
553    #[track_caller]
554    fn expect_compile_error(self) -> Result<CompileError> {
555        match self {
556            Ok(got) => Err(TestError {
557                location: TestLocation(Location::caller()),
558                kind: TestErrorKind::GotValue { got },
559            }),
560            Err(TestError {
561                kind: TestErrorKind::Compile(err),
562                ..
563            }) => Ok(err),
564            Err(err) => Err(err.update_location()),
565        }
566    }
567
568    #[track_caller]
569    fn expect_io_error(self) -> Result<IoError> {
570        match self {
571            Ok(got) => Err(TestError {
572                location: TestLocation(Location::caller()),
573                kind: TestErrorKind::GotValue { got },
574            }),
575            Err(TestError {
576                kind: TestErrorKind::Shell(ShellError::Io(err)),
577                ..
578            }) => Ok(err),
579            Err(err) => Err(err.update_location()),
580        }
581    }
582
583    #[track_caller]
584    fn expect_network_error(self) -> Result<NetworkError> {
585        match self {
586            Ok(got) => Err(TestError {
587                location: TestLocation(Location::caller()),
588                kind: TestErrorKind::GotValue { got },
589            }),
590            Err(TestError {
591                kind: TestErrorKind::Shell(ShellError::Network(err)),
592                ..
593            }) => Ok(err),
594            Err(err) => Err(err.update_location()),
595        }
596    }
597
598    #[track_caller]
599    fn expect_labeled_error(self) -> Result<LabeledError> {
600        match self {
601            Ok(got) => Err(TestError {
602                location: TestLocation(Location::caller()),
603                kind: TestErrorKind::GotValue { got },
604            }),
605            Err(TestError {
606                kind: TestErrorKind::Shell(ShellError::LabeledError(err)),
607                ..
608            }) => Ok(*err),
609            Err(err) => Err(err.update_location()),
610        }
611    }
612}
613
/// Extensions for interrogating [`ShellError`] values in tests.
pub trait ShellErrorExt {
    /// Tries to convert into an inner value from a [`ShellError`].
    ///
    /// Useful if the error is expected to be a generic error that contains an inner error or a
    /// chained error that chained another error.
    /// Only the first inner error or source is returned.
    ///
    /// However, this function returns an error of kind [`TestErrorKind::NoInner`]
    /// - if `inner` of [`ShellError::Generic`] is empty
    /// - if `sources` of [`ShellError::ChainedError`] is empty
    /// - if the error is none of the above types
    ///
    /// So make sure that such an error does not come as a surprise.
    fn into_inner(self) -> Result<ShellError>;

    /// Extract the [`LabeledError`] from [`ShellError::LabeledError`], if it is one.
    ///
    /// Any other error is reported as [`TestErrorKind::UnexpectedErrorKind`].
    fn into_labeled(self) -> Result<LabeledError>;

    /// Extract the error field from [`ShellError::Generic`], if it is one.
    ///
    /// Any other error is reported as [`TestErrorKind::UnexpectedErrorKind`].
    fn generic_error(self) -> Result<String>;

    /// Extract the message field from [`ShellError::Generic`], if it is one.
    ///
    /// Any other error is reported as [`TestErrorKind::UnexpectedErrorKind`].
    fn generic_msg(self) -> Result<String>;
}
638
639impl ShellErrorExt for ShellError {
640    #[track_caller]
641    fn into_inner(self) -> Result<ShellError> {
642        let no_inner = TestError {
643            location: TestLocation(Location::caller()),
644            kind: TestErrorKind::NoInner,
645        };
646        match self {
647            ShellError::Generic(err) => err.inner.into_iter().next().ok_or(no_inner),
648            ShellError::ChainedError(err) => err.sources_iter().next().ok_or(no_inner),
649            _ => Err(no_inner),
650        }
651    }
652
653    #[track_caller]
654    fn into_labeled(self) -> Result<LabeledError> {
655        match self {
656            ShellError::LabeledError(err) => Ok(*err),
657            got => Err(TestError {
658                location: TestLocation(Location::caller()),
659                kind: TestErrorKind::UnexpectedErrorKind {
660                    expected: "Labeled",
661                    got,
662                },
663            }),
664        }
665    }
666
667    #[track_caller]
668    fn generic_error(self) -> Result<String> {
669        match self {
670            ShellError::Generic(err) => Ok(err.error.into_owned()),
671            got => Err(TestError {
672                location: TestLocation(Location::caller()),
673                kind: TestErrorKind::UnexpectedErrorKind {
674                    expected: "Generic",
675                    got,
676                },
677            }),
678        }
679    }
680
681    #[track_caller]
682    fn generic_msg(self) -> Result<String> {
683        match self {
684            ShellError::Generic(err) => Ok(err.msg.into_owned()),
685            got => Err(TestError {
686                location: TestLocation(Location::caller()),
687                kind: TestErrorKind::UnexpectedErrorKind {
688                    expected: "Generic",
689                    got,
690                },
691            }),
692        }
693    }
694}