use std::{
env,
error::Error,
fmt::{Debug, Display},
panic::Location,
path::PathBuf,
sync::{Arc, LazyLock},
};
use nu_protocol::{
CompileError, Config, FromValue, IntoValue, LabeledError, ParseError, PipelineData,
PipelineExecutionData, ShellError, Span, Value,
ast::Block,
debugger::WithoutDebug,
engine::{Command, EngineState, Stack, StateDelta, StateWorkingSet},
shell_error::{io::IoError, network::NetworkError},
};
use nu_utils::{consts::ENV_PATH_SEPARATOR_CHAR, sync::KeyedLazyLock};
use crate::harness::group::GroupKey;
/// Canonicalized path two directories above this crate's manifest directory
/// (presumably the workspace root — confirm if the crate moves).
///
/// It is the default `PWD` of every tester and the base that `NuTester::cwd`
/// resolves relative paths against.
///
/// # Panics
///
/// Panics on first access if the path cannot be canonicalized.
static ROOT: LazyLock<PathBuf> = LazyLock::new(|| {
    let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    let root = manifest_dir.join("../..");
    root.canonicalize().expect("could not canonicalize root")
});
/// Base engine state that every `NuTester` starts from, cached per [`GroupKey`].
///
/// The initializer ignores its key, so every group receives an identically
/// built state. That state has:
/// - the core-language, shell and extra command sets;
/// - the `$nu` constant;
/// - `PWD` set to the harness root and `config` set to the default [`Config`];
/// - the standard library loaded.
///
/// # Panics
///
/// Panics on first access for a key if the standard library fails to load.
static INITIAL_ENGINE_STATES: KeyedLazyLock<GroupKey, EngineState> = KeyedLazyLock::new(|_group| {
    // Layer the command sets: core language first, then shell, then extra commands.
    let mut engine_state = nu_cmd_extra::add_extra_command_context(
        nu_command::add_shell_command_context(nu_cmd_lang::create_default_context()),
    );
    engine_state.generate_nu_constant();

    let initial_env = [
        ("PWD", Value::test_string(ROOT.to_string_lossy())),
        ("config", Config::default().into_value(Span::unknown())),
    ];
    for (name, value) in initial_env {
        engine_state.add_env_var(name.into(), value);
    }

    nu_std::load_standard_library(&mut engine_state).expect("could not load standard library");
    engine_state
});
pub fn test() -> NuTester {
NuTester::default()
}
/// Builder-style harness for parsing, compiling and evaluating Nushell code in tests.
///
/// Configuration methods (`cwd`, `env`, `locale`, ...) take and return `self`.
/// The engine state and stack are kept in the tester, so state from one `run*`
/// call is visible to later calls on the same tester. Clone the tester to get
/// an independent copy.
#[derive(Clone)]
pub struct NuTester {
// Engine state used for parsing/evaluation; env vars set by the builder methods
// and parse deltas from each run are stored here.
engine_state: EngineState,
// Evaluation stack, reused across runs on this tester.
stack: Stack,
}
impl Default for NuTester {
    /// Builds a tester from two pieces:
    /// - a clone of the cached initial engine state for the current [`GroupKey`];
    /// - a new stack configured with `collect_value()`.
    fn default() -> Self {
        let engine_state = INITIAL_ENGINE_STATES.get(&GroupKey::current()).clone();
        let stack = Stack::new().collect_value();
        Self {
            engine_state,
            stack,
        }
    }
}
impl NuTester {
pub fn new() -> Self {
Self::default()
}
pub fn cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
let cwd = cwd.into();
let cwd = match cwd.is_absolute() {
true => cwd,
false => ROOT
.join(cwd)
.canonicalize()
.expect("could not canonicalize path"),
};
self.engine_state
.add_env_var("PWD".into(), Value::test_string(cwd.to_string_lossy()));
self
}
pub fn locale(mut self, locale: impl Into<String>) -> Self {
self.engine_state.add_env_var(
"NU_TEST_LOCALE_OVERRIDE".into(),
Value::test_string(locale.into()),
);
self
}
pub fn locale_en(self) -> Self {
self.locale("en_US.utf8")
}
pub fn inherit_path(self) -> Self {
let path = env::var("PATH").expect("PATH not available in env");
self.env("PATH", path)
}
pub fn inherit_env_if_set(self, key: impl AsRef<str>) -> Self {
let key = key.as_ref();
match env::var(key) {
Ok(val) => self.env(key, val),
Err(_) => self,
}
}
pub fn inherit_rust_toolchain_env(self) -> Self {
self.inherit_env_if_set("PATH")
.inherit_env_if_set("CARGO_HOME")
.inherit_env_if_set("RUSTUP_HOME")
.inherit_env_if_set("RUSTUP_TOOLCHAIN")
.inherit_env_if_set("RUSTUP_DIST_SERVER")
.inherit_env_if_set("RUSTUP_UPDATE_ROOT")
.inherit_env_if_set("HTTP_PROXY")
.inherit_env_if_set("HTTPS_PROXY")
.inherit_env_if_set("NO_PROXY")
.inherit_env_if_set("http_proxy")
.inherit_env_if_set("https_proxy")
.inherit_env_if_set("no_proxy")
}
pub fn add_nu_to_path(self) -> Self {
let nu_home = crate::fs::binaries();
let path = self.engine_state.get_env_var("PATH");
let path = match path {
None => nu_home.display().to_string(),
Some(path) => format!(
"{nu}{sep}{prev}",
nu = nu_home.display(),
sep = ENV_PATH_SEPARATOR_CHAR,
prev = path.as_str().expect("PATH should always be a string")
),
};
self.env("PATH", path)
}
pub fn env(mut self, key: impl Into<String>, val: impl Into<String>) -> Self {
self.engine_state
.add_env_var(key.into(), Value::test_string(val.into()));
self
}
#[track_caller]
pub fn run<T: FromValue>(&mut self, code: impl AsRef<str>) -> Result<T> {
Self::extract_value(self.run_raw(code)?)
}
#[track_caller]
pub fn run_with_data<T: FromValue>(
&mut self,
code: impl AsRef<str>,
data: impl IntoValue,
) -> Result<T> {
let input = PipelineData::value(data.into_value(Span::test_data()), None);
Self::extract_value(self.run_raw_with_data(code, input)?)
}
#[track_caller]
pub fn run_raw(&mut self, code: impl AsRef<str>) -> Result<PipelineExecutionData> {
self.run_raw_with_data(code, PipelineData::empty())
}
#[track_caller]
pub fn run_raw_with_data(
&mut self,
code: impl AsRef<str>,
data: PipelineData,
) -> Result<PipelineExecutionData> {
let location = TestLocation(Location::caller());
let (delta, block) = self.parse_and_compile(code)?;
self.engine_state.merge_delta(delta)?;
nu_engine::eval_block::<WithoutDebug>(&self.engine_state, &mut self.stack, &block, data)
.map_err(|err| TestError {
location,
kind: TestErrorKind::Shell(err),
})
}
#[track_caller]
pub fn parse_and_compile(&self, code: impl AsRef<str>) -> Result<(StateDelta, Arc<Block>)> {
let location = TestLocation(Location::caller());
let code = code.as_ref().as_bytes();
let mut working_set = StateWorkingSet::new(&self.engine_state);
let block = nu_parser::parse(&mut working_set, None, code, false);
if let Some(err) = working_set.parse_errors.into_iter().next() {
return Err(TestError {
location,
kind: TestErrorKind::Parse(err),
});
}
if let Some(err) = working_set.compile_errors.into_iter().next() {
return Err(TestError {
location,
kind: TestErrorKind::Compile(err),
});
}
Ok((working_set.delta, block))
}
#[track_caller]
fn extract_value<T: FromValue>(
pipeline_execution_data: PipelineExecutionData,
) -> Result<T, TestError> {
let pipeline_data = pipeline_execution_data.body;
let value = pipeline_data.into_value(Span::test_data())?;
let value = T::from_value(value)?;
Ok(value)
}
#[track_caller]
pub fn examples(&self, command: impl Command + 'static) -> Result {
let location = TestLocation(Location::caller());
for example in command.examples() {
match example.result {
None => self
.parse_and_compile(example.example)
.map(|_| ())
.map_err(|err| TestError {
location,
kind: TestErrorKind::ExampleFailed {
command: command.name().to_string(),
description: example.description.to_string(),
code: example.example.to_string(),
err: Box::new(err.kind),
},
})?,
Some(expected) => {
let got = self.clone().run(example.example)?;
if got != expected {
return Err(TestError {
location,
kind: TestErrorKind::ExampleFailed {
command: command.name().to_string(),
description: example.description.to_string(),
code: example.example.to_string(),
err: Box::new(TestErrorKind::UnexpectedValue { expected, got }),
},
});
}
}
}
}
Ok(())
}
}
/// Failure reported by the test harness.
///
/// It pairs what went wrong (`kind`) with the source location the failure is
/// attributed to (`location`), which is captured through `#[track_caller]`
/// functions.
#[derive(Debug, Clone, PartialEq)]
pub struct TestError {
// Call site the failure is attributed to.
location: TestLocation,
// What went wrong.
kind: TestErrorKind,
}
/// A source location captured with `Location::caller()`.
///
/// Its `Debug` output is the location's `Display` form, `file:line:column`,
/// so error dumps stay compact.
#[derive(Clone, Copy, PartialEq)]
pub struct TestLocation(&'static Location<'static>);

