nu_plugin_test_support/
plugin_test.rs

1use std::{cmp::Ordering, convert::Infallible, sync::Arc};
2
3use nu_ansi_term::Style;
4use nu_cmd_lang::create_default_context;
5use nu_engine::eval_block;
6use nu_parser::parse;
7use nu_plugin::{Plugin, PluginCommand};
8use nu_plugin_engine::{PluginCustomValueWithSource, PluginSource, WithSource};
9use nu_plugin_protocol::PluginCustomValue;
10use nu_protocol::{
11    CustomValue, Example, IntoSpanned as _, LabeledError, PipelineData, ShellError, Signals, Span,
12    Value,
13    debugger::WithoutDebug,
14    engine::{EngineState, Stack, StateWorkingSet},
15    report_shell_error,
16};
17
18use crate::{diff::diff_by_line, fake_register::fake_register};
19
/// An object through which plugins can be tested.
///
/// Holds an engine state that the plugin's commands have been registered into, so that Nushell
/// source code calling those commands can be evaluated and plugin examples can be checked.
pub struct PluginTest {
    /// Engine state the plugin was registered into; each evaluated entry's parse delta is merged
    /// back into it.
    engine_state: EngineState,
    /// Source of the registered plugin. Used to tag custom values with their origin and to reach
    /// the plugin for custom value operations (comparison, base value conversion).
    source: Arc<PluginSource>,
    /// Number used to name the next evaluated entry (`entry #N`); starts at 1 and increments on
    /// every evaluation.
    entry_num: usize,
}
26
27impl PluginTest {
28    /// Create a new test for the given `plugin` named `name`.
29    ///
30    /// # Example
31    ///
32    /// ```rust,no_run
33    /// # use nu_plugin_test_support::PluginTest;
34    /// # use nu_protocol::ShellError;
35    /// # use nu_plugin::*;
36    /// # fn test(MyPlugin: impl Plugin + Send + 'static) -> Result<PluginTest, ShellError> {
37    /// PluginTest::new("my_plugin", MyPlugin.into())
38    /// # }
39    /// ```
40    pub fn new(
41        name: &str,
42        plugin: Arc<impl Plugin + Send + 'static>,
43    ) -> Result<PluginTest, ShellError> {
44        let mut engine_state = create_default_context();
45        let mut working_set = StateWorkingSet::new(&engine_state);
46
47        let reg_plugin = fake_register(&mut working_set, name, plugin)?;
48        let source = Arc::new(PluginSource::new(reg_plugin));
49
50        engine_state.merge_delta(working_set.render())?;
51
52        Ok(PluginTest {
53            engine_state,
54            source,
55            entry_num: 1,
56        })
57    }
58
59    /// Get the [`EngineState`].
60    pub fn engine_state(&self) -> &EngineState {
61        &self.engine_state
62    }
63
64    /// Get a mutable reference to the [`EngineState`].
65    pub fn engine_state_mut(&mut self) -> &mut EngineState {
66        &mut self.engine_state
67    }
68
69    /// Make additional command declarations available for use by tests.
70    ///
71    /// This can be used to pull in commands from `nu-cmd-lang` for example, as required.
72    pub fn add_decl(
73        &mut self,
74        decl: Box<dyn nu_protocol::engine::Command>,
75    ) -> Result<&mut Self, ShellError> {
76        let mut working_set = StateWorkingSet::new(&self.engine_state);
77        working_set.add_decl(decl);
78        self.engine_state.merge_delta(working_set.render())?;
79        Ok(self)
80    }
81
82    /// Evaluate some Nushell source code with the plugin commands in scope with the given input to
83    /// the pipeline.
84    ///
85    /// # Example
86    ///
87    /// ```rust,no_run
88    /// # use nu_plugin_test_support::PluginTest;
89    /// # use nu_protocol::{IntoInterruptiblePipelineData, ShellError, Signals, Span, Value};
90    /// # use nu_plugin::*;
91    /// # fn test(MyPlugin: impl Plugin + Send + 'static) -> Result<(), ShellError> {
92    /// let result = PluginTest::new("my_plugin", MyPlugin.into())?
93    ///     .eval_with(
94    ///         "my-command",
95    ///         vec![Value::test_int(42)].into_pipeline_data(Span::test_data(), Signals::empty())
96    ///     )?
97    ///     .into_value(Span::test_data())?;
98    /// assert_eq!(Value::test_string("42"), result);
99    /// # Ok(())
100    /// # }
101    /// ```
102    pub fn eval_with(
103        &mut self,
104        nu_source: &str,
105        input: PipelineData,
106    ) -> Result<PipelineData, ShellError> {
107        let mut working_set = StateWorkingSet::new(&self.engine_state);
108        let fname = format!("entry #{}", self.entry_num);
109        self.entry_num += 1;
110
111        // Parse the source code
112        let block = parse(&mut working_set, Some(&fname), nu_source.as_bytes(), false);
113
114        // Check for parse errors
115        let error = if !working_set.parse_errors.is_empty() {
116            // ShellError doesn't have ParseError, use LabeledError to contain it.
117            let mut error = LabeledError::new("Example failed to parse");
118            error.inner.extend(
119                working_set
120                    .parse_errors
121                    .iter()
122                    .map(LabeledError::from_diagnostic),
123            );
124            Some(ShellError::LabeledError(error.into()))
125        } else {
126            None
127        };
128
129        // Merge into state
130        self.engine_state.merge_delta(working_set.render())?;
131
132        // Return error if set. We merge the delta even if we have errors so that printing the error
133        // based on the engine state still works.
134        if let Some(error) = error {
135            return Err(error);
136        }
137
138        // Serialize custom values in the input
139        let source = self.source.clone();
140        let input = match input {
141            input @ PipelineData::ByteStream(..) => input,
142            input => input.map(
143                move |mut value| {
144                    let result = PluginCustomValue::serialize_custom_values_in(&mut value)
145                        // Make sure to mark them with the source so they pass correctly, too.
146                        .and_then(|_| {
147                            PluginCustomValueWithSource::add_source_in(&mut value, &source)
148                        });
149                    match result {
150                        Ok(()) => value,
151                        Err(err) => Value::error(err, value.span()),
152                    }
153                },
154                &Signals::empty(),
155            )?,
156        };
157
158        // Eval the block with the input
159        let mut stack = Stack::new().collect_value();
160        let data = eval_block::<WithoutDebug>(&self.engine_state, &mut stack, &block, input)
161            .map(|p| p.body)?;
162        match data {
163            data @ PipelineData::ByteStream(..) => Ok(data),
164            data => data.map(
165                |mut value| {
166                    // Make sure to deserialize custom values
167                    let result = PluginCustomValueWithSource::remove_source_in(&mut value)
168                        .and_then(|_| PluginCustomValue::deserialize_custom_values_in(&mut value));
169                    match result {
170                        Ok(()) => value,
171                        Err(err) => Value::error(err, value.span()),
172                    }
173                },
174                &Signals::empty(),
175            ),
176        }
177    }
178
179    /// Evaluate some Nushell source code with the plugin commands in scope.
180    ///
181    /// # Example
182    ///
183    /// ```rust,no_run
184    /// # use nu_plugin_test_support::PluginTest;
185    /// # use nu_protocol::{ShellError, Span, Value, IntoInterruptiblePipelineData};
186    /// # use nu_plugin::*;
187    /// # fn test(MyPlugin: impl Plugin + Send + 'static) -> Result<(), ShellError> {
188    /// let result = PluginTest::new("my_plugin", MyPlugin.into())?
189    ///     .eval("42 | my-command")?
190    ///     .into_value(Span::test_data())?;
191    /// assert_eq!(Value::test_string("42"), result);
192    /// # Ok(())
193    /// # }
194    /// ```
195    pub fn eval(&mut self, nu_source: &str) -> Result<PipelineData, ShellError> {
196        self.eval_with(nu_source, PipelineData::empty())
197    }
198
199    /// Test a list of plugin examples. Prints an error for each failing example.
200    ///
201    /// See [`.test_command_examples()`] for easier usage of this method on a command's examples.
202    ///
203    /// # Example
204    ///
205    /// ```rust,no_run
206    /// # use nu_plugin_test_support::PluginTest;
207    /// # use nu_protocol::{ShellError, Example, Value};
208    /// # use nu_plugin::*;
209    /// # fn test(MyPlugin: impl Plugin + Send + 'static) -> Result<(), ShellError> {
210    /// PluginTest::new("my_plugin", MyPlugin.into())?
211    ///     .test_examples(&[
212    ///         Example {
213    ///             example: "my-command",
214    ///             description: "Run my-command",
215    ///             result: Some(Value::test_string("my-command output")),
216    ///         },
217    ///     ])
218    /// # }
219    /// ```
220    pub fn test_examples(&mut self, examples: &[Example]) -> Result<(), ShellError> {
221        let mut failed = false;
222
223        for example in examples {
224            let bold = Style::new().bold();
225            let mut failed_header = || {
226                failed = true;
227                eprintln!("{} {}", bold.paint("Example:"), example.example);
228                eprintln!("{} {}", bold.paint("Description:"), example.description);
229            };
230            if let Some(expectation) = &example.result {
231                match self.eval(example.example) {
232                    Ok(data) => {
233                        let mut value = data.into_value(Span::test_data())?;
234
235                        // Set all of the spans in the value to test_data() to avoid unnecessary
236                        // differences when printing
237                        let _: Result<(), Infallible> = value.recurse_mut(&mut |here| {
238                            here.set_span(Span::test_data());
239                            Ok(())
240                        });
241
242                        // Check for equality with the result
243                        if !self.value_eq(expectation, &value)? {
244                            // If they're not equal, print a diff of the debug format
245                            let (expectation_formatted, value_formatted) =
246                                match (expectation, &value) {
247                                    (
248                                        Value::Custom { val: ex_val, .. },
249                                        Value::Custom { val: v_val, .. },
250                                    ) => {
251                                        // We have to serialize both custom values before handing them to the plugin
252                                        let expectation_serialized =
253                                            PluginCustomValue::serialize_from_custom_value(
254                                                ex_val.as_ref(),
255                                                expectation.span(),
256                                            )?
257                                            .with_source(self.source.clone());
258
259                                        let value_serialized =
260                                            PluginCustomValue::serialize_from_custom_value(
261                                                v_val.as_ref(),
262                                                expectation.span(),
263                                            )?
264                                            .with_source(self.source.clone());
265
266                                        let persistent =
267                                            self.source.persistent(None)?.get_plugin(None)?;
268                                        let expectation_base = persistent
269                                            .custom_value_to_base_value(
270                                                expectation_serialized
271                                                    .into_spanned(expectation.span()),
272                                            )?;
273                                        let value_base = persistent.custom_value_to_base_value(
274                                            value_serialized.into_spanned(value.span()),
275                                        )?;
276
277                                        (
278                                            format!("{expectation_base:#?}"),
279                                            format!("{value_base:#?}"),
280                                        )
281                                    }
282                                    _ => (format!("{expectation:#?}"), format!("{value:#?}")),
283                                };
284
285                            let diff = diff_by_line(&expectation_formatted, &value_formatted);
286                            failed_header();
287                            eprintln!("{} {}", bold.paint("Result:"), diff);
288                        }
289                    }
290                    Err(err) => {
291                        // Report the error
292                        failed_header();
293                        report_shell_error(&self.engine_state, &err);
294                    }
295                }
296            }
297        }
298
299        if !failed {
300            Ok(())
301        } else {
302            Err(ShellError::GenericError {
303                error: "Some examples failed. See the error output for details".into(),
304                msg: "".into(),
305                span: None,
306                help: None,
307                inner: vec![],
308            })
309        }
310    }
311
312    /// Test examples from a command.
313    ///
314    /// # Example
315    ///
316    /// ```rust,no_run
317    /// # use nu_plugin_test_support::PluginTest;
318    /// # use nu_protocol::ShellError;
319    /// # use nu_plugin::*;
320    /// # fn test(MyPlugin: impl Plugin + Send + 'static, MyCommand: impl PluginCommand) -> Result<(), ShellError> {
321    /// PluginTest::new("my_plugin", MyPlugin.into())?
322    ///     .test_command_examples(&MyCommand)
323    /// # }
324    /// ```
325    pub fn test_command_examples(
326        &mut self,
327        command: &impl PluginCommand,
328    ) -> Result<(), ShellError> {
329        self.test_examples(&command.examples())
330    }
331
332    /// This implements custom value comparison with `plugin.custom_value_partial_cmp()` to behave
333    /// as similarly as possible to comparison in the engine.
334    ///
335    /// NOTE: Try to keep these reflecting the same comparison as `Value::partial_cmp` does under
336    /// normal circumstances. Otherwise people will be very confused.
337    fn value_eq(&self, a: &Value, b: &Value) -> Result<bool, ShellError> {
338        match (a, b) {
339            (Value::Custom { val, .. }, _) => {
340                // We have to serialize both custom values before handing them to the plugin
341                let serialized =
342                    PluginCustomValue::serialize_from_custom_value(val.as_ref(), a.span())?
343                        .with_source(self.source.clone());
344                let mut b_serialized = b.clone();
345                PluginCustomValue::serialize_custom_values_in(&mut b_serialized)?;
346                PluginCustomValueWithSource::add_source_in(&mut b_serialized, &self.source)?;
347                // Now get the plugin reference and execute the comparison
348                let persistent = self.source.persistent(None)?.get_plugin(None)?;
349                let ordering = persistent.custom_value_partial_cmp(serialized, b_serialized)?;
350                Ok(matches!(
351                    ordering.map(Ordering::from),
352                    Some(Ordering::Equal)
353                ))
354            }
355            // All container types need to be here except Closure.
356            (Value::List { vals: a_vals, .. }, Value::List { vals: b_vals, .. }) => {
357                // Must be the same length, with all elements equivalent
358                Ok(a_vals.len() == b_vals.len() && {
359                    for (a_el, b_el) in a_vals.iter().zip(b_vals) {
360                        if !self.value_eq(a_el, b_el)? {
361                            return Ok(false);
362                        }
363                    }
364                    true
365                })
366            }
367            (Value::Record { val: a_rec, .. }, Value::Record { val: b_rec, .. }) => {
368                // Must be the same length
369                if a_rec.len() != b_rec.len() {
370                    return Ok(false);
371                }
372
373                // reorder cols and vals to make more logically compare.
374                // more general, if two record have same col and values,
375                // the order of cols shouldn't affect the equal property.
376                let mut a_rec = a_rec.clone().into_owned();
377                let mut b_rec = b_rec.clone().into_owned();
378                a_rec.sort_cols();
379                b_rec.sort_cols();
380
381                // Check columns first
382                for (a, b) in a_rec.columns().zip(b_rec.columns()) {
383                    if a != b {
384                        return Ok(false);
385                    }
386                }
387                // Then check the values
388                for (a, b) in a_rec.values().zip(b_rec.values()) {
389                    if !self.value_eq(a, b)? {
390                        return Ok(false);
391                    }
392                }
393                // All equal, and same length
394                Ok(true)
395            }
396            // Fall back to regular eq.
397            _ => Ok(a == b),
398        }
399    }
400
401    /// This implements custom value comparison with `plugin.custom_value_to_base_value()` to behave
402    /// as similarly as possible to comparison in the engine.
403    pub fn custom_value_to_base_value(
404        &self,
405        val: &dyn CustomValue,
406        span: Span,
407    ) -> Result<Value, ShellError> {
408        let serialized = PluginCustomValue::serialize_from_custom_value(val, span)?
409            .with_source(self.source.clone());
410        let persistent = self.source.persistent(None)?.get_plugin(None)?;
411        persistent.custom_value_to_base_value(serialized.into_spanned(span))
412    }
413}