pest_test/
model.rs

1use crate::parser::Rule;
2use colored::{Color, Colorize};
3use pest::{iterators::Pair, RuleType};
4use snailquote::unescape;
5use std::{
6    collections::HashSet,
7    fmt::{Display, Result as FmtResult, Write},
8};
9use thiserror::Error;
10
11#[derive(Error, Debug)]
12#[error("Error creating model element from parser pair")]
13pub struct ModelError(String);
14
15impl ModelError {
16    fn from_str(msg: &str) -> Self {
17        Self(msg.to_owned())
18    }
19}
20
21fn assert_rule(pair: Pair<'_, Rule>, rule: Rule) -> Result<Pair<'_, Rule>, ModelError> {
22    if pair.as_rule() == rule {
23        Ok(pair)
24    } else {
25        Err(ModelError(format!(
26            "Expected pair {:?} rule to be {:?}",
27            pair, rule
28        )))
29    }
30}
31
/// A node of an expression tree, built either from a written s-expression or from parsed code.
#[derive(Debug, Clone)]
pub enum Expression {
    /// A leaf node: a rule name with an optional textual value.
    Terminal {
        /// Name of the rule.
        name: String,
        /// Text associated with the node, if any.
        value: Option<String>,
    },
    /// An inner node: a rule name together with its child expressions.
    NonTerminal {
        /// Name of the rule.
        name: String,
        /// Child expressions, in order.
        children: Vec<Expression>,
    },
    /// A wrapper carrying a skip depth in front of another expression.
    Skip {
        /// The skip depth attached to `next`.
        depth: usize,
        /// The wrapped expression.
        next: Box<Expression>,
    },
}
47
48impl Expression {
49    pub fn try_from_sexpr(pair: Pair<'_, Rule>) -> Result<Self, ModelError> {
50        let mut inner = pair.into_inner();
51        let skip_depth: usize = if inner.peek().map(|pair| pair.as_rule()) == Some(Rule::skip) {
52            let depth_pair = inner
53                .next()
54                .unwrap()
55                .into_inner()
56                .next()
57                .ok_or_else(|| ModelError::from_str("Missing skip depth"))
58                .and_then(|pair| assert_rule(pair, Rule::int))?;
59            depth_pair
60                .as_str()
61                .parse()
62                .map_err(|err| ModelError(format!("Error parsing skip depth: {:?}", err)))?
63        } else {
64            0
65        };
66        let name = inner
67            .next()
68            .ok_or_else(|| ModelError::from_str("Missing rule name"))
69            .and_then(|pair| assert_rule(pair, Rule::identifier))
70            .map(|pair| pair.as_str().to_owned())?;
71        let expr = match inner.next() {
72            None => Self::Terminal { name, value: None },
73            Some(pair) => match pair.as_rule() {
74                Rule::sub_expressions => {
75                    let children: Result<Vec<Expression>, ModelError> =
76                        pair.into_inner().map(Self::try_from_sexpr).collect();
77                    Self::NonTerminal {
78                        name,
79                        children: children?,
80                    }
81                }
82                Rule::string => {
83                    let s = pair.as_str().trim();
84                    let value = Some(unescape(s).map_err(|err| {
85                        ModelError(format!("Error unescaping string value {}: {:?}", s, err))
86                    })?);
87                    Self::Terminal { name, value }
88                }
89                other => return Err(ModelError(format!("Unexpected rule {:?}", other))),
90            },
91        };
92        if skip_depth == 0 {
93            Ok(expr)
94        } else {
95            Ok(Self::Skip {
96                depth: skip_depth,
97                next: Box::new(expr),
98            })
99        }
100    }
101
102    pub fn try_from_code<R: RuleType>(
103        pair: Pair<'_, R>,
104        skip_rules: &HashSet<R>,
105    ) -> Result<Self, ModelError> {
106        let name = format!("{:?}", pair.as_rule());
107        let value = pair.as_str();
108        let children: Result<Vec<Expression>, ModelError> = pair
109            .into_inner()
110            .filter(|pair| !skip_rules.contains(&pair.as_rule()))
111            .map(|pair| Self::try_from_code(pair, skip_rules))
112            .collect();
113        match children {
114            Ok(children) if children.is_empty() => Ok(Self::Terminal {
115                name,
116                value: Some(value.to_owned()),
117            }),
118            Ok(children) => Ok(Self::NonTerminal { name, children }),
119            Err(e) => Err(e),
120        }
121    }
122
123    pub fn name(&self) -> &String {
124        match self {
125            Self::Terminal { name, value: _ } => name,
126            Self::NonTerminal { name, children: _ } => name,
127            Self::Skip { depth: _, next } => next.name(),
128        }
129    }
130
131    pub fn skip_depth(&self) -> usize {
132        match self {
133            Expression::Skip { depth, next: _ } => *depth,
134            _ => 0,
135        }
136    }
137
138    /// Returns the `Nth` descendant of this expression, where `N = depth`. For a
139    /// `NonTerminal` expression, the descendant is its first child. For a `Terminal` node, there
140    /// is no descendant.
141    pub fn get_descendant(&self, depth: usize) -> Option<&Expression> {
142        if depth > 0 {
143            match self {
144                Self::NonTerminal { name: _, children } if !children.is_empty() => {
145                    children.first().unwrap().get_descendant(depth - 1)
146                }
147                Self::Skip {
148                    depth: skip_depth,
149                    next,
150                } if *skip_depth <= depth => next.as_ref().get_descendant(depth - skip_depth),
151                _ => None,
152            }
153        } else {
154            Some(self)
155        }
156    }
157}
158
159pub struct ExpressionFormatter<'a> {
160    writer: &'a mut dyn Write,
161    indent: &'a str,
162    pub(crate) level: usize,
163    pub(crate) color: Option<Color>,
164    buffering: bool,
165}
166
167impl<'a> ExpressionFormatter<'a> {
168    pub fn from_defaults(writer: &'a mut dyn Write) -> Self {
169        Self {
170            writer,
171            indent: "  ",
172            level: 0,
173            color: None,
174            buffering: true,
175        }
176    }
177
178    pub(crate) fn write_indent(&mut self) -> FmtResult {
179        for _ in 0..self.level {
180            self.writer.write_str(self.indent)?;
181        }
182        Ok(())
183    }
184
185    pub(crate) fn write_newline(&mut self) -> FmtResult {
186        self.writer.write_char('\n')
187    }
188
189    pub(crate) fn write_char(&mut self, c: char) -> FmtResult {
190        match self.color {
191            Some(color) => self
192                .writer
193                .write_str(format!("{}", c.to_string().color(color)).as_ref()),
194            None => self.writer.write_char(c),
195        }
196    }
197
198    pub(crate) fn write_str(&mut self, s: &str) -> FmtResult {
199        match self.color {
200            Some(color) => self
201                .writer
202                .write_str(format!("{}", s.color(color)).as_ref()),
203            None => self.writer.write_str(s),
204        }
205    }
206
207    fn fmt_buffered(&mut self, expression: &Expression) -> FmtResult {
208        let mut buf = String::with_capacity(1024);
209        let mut string_formatter = ExpressionFormatter {
210            writer: &mut buf,
211            indent: self.indent,
212            level: self.level,
213            color: None,
214            buffering: false,
215        };
216        string_formatter.fmt(expression)?;
217        self.write_str(buf.as_ref())?;
218        Ok(())
219    }
220
221    fn fmt_unbuffered(&mut self, expression: &Expression) -> FmtResult {
222        self.write_indent()?;
223        match expression {
224            Expression::Terminal { name, value } => {
225                self.write_char('(')?;
226                self.write_str(name)?;
227                if let Some(value) = value {
228                    self.write_str(": \"")?;
229                    self.write_str(&value.escape_default().to_string())?;
230                    self.write_char('"')?;
231                }
232                self.write_char(')')?;
233            }
234            Expression::NonTerminal { name, children } if children.is_empty() => {
235                self.write_char('(')?;
236                self.write_str(name)?;
237                self.write_char(')')?;
238            }
239            Expression::NonTerminal { name, children } => {
240                self.write_char('(')?;
241                self.write_str(name)?;
242                self.write_newline()?;
243                self.level += 1;
244                for child in children {
245                    self.fmt(child)?;
246                    self.write_newline()?;
247                }
248                self.level -= 1;
249                self.write_indent()?;
250                self.write_char(')')?;
251            }
252            Expression::Skip { depth, next } => {
253                self.write_str(format!("#[skip(depth = {})]", depth).as_ref())?;
254                self.write_newline()?;
255                self.fmt_unbuffered(next.as_ref())?;
256            }
257        }
258        Ok(())
259    }
260
261    pub fn fmt(&mut self, expression: &Expression) -> FmtResult {
262        if self.buffering {
263            self.fmt_buffered(expression)
264        } else {
265            self.fmt_unbuffered(expression)
266        }
267    }
268}
269
270impl Display for Expression {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> FmtResult {
272        ExpressionFormatter::from_defaults(f).fmt(self)
273    }
274}
275
276#[derive(Clone, Debug)]
277pub struct TestCase {
278    pub name: String,
279    pub code: String,
280    pub expression: Expression,
281}
282
283impl TestCase {
284    pub fn try_from_pair(pair: Pair<'_, Rule>) -> Result<Self, ModelError> {
285        let mut inner = pair.into_inner();
286        let name = inner
287            .next()
288            .ok_or_else(|| ModelError::from_str("Missing test name"))
289            .and_then(|pair| assert_rule(pair, Rule::test_name))
290            .map(|pair| pair.as_str().trim().to_owned())?;
291        let mut code_block = inner
292            .next()
293            .ok_or_else(|| ModelError::from_str("Missing code block"))
294            .and_then(|pair| assert_rule(pair, Rule::code_block))
295            .map(|pair| pair.into_inner())?;
296        code_block
297            .next()
298            .ok_or_else(|| ModelError::from_str("Missing div"))
299            .and_then(|pair| assert_rule(pair, Rule::div))?;
300        let code_untrimmed = code_block
301            .next()
302            .ok_or_else(|| ModelError::from_str("Missing code"))
303            .and_then(|pair| assert_rule(pair, Rule::code))
304            .map(|pair| pair.as_str())?;
305        // The code must start and end with at least one line separator - remove first and last
306        let code_len = code_untrimmed.len();
307        assert!(code_len >= 2);
308        let mut code_chars = code_untrimmed.chars();
309        let code_start: usize = match code_chars.next() {
310            Some('\n') => 1,
311            Some('\r') => match code_chars.next() {
312                Some('\n') if code_len > 2 => 2,
313                _ => 1,
314            },
315            _ => {
316                return Err(ModelError::from_str(
317                    "Code block must be preceeded by at least one line separator",
318                ))
319            }
320        };
321        let mut code_chars = code_untrimmed.chars().rev();
322        let code_end: usize = code_len
323            - match code_chars.next() {
324                Some('\r') => 1,
325                Some('\n') => match code_chars.next() {
326                    Some('\r') if code_len - code_start > 2 => 2,
327                    _ => 1,
328                },
329                _ => {
330                    return Err(ModelError::from_str(
331                        "Code block must be followed by at least one line separator",
332                    ))
333                }
334            };
335        let code = code_untrimmed[code_start..code_end].to_owned();
336        let expression = inner
337            .next()
338            .ok_or_else(|| ModelError::from_str("Missing expression"))
339            .and_then(|pair| assert_rule(pair, Rule::expression))?;
340        Ok(TestCase {
341            name,
342            code,
343            expression: Expression::try_from_sexpr(expression)?,
344        })
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::{Expression, ExpressionFormatter, TestCase};
    use crate::{
        parser::{Rule, TestParser},
        TestError,
    };
    use indoc::indoc;
    use std::collections::HashSet;

    /// Asserts that `expression` is a non-terminal named `expected_name`; returns its children.
    fn assert_nonterminal<'a>(
        expression: &'a Expression,
        expected_name: &str,
    ) -> &'a Vec<Expression> {
        match expression {
            Expression::NonTerminal { name, children } => {
                assert_eq!(name, expected_name);
                children
            }
            _ => panic!("Expected non-terminal expression but found {expression:?}"),
        }
    }

    /// Asserts that `expression` is a skip with `expected_depth`; returns the wrapped expression.
    fn assert_skip(expression: &Expression, expected_depth: usize) -> &Expression {
        match expression {
            Expression::Skip { depth, next } => {
                assert_eq!(expected_depth, *depth);
                next.as_ref()
            }
            _ => panic!("Expected skip expression but found {expression:?}"),
        }
    }

    /// Asserts that `expression` is a terminal with the given name and (trimmed) value.
    fn assert_terminal(expression: &Expression, expected_name: &str, expected_value: Option<&str>) {
        match expression {
            Expression::Terminal { name, value } => {
                assert_eq!(name, expected_name);
                match (value, expected_value) {
                    (Some(actual), Some(expected)) => assert_eq!(actual.trim(), expected),
                    (Some(actual), None) => {
                        panic!("Terminal node has value {actual} but there is no expected value")
                    }
                    (None, Some(expected)) => {
                        panic!("Terminal node has no value but expected {expected}")
                    }
                    _ => (),
                }
            }
            _ => panic!("Expected terminal expression but found {expression:?}"),
        }
    }

    /// Asserts that `expression` is an `expression` node made of an identifier named
    /// `expected_name` followed by `sub_expressions`; returns the sub-expressions.
    fn assert_nonterminal_sexpr<'a>(
        expression: &'a Expression,
        expected_name: &str,
    ) -> &'a Vec<Expression> {
        let children = assert_nonterminal(expression, "expression");
        assert_eq!(children.len(), 2);
        assert_terminal(&children[0], "identifier", Some(expected_name));
        assert_nonterminal(&children[1], "sub_expressions")
    }

    /// Asserts that `expression` is an `expression` node made of an identifier named
    /// `expected_name`, followed by a string value when `expected_value` is given.
    fn assert_terminal_sexpr(
        expression: &Expression,
        expected_name: &str,
        expected_value: Option<&str>,
    ) {
        let children = assert_nonterminal(expression, "expression");
        assert!(!children.is_empty());
        assert_terminal(&children[0], "identifier", Some(expected_name));
        if expected_value.is_some() {
            assert_eq!(children.len(), 2);
            let value = assert_nonterminal(&children[1], "string");
            assert_eq!(value.len(), 1);
            assert_terminal(&value[0], "string_value", expected_value);
        }
    }

    const WITH_QUOTE: &str = indoc! {r#"
    Quoted
    ======

    x = "hi"
    
    ======

    (source_file
        (declaration
            (identifier: "x")
            (value: "\"hi\"")
        )
    )
    "#};

    #[test]
    fn test_quoted_value() -> Result<(), TestError<Rule>> {
        let test_case: TestCase = TestParser::parse(WITH_QUOTE)
            .map_err(|source| TestError::Parser { source })
            .and_then(|pair| {
                TestCase::try_from_pair(pair).map_err(|source| TestError::Model { source })
            })?;
        let expression = test_case.expression;
        let children = assert_nonterminal(&expression, "source_file");
        assert_eq!(children.len(), 1);
        let children = assert_nonterminal(&children[0], "declaration");
        assert_eq!(children.len(), 2);
        assert_terminal(&children[0], "identifier", Some("x"));
        assert_terminal(&children[1], "value", Some("\"hi\""));
        Ok(())
    }

    const BLANK_LINES: &str = indoc! {r#"


"#};

    #[test]
    fn test_escape_whitespace() -> Result<(), TestError<Rule>> {
        let mut writer = String::new();
        let mut formatter = ExpressionFormatter::from_defaults(&mut writer);
        let expression = Expression::Terminal {
            name: "blank_lines".to_string(),
            value: Some(BLANK_LINES.to_string()),
        };
        formatter
            .fmt(&expression)
            .expect("Error formatting expression");
        let expected = r#"(blank_lines: "\n\n")"#;
        assert_eq!(writer, expected);
        Ok(())
    }

    const TEXT: &str = indoc! {r#"
    My Test

    =======

    fn x() int {
      return 1;
    }

    =======
    
    (source_file
      (function_definition
        (identifier: "x")
        (parameter_list)
        (primitive_type: "int")
        (block
          (return_statement 
            (number: "1")
          )
        )
      )
    )
    "#};

    #[test]
    fn test_parse_from_code() -> Result<(), TestError<Rule>> {
        let test_pair = TestParser::parse(TEXT).map_err(|source| TestError::Parser { source })?;
        let skip_rules = HashSet::from([Rule::EOI]);
        let code_expression = Expression::try_from_code(test_pair, &skip_rules)
            .map_err(|source| TestError::Model { source })?;
        let children = assert_nonterminal(&code_expression, "test_case");
        assert_eq!(children.len(), 3);
        assert_terminal(&children[0], "test_name", Some("My Test"));
        let code_block = assert_nonterminal(&children[1], "code_block");
        assert_eq!(code_block.len(), 2);
        assert_terminal(&code_block[0], "div", Some("======="));
        assert_terminal(&code_block[1], "code", Some("fn x() int {\n  return 1;\n}"));
        let s_expression = assert_nonterminal_sexpr(&children[2], "source_file");
        assert_eq!(s_expression.len(), 1);
        let s_expression = assert_nonterminal_sexpr(&s_expression[0], "function_definition");
        assert_eq!(s_expression.len(), 4);
        assert_terminal_sexpr(&s_expression[0], "identifier", Some("x"));
        assert_terminal_sexpr(&s_expression[1], "parameter_list", None);
        assert_terminal_sexpr(&s_expression[2], "primitive_type", Some("int"));
        let s_expression = assert_nonterminal_sexpr(&s_expression[3], "block");
        assert_eq!(s_expression.len(), 1);
        let s_expression = assert_nonterminal_sexpr(&s_expression[0], "return_statement");
        assert_eq!(s_expression.len(), 1);
        assert_terminal_sexpr(&s_expression[0], "number", Some("1"));
        Ok(())
    }

    const TEXT_WITH_SKIP: &str = indoc! {r#"
    My Test

    =======

    fn x() int {
      return 1;
    }

    =======
    
    (source_file
      (function_definition
        (identifier: "x")
        (parameter_list)
        (primitive_type: "int")
        (block
          #[skip(depth = 1)]
          (return_statement 
            (number: "1")
          )
        )
      )
    )
    "#};

    #[test]
    fn test_parse() -> Result<(), TestError<Rule>> {
        let test_case: TestCase = TestParser::parse(TEXT_WITH_SKIP)
            .map_err(|source| TestError::Parser { source })
            .and_then(|pair| {
                TestCase::try_from_pair(pair).map_err(|source| TestError::Model { source })
            })?;
        assert_eq!(test_case.name, "My Test");
        assert_eq!(test_case.code, "\nfn x() int {\n  return 1;\n}\n");
        let expression = test_case.expression;
        let children = assert_nonterminal(&expression, "source_file");
        assert_eq!(children.len(), 1);
        let children = assert_nonterminal(&children[0], "function_definition");
        assert_eq!(children.len(), 4);
        assert_terminal(&children[0], "identifier", Some("x"));
        assert_terminal(&children[1], "parameter_list", None);
        assert_terminal(&children[2], "primitive_type", Some("int"));
        let children = assert_nonterminal(&children[3], "block");
        assert_eq!(children.len(), 1);
        let next = assert_skip(&children[0], 1);
        let children = assert_nonterminal(next, "return_statement");
        assert_eq!(children.len(), 1);
        assert_terminal(&children[0], "number", Some("1"));
        Ok(())
    }

    #[test]
    fn test_format() -> Result<(), TestError<Rule>> {
        let mut writer = String::new();
        let mut formatter = ExpressionFormatter::from_defaults(&mut writer);
        let test_case: TestCase = TestParser::parse(TEXT_WITH_SKIP)
            .map_err(|source| TestError::Parser { source })
            .and_then(|pair| {
                TestCase::try_from_pair(pair).map_err(|source| TestError::Model { source })
            })?;
        formatter
            .fmt(&test_case.expression)
            .expect("Error formatting expression");
        let expected = indoc! {r#"
        (source_file
          (function_definition
            (identifier: "x")
            (parameter_list)
            (primitive_type: "int")
            (block
              #[skip(depth = 1)]
              (return_statement
                (number: "1")
              )
            )
          )
        )"#};
        assert_eq!(writer, expected);
        Ok(())
    }
}