impl Debug for TestLocation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
/// What went wrong in a harness operation.
///
/// Marked `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum TestErrorKind {
/// The parser reported an error (only the first one is kept).
Parse(ParseError),
/// Compilation reported an error (only the first one is kept).
Compile(CompileError),
/// Evaluation or value conversion failed with a shell error.
Shell(ShellError),
/// An error was expected, but the code produced this value instead.
GotValue {
got: Value,
},
/// `into_inner` was called on a shell error that has no inner/source error.
NoInner,
/// A shell error had a different variant than the one requested.
UnexpectedErrorKind {
// Name of the variant that was requested (e.g. "Generic").
expected: &'static str,
got: ShellError,
},
/// A produced value did not equal the expected one.
UnexpectedValue {
expected: Value,
got: Value,
},
/// An example declared by a command failed; `err` holds the underlying reason.
ExampleFailed {
command: String,
description: String,
code: String,
err: Box<TestErrorKind>,
},
}
impl Display for TestError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:#?}")
}
}
impl Error for TestError {}
impl From<ShellError> for TestError {
#[track_caller]
fn from(err: ShellError) -> Self {
Self {
location: TestLocation(Location::caller()),
kind: TestErrorKind::Shell(err),
}
}
}
impl From<ParseError> for TestError {
#[track_caller]
fn from(err: ParseError) -> Self {
Self {
location: TestLocation(Location::caller()),
kind: TestErrorKind::Parse(err),
}
}
}
impl TestError {
    /// Returns the wrapped [`ParseError`] if this is a parse failure; otherwise
    /// hands `self` back unchanged.
    pub fn parse(self) -> Result<ParseError, TestError> {
        if let TestErrorKind::Parse(err) = self.kind {
            Ok(err)
        } else {
            Err(self)
        }
    }

