1use std::{
2 env,
3 error::Error,
4 fmt::{Debug, Display},
5 panic::Location,
6 path::PathBuf,
7 sync::{Arc, LazyLock},
8};
9
10use nu_protocol::{
11 CompileError, Config, FromValue, IntoValue, LabeledError, ParseError, PipelineData,
12 PipelineExecutionData, ShellError, Span, Value,
13 ast::Block,
14 debugger::WithoutDebug,
15 engine::{Command, EngineState, Stack, StateDelta, StateWorkingSet},
16 shell_error::{io::IoError, network::NetworkError},
17};
18use nu_utils::{consts::ENV_PATH_SEPARATOR_CHAR, sync::KeyedLazyLock};
19
20use crate::harness::group::GroupKey;
21
22static ROOT: LazyLock<PathBuf> = LazyLock::new(|| {
23 PathBuf::from(env!("CARGO_MANIFEST_DIR"))
24 .join("../..")
25 .canonicalize()
26 .expect("could not canonicalize root")
27});
28
29static INITIAL_ENGINE_STATES: KeyedLazyLock<GroupKey, EngineState> = KeyedLazyLock::new(|_| {
32 let engine_state = nu_cmd_lang::create_default_context();
33 let engine_state = nu_command::add_shell_command_context(engine_state);
34 let mut engine_state = nu_cmd_extra::add_extra_command_context(engine_state);
35
36 engine_state.generate_nu_constant();
37 [
38 ("PWD", Value::test_string(ROOT.to_string_lossy())),
39 ("config", Config::default().into_value(Span::unknown())),
40 ]
41 .into_iter()
42 .for_each(|(key, val)| engine_state.add_env_var(key.into(), val));
43
44 nu_std::load_standard_library(&mut engine_state).expect("could not load standard library");
45
46 engine_state
47});
48
49pub fn test() -> NuTester {
100 NuTester::default()
101}
102
103#[derive(Clone)]
108pub struct NuTester {
109 engine_state: EngineState,
110 stack: Stack,
111}
112
113impl Default for NuTester {
114 fn default() -> Self {
118 Self {
119 engine_state: INITIAL_ENGINE_STATES.get(&GroupKey::current()).clone(),
120 stack: Stack::new().collect_value(),
121 }
122 }
123}
124
125impl NuTester {
126 pub fn new() -> Self {
130 Self::default()
131 }
132
133 pub fn cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
137 let cwd = cwd.into();
138
139 let cwd = match cwd.is_absolute() {
140 true => cwd,
141 false => ROOT
142 .join(cwd)
143 .canonicalize()
144 .expect("could not canonicalize path"),
145 };
146
147 self.engine_state
148 .add_env_var("PWD".into(), Value::test_string(cwd.to_string_lossy()));
149 self
150 }
151
152 pub fn locale(mut self, locale: impl Into<String>) -> Self {
154 self.engine_state.add_env_var(
155 "NU_TEST_LOCALE_OVERRIDE".into(),
156 Value::test_string(locale.into()),
157 );
158 self
159 }
160
161 pub fn locale_en(self) -> Self {
163 self.locale("en_US.utf8")
164 }
165
166 pub fn inherit_path(self) -> Self {
173 let path = env::var("PATH").expect("PATH not available in env");
174 self.env("PATH", path)
175 }
176
177 pub fn inherit_env_if_set(self, key: impl AsRef<str>) -> Self {
181 let key = key.as_ref();
182 match env::var(key) {
183 Ok(val) => self.env(key, val),
184 Err(_) => self,
185 }
186 }
187
188 pub fn inherit_rust_toolchain_env(self) -> Self {
211 self.inherit_env_if_set("PATH")
212 .inherit_env_if_set("CARGO_HOME")
213 .inherit_env_if_set("RUSTUP_HOME")
214 .inherit_env_if_set("RUSTUP_TOOLCHAIN")
215 .inherit_env_if_set("RUSTUP_DIST_SERVER")
216 .inherit_env_if_set("RUSTUP_UPDATE_ROOT")
217 .inherit_env_if_set("HTTP_PROXY")
218 .inherit_env_if_set("HTTPS_PROXY")
219 .inherit_env_if_set("NO_PROXY")
220 .inherit_env_if_set("http_proxy")
221 .inherit_env_if_set("https_proxy")
222 .inherit_env_if_set("no_proxy")
223 }
224
225 pub fn add_nu_to_path(self) -> Self {
229 let nu_home = crate::fs::binaries();
230 let path = self.engine_state.get_env_var("PATH");
231 let path = match path {
232 None => nu_home.display().to_string(),
233 Some(path) => format!(
234 "{nu}{sep}{prev}",
235 nu = nu_home.display(),
236 sep = ENV_PATH_SEPARATOR_CHAR,
237 prev = path.as_str().expect("PATH should always be a string")
238 ),
239 };
240 self.env("PATH", path)
241 }
242
243 pub fn env(mut self, key: impl Into<String>, val: impl Into<String>) -> Self {
245 self.engine_state
246 .add_env_var(key.into(), Value::test_string(val.into()));
247 self
248 }
249
250 #[track_caller]
254 pub fn run<T: FromValue>(&mut self, code: impl AsRef<str>) -> Result<T> {
255 Self::extract_value(self.run_raw(code)?)
256 }
257
258 #[track_caller]
262 pub fn run_with_data<T: FromValue>(
263 &mut self,
264 code: impl AsRef<str>,
265 data: impl IntoValue,
266 ) -> Result<T> {
267 let input = PipelineData::value(data.into_value(Span::test_data()), None);
268 Self::extract_value(self.run_raw_with_data(code, input)?)
269 }
270
271 #[track_caller]
273 pub fn run_raw(&mut self, code: impl AsRef<str>) -> Result<PipelineExecutionData> {
274 self.run_raw_with_data(code, PipelineData::empty())
275 }
276
277 #[track_caller]
281 pub fn run_raw_with_data(
282 &mut self,
283 code: impl AsRef<str>,
284 data: PipelineData,
285 ) -> Result<PipelineExecutionData> {
286 let location = TestLocation(Location::caller());
287 let (delta, block) = self.parse_and_compile(code)?;
288 self.engine_state.merge_delta(delta)?;
289 nu_engine::eval_block::<WithoutDebug>(&self.engine_state, &mut self.stack, &block, data)
290 .map_err(|err| TestError {
291 location,
292 kind: TestErrorKind::Shell(err),
293 })
294 }
295
296 #[track_caller]
297 pub fn parse_and_compile(&self, code: impl AsRef<str>) -> Result<(StateDelta, Arc<Block>)> {
298 let location = TestLocation(Location::caller());
299 let code = code.as_ref().as_bytes();
300
301 let mut working_set = StateWorkingSet::new(&self.engine_state);
302 let block = nu_parser::parse(&mut working_set, None, code, false);
303
304 if let Some(err) = working_set.parse_errors.into_iter().next() {
305 return Err(TestError {
306 location,
307 kind: TestErrorKind::Parse(err),
308 });
309 }
310
311 if let Some(err) = working_set.compile_errors.into_iter().next() {
312 return Err(TestError {
313 location,
314 kind: TestErrorKind::Compile(err),
315 });
316 }
317
318 Ok((working_set.delta, block))
319 }
320
321 #[track_caller]
322 fn extract_value<T: FromValue>(
323 pipeline_execution_data: PipelineExecutionData,
324 ) -> Result<T, TestError> {
325 let pipeline_data = pipeline_execution_data.body;
326 let value = pipeline_data.into_value(Span::test_data())?;
327 let value = T::from_value(value)?;
328 Ok(value)
329 }
330
331 #[track_caller]
333 pub fn examples(&self, command: impl Command + 'static) -> Result {
334 let location = TestLocation(Location::caller());
335 for example in command.examples() {
336 match example.result {
337 None => self
338 .parse_and_compile(example.example)
339 .map(|_| ())
340 .map_err(|err| TestError {
341 location,
342 kind: TestErrorKind::ExampleFailed {
343 command: command.name().to_string(),
344 description: example.description.to_string(),
345 code: example.example.to_string(),
346 err: Box::new(err.kind),
347 },
348 })?,
349 Some(expected) => {
350 let got = self.clone().run(example.example)?;
351 if got != expected {
352 return Err(TestError {
353 location,
354 kind: TestErrorKind::ExampleFailed {
355 command: command.name().to_string(),
356 description: example.description.to_string(),
357 code: example.example.to_string(),
358 err: Box::new(TestErrorKind::UnexpectedValue { expected, got }),
359 },
360 });
361 }
362 }
363 }
364 }
365
366 Ok(())
367 }
368}
369
370#[derive(Debug, Clone, PartialEq)]
371pub struct TestError {
372 location: TestLocation,
373 kind: TestErrorKind,
374}
375
/// Source location a test failure is attributed to, as captured by `Location::caller()`.
#[derive(Clone, Copy, PartialEq)]
pub struct TestLocation(&'static Location<'static>);

