Skip to main content

nu_plugin_test_support/
plugin_test.rs

1use std::{cmp::Ordering, convert::Infallible, sync::Arc};
2
3use nu_ansi_term::Style;
4use nu_cmd_lang::create_default_context;
5use nu_engine::eval_block;
6use nu_parser::parse;
7use nu_plugin::{Plugin, PluginCommand};
8use nu_plugin_engine::{PluginCustomValueWithSource, PluginSource, WithSource};
9use nu_plugin_protocol::PluginCustomValue;
10use nu_protocol::{
11    CustomValue, Example, IntoSpanned as _, LabeledError, PipelineData, ShellError, Signals, Span,
12    Value,
13    debugger::WithoutDebug,
14    engine::{EngineState, Stack, StateWorkingSet},
15    report_shell_error,
16    shell_error::generic::GenericError,
17};
18
19use crate::{diff::diff_by_line, fake_register::fake_register};
20
21/// An object through which plugins can be tested.
22pub struct PluginTest {
23    engine_state: EngineState,
24    source: Arc<PluginSource>,
25    entry_num: usize,
26}
27
28impl PluginTest {
29    /// Create a new test for the given `plugin` named `name`.
30    ///
31    /// # Example
32    ///
33    /// ```rust,no_run
34    /// # use nu_plugin_test_support::PluginTest;
35    /// # use nu_protocol::ShellError;
36    /// # use nu_plugin::*;
37    /// # fn test(MyPlugin: impl Plugin + Send + 'static) -> Result<PluginTest, ShellError> {
38    /// PluginTest::new("my_plugin", MyPlugin.into())
39    /// # }
40    /// ```
41    pub fn new(
42        name: &str,
43        plugin: Arc<impl Plugin + Send + 'static>,
44    ) -> Result<PluginTest, ShellError> {
45        let mut engine_state = create_default_context();
46        let mut working_set = StateWorkingSet::new(&engine_state);
47
48        let reg_plugin = fake_register(&mut working_set, name, plugin)?;
49        let source = Arc::new(PluginSource::new(reg_plugin));
50
51        engine_state.merge_delta(working_set.render())?;
52
53        Ok(PluginTest {
54            engine_state,
55            source,
56            entry_num: 1,
57        })
58    }
59
60    /// Get the [`EngineState`].
61    pub fn engine_state(&self) -> &EngineState {
62        &self.engine_state
63    }
64
65    /// Get a mutable reference to the [`EngineState`].
66    pub fn engine_state_mut(&mut self) -> &mut EngineState {
67        &mut self.engine_state
68    }
69
70    /// Make additional command declarations available for use by tests.
71    ///
72    /// This can be used to pull in commands from `nu-cmd-lang` for example, as required.
73    pub fn add_decl(
74        &mut self,
75        decl: Box<dyn nu_protocol::engine::Command>,
76    ) -> Result<&mut Self, ShellError> {
77        let mut working_set = StateWorkingSet::new(&self.engine_state);
78        working_set.add_decl(decl);
79        self.engine_state.merge_delta(working_set.render())?;
80        Ok(self)
81    }
82
83    /// Evaluate some Nushell source code with the plugin commands in scope with the given input to
84    /// the pipeline.
85    ///
86    /// # Example
87    ///
88    /// ```rust,no_run
89    /// # use nu_plugin_test_support::PluginTest;
90    /// # use nu_protocol::{IntoInterruptiblePipelineData, ShellError, Signals, Span, Value};
91    /// # use nu_plugin::*;
92    /// # fn test(MyPlugin: impl Plugin + Send + 'static) -> Result<(), ShellError> {
93    /// let result = PluginTest::new("my_plugin", MyPlugin.into())?
94    ///     .eval_with(
95    ///         "my-command",
96    ///         vec![Value::test_int(42)].into_pipeline_data(Span::test_data(), Signals::empty())
97    ///     )?
98    ///     .into_value(Span::test_data())?;
99    /// assert_eq!(Value::test_string("42"), result);
100    /// # Ok(())
101    /// # }
102    /// ```
103    pub fn eval_with(
104        &mut self,
105        nu_source: &str,
106        input: PipelineData,
107    ) -> Result<PipelineData, ShellError> {
108        let mut working_set = StateWorkingSet::new(&self.engine_state);
109        let fname = format!("repl_entry #{}", self.entry_num);
110        self.entry_num += 1;
111
112        // Parse the source code
113        let block = parse(&mut working_set, Some(&fname), nu_source.as_bytes(), false);
114
115        // Check for parse errors
116        let error = if !working_set.parse_errors.is_empty() {
117            // ShellError doesn't have ParseError, use LabeledError to contain it.
118            let mut error = LabeledError::new("Example failed to parse");
119            error.inner.extend(
120                working_set
121                    .parse_errors
122                    .iter()
123                    .map(|i| LabeledError::from_diagnostic(i).into()),
124            );
125            Some(ShellError::LabeledError(error.into()))
126        } else {
127            None
128        };
129
130        // Merge into state
131        self.engine_state.merge_delta(working_set.render())?;
132
133        // Return error if set. We merge the delta even if we have errors so that printing the error
134        // based on the engine state still works.
135        if let Some(error) = error {
136            return Err(error);
137        }
138
139        // Serialize custom values in the input
140        let source = self.source.clone();
141        let input = match input {
142            input @ PipelineData::ByteStream(..) => input,
143            input => input.map(
144                move |mut value| {
145                    let result = PluginCustomValue::serialize_custom_values_in(&mut value)
146                        // Make sure to mark them with the source so they pass correctly, too.
147                        .and_then(|_| {
148                            PluginCustomValueWithSource::add_source_in(&mut value, &source)
149                        });
150                    match result {
151                        Ok(()) => value,
152                        Err(err) => Value::error(err, value.span()),
153                    }
154                },
155                &Signals::empty(),
156            )?,
157        };
158
159        // Eval the block with the input
160        let mut stack = Stack::new().collect_value();
161        let data = eval_block::<WithoutDebug>(&self.engine_state, &mut stack, &block, input)
162            .map(|p| p.body)?;
163        match data {
164            data @ PipelineData::ByteStream(..) => Ok(data),
165            data => data.map(
166                |mut value| {
167                    // Make sure to deserialize custom values
168                    let result = PluginCustomValueWithSource::remove_source_in(&mut value)
169                        .and_then(|_| PluginCustomValue::deserialize_custom_values_in(&mut value));
170                    match result {
171                        Ok(()) => value,
172                        Err(err) => Value::error(err, value.span()),
173                    }
174                },
175                &Signals::empty(),
176            ),
177        }
178    }
179
180    /// Evaluate some Nushell source code with the plugin commands in scope.
181    ///
182    /// # Example
183    ///
184    /// ```rust,no_run
185    /// # use nu_plugin_test_support::PluginTest;
186    /// # use nu_protocol::{ShellError, Span, Value, IntoInterruptiblePipelineData};
187    /// # use nu_plugin::*;
188    /// # fn test(MyPlugin: impl Plugin + Send + 'static) -> Result<(), ShellError> {
189    /// let result = PluginTest::new("my_plugin", MyPlugin.into())?
190    ///     .eval("42 | my-command")?
191    ///     .into_value(Span::test_data())?;
192    /// assert_eq!(Value::test_string("42"), result);
193    /// # Ok(())
194    /// # }
195    /// ```
196    pub fn eval(&mut self, nu_source: &str) -> Result<PipelineData, ShellError> {
197        self.eval_with(nu_source, PipelineData::empty())
198    }
199
200    /// Test a list of plugin examples. Prints an error for each failing example.
201    ///
202    /// See [`.test_command_examples()`](Self::test_command_examples) for easier usage of this method on a command's examples.
203    ///
204    /// # Example
205    ///
206    /// ```rust,no_run
207    /// # use nu_plugin_test_support::PluginTest;
208    /// # use nu_protocol::{ShellError, Example, Value};
209    /// # use nu_plugin::*;
210    /// # fn test(MyPlugin: impl Plugin + Send + 'static) -> Result<(), ShellError> {
211    /// PluginTest::new("my_plugin", MyPlugin.into())?
212    ///     .test_examples(&[
213    ///         Example {
214    ///             example: "my-command",
215    ///             description: "Run my-command",
216    ///             result: Some(Value::test_string("my-command output")),
217    ///         },
218    ///     ])
219    /// # }
220    /// ```
221    pub fn test_examples(&mut self, examples: &[Example]) -> Result<(), ShellError> {
222        let mut failed = false;
223
224        for example in examples {
225            let bold = Style::new().bold();
226            let mut failed_header = || {
227                failed = true;
228                eprintln!("{} {}", bold.paint("Example:"), example.example);
229                eprintln!("{} {}", bold.paint("Description:"), example.description);
230            };
231            if let Some(expectation) = &example.result {
232                match self.eval(example.example) {
233                    Ok(data) => {
234                        let mut value = data.into_value(Span::test_data())?;
235
236                        // Set all of the spans in the value to test_data() to avoid unnecessary
237                        // differences when printing
238                        let _: Result<(), Infallible> = value.recurse_mut(&mut |here| {
239                            here.set_span(Span::test_data());
240                            Ok(())
241                        });
242
243                        // Check for equality with the result
244                        if !self.value_eq(expectation, &value)? {
245                            // If they're not equal, print a diff of the debug format
246                            let (expectation_formatted, value_formatted) =
247                                match (expectation, &value) {
248                                    (
249                                        Value::Custom { val: ex_val, .. },
250                                        Value::Custom { val: v_val, .. },
251                                    ) => {
252                                        // We have to serialize both custom values before handing them to the plugin
253                                        let expectation_serialized =
254                                            PluginCustomValue::serialize_from_custom_value(
255                                                ex_val.as_ref(),
256                                                expectation.span(),
257                                            )?
258                                            .with_source(self.source.clone());
259
260                                        let value_serialized =
261                                            PluginCustomValue::serialize_from_custom_value(
262                                                v_val.as_ref(),
263                                                expectation.span(),
264                                            )?
265                                            .with_source(self.source.clone());
266
267                                        let persistent =
268                                            self.source.persistent(None)?.get_plugin(None)?;
269                                        let expectation_base = persistent
270                                            .custom_value_to_base_value(
271                                                expectation_serialized
272                                                    .into_spanned(expectation.span()),
273                                            )?;
274                                        let value_base = persistent.custom_value_to_base_value(
275                                            value_serialized.into_spanned(value.span()),
276                                        )?;
277
278                                        (
279                                            format!("{expectation_base:#?}"),
280                                            format!("{value_base:#?}"),
281                                        )
282                                    }
283                                    _ => (format!("{expectation:#?}"), format!("{value:#?}")),
284                                };
285
286                            let diff = diff_by_line(&expectation_formatted, &value_formatted);
287                            failed_header();
288                            eprintln!("{} {}", bold.paint("Result:"), diff);
289                        }
290                    }
291                    Err(err) => {
292                        // Report the error
293                        failed_header();
294                        report_shell_error(None, &self.engine_state, &err);
295                    }
296                }
297            }
298        }
299
300        if !failed {
301            Ok(())
302        } else {
303            Err(ShellError::Generic(GenericError::new_internal(
304                "Some examples failed. See the error output for details",
305                "",
306            )))
307        }
308    }
309
310    /// Test examples from a command.
311    ///
312    /// # Example
313    ///
314    /// ```rust,no_run
315    /// # use nu_plugin_test_support::PluginTest;
316    /// # use nu_protocol::ShellError;
317    /// # use nu_plugin::*;
318    /// # fn test(MyPlugin: impl Plugin + Send + 'static, MyCommand: impl PluginCommand) -> Result<(), ShellError> {
319    /// PluginTest::new("my_plugin", MyPlugin.into())?
320    ///     .test_command_examples(&MyCommand)
321    /// # }
322    /// ```
323    pub fn test_command_examples(
324        &mut self,
325        command: &impl PluginCommand,
326    ) -> Result<(), ShellError> {
327        self.test_examples(&command.examples())
328    }
329
330    /// This implements custom value comparison with `plugin.custom_value_partial_cmp()` to behave
331    /// as similarly as possible to comparison in the engine.
332    ///
333    /// NOTE: Try to keep these reflecting the same comparison as `Value::partial_cmp` does under
334    /// normal circumstances. Otherwise people will be very confused.
335    fn value_eq(&self, a: &Value, b: &Value) -> Result<bool, ShellError> {
336        match (a, b) {
337            (Value::Custom { val, .. }, _) => {
338                // We have to serialize both custom values before handing them to the plugin
339                let serialized =
340                    PluginCustomValue::serialize_from_custom_value(val.as_ref(), a.span())?
341                        .with_source(self.source.clone());
342                let mut b_serialized = b.clone();
343                PluginCustomValue::serialize_custom_values_in(&mut b_serialized)?;
344                PluginCustomValueWithSource::add_source_in(&mut b_serialized, &self.source)?;
345                // Now get the plugin reference and execute the comparison
346                let persistent = self.source.persistent(None)?.get_plugin(None)?;
347                let ordering =
348                    persistent.custom_value_partial_cmp(serialized, b_serialized, a.span())?;
349                Ok(matches!(
350                    ordering.map(Ordering::from),
351                    Some(Ordering::Equal)
352                ))
353            }
354            // All container types need to be here except Closure.
355            (Value::List { vals: a_vals, .. }, Value::List { vals: b_vals, .. }) => {
356                // Must be the same length, with all elements equivalent
357                Ok(a_vals.len() == b_vals.len() && {
358                    for (a_el, b_el) in a_vals.iter().zip(b_vals) {
359                        if !self.value_eq(a_el, b_el)? {
360                            return Ok(false);
361                        }
362                    }
363                    true
364                })
365            }
366            (Value::Record { val: a_rec, .. }, Value::Record { val: b_rec, .. }) => {
367                // Must be the same length
368                if a_rec.len() != b_rec.len() {
369                    return Ok(false);
370                }
371
372                // reorder cols and vals to make more logically compare.
373                // more general, if two record have same col and values,
374                // the order of cols shouldn't affect the equal property.
375                let mut a_rec = a_rec.clone().into_owned();
376                let mut b_rec = b_rec.clone().into_owned();
377                a_rec.sort_cols();
378                b_rec.sort_cols();
379
380                // Check columns first
381                for (a, b) in a_rec.columns().zip(b_rec.columns()) {
382                    if a != b {
383                        return Ok(false);
384                    }
385                }
386                // Then check the values
387                for (a, b) in a_rec.values().zip(b_rec.values()) {
388                    if !self.value_eq(a, b)? {
389                        return Ok(false);
390                    }
391                }
392                // All equal, and same length
393                Ok(true)
394            }
395            // Fall back to regular eq.
396            _ => Ok(a == b),
397        }
398    }
399
400    /// This implements custom value comparison with `plugin.custom_value_to_base_value()` to behave
401    /// as similarly as possible to comparison in the engine.
402    pub fn custom_value_to_base_value(
403        &self,
404        val: &dyn CustomValue,
405        span: Span,
406    ) -> Result<Value, ShellError> {
407        let serialized = PluginCustomValue::serialize_from_custom_value(val, span)?
408            .with_source(self.source.clone());
409        let persistent = self.source.persistent(None)?.get_plugin(None)?;
410        persistent.custom_value_to_base_value(serialized.into_spanned(span))
411    }
412}