    /// Returns the wrapped [`CompileError`] if this is a compile failure;
    /// otherwise hands `self` back unchanged.
    pub fn compile(self) -> Result<CompileError, TestError> {
        if let TestErrorKind::Compile(err) = self.kind {
            Ok(err)
        } else {
            Err(self)
        }
    }

    /// Returns the wrapped [`ShellError`] if this is a shell failure; otherwise
    /// hands `self` back unchanged.
    pub fn shell(self) -> Result<ShellError, TestError> {
        if let TestErrorKind::Shell(err) = self.kind {
            Ok(err)
        } else {
            Err(self)
        }
    }

    /// Re-attributes this error to the caller's location, keeping its kind.
    #[track_caller]
    pub fn update_location(self) -> Self {
        let location = TestLocation(Location::caller());
        Self { location, ..self }
    }
}
/// Result type used throughout the harness.
///
/// It defaults to `()` on success and [`TestError`] on failure.
pub type Result<T = (), E = TestError> = core::result::Result<T, E>;
/// Assertion helpers for harness results.
///
/// Each method yields the expected payload, or a [`TestError`] describing
/// what was produced instead.
pub trait TestResultExt: Sized {
    /// Succeeds if the result is `Ok` and equals `value` (converted with [`IntoValue`]).
    fn expect_value_eq<T: IntoValue>(self, value: T) -> Result;
    /// Yields the [`ShellError`] the code failed with.
    fn expect_shell_error(self) -> Result<ShellError>;
    /// Yields the [`ParseError`] the code failed with.
    fn expect_parse_error(self) -> Result<ParseError>;
    /// Yields the [`CompileError`] the code failed with.
    fn expect_compile_error(self) -> Result<CompileError>;
    /// Yields the [`IoError`] wrapped in a shell error.
    fn expect_io_error(self) -> Result<IoError>;
    /// Yields the [`NetworkError`] wrapped in a shell error.
    fn expect_network_error(self) -> Result<NetworkError>;
    /// Yields the [`LabeledError`] wrapped in a shell error.
    fn expect_labeled_error(self) -> Result<LabeledError>;