impl Debug for TestLocation {
    /// Prints only the location's `Display` form (`file:line:column`), without a wrapper type
    /// name, to keep error dumps compact.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let TestLocation(caller) = self;
        f.write_fmt(format_args!("{caller}"))
    }
}
384
385#[non_exhaustive]
389#[derive(Debug, Clone, PartialEq)]
390pub enum TestErrorKind {
391 Parse(ParseError),
392 Compile(CompileError),
393 Shell(ShellError),
394 GotValue {
395 got: Value,
396 },
397 NoInner,
398 UnexpectedErrorKind {
399 expected: &'static str,
400 got: ShellError,
401 },
402 UnexpectedValue {
403 expected: Value,
404 got: Value,
405 },
406 ExampleFailed {
407 command: String,
408 description: String,
409 code: String,
410 err: Box<TestErrorKind>,
411 },
412}
413
414impl Display for TestError {
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 write!(f, "{self:#?}")
417 }
418}
419
420impl Error for TestError {}
421
422impl From<ShellError> for TestError {
423 #[track_caller]
424 fn from(err: ShellError) -> Self {
425 Self {
426 location: TestLocation(Location::caller()),
427 kind: TestErrorKind::Shell(err),
428 }
429 }
430}
431
432impl From<ParseError> for TestError {
433 #[track_caller]
434 fn from(err: ParseError) -> Self {
435 Self {
436 location: TestLocation(Location::caller()),
437 kind: TestErrorKind::Parse(err),
438 }
439 }
440}
441
442impl TestError {
443 pub fn parse(self) -> Result<ParseError, TestError> {
445 match self.kind {
446 TestErrorKind::Parse(err) => Ok(err),
447 _ => Err(self),
448 }
449 }
450
451 pub fn compile(self) -> Result<CompileError, TestError> {
453 match self.kind {
454 TestErrorKind::Compile(err) => Ok(err),
455 _ => Err(self),
456 }
457 }
458
459 pub fn shell(self) -> Result<ShellError, TestError> {
461 match self.kind {
462 TestErrorKind::Shell(err) => Ok(err),
463 _ => Err(self),
464 }
465 }
466
467 #[track_caller]
469 pub fn update_location(self) -> Self {
470 Self {
471 location: TestLocation(Location::caller()),
472 ..self
473 }
474 }
475}
476
477pub type Result<T = (), E = TestError> = std::result::Result<T, E>;
479
480pub trait TestResultExt: Sized {
482 fn expect_value_eq<T: IntoValue>(self, value: T) -> Result;
484
485 fn expect_shell_error(self) -> Result<ShellError>;
487 fn expect_parse_error(self) -> Result<ParseError>;
489 fn expect_compile_error(self) -> Result<CompileError>;
491
492 fn expect_io_error(self) -> Result<IoError>;
494 fn expect_network_error(self) -> Result<NetworkError>;
496 fn expect_labeled_error(self) -> Result<LabeledError>;
498
499 #[track_caller]
501 fn expect_error(self) -> Result<ShellError> {
502 self.expect_shell_error()
503 }
504}
505
506impl TestResultExt for Result<Value> {
507 #[track_caller]
508 fn expect_value_eq<T: IntoValue>(self, expected: T) -> Result {
509 let expected = expected.into_value(Span::test_data());
510 match self {
511 Err(err) => Err(err.update_location()),
512 Ok(actual) if actual == expected => Ok(()),
513 Ok(actual) => Err(TestError {
514 location: TestLocation(Location::caller()),
515 kind: TestErrorKind::UnexpectedValue {
516 expected,
517 got: actual,
518 },
519 }),
520 }
521 }
522
523 #[track_caller]
524 fn expect_shell_error(self) -> Result<ShellError> {
525 match self {
526 Ok(got) => Err(TestError {
527 location: TestLocation(Location::caller()),
528 kind: TestErrorKind::GotValue { got },
529 }),
530 Err(TestError {
531 kind: TestErrorKind::Shell(err),
532 ..
533 }) => Ok(err),
534 Err(err) => Err(err.update_location()),
535 }
536 }
537
538 #[track_caller]
539 fn expect_parse_error(self) -> Result<ParseError> {
540 match self {
541 Ok(got) => Err(TestError {
542 location: TestLocation(Location::caller()),
543 kind: TestErrorKind::GotValue { got },
544 }),
545 Err(TestError {
546 kind: TestErrorKind::Parse(err),
547 ..
548 }) => Ok(err),
549 Err(err) => Err(err.update_location()),
550 }
551 }
552
553 #[track_caller]
554 fn expect_compile_error(self) -> Result<CompileError> {
555 match self {
556 Ok(got) => Err(TestError {
557 location: TestLocation(Location::caller()),
558 kind: TestErrorKind::GotValue { got },
559 }),
560 Err(TestError {
561 kind: TestErrorKind::Compile(err),
562 ..
563 }) => Ok(err),
564 Err(err) => Err(err.update_location()),
565 }
566 }
567
568 #[track_caller]
569 fn expect_io_error(self) -> Result<IoError> {
570 match self {
571 Ok(got) => Err(TestError {
572 location: TestLocation(Location::caller()),
573 kind: TestErrorKind::GotValue { got },
574 }),
575 Err(TestError {
576 kind: TestErrorKind::Shell(ShellError::Io(err)),
577 ..
578 }) => Ok(err),
579 Err(err) => Err(err.update_location()),
580 }
581 }
582
583 #[track_caller]
584 fn expect_network_error(self) -> Result<NetworkError> {
585 match self {
586 Ok(got) => Err(TestError {
587 location: TestLocation(Location::caller()),
588 kind: TestErrorKind::GotValue { got },
589 }),
590 Err(TestError {
591 kind: TestErrorKind::Shell(ShellError::Network(err)),
592 ..
593 }) => Ok(err),
594 Err(err) => Err(err.update_location()),
595 }
596 }
597
598 #[track_caller]
599 fn expect_labeled_error(self) -> Result<LabeledError> {
600 match self {
601 Ok(got) => Err(TestError {
602 location: TestLocation(Location::caller()),
603 kind: TestErrorKind::GotValue { got },
604 }),
605 Err(TestError {
606 kind: TestErrorKind::Shell(ShellError::LabeledError(err)),
607 ..
608 }) => Ok(*err),
609 Err(err) => Err(err.update_location()),
610 }
611 }
612}
613
614pub trait ShellErrorExt {
616 fn into_inner(self) -> Result<ShellError>;
628
629 fn into_labeled(self) -> Result<LabeledError>;
631
632 fn generic_error(self) -> Result<String>;
634
635 fn generic_msg(self) -> Result<String>;
637}
638
639impl ShellErrorExt for ShellError {
640 #[track_caller]
641 fn into_inner(self) -> Result<ShellError> {
642 let no_inner = TestError {
643 location: TestLocation(Location::caller()),
644 kind: TestErrorKind::NoInner,
645 };
646 match self {
647 ShellError::Generic(err) => err.inner.into_iter().next().ok_or(no_inner),
648 ShellError::ChainedError(err) => err.sources_iter().next().ok_or(no_inner),
649 _ => Err(no_inner),
650 }
651 }
652
653 #[track_caller]
654 fn into_labeled(self) -> Result<LabeledError> {
655 match self {
656 ShellError::LabeledError(err) => Ok(*err),
657 got => Err(TestError {
658 location: TestLocation(Location::caller()),
659 kind: TestErrorKind::UnexpectedErrorKind {
660 expected: "Labeled",
661 got,
662 },
663 }),
664 }
665 }
666
667 #[track_caller]
668 fn generic_error(self) -> Result<String> {
669 match self {
670 ShellError::Generic(err) => Ok(err.error.into_owned()),
671 got => Err(TestError {
672 location: TestLocation(Location::caller()),
673 kind: TestErrorKind::UnexpectedErrorKind {
674 expected: "Generic",
675 got,
676 },
677 }),
678 }
679 }
680
681 #[track_caller]
682 fn generic_msg(self) -> Result<String> {
683 match self {
684 ShellError::Generic(err) => Ok(err.msg.into_owned()),
685 got => Err(TestError {
686 location: TestLocation(Location::caller()),
687 kind: TestErrorKind::UnexpectedErrorKind {
688 expected: "Generic",
689 got,
690 },
691 }),
692 }
693 }
694}