pest_test/
diff.rs

1use crate::model::{Expression, ExpressionFormatter};
2use colored::Color;
3use std::{
4    collections::HashSet,
5    fmt::{Display, Result as FmtResult},
6};
7
/// Result of structurally comparing an expected [`Expression`] tree against an actual one;
/// produced by [`ExpressionDiff::from_expressions`].
#[derive(Debug)]
pub enum ExpressionDiff {
    /// The two expressions match. Holds the matched expression as built by
    /// `from_expressions` (for terminals, a clone of the actual expression).
    Equal(Expression),
    /// The two expressions could not be matched, e.g. different names, different terminal
    /// values, or a terminal compared against a non-terminal.
    NotEqual {
        expected: Expression,
        actual: Expression,
    },
    /// A child that appears in the expected tree but has no counterpart in the actual tree.
    Missing(Expression),
    /// A child that appears in the actual tree but has no counterpart in the expected tree.
    Extra(Expression),
    /// Same-named non-terminals whose children differ; `children` holds one diff per
    /// aligned, missing, or extra child, in order.
    Partial {
        name: String,
        children: Vec<ExpressionDiff>,
    },
}
22
23impl ExpressionDiff {
24    pub fn from_expressions(
25        expected: &Expression,
26        actual: &Expression,
27        ignore_missing_expected_values: bool,
28    ) -> ExpressionDiff {
29        match (expected, actual) {
30            (
31                Expression::Terminal {
32                    name: expected_name,
33                    value: expected_value,
34                },
35                Expression::Terminal {
36                    name: actual_name,
37                    value: actual_value,
38                },
39            ) if expected_name == actual_name && expected_value == actual_value => {
40                ExpressionDiff::Equal(actual.clone())
41            }
42            (
43                Expression::Terminal {
44                    name: expected_name,
45                    value: None,
46                },
47                Expression::Terminal {
48                    name: actual_name,
49                    value: Some(actual_value),
50                },
51            ) if expected_name == actual_name
52                && (ignore_missing_expected_values || actual_value.is_empty()) =>
53            {
54                ExpressionDiff::Equal(actual.clone())
55            }
56            (
57                Expression::NonTerminal {
58                    name: expected_name,
59                    children: expected_children,
60                },
61                Expression::NonTerminal {
62                    name: actual_name,
63                    children: actual_children,
64                },
65            ) if expected_name == actual_name => {
66                let expected_names: HashSet<&String> =
67                    expected_children.iter().map(|expr| expr.name()).collect();
68                let mut expected_iter = expected_children.iter().peekable();
69                let mut actual_iter = actual_children.iter();
70                let mut children = Vec::new();
71                loop {
72                    if let Some(expected_child) = expected_iter.next() {
73                        match actual_iter.next() {
74                            Some(actual_child)
75                                if Some(expected_child.name())
76                                    == actual_child
77                                        .get_descendant(expected_child.skip_depth())
78                                        .map(|e| e.name()) =>
79                            {
80                                children.push(Self::from_expressions(
81                                    expected_child,
82                                    actual_child,
83                                    ignore_missing_expected_values,
84                                ));
85                            }
86                            Some(actual_child) => {
87                                children.push(ExpressionDiff::Missing(expected_child.clone()));
88                                if expected_names.contains(actual_child.name()) {
89                                    while let Some(next) = expected_iter.peek() {
90                                        if next.name() == actual_child.name() {
91                                            break;
92                                        } else {
93                                            children.push(ExpressionDiff::Missing(
94                                                expected_iter.next().unwrap().clone(),
95                                            ));
96                                        }
97                                    }
98                                } else {
99                                    children.push(ExpressionDiff::Extra(actual_child.clone()))
100                                }
101                            }
102                            None => children.push(ExpressionDiff::Missing(expected_child.clone())),
103                        }
104                    } else {
105                        children.extend(
106                            actual_iter
107                                .map(|actual_child| ExpressionDiff::Extra(actual_child.clone())),
108                        );
109                        break;
110                    }
111                }
112                let partial = children
113                    .iter()
114                    .filter(|child| !matches!(child, ExpressionDiff::Equal(_)))
115                    .count()
116                    > 0;
117                if partial {
118                    ExpressionDiff::Partial {
119                        name: expected_name.clone(),
120                        children,
121                    }
122                } else {
123                    ExpressionDiff::Equal(Expression::NonTerminal {
124                        name: expected_name.clone(),
125                        children: children
126                            .into_iter()
127                            .map(|child| match child {
128                                ExpressionDiff::Equal(expression) => expression,
129                                _ => panic!("Unexpected non-equal value"),
130                            })
131                            .collect(),
132                    })
133                }
134            }
135            (Expression::Skip { depth, next }, actual) => match actual.get_descendant(*depth) {
136                Some(descendant) => Self::from_expressions(
137                    next.as_ref(),
138                    descendant,
139                    ignore_missing_expected_values,
140                ),
141                None => ExpressionDiff::NotEqual {
142                    expected: expected.clone(),
143                    actual: actual.clone(),
144                },
145            },
146            _ => ExpressionDiff::NotEqual {
147                expected: expected.clone(),
148                actual: actual.clone(),
149            },
150        }
151    }
152
153    pub fn name(&self) -> String {
154        match self {
155            ExpressionDiff::Equal(exp) => exp.name().clone(),
156            ExpressionDiff::NotEqual { expected, actual } if expected.name() == actual.name() => {
157                expected.name().to_owned()
158            }
159            ExpressionDiff::NotEqual { expected, actual } => {
160                format!("{}/{}", expected.name(), actual.name())
161            }
162            ExpressionDiff::Missing(exp) => exp.name().to_owned(),
163            ExpressionDiff::Extra(exp) => exp.name().to_owned(),
164            ExpressionDiff::Partial { name, children: _ } => name.to_owned(),
165        }
166    }
167
168    /// Print this diff to stderr. Intended to be used in a unit test to print the diff when the
169    /// evaluation result is a `TestError::Diff`. This is necessary because, by default, an Err
170    /// result is displayed using its `Debug` value.
171    ///
172    /// Example:
173    /// fn test () -> Result<(), TestError> {
174    ///   let tester: PestTester<Rule, MyParser> = PestTester::from_defaults(Rule::root_rule);
175    ///   let res = tester.evaluate_strict("mytest");
176    ///   if let Err(TestError::Diff { diff }) = res {
177    ///     diff.print_test_result();
178    ///   }
179    ///   res
180    /// }
181    pub fn print_test_result(&self, colorize: bool) -> FmtResult {
182        let mut writer = String::new();
183        let (expected_color, actual_color) = if colorize {
184            (Some(Color::Green), Some(Color::Red))
185        } else {
186            (None, None)
187        };
188        let mut formatter = ExpressionFormatter::from_defaults(&mut writer);
189        formatter.write_str("========================================================\n")?;
190        formatter.write_str("Parse tree differs between ")?;
191        formatter.color = expected_color;
192        formatter.write_str("expected")?;
193        formatter.color = None;
194        formatter.write_str(" and ")?;
195        formatter.color = actual_color;
196        formatter.write_str("actual")?;
197        formatter.color = None;
198        formatter.write_str(" results:")?;
199        formatter.write_newline()?;
200        formatter.write_str("========================================================")?;
201        formatter.write_newline()?;
202        formatter.fmt_diff(self, expected_color, actual_color)?;
203        formatter.write_newline()?;
204        formatter.write_str("========================================================")?;
205        formatter.write_newline()?;
206        eprintln!("{}", writer);
207        Ok(())
208    }
209}
210
/// Extension trait for rendering an [`ExpressionDiff`] through an expression formatter.
pub trait ExpressionDiffFormatterExt {
    /// Writes `diff` to this formatter.
    ///
    /// `expected_color` is applied to parts that come from the expected tree only.
    /// `actual_color` is applied to parts that come from the actual tree only. `None`
    /// writes that side without color.
    fn fmt_diff(
        &mut self,
        diff: &ExpressionDiff,
        expected_color: Option<Color>,
        actual_color: Option<Color>,
    ) -> FmtResult;
}
219
220impl<'a> ExpressionDiffFormatterExt for ExpressionFormatter<'a> {
221    fn fmt_diff(
222        &mut self,
223        diff: &ExpressionDiff,
224        expected_color: Option<Color>,
225        actual_color: Option<Color>,
226    ) -> FmtResult {
227        match diff {
228            ExpressionDiff::Equal(expression) => self.fmt(expression)?,
229            ExpressionDiff::NotEqual { expected, actual } => {
230                self.color = expected_color;
231                self.fmt(expected)?;
232                self.write_newline()?;
233                self.color = actual_color;
234                self.fmt(actual)?;
235                self.color = None;
236            }
237            ExpressionDiff::Missing(expression) => {
238                self.color = expected_color;
239                self.fmt(expression)?;
240                self.color = None;
241            }
242            ExpressionDiff::Extra(expression) => {
243                self.color = actual_color;
244                self.fmt(expression)?;
245                self.color = None;
246            }
247            ExpressionDiff::Partial { name, children } => {
248                self.write_indent()?;
249                self.write_char('(')?;
250                self.write_str(name)?;
251                self.write_newline()?;
252                self.level += 1;
253                for child in children {
254                    self.fmt_diff(child, expected_color, actual_color)?;
255                    self.write_newline()?;
256                }
257                self.level -= 1;
258                self.write_indent()?;
259                self.write_char(')')?;
260            }
261        }
262        Ok(())
263    }
264}
265
266impl Display for ExpressionDiff {
267    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
268        ExpressionFormatter::from_defaults(f).fmt_diff(self, Some(Color::Green), Some(Color::Red))
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::{ExpressionDiff, ExpressionDiffFormatterExt};
    use crate::{
        model::{Expression, ExpressionFormatter, TestCase},
        parser::Rule,
        TestError, TestParser,
    };
    use colored::Color;
    use indoc::indoc;

    const TEXT: &str = indoc! {r#"
    My Test
    =======

    fn x() int {
      return 1;
    }

    =======
    
    (source_file
      (function_definition
        (identifier: "x")
        (parameter_list)
        (primitive_type: "int")
        (block
          (return_statement 
            (number: "1")
          )
        )
      )
    )
    "#};

    /// Parses [`TEXT`] and returns the expression parsed from its S-expression section.
    /// These tests use it as the "actual" tree.
    fn parse_actual_expression() -> Result<Expression, TestError<Rule>> {
        TestParser::parse(TEXT)
            .map_err(|source| TestError::Parser { source })
            .and_then(|pair| {
                TestCase::try_from_pair(pair).map_err(|source| TestError::Model { source })
            })
            .map(|test_case| test_case.expression)
    }

    /// Builds an expected tree that deliberately differs from the tree in [`TEXT`]:
    /// * `identifier` has the wrong value;
    /// * an unknown `missing` node is present and `parameter_list` is absent;
    /// * `primitive_type` has no value.
    ///
    /// With `with_skip`, `block` holds a depth-1 `Skip` down to `number: "1"`. Otherwise it
    /// holds a valueless terminal `return_statement`.
    fn make_expected_sexpression(with_skip: bool) -> Expression {
        let block_children = if with_skip {
            Vec::from([Expression::Skip {
                depth: 1,
                next: Box::new(Expression::Terminal {
                    name: String::from("number"),
                    value: Some(String::from("1")),
                }),
            }])
        } else {
            Vec::from([Expression::Terminal {
                name: String::from("return_statement"),
                value: None,
            }])
        };
        Expression::NonTerminal {
            name: String::from("source_file"),
            children: Vec::from([Expression::NonTerminal {
                name: String::from("function_definition"),
                children: Vec::from([
                    Expression::Terminal {
                        name: String::from("identifier"),
                        value: Some(String::from("y")),
                    },
                    Expression::NonTerminal {
                        name: String::from("missing"),
                        children: Vec::from([Expression::Terminal {
                            name: String::from("foo"),
                            value: None,
                        }]),
                    },
                    Expression::Terminal {
                        name: String::from("primitive_type"),
                        value: None,
                    },
                    Expression::NonTerminal {
                        name: String::from("block"),
                        children: block_children,
                    },
                ]),
            }]),
        }
    }

    fn assert_equal<'a>(diff: &'a ExpressionDiff, expected_name: &str) -> &'a Expression {
        match diff {
            ExpressionDiff::Equal(expr) => {
                assert_eq!(expr.name(), expected_name);
                expr
            }
            _ => panic!("Expected diff to be equal but was {}", diff),
        }
    }

    fn assert_partial<'a>(diff: &'a ExpressionDiff, expected_name: &str) -> &'a [ExpressionDiff] {
        match diff {
            ExpressionDiff::Partial { name, children } => {
                assert_eq!(expected_name, name);
                children.as_slice()
            }
            _ => panic!("Expected diff to be partial but was {}", diff),
        }
    }

    /// Asserts `diff` is an equal terminal named `expected_name`. When `expected_value` is
    /// `Some`, the terminal's value must equal it; `None` accepts any value.
    fn assert_value_equal(
        diff: &ExpressionDiff,
        expected_name: &str,
        expected_value: Option<&str>,
    ) {
        match diff {
            ExpressionDiff::Equal(Expression::Terminal { name, value }) => {
                assert_eq!(expected_name, name);
                if let Some(expected) = expected_value {
                    // Previously a missing actual value passed silently here.
                    assert_eq!(Some(expected), value.as_deref());
                }
            }
            _ => panic!("Expected diff to be equal but was {}", diff),
        }
    }

    fn assert_value_nonequal(
        diff: &ExpressionDiff,
        name: &str,
        expected_expected_value: Option<&str>,
        expected_actual_value: Option<&str>,
    ) {
        match diff {
            ExpressionDiff::NotEqual {
                expected:
                    Expression::Terminal {
                        name: expected_name,
                        value: expected_value,
                    },
                actual:
                    Expression::Terminal {
                        name: actual_name,
                        value: actual_value,
                    },
            } => {
                assert_eq!(expected_name, name);
                assert_eq!(actual_name, name);
                assert_eq!(expected_expected_value, expected_value.as_deref());
                assert_eq!(expected_actual_value, actual_value.as_deref());
            }
            _ => panic!("Expected diff to be non-equal but was {}", diff),
        }
    }

    fn assert_missing(diff: &ExpressionDiff, expected_name: &str) {
        match diff {
            ExpressionDiff::Missing(expr) => assert_eq!(expr.name(), expected_name),
            _ => panic!("Expected diff to be missing but was {}", diff),
        }
    }

    fn assert_extra(diff: &ExpressionDiff, expected_name: &str) {
        match diff {
            ExpressionDiff::Extra(expr) => assert_eq!(expr.name(), expected_name),
            _ => panic!("Expected diff to be extra but was {}", diff),
        }
    }

    /// Asserts `diff` is a `NotEqual` between a terminal and a non-terminal (in either
    /// order) that are both named `expected_name`.
    fn assert_nonequal_type(diff: &ExpressionDiff, expected_name: &str) {
        match diff {
            ExpressionDiff::NotEqual {
                expected: Expression::Terminal {
                    name: terminal_name, ..
                },
                actual: Expression::NonTerminal {
                    name: nonterminal_name,
                    ..
                },
            }
            | ExpressionDiff::NotEqual {
                expected: Expression::NonTerminal {
                    name: nonterminal_name,
                    ..
                },
                actual: Expression::Terminal {
                    name: terminal_name, ..
                },
            } => {
                assert_eq!(expected_name, nonterminal_name);
                assert_eq!(expected_name, terminal_name);
            }
            _ => panic!("Expected diff to be non-equal but was {}", diff),
        }
    }

    #[test]
    fn test_diff_strict() -> Result<(), TestError<Rule>> {
        let actual = parse_actual_expression()?;
        let expected_sexpr = make_expected_sexpression(false);
        let diff_strict = ExpressionDiff::from_expressions(&expected_sexpr, &actual, false);
        let children = assert_partial(&diff_strict, "source_file");
        assert_eq!(children.len(), 1);
        let children = assert_partial(&children[0], "function_definition");
        assert_eq!(children.len(), 5);
        assert_value_nonequal(&children[0], "identifier", Some("y"), Some("x"));
        assert_missing(&children[1], "missing");
        assert_extra(&children[2], "parameter_list");
        assert_value_nonequal(&children[3], "primitive_type", None, Some("int"));
        let children = assert_partial(&children[4], "block");
        assert_eq!(children.len(), 1);
        assert_nonequal_type(&children[0], "return_statement");
        Ok(())
    }

    #[test]
    fn test_diff_lenient() -> Result<(), TestError<Rule>> {
        let actual = parse_actual_expression()?;
        let expected_sexpr = make_expected_sexpression(false);
        let diff_lenient = ExpressionDiff::from_expressions(&expected_sexpr, &actual, true);
        let children = assert_partial(&diff_lenient, "source_file");
        let children = assert_partial(&children[0], "function_definition");
        assert_value_equal(&children[3], "primitive_type", Some("int"));
        Ok(())
    }

    #[test]
    fn test_diff_with_skip() -> Result<(), TestError<Rule>> {
        let actual = parse_actual_expression()?;
        let expected_sexpr = make_expected_sexpression(true);
        let diff_lenient = ExpressionDiff::from_expressions(&expected_sexpr, &actual, true);
        let children = assert_partial(&diff_lenient, "source_file");
        let children = assert_partial(&children[0], "function_definition");
        assert_value_equal(&children[3], "primitive_type", Some("int"));
        assert_equal(&children[4], "block");
        Ok(())
    }

    #[test]
    fn test_format_nocolor() -> Result<(), TestError<Rule>> {
        let actual = parse_actual_expression()?;
        let expected_sexpr = make_expected_sexpression(false);
        let diff = ExpressionDiff::from_expressions(&expected_sexpr, &actual, false);
        let mut writer = String::new();
        let mut formatter = ExpressionFormatter::from_defaults(&mut writer);
        formatter
            .fmt_diff(&diff, None, None)
            .expect("formatting a diff into a String should not fail");
        let expected = indoc! {r#"
            (source_file
              (function_definition
                (identifier: "y")
                (identifier: "x")
                (missing
                  (foo)
                )
                (parameter_list)
                (primitive_type)
                (primitive_type: "int")
                (block
                  (return_statement)
                  (return_statement
                    (number: "1")
                  )
                )
              )
            )"#};
        assert_eq!(writer, expected);
        Ok(())
    }

    #[test]
    fn test_format_color() -> Result<(), TestError<Rule>> {
        let actual = parse_actual_expression()?;
        let expected_sexpr = make_expected_sexpression(false);
        let diff = ExpressionDiff::from_expressions(&expected_sexpr, &actual, false);
        let mut writer = String::new();
        let mut formatter = ExpressionFormatter::from_defaults(&mut writer);
        formatter
            .fmt_diff(&diff, Some(Color::Green), Some(Color::Red))
            .expect("formatting a diff into a String should not fail");
        let expected = format!(
            indoc! {r#"
        (source_file
          (function_definition
        {green_start}    (identifier: "y"){end}
        {red_start}    (identifier: "x"){end}
        {green_start}    (missing
              (foo)
            ){end}
        {red_start}    (parameter_list){end}
        {green_start}    (primitive_type){end}
        {red_start}    (primitive_type: "int"){end}
            (block
        {green_start}      (return_statement){end}
        {red_start}      (return_statement
                (number: "1")
              ){end}
            )
          )
        )"#},
            green_start = "\u{1b}[32m",
            red_start = "\u{1b}[31m",
            end = "\u{1b}[0m",
        );
        assert_eq!(writer, expected);
        Ok(())
    }
}