    /// Alias for [`TestResultExt::expect_shell_error`].
    #[track_caller]
    fn expect_error(self) -> Result<ShellError> {
        Self::expect_shell_error(self)
    }
}
impl TestResultExt for Result<Value> {
#[track_caller]
fn expect_value_eq<T: IntoValue>(self, expected: T) -> Result {
let expected = expected.into_value(Span::test_data());
match self {
Err(err) => Err(err.update_location()),
Ok(actual) if actual == expected => Ok(()),
Ok(actual) => Err(TestError {
location: TestLocation(Location::caller()),
kind: TestErrorKind::UnexpectedValue {
expected,
got: actual,
},
}),
}
}
#[track_caller]
fn expect_shell_error(self) -> Result<ShellError> {
match self {
Ok(got) => Err(TestError {
location: TestLocation(Location::caller()),
kind: TestErrorKind::GotValue { got },
}),
Err(TestError {
kind: TestErrorKind::Shell(err),
..
}) => Ok(err),
Err(err) => Err(err.update_location()),
}
}
#[track_caller]
fn expect_parse_error(self) -> Result<ParseError> {
match self {
Ok(got) => Err(TestError {
location: TestLocation(Location::caller()),
kind: TestErrorKind::GotValue { got },
}),
Err(TestError {
kind: TestErrorKind::Parse(err),
..
}) => Ok(err),
Err(err) => Err(err.update_location()),
}
}
#[track_caller]
fn expect_compile_error(self) -> Result<CompileError> {
match self {
Ok(got) => Err(TestError {
location: TestLocation(Location::caller()),
kind: TestErrorKind::GotValue { got },
}),
Err(TestError {
kind: TestErrorKind::Compile(err),
..
}) => Ok(err),
Err(err) => Err(err.update_location()),
}
}
#[track_caller]
fn expect_io_error(self) -> Result<IoError> {
match self {
Ok(got) => Err(TestError {
location: TestLocation(Location::caller()),
kind: TestErrorKind::GotValue { got },
}),
Err(TestError {
kind: TestErrorKind::Shell(ShellError::Io(err)),
..
}) => Ok(err),
Err(err) => Err(err.update_location()),
}
}
#[track_caller]
fn expect_network_error(self) -> Result<NetworkError> {
match self {
Ok(got) => Err(TestError {
location: TestLocation(Location::caller()),
kind: TestErrorKind::GotValue { got },
}),
Err(TestError {
kind: TestErrorKind::Shell(ShellError::Network(err)),
..
}) => Ok(err),
Err(err) => Err(err.update_location()),
}
}
#[track_caller]
fn expect_labeled_error(self) -> Result<LabeledError> {
match self {
Ok(got) => Err(TestError {
location: TestLocation(Location::caller()),
kind: TestErrorKind::GotValue { got },
}),
Err(TestError {
kind: TestErrorKind::Shell(ShellError::LabeledError(err)),
..
}) => Ok(*err),
Err(err) => Err(err.update_location()),
}
}
}
/// Test accessors for digging into a [`ShellError`].
///
/// Each accessor returns a [`TestError`] instead of the requested data when the
/// error has a different shape.
pub trait ShellErrorExt {
/// First inner/source error of a `Generic` or `ChainedError` shell error.
fn into_inner(self) -> Result<ShellError>;
/// Payload of a `LabeledError` shell error.
fn into_labeled(self) -> Result<LabeledError>;
/// The `error` text of a `Generic` shell error.
fn generic_error(self) -> Result<String>;
/// The `msg` text of a `Generic` shell error.
fn generic_msg(self) -> Result<String>;
}
impl ShellErrorExt for ShellError {
#[track_caller]
fn into_inner(self) -> Result<ShellError> {
let no_inner = TestError {
location: TestLocation(Location::caller()),
kind: TestErrorKind::NoInner,
};
match self {
ShellError::Generic(err) => err.inner.into_iter().next().ok_or(no_inner),
ShellError::ChainedError(err) => err.sources_iter().next().ok_or(no_inner),
_ => Err(no_inner),
}
}
#[track_caller]
fn into_labeled(self) -> Result<LabeledError> {
match self {
ShellError::LabeledError(err) => Ok(*err),
got => Err(TestError {
location: TestLocation(Location::caller()),
kind: TestErrorKind::UnexpectedErrorKind {
expected: "Labeled",
got,
},
}),
}
}
#[track_caller]
fn generic_error(self) -> Result<String> {
match self {
ShellError::Generic(err) => Ok(err.error.into_owned()),
got => Err(TestError {
location: TestLocation(Location::caller()),
kind: TestErrorKind::UnexpectedErrorKind {
expected: "Generic",
got,
},
}),
}
}
#[track_caller]
fn generic_msg(self) -> Result<String> {
match self {
ShellError::Generic(err) => Ok(err.msg.into_owned()),
got => Err(TestError {
location: TestLocation(Location::caller()),
kind: TestErrorKind::UnexpectedErrorKind {
expected: "Generic",
got,
},
}),
}
}
}