1use crate::parser::Rule;
2use colored::{Color, Colorize};
3use pest::{iterators::Pair, RuleType};
4use snailquote::unescape;
5use std::{
6 collections::HashSet,
7 fmt::{Display, Result as FmtResult, Write},
8};
9use thiserror::Error;
10
11#[derive(Error, Debug)]
12#[error("Error creating model element from parser pair")]
13pub struct ModelError(String);
14
15impl ModelError {
16 fn from_str(msg: &str) -> Self {
17 Self(msg.to_owned())
18 }
19}
20
21fn assert_rule(pair: Pair<'_, Rule>, rule: Rule) -> Result<Pair<'_, Rule>, ModelError> {
22 if pair.as_rule() == rule {
23 Ok(pair)
24 } else {
25 Err(ModelError(format!(
26 "Expected pair {:?} rule to be {:?}",
27 pair, rule
28 )))
29 }
30}
31
/// A node of an expression tree.
///
/// Trees are built either from a parsed s-expression (`Expression::try_from_sexpr`)
/// or from the pairs produced by an arbitrary parser (`Expression::try_from_code`).
#[derive(Clone, Debug)]
pub enum Expression {
 /// A leaf node: a rule name with an optional string value.
 Terminal {
 /// Name of the rule this node represents.
 name: String,
 /// Text associated with the node; `None` when the s-expression gave no string.
 value: Option<String>,
 },
 /// An inner node: a rule name together with its child expressions.
 NonTerminal {
 /// Name of the rule this node represents.
 name: String,
 /// Child expressions, in the order they were parsed.
 children: Vec<Expression>,
 },
 /// A skip marker (written as `#[skip(depth = N)]`) wrapping the expression that follows it.
 Skip {
 /// Number of levels consumed by this marker when descending (see `get_descendant`).
 depth: usize,
 /// The expression the marker is attached to.
 next: Box<Expression>,
 },
}
47
48impl Expression {
49 pub fn try_from_sexpr(pair: Pair<'_, Rule>) -> Result<Self, ModelError> {
50 let mut inner = pair.into_inner();
51 let skip_depth: usize = if inner.peek().map(|pair| pair.as_rule()) == Some(Rule::skip) {
52 let depth_pair = inner
53 .next()
54 .unwrap()
55 .into_inner()
56 .next()
57 .ok_or_else(|| ModelError::from_str("Missing skip depth"))
58 .and_then(|pair| assert_rule(pair, Rule::int))?;
59 depth_pair
60 .as_str()
61 .parse()
62 .map_err(|err| ModelError(format!("Error parsing skip depth: {:?}", err)))?
63 } else {
64 0
65 };
66 let name = inner
67 .next()
68 .ok_or_else(|| ModelError::from_str("Missing rule name"))
69 .and_then(|pair| assert_rule(pair, Rule::identifier))
70 .map(|pair| pair.as_str().to_owned())?;
71 let expr = match inner.next() {
72 None => Self::Terminal { name, value: None },
73 Some(pair) => match pair.as_rule() {
74 Rule::sub_expressions => {
75 let children: Result<Vec<Expression>, ModelError> =
76 pair.into_inner().map(Self::try_from_sexpr).collect();
77 Self::NonTerminal {
78 name,
79 children: children?,
80 }
81 }
82 Rule::string => {
83 let s = pair.as_str().trim();
84 let value = Some(unescape(s).map_err(|err| {
85 ModelError(format!("Error unescaping string value {}: {:?}", s, err))
86 })?);
87 Self::Terminal { name, value }
88 }
89 other => return Err(ModelError(format!("Unexpected rule {:?}", other))),
90 },
91 };
92 if skip_depth == 0 {
93 Ok(expr)
94 } else {
95 Ok(Self::Skip {
96 depth: skip_depth,
97 next: Box::new(expr),
98 })
99 }
100 }
101
102 pub fn try_from_code<R: RuleType>(
103 pair: Pair<'_, R>,
104 skip_rules: &HashSet<R>,
105 ) -> Result<Self, ModelError> {
106 let name = format!("{:?}", pair.as_rule());
107 let value = pair.as_str();
108 let children: Result<Vec<Expression>, ModelError> = pair
109 .into_inner()
110 .filter(|pair| !skip_rules.contains(&pair.as_rule()))
111 .map(|pair| Self::try_from_code(pair, skip_rules))
112 .collect();
113 match children {
114 Ok(children) if children.is_empty() => Ok(Self::Terminal {
115 name,
116 value: Some(value.to_owned()),
117 }),
118 Ok(children) => Ok(Self::NonTerminal { name, children }),
119 Err(e) => Err(e),
120 }
121 }
122
123 pub fn name(&self) -> &String {
124 match self {
125 Self::Terminal { name, value: _ } => name,
126 Self::NonTerminal { name, children: _ } => name,
127 Self::Skip { depth: _, next } => next.name(),
128 }
129 }
130
131 pub fn skip_depth(&self) -> usize {
132 match self {
133 Expression::Skip { depth, next: _ } => *depth,
134 _ => 0,
135 }
136 }
137
138 pub fn get_descendant(&self, depth: usize) -> Option<&Expression> {
142 if depth > 0 {
143 match self {
144 Self::NonTerminal { name: _, children } if !children.is_empty() => {
145 children.first().unwrap().get_descendant(depth - 1)
146 }
147 Self::Skip {
148 depth: skip_depth,
149 next,
150 } if *skip_depth <= depth => next.as_ref().get_descendant(depth - skip_depth),
151 _ => None,
152 }
153 } else {
154 Some(self)
155 }
156 }
157}
158
/// Writes `Expression` trees as indented s-expression text to a `std::fmt::Write` sink.
pub struct ExpressionFormatter<'a> {
 // Destination that receives the formatted text.
 writer: &'a mut dyn Write,
 // String written once per nesting level at the start of a line.
 indent: &'a str,
 // Current nesting level; incremented while the children of a non-terminal are written.
 pub(crate) level: usize,
 // When set, text written through `write_char`/`write_str` is wrapped in this color.
 pub(crate) color: Option<Color>,
 // When true, `fmt` renders into an intermediate `String` before writing to `writer`.
 buffering: bool,
}
166
167impl<'a> ExpressionFormatter<'a> {
168 pub fn from_defaults(writer: &'a mut dyn Write) -> Self {
169 Self {
170 writer,
171 indent: " ",
172 level: 0,
173 color: None,
174 buffering: true,
175 }
176 }
177
178 pub(crate) fn write_indent(&mut self) -> FmtResult {
179 for _ in 0..self.level {
180 self.writer.write_str(self.indent)?;
181 }
182 Ok(())
183 }
184
185 pub(crate) fn write_newline(&mut self) -> FmtResult {
186 self.writer.write_char('\n')
187 }
188
189 pub(crate) fn write_char(&mut self, c: char) -> FmtResult {
190 match self.color {
191 Some(color) => self
192 .writer
193 .write_str(format!("{}", c.to_string().color(color)).as_ref()),
194 None => self.writer.write_char(c),
195 }
196 }
197
198 pub(crate) fn write_str(&mut self, s: &str) -> FmtResult {
199 match self.color {
200 Some(color) => self
201 .writer
202 .write_str(format!("{}", s.color(color)).as_ref()),
203 None => self.writer.write_str(s),
204 }
205 }
206
207 fn fmt_buffered(&mut self, expression: &Expression) -> FmtResult {
208 let mut buf = String::with_capacity(1024);
209 let mut string_formatter = ExpressionFormatter {
210 writer: &mut buf,
211 indent: self.indent,
212 level: self.level,
213 color: None,
214 buffering: false,
215 };
216 string_formatter.fmt(expression)?;
217 self.write_str(buf.as_ref())?;
218 Ok(())
219 }
220
221 fn fmt_unbuffered(&mut self, expression: &Expression) -> FmtResult {
222 self.write_indent()?;
223 match expression {
224 Expression::Terminal { name, value } => {
225 self.write_char('(')?;
226 self.write_str(name)?;
227 if let Some(value) = value {
228 self.write_str(": \"")?;
229 self.write_str(&value.escape_default().to_string())?;
230 self.write_char('"')?;
231 }
232 self.write_char(')')?;
233 }
234 Expression::NonTerminal { name, children } if children.is_empty() => {
235 self.write_char('(')?;
236 self.write_str(name)?;
237 self.write_char(')')?;
238 }
239 Expression::NonTerminal { name, children } => {
240 self.write_char('(')?;
241 self.write_str(name)?;
242 self.write_newline()?;
243 self.level += 1;
244 for child in children {
245 self.fmt(child)?;
246 self.write_newline()?;
247 }
248 self.level -= 1;
249 self.write_indent()?;
250 self.write_char(')')?;
251 }
252 Expression::Skip { depth, next } => {
253 self.write_str(format!("#[skip(depth = {})]", depth).as_ref())?;
254 self.write_newline()?;
255 self.fmt_unbuffered(next.as_ref())?;
256 }
257 }
258 Ok(())
259 }
260
261 pub fn fmt(&mut self, expression: &Expression) -> FmtResult {
262 if self.buffering {
263 self.fmt_buffered(expression)
264 } else {
265 self.fmt_unbuffered(expression)
266 }
267 }
268}
269
270impl Display for Expression {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> FmtResult {
272 ExpressionFormatter::from_defaults(f).fmt(self)
273 }
274}
275
/// A single parsed test case: a name, the code under test and the expected expression tree.
#[derive(Clone, Debug)]
pub struct TestCase {
 /// Test name, with surrounding whitespace trimmed.
 pub name: String,
 /// Code under test, with one leading and one trailing line separator removed.
 pub code: String,
 /// Expected expression tree parsed from the test's s-expression.
 pub expression: Expression,
}
282
283impl TestCase {
284 pub fn try_from_pair(pair: Pair<'_, Rule>) -> Result<Self, ModelError> {
285 let mut inner = pair.into_inner();
286 let name = inner
287 .next()
288 .ok_or_else(|| ModelError::from_str("Missing test name"))
289 .and_then(|pair| assert_rule(pair, Rule::test_name))
290 .map(|pair| pair.as_str().trim().to_owned())?;
291 let mut code_block = inner
292 .next()
293 .ok_or_else(|| ModelError::from_str("Missing code block"))
294 .and_then(|pair| assert_rule(pair, Rule::code_block))
295 .map(|pair| pair.into_inner())?;
296 code_block
297 .next()
298 .ok_or_else(|| ModelError::from_str("Missing div"))
299 .and_then(|pair| assert_rule(pair, Rule::div))?;
300 let code_untrimmed = code_block
301 .next()
302 .ok_or_else(|| ModelError::from_str("Missing code"))
303 .and_then(|pair| assert_rule(pair, Rule::code))
304 .map(|pair| pair.as_str())?;
305 let code_len = code_untrimmed.len();
307 assert!(code_len >= 2);
308 let mut code_chars = code_untrimmed.chars();
309 let code_start: usize = match code_chars.next() {
310 Some('\n') => 1,
311 Some('\r') => match code_chars.next() {
312 Some('\n') if code_len > 2 => 2,
313 _ => 1,
314 },
315 _ => {
316 return Err(ModelError::from_str(
317 "Code block must be preceeded by at least one line separator",
318 ))
319 }
320 };
321 let mut code_chars = code_untrimmed.chars().rev();
322 let code_end: usize = code_len
323 - match code_chars.next() {
324 Some('\r') => 1,
325 Some('\n') => match code_chars.next() {
326 Some('\r') if code_len - code_start > 2 => 2,
327 _ => 1,
328 },
329 _ => {
330 return Err(ModelError::from_str(
331 "Code block must be followed by at least one line separator",
332 ))
333 }
334 };
335 let code = code_untrimmed[code_start..code_end].to_owned();
336 let expression = inner
337 .next()
338 .ok_or_else(|| ModelError::from_str("Missing expression"))
339 .and_then(|pair| assert_rule(pair, Rule::expression))?;
340 Ok(TestCase {
341 name,
342 code,
343 expression: Expression::try_from_sexpr(expression)?,
344 })
345 }
346}
347
// Unit tests for parsing test cases into expressions and formatting them back to text.
#[cfg(test)]
mod tests {
 use super::{Expression, ExpressionFormatter, TestCase};
 use crate::{
 parser::{Rule, TestParser},
 TestError,
 };
 use indoc::indoc;
 use std::collections::HashSet;

