1use std::{
2 env,
3 error::Error,
4 fmt::{Debug, Display},
5 panic::Location,
6 path::PathBuf,
7 sync::{Arc, LazyLock},
8};
9
10use nu_protocol::{
11 CompileError, Config, FromValue, IntoValue, LabeledError, ParseError, PipelineData,
12 PipelineExecutionData, ShellError, Span, Value,
13 ast::Block,
14 debugger::WithoutDebug,
15 engine::{Command, EngineState, Stack, StateDelta, StateWorkingSet},
16 shell_error::{io::IoError, network::NetworkError},
17};
18use nu_utils::{consts::ENV_PATH_SEPARATOR_CHAR, sync::KeyedLazyLock};
19
20use crate::harness::group::GroupKey;
21
22static ROOT: LazyLock<PathBuf> = LazyLock::new(|| {
23 PathBuf::from(env!("CARGO_MANIFEST_DIR"))
24 .join("../..")
25 .canonicalize()
26 .expect("could not canonicalize root")
27});
28
29static INITIAL_ENGINE_STATES: KeyedLazyLock<GroupKey, EngineState> = KeyedLazyLock::new(|_| {
32 let engine_state = nu_cmd_lang::create_default_context();
33 let engine_state = nu_command::add_shell_command_context(engine_state);
34 let mut engine_state = nu_cmd_extra::add_extra_command_context(engine_state);
35
36 engine_state.generate_nu_constant();
37 [
38 ("PWD", Value::test_string(ROOT.to_string_lossy())),
39 ("config", Config::default().into_value(Span::unknown())),
40 ]
41 .into_iter()
42 .for_each(|(key, val)| engine_state.add_env_var(key.into(), val));
43
44 nu_std::load_standard_library(&mut engine_state).expect("could not load standard library");
45
46 engine_state
47});
48
49pub fn test() -> NuTester {
100 NuTester::default()
101}
102
/// Builder-style harness that evaluates Nushell code in tests.
///
/// Owns an engine state and a stack that are reused across `run*` calls on the same
/// tester. It is `Clone` so code can be evaluated on a copy, as
/// [`NuTester::examples`] does.
#[derive(Clone)]
pub struct NuTester {
    // Engine state; parse deltas from `run*` calls are merged into it.
    engine_state: EngineState,
    // Evaluation stack passed to `eval_block` on every run.
    stack: Stack,
}
112
113impl Default for NuTester {
114 fn default() -> Self {
118 Self {
119 engine_state: INITIAL_ENGINE_STATES.get(&GroupKey::current()).clone(),
120 stack: Stack::new().collect_value(),
121 }
122 }
123}
124
125impl NuTester {
126 pub fn new() -> Self {
130 Self::default()
131 }
132
133 pub fn cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
137 let cwd = cwd.into();
138
139 let cwd = match cwd.is_absolute() {
140 true => cwd,
141 false => ROOT
142 .join(cwd)
143 .canonicalize()
144 .expect("could not canonicalize path"),
145 };
146
147 self.engine_state
148 .add_env_var("PWD".into(), Value::test_string(cwd.to_string_lossy()));
149 self
150 }
151
152 pub fn locale(mut self, locale: impl Into<String>) -> Self {
154 self.engine_state.add_env_var(
155 "NU_TEST_LOCALE_OVERRIDE".into(),
156 Value::test_string(locale.into()),
157 );
158 self
159 }
160
161 pub fn locale_en(self) -> Self {
163 self.locale("en_US.utf8")
164 }
165
166 pub fn inherit_path(self) -> Self {
173 let path = env::var("PATH").expect("PATH not available in env");
174 self.env("PATH", path)
175 }
176
177 pub fn inherit_env_if_set(self, key: impl AsRef<str>) -> Self {
181 let key = key.as_ref();
182 match env::var(key) {
183 Ok(val) => self.env(key, val),
184 Err(_) => self,
185 }
186 }
187
188 pub fn inherit_rust_toolchain_env(self) -> Self {
211 self.inherit_env_if_set("PATH")
212 .inherit_env_if_set("CARGO_HOME")
213 .inherit_env_if_set("RUSTUP_HOME")
214 .inherit_env_if_set("RUSTUP_TOOLCHAIN")
215 .inherit_env_if_set("RUSTUP_DIST_SERVER")
216 .inherit_env_if_set("RUSTUP_UPDATE_ROOT")
217 .inherit_env_if_set("HTTP_PROXY")
218 .inherit_env_if_set("HTTPS_PROXY")
219 .inherit_env_if_set("NO_PROXY")
220 .inherit_env_if_set("http_proxy")
221 .inherit_env_if_set("https_proxy")
222 .inherit_env_if_set("no_proxy")
223 }
224
225 pub fn add_nu_to_path(self) -> Self {
229 let nu_home = crate::fs::binaries();
230 let path = self.engine_state.get_env_var("PATH");
231 let path = match path {
232 None => nu_home.display().to_string(),
233 Some(path) => format!(
234 "{nu}{sep}{prev}",
235 nu = nu_home.display(),
236 sep = ENV_PATH_SEPARATOR_CHAR,
237 prev = path.as_str().expect("PATH should always be a string")
238 ),
239 };
240 self.env("PATH", path)
241 }
242
243 pub fn env(mut self, key: impl Into<String>, val: impl Into<String>) -> Self {
245 self.engine_state
246 .add_env_var(key.into(), Value::test_string(val.into()));
247 self
248 }
249
250 #[track_caller]
254 pub fn run<T: FromValue>(&mut self, code: impl AsRef<str>) -> Result<T> {
255 Self::extract_value(self.run_raw(code)?)
256 }
257
258 #[track_caller]
262 pub fn run_with_data<T: FromValue>(
263 &mut self,
264 code: impl AsRef<str>,
265 data: impl IntoValue,
266 ) -> Result<T> {
267 let input = PipelineData::value(data.into_value(Span::test_data()), None);
268 Self::extract_value(self.run_raw_with_data(code, input)?)
269 }
270
271 #[track_caller]
273 pub fn run_raw(&mut self, code: impl AsRef<str>) -> Result<PipelineExecutionData> {
274 self.run_raw_with_data(code, PipelineData::empty())
275 }
276
277 #[track_caller]
281 pub fn run_raw_with_data(
282 &mut self,
283 code: impl AsRef<str>,
284 data: PipelineData,
285 ) -> Result<PipelineExecutionData> {
286 let location = TestLocation(Location::caller());
287 let (delta, block) = self.parse_and_compile(code)?;
288 self.engine_state.merge_delta(delta)?;
289 nu_engine::eval_block::<WithoutDebug>(&self.engine_state, &mut self.stack, &block, data)
290 .map_err(|err| TestError {
291 location,
292 kind: TestErrorKind::Shell(err),
293 })
294 }
295
296 #[track_caller]
297 pub fn parse_and_compile(&self, code: impl AsRef<str>) -> Result<(StateDelta, Arc<Block>)> {
298 let location = TestLocation(Location::caller());
299 let code = code.as_ref().as_bytes();
300
301 let mut working_set = StateWorkingSet::new(&self.engine_state);
302 let block = nu_parser::parse(&mut working_set, None, code, false);
303
304 if let Some(err) = working_set.parse_errors.into_iter().next() {
305 return Err(TestError {
306 location,
307 kind: TestErrorKind::Parse(err),
308 });
309 }
310
311 if let Some(err) = working_set.compile_errors.into_iter().next() {
312 return Err(TestError {
313 location,
314 kind: TestErrorKind::Compile(err),
315 });
316 }
317
318 Ok((working_set.delta, block))
319 }
320
321 #[track_caller]
322 fn extract_value<T: FromValue>(
323 pipeline_execution_data: PipelineExecutionData,
324 ) -> Result<T, TestError> {
325 let pipeline_data = pipeline_execution_data.body;
326 let value = pipeline_data.into_value(Span::test_data())?;
327 let value = T::from_value(value)?;
328 Ok(value)
329 }
330
331 #[track_caller]
333 pub fn examples(&self, command: impl Command + 'static) -> Result {
334 let location = TestLocation(Location::caller());
335 for example in command.examples() {
336 match example.result {
337 None => self
338 .parse_and_compile(example.example)
339 .map(|_| ())
340 .map_err(|err| TestError {
341 location,
342 kind: TestErrorKind::ExampleFailed {
343 command: command.name().to_string(),
344 description: example.description.to_string(),
345 code: example.example.to_string(),
346 err: Box::new(err.kind),
347 },
348 })?,
349 Some(expected) => {
350 let got = self.clone().run(example.example)?;
351 if got != expected {
352 return Err(TestError {
353 location,
354 kind: TestErrorKind::ExampleFailed {
355 command: command.name().to_string(),
356 description: example.description.to_string(),
357 code: example.example.to_string(),
358 err: Box::new(TestErrorKind::UnexpectedValue { expected, got }),
359 },
360 });
361 }
362 }
363 }
364 }
365
366 Ok(())
367 }
368}
369
/// Error produced by the harness: what went wrong plus where in the test it happened.
///
/// Its `Display` is the pretty-printed `Debug` form, so both fields appear in failure
/// output.
#[derive(Debug, Clone, PartialEq)]
pub struct TestError {
    // Caller location captured via `#[track_caller]` by the harness entry points.
    location: TestLocation,
    // The failure itself.
    kind: TestErrorKind,
}
375
/// Source location recorded with `Location::caller()`.
///
/// `Debug` is rendered through the wrapped `Location`'s `Display` (`file:line:col`)
/// rather than the default struct dump.
#[derive(Clone, Copy, PartialEq, derive_more::Debug)]
#[debug("{_0}")]
pub struct TestLocation(&'static Location<'static>);
379
/// What went wrong in a harness operation.
///
/// Marked `#[non_exhaustive]` so new failure kinds can be added without breaking
/// exhaustive matches in downstream tests.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum TestErrorKind {
    /// The code failed to parse (first parse error only).
    Parse(ParseError),
    /// The code parsed but failed to compile (first compile error only).
    Compile(CompileError),
    /// Evaluation, delta merging, or value conversion failed.
    Shell(ShellError),
    /// An error was expected, but the code produced this value.
    GotValue {
        got: Value,
    },
    /// `ShellErrorExt::into_inner` found no inner error.
    NoInner,
    /// A `ShellError` was not of the variant named by `expected`.
    UnexpectedErrorKind {
        expected: &'static str,
        got: ShellError,
    },
    /// A value was produced but differs from the expected one.
    UnexpectedValue {
        expected: Value,
        got: Value,
    },
    /// A command example failed; `err` is the underlying failure.
    ExampleFailed {
        command: String,
        description: String,
        code: String,
        err: Box<TestErrorKind>,
    },
}
408
409impl Display for TestError {
410 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
411 write!(f, "{self:#?}")
412 }
413}
414
/// Lets `TestError` be used wherever a `std::error::Error` is expected (e.g. boxed as
/// `dyn Error`); the message comes from the `Display` impl.
impl Error for TestError {}
416
417impl From<ShellError> for TestError {
418 #[track_caller]
419 fn from(err: ShellError) -> Self {
420 Self {
421 location: TestLocation(Location::caller()),
422 kind: TestErrorKind::Shell(err),
423 }
424 }
425}
426
427impl From<ParseError> for TestError {
428 #[track_caller]
429 fn from(err: ParseError) -> Self {
430 Self {
431 location: TestLocation(Location::caller()),
432 kind: TestErrorKind::Parse(err),
433 }
434 }
435}
436
437impl TestError {
438 pub fn parse(self) -> Result<ParseError, TestError> {
440 match self.kind {
441 TestErrorKind::Parse(err) => Ok(err),
442 _ => Err(self),
443 }
444 }
445
446 pub fn compile(self) -> Result<CompileError, TestError> {
448 match self.kind {
449 TestErrorKind::Compile(err) => Ok(err),
450 _ => Err(self),
451 }
452 }
453
454 pub fn shell(self) -> Result<ShellError, TestError> {
456 match self.kind {
457 TestErrorKind::Shell(err) => Ok(err),
458 _ => Err(self),
459 }
460 }
461
462 #[track_caller]
464 pub fn update_location(self) -> Self {
465 Self {
466 location: TestLocation(Location::caller()),
467 ..self
468 }
469 }
470}
471
/// Result alias for the harness: defaults to `()` on success and [`TestError`] on
/// failure, so a bare `Result` is the usual return type of a test.
pub type Result<T = (), E = TestError> = std::result::Result<T, E>;
474
/// Assertions on the outcome of running code, e.g. `tester.run(code).expect_parse_error()`.
///
/// In the provided implementation, a mismatch yields a [`TestError`] located at the
/// assertion call site.
pub trait TestResultExt: Sized {
    /// Expects `Ok` with a value equal to `value` (converted with [`IntoValue`]).
    fn expect_value_eq<T: IntoValue>(self, value: T) -> Result;

    /// Expects a [`TestErrorKind::Shell`] error and returns it.
    fn expect_shell_error(self) -> Result<ShellError>;
    /// Expects a [`TestErrorKind::Parse`] error and returns it.
    fn expect_parse_error(self) -> Result<ParseError>;
    /// Expects a [`TestErrorKind::Compile`] error and returns it.
    fn expect_compile_error(self) -> Result<CompileError>;

    /// Expects a shell error of the `ShellError::Io` variant and returns its payload.
    fn expect_io_error(self) -> Result<IoError>;
    /// Expects a shell error of the `ShellError::Network` variant and returns its payload.
    fn expect_network_error(self) -> Result<NetworkError>;
    /// Expects a shell error of the `ShellError::LabeledError` variant and returns its
    /// payload.
    fn expect_labeled_error(self) -> Result<LabeledError>;

    /// Alias for [`expect_shell_error`](Self::expect_shell_error).
    #[track_caller]
    fn expect_error(self) -> Result<ShellError> {
        self.expect_shell_error()
    }
}
500
501impl TestResultExt for Result<Value> {
502 #[track_caller]
503 fn expect_value_eq<T: IntoValue>(self, expected: T) -> Result {
504 let expected = expected.into_value(Span::test_data());
505 match self {
506 Err(err) => Err(err.update_location()),
507 Ok(actual) if actual == expected => Ok(()),
508 Ok(actual) => Err(TestError {
509 location: TestLocation(Location::caller()),
510 kind: TestErrorKind::UnexpectedValue {
511 expected,
512 got: actual,
513 },
514 }),
515 }
516 }
517
518 #[track_caller]
519 fn expect_shell_error(self) -> Result<ShellError> {
520 match self {
521 Ok(got) => Err(TestError {
522 location: TestLocation(Location::caller()),
523 kind: TestErrorKind::GotValue { got },
524 }),
525 Err(TestError {
526 kind: TestErrorKind::Shell(err),
527 ..
528 }) => Ok(err),
529 Err(err) => Err(err.update_location()),
530 }
531 }
532
533 #[track_caller]
534 fn expect_parse_error(self) -> Result<ParseError> {
535 match self {
536 Ok(got) => Err(TestError {
537 location: TestLocation(Location::caller()),
538 kind: TestErrorKind::GotValue { got },
539 }),
540 Err(TestError {
541 kind: TestErrorKind::Parse(err),
542 ..
543 }) => Ok(err),
544 Err(err) => Err(err.update_location()),
545 }
546 }
547
548 #[track_caller]
549 fn expect_compile_error(self) -> Result<CompileError> {
550 match self {
551 Ok(got) => Err(TestError {
552 location: TestLocation(Location::caller()),
553 kind: TestErrorKind::GotValue { got },
554 }),
555 Err(TestError {
556 kind: TestErrorKind::Compile(err),
557 ..
558 }) => Ok(err),
559 Err(err) => Err(err.update_location()),
560 }
561 }
562
563 #[track_caller]
564 fn expect_io_error(self) -> Result<IoError> {
565 match self {
566 Ok(got) => Err(TestError {
567 location: TestLocation(Location::caller()),
568 kind: TestErrorKind::GotValue { got },
569 }),
570 Err(TestError {
571 kind: TestErrorKind::Shell(ShellError::Io(err)),
572 ..
573 }) => Ok(err),
574 Err(err) => Err(err.update_location()),
575 }
576 }
577
578 #[track_caller]
579 fn expect_network_error(self) -> Result<NetworkError> {
580 match self {
581 Ok(got) => Err(TestError {
582 location: TestLocation(Location::caller()),
583 kind: TestErrorKind::GotValue { got },
584 }),
585 Err(TestError {
586 kind: TestErrorKind::Shell(ShellError::Network(err)),
587 ..
588 }) => Ok(err),
589 Err(err) => Err(err.update_location()),
590 }
591 }
592
593 #[track_caller]
594 fn expect_labeled_error(self) -> Result<LabeledError> {
595 match self {
596 Ok(got) => Err(TestError {
597 location: TestLocation(Location::caller()),
598 kind: TestErrorKind::GotValue { got },
599 }),
600 Err(TestError {
601 kind: TestErrorKind::Shell(ShellError::LabeledError(err)),
602 ..
603 }) => Ok(*err),
604 Err(err) => Err(err.update_location()),
605 }
606 }
607}
608
/// Accessors for drilling into a [`ShellError`] in tests.
///
/// In the provided implementation, a variant mismatch yields a [`TestError`] located at
/// the call site.
pub trait ShellErrorExt {
    /// Returns the first nested error of a `Generic` or `ChainedError` shell error.
    /// Fails with [`TestErrorKind::NoInner`] if there is none or the variant has no
    /// nested errors.
    fn into_inner(self) -> Result<ShellError>;