 // Asserts that `expression` is a non-terminal named `expected_name` and returns its children.
 fn assert_nonterminal<'a>(
 expression: &'a Expression,
 expected_name: &str,
 ) -> &'a Vec<Expression> {
 match expression {
 Expression::NonTerminal { name, children } => {
 assert_eq!(name, expected_name);
 children
 }
 _ => panic!("Expected non-terminal expression but found {expression:?}"),
 }
 }

 // Asserts that `expression` is a skip with `expected_depth` and returns the wrapped expression.
 fn assert_skip<'a>(expression: &'a Expression, expected_depth: usize) -> &'a Box<Expression> {
 match expression {
 Expression::Skip { depth, next } => {
 assert_eq!(expected_depth, *depth);
 next
 }
 _ => panic!("Expected skip expression but found {expression:?}"),
 }
 }

 // Asserts that `expression` is a terminal named `expected_name`. When both values are
 // present the trimmed actual value must equal `expected_value`; a value present on only
 // one side is a failure.
 fn assert_terminal(expression: &Expression, expected_name: &str, expected_value: Option<&str>) {
 match expression {
 Expression::Terminal { name, value } => {
 assert_eq!(name, expected_name);
 match (value, expected_value) {
 (Some(actual), Some(expected)) => assert_eq!(actual.trim(), expected),
 (Some(actual), None) => {
 panic!("Terminal node has value {actual} but there is no expected value")
 }
 (None, Some(expected)) => {
 panic!("Terminal node has no value but expected {expected}")
 }
 _ => (),
 }
 }
 _ => panic!("Expected terminal expression but found {expression:?}"),
 }
 }

 // Checks the raw parse tree of a non-terminal s-expression: an `expression` node with an
 // `identifier` equal to `expected_name` followed by `sub_expressions`, whose children are
 // returned.
 fn assert_nonterminal_sexpr<'a>(
 expression: &'a Expression,
 expected_name: &str,
 ) -> &'a Vec<Expression> {
 let children = assert_nonterminal(expression, "expression");
 assert_eq!(children.len(), 2);
 assert_terminal(&children[0], "identifier", Some(expected_name));
 assert_nonterminal(&children[1], "sub_expressions")
 }

 // Checks the raw parse tree of a terminal s-expression: an `expression` node whose first
 // child is an `identifier` equal to `expected_name`. When `expected_value` is given, a
 // second `string` child must hold a single `string_value` terminal with that value.
 fn assert_terminal_sexpr(
 expression: &Expression,
 expected_name: &str,
 expected_value: Option<&str>,
 ) {
 let children = assert_nonterminal(expression, "expression");
 assert!(children.len() >= 1);
 assert_terminal(&children[0], "identifier", Some(expected_name));
 if expected_value.is_some() {
 assert_eq!(children.len(), 2);
 let value = assert_nonterminal(&children[1], "string");
 assert_eq!(value.len(), 1);
 assert_terminal(&value[0], "string_value", expected_value);
 }
 }

 // Test case whose expected value contains escaped double quotes.
 const WITH_QUOTE: &str = indoc! {r#"
 Quoted
 ======

 x = "hi"

 ======

 (source_file
 (declaration
 (identifier: "x")
 (value: "\"hi\"")
 )
 )
 "#};

 // Escaped quotes in an s-expression string must be unescaped in the parsed terminal value.
 #[test]
 fn test_quoted_value() -> Result<(), TestError<Rule>> {
 let test_case: TestCase = TestParser::parse(WITH_QUOTE)
 .map_err(|source| TestError::Parser { source })
 .and_then(|pair| {
 TestCase::try_from_pair(pair).map_err(|source| TestError::Model { source })
 })?;
 let expression = test_case.expression;
 let children = assert_nonterminal(&expression, "source_file");
 assert_eq!(children.len(), 1);
 let children = assert_nonterminal(&children[0], "declaration");
 assert_eq!(children.len(), 2);
 assert_terminal(&children[0], "identifier", Some("x"));
 assert_terminal(&children[1], "value", Some("\"hi\""));
 Ok(())
 }

 // A value consisting only of line breaks (the test below expects it to format as "\n\n").
 const BLANK_LINES: &str = indoc! {r#"


"#};

 // Line breaks inside a terminal value must be written in escaped form by the formatter.
 #[test]
 fn test_escape_whitespace() -> Result<(), TestError<Rule>> {
 let mut writer = String::new();
 let mut formatter = ExpressionFormatter::from_defaults(&mut writer);
 let expression = Expression::Terminal {
 name: "blank_lines".to_string(),
 value: Some(BLANK_LINES.to_string()),
 };
 formatter
 .fmt(&expression)
 .expect("Error formatting expression");
 let expected = r#"(blank_lines: "\n\n")"#;
 assert_eq!(writer, expected);
 Ok(())
 }

 // Complete test case (name, code block and expected s-expression) without skip markers.
 const TEXT: &str = indoc! {r#"
 My Test

 =======

 fn x() int {
 return 1;
 }

 =======

 (source_file
 (function_definition
 (identifier: "x")
 (parameter_list)
 (primitive_type: "int")
 (block
 (return_statement
 (number: "1")
 )
 )
 )
 )
 "#};

 // Converts the test parser's own parse tree with `try_from_code` (ignoring `EOI`) and
 // checks the resulting structure node by node.
 #[test]
 fn test_parse_from_code() -> Result<(), TestError<Rule>> {
 let test_pair = TestParser::parse(TEXT).map_err(|source| TestError::Parser { source })?;
 let skip_rules = HashSet::from([Rule::EOI]);
 let code_expression = Expression::try_from_code(test_pair, &skip_rules)
 .map_err(|source| TestError::Model { source })?;
 let children = assert_nonterminal(&code_expression, "test_case");
 assert_eq!(children.len(), 3);
 assert_terminal(&children[0], "test_name", Some("My Test"));
 let code_block = assert_nonterminal(&children[1], "code_block");
 assert_eq!(code_block.len(), 2);
 assert_terminal(&code_block[0], "div", Some("======="));
 assert_terminal(&code_block[1], "code", Some("fn x() int {\n return 1;\n}"));
 let s_expression = assert_nonterminal_sexpr(&children[2], "source_file");
 assert_eq!(s_expression.len(), 1);
 let s_expression = assert_nonterminal_sexpr(&s_expression[0], "function_definition");
 assert_eq!(s_expression.len(), 4);
 assert_terminal_sexpr(&s_expression[0], "identifier", Some("x"));
 assert_terminal_sexpr(&s_expression[1], "parameter_list", None);
 assert_terminal_sexpr(&s_expression[2], "primitive_type", Some("int"));
 let s_expression = assert_nonterminal_sexpr(&s_expression[3], "block");
 assert_eq!(s_expression.len(), 1);
 let s_expression = assert_nonterminal_sexpr(&s_expression[0], "return_statement");
 assert_eq!(s_expression.len(), 1);
 assert_terminal_sexpr(&s_expression[0], "number", Some("1"));
 Ok(())
 }

 // Same test case as `TEXT`, with a skip marker in front of the return statement.
 const TEXT_WITH_SKIP: &str = indoc! {r#"
 My Test

 =======

 fn x() int {
 return 1;
 }

 =======

 (source_file
 (function_definition
 (identifier: "x")
 (parameter_list)
 (primitive_type: "int")
 (block
 #[skip(depth = 1)]
 (return_statement
 (number: "1")
 )
 )
 )
 )
 "#};

 // Parses a full test case and checks its name, code and expression tree, including the
 // skip node wrapping the return statement.
 #[test]
 fn test_parse() -> Result<(), TestError<Rule>> {
 let test_case: TestCase = TestParser::parse(TEXT_WITH_SKIP)
 .map_err(|source| TestError::Parser { source })
 .and_then(|pair| {
 TestCase::try_from_pair(pair).map_err(|source| TestError::Model { source })
 })?;
 assert_eq!(test_case.name, "My Test");
 assert_eq!(test_case.code, "\nfn x() int {\n return 1;\n}\n");
 let expression = test_case.expression;
 let children = assert_nonterminal(&expression, "source_file");
 assert_eq!(children.len(), 1);
 let children = assert_nonterminal(&children[0], "function_definition");
 assert_eq!(children.len(), 4);
 assert_terminal(&children[0], "identifier", Some("x"));
 assert_terminal(&children[1], "parameter_list", None);
 assert_terminal(&children[2], "primitive_type", Some("int"));
 let children = assert_nonterminal(&children[3], "block");
 assert_eq!(children.len(), 1);
 let next = assert_skip(&children[0], 1);
 let children = assert_nonterminal(&next, "return_statement");
 assert_eq!(children.len(), 1);
 assert_terminal(&children[0], "number", Some("1"));
 Ok(())
 }

 // Formats the parsed expression with the default formatter and compares the output with
 // the expected s-expression text.
 #[test]
 fn test_format() -> Result<(), TestError<Rule>> {
 let mut writer = String::new();
 let mut formatter = ExpressionFormatter::from_defaults(&mut writer);
 let test_case: TestCase = TestParser::parse(TEXT_WITH_SKIP)
 .map_err(|source| TestError::Parser { source })
 .and_then(|pair| {
 TestCase::try_from_pair(pair).map_err(|source| TestError::Model { source })
 })?;
 formatter
 .fmt(&test_case.expression)
 .expect("Error formatting expression");
 let expected = indoc! {r#"
 (source_file
 (function_definition
 (identifier: "x")
 (parameter_list)
 (primitive_type: "int")
 (block
 #[skip(depth = 1)]
 (return_statement
 (number: "1")
 )
 )
 )
 )"#};
 assert_eq!(writer, expected);
 Ok(())
 }
}