    /// Unwraps a `ShellError::LabeledError`. Any other variant fails with
    /// [`TestErrorKind::UnexpectedErrorKind`].
    fn into_labeled(self) -> Result<LabeledError>;

    /// Returns the `error` text of a `ShellError::Generic`.
    fn generic_error(self) -> Result<String>;

    /// Returns the `msg` text of a `ShellError::Generic`.
    fn generic_msg(self) -> Result<String>;
}
633
634impl ShellErrorExt for ShellError {
635 #[track_caller]
636 fn into_inner(self) -> Result<ShellError> {
637 let no_inner = TestError {
638 location: TestLocation(Location::caller()),
639 kind: TestErrorKind::NoInner,
640 };
641 match self {
642 ShellError::Generic(err) => err.inner.into_iter().next().ok_or(no_inner),
643 ShellError::ChainedError(err) => err.sources_iter().next().ok_or(no_inner),
644 _ => Err(no_inner),
645 }
646 }
647
648 #[track_caller]
649 fn into_labeled(self) -> Result<LabeledError> {
650 match self {
651 ShellError::LabeledError(err) => Ok(*err),
652 got => Err(TestError {
653 location: TestLocation(Location::caller()),
654 kind: TestErrorKind::UnexpectedErrorKind {
655 expected: "Labeled",
656 got,
657 },
658 }),
659 }
660 }
661
662 #[track_caller]
663 fn generic_error(self) -> Result<String> {
664 match self {
665 ShellError::Generic(err) => Ok(err.error.into_owned()),
666 got => Err(TestError {
667 location: TestLocation(Location::caller()),
668 kind: TestErrorKind::UnexpectedErrorKind {
669 expected: "Generic",
670 got,
671 },
672 }),
673 }
674 }
675
676 #[track_caller]
677 fn generic_msg(self) -> Result<String> {
678 match self {
679 ShellError::Generic(err) => Ok(err.msg.into_owned()),
680 got => Err(TestError {
681 location: TestLocation(Location::caller()),
682 kind: TestErrorKind::UnexpectedErrorKind {
683 expected: "Generic",
684 got,
685 },
686 }),
687 }
688 }
689}