1use std::fmt::Write;
11use std::sync::Arc;
12
13use panproto_expr::{BuiltinOp, Expr, Literal, Pattern};
14
/// Binding strength of the syntactic forms the printer emits.
///
/// Variants are declared from loosest (`Top`) to tightest (`Atom`), and the
/// derived `PartialOrd`/`Ord` follow that order. Throughout the printer,
/// `ctx > p` means "the surrounding context binds tighter than level `p`, so an
/// expression at level `p` must be parenthesized".
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum Prec {
    /// Unrestricted context: lambdas, `let`, `if`/`case` and any operator.
    Top = 0,
    /// Pipe level. NOTE(review): no operator in this file maps to it; it is
    /// only reached as `next_prec(Prec::Top)`.
    Pipe = 1,
    /// `||`. (Discriminant 2 is unused.)
    Or = 3,
    /// `&&`.
    And = 4,
    /// Comparisons: `==`, `/=`, `<`, `<=`, `>`, `>=`.
    Cmp = 5,
    /// `++`.
    Concat = 6,
    /// `+` and `-`.
    AddSub = 7,
    /// `*`, `/` and `%`.
    MulDiv = 8,
    /// Prefix `-` and `not`.
    Unary = 9,
    /// Function and builtin application.
    App = 10,
    /// Tightest level: forms that never need parentheses (names, records,
    /// lists, field access, edge traversal) and the context for application
    /// arguments.
    Atom = 11,
}
43
/// Associativity of an infix operator.
///
/// It decides which operand may contain another operator of the same
/// precedence without parentheses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Assoc {
    /// `a op b op c` groups as `(a op b) op c`: the left operand is printed at
    /// the operator's own precedence, the right one at the next tighter level.
    Left,
    /// `a op b op c` groups as `a op (b op c)`: the right operand is printed at
    /// the operator's own precedence, the left one at the next tighter level.
    Right,
}
50
51#[must_use]
69pub fn pretty_print(expr: &Expr) -> String {
70 let mut buf = String::new();
71 write_expr(&mut buf, expr, Prec::Top);
72 buf
73}
74
75fn write_expr(buf: &mut String, expr: &Expr, ctx: Prec) {
79 match expr {
80 Expr::Var(name) => buf.push_str(name),
81
82 Expr::Lit(lit) => write_literal(buf, lit),
83
84 Expr::Lam(param, body) => {
85 let needs_parens = ctx > Prec::Top;
86 if needs_parens {
87 buf.push('(');
88 }
89 write_lambda_chain(buf, param, body);
90 if needs_parens {
91 buf.push(')');
92 }
93 }
94
95 Expr::App(func, arg) => {
96 write_app(buf, expr, ctx);
97 let _ = (func, arg); }
99
100 Expr::Record(fields) => {
101 write_record_expr(buf, fields);
102 }
103
104 Expr::List(items) => {
105 buf.push('[');
106 for (i, item) in items.iter().enumerate() {
107 if i > 0 {
108 buf.push_str(", ");
109 }
110 write_expr(buf, item, Prec::Top);
111 }
112 buf.push(']');
113 }
114
115 Expr::Field(inner, name) => {
116 write_expr(buf, inner, Prec::Atom);
117 buf.push('.');
118 buf.push_str(name);
119 }
120
121 Expr::Index(inner, idx) => {
122 write_expr(buf, inner, Prec::Atom);
123 buf.push('[');
124 write_expr(buf, idx, Prec::Top);
125 buf.push(']');
126 }
127
128 Expr::Match { scrutinee, arms } => {
129 write_match(buf, scrutinee, arms, ctx);
130 }
131
132 Expr::Let { name, value, body } => {
133 write_let(buf, name, value, body, ctx);
134 }
135
136 Expr::Builtin(op, args) => {
137 write_builtin(buf, *op, args, ctx);
138 }
139 }
140}
141
142fn write_lambda_chain(buf: &mut String, first_param: &Arc<str>, first_body: &Expr) {
144 buf.push('\\');
145 buf.push_str(first_param);
146 let mut body = first_body;
147 while let Expr::Lam(param, inner) = body {
148 buf.push(' ');
149 buf.push_str(param);
150 body = inner;
151 }
152 buf.push_str(" -> ");
153 write_expr(buf, body, Prec::Top);
154}
155
156fn write_app(buf: &mut String, expr: &Expr, ctx: Prec) {
158 let needs_parens = ctx > Prec::App;
159 if needs_parens {
160 buf.push('(');
161 }
162
163 let mut spine: Vec<&Expr> = Vec::new();
165 let mut head = expr;
166 while let Expr::App(func, arg) = head {
167 spine.push(arg);
168 head = func;
169 }
170 spine.reverse();
171
172 write_expr(buf, head, Prec::App);
173 for arg in &spine {
174 buf.push(' ');
175 write_expr(buf, arg, Prec::Atom);
176 }
177
178 if needs_parens {
179 buf.push(')');
180 }
181}
182
183fn write_record_expr(buf: &mut String, fields: &[(Arc<str>, Expr)]) {
185 buf.push_str("{ ");
186 for (i, (name, val)) in fields.iter().enumerate() {
187 if i > 0 {
188 buf.push_str(", ");
189 }
190 if let Expr::Var(v) = val {
192 if v == name {
193 buf.push_str(name);
194 continue;
195 }
196 }
197 buf.push_str(name);
198 buf.push_str(" = ");
199 write_expr(buf, val, Prec::Top);
200 }
201 buf.push_str(" }");
202}
203
204fn write_match(buf: &mut String, scrutinee: &Expr, arms: &[(Pattern, Expr)], ctx: Prec) {
209 if arms.len() == 2 {
211 if let (Pattern::Lit(Literal::Bool(true)), then_branch) = &arms[0] {
212 if let (Pattern::Wildcard, else_branch) = &arms[1] {
213 let needs_parens = ctx > Prec::Top;
214 if needs_parens {
215 buf.push('(');
216 }
217 buf.push_str("if ");
218 write_expr(buf, scrutinee, Prec::Top);
219 buf.push_str(" then ");
220 write_expr(buf, then_branch, Prec::Top);
221 buf.push_str(" else ");
222 write_expr(buf, else_branch, Prec::Top);
223 if needs_parens {
224 buf.push(')');
225 }
226 return;
227 }
228 }
229 }
230
231 let needs_parens = ctx > Prec::Top;
232 if needs_parens {
233 buf.push('(');
234 }
235 buf.push_str("case ");
236 write_expr(buf, scrutinee, Prec::Top);
237 buf.push_str(" of\n");
238 for (i, (pat, body)) in arms.iter().enumerate() {
239 if i > 0 {
240 buf.push('\n');
241 }
242 buf.push_str(" ");
243 write_pattern(buf, pat);
244 buf.push_str(" -> ");
245 write_expr(buf, body, Prec::Top);
246 }
247 if needs_parens {
248 buf.push(')');
249 }
250}
251
252fn write_let(buf: &mut String, name: &Arc<str>, value: &Expr, body: &Expr, ctx: Prec) {
254 let needs_parens = ctx > Prec::Top;
255 if needs_parens {
256 buf.push('(');
257 }
258
259 let mut bindings: Vec<(&Arc<str>, &Expr)> = vec![(name, value)];
261 let mut final_body = body;
262 while let Expr::Let {
263 name: n,
264 value: v,
265 body: b,
266 } = final_body
267 {
268 bindings.push((n, v));
269 final_body = b;
270 }
271
272 if bindings.len() == 1 {
273 buf.push_str("let ");
274 buf.push_str(name);
275 buf.push_str(" = ");
276 write_expr(buf, value, Prec::Top);
277 buf.push_str(" in ");
278 } else {
279 buf.push_str("let\n");
280 for (n, v) in &bindings {
281 buf.push_str(" ");
282 buf.push_str(n);
283 buf.push_str(" = ");
284 write_expr(buf, v, Prec::Top);
285 buf.push('\n');
286 }
287 buf.push_str("in ");
288 }
289 write_expr(buf, final_body, Prec::Top);
290
291 if needs_parens {
292 buf.push(')');
293 }
294}
295
296fn write_builtin(buf: &mut String, op: BuiltinOp, args: &[Expr], ctx: Prec) {
298 if let Some((sym, prec, assoc)) = infix_info(op) {
300 if args.len() == 2 {
301 let needs_parens = ctx > prec;
302 if needs_parens {
303 buf.push('(');
304 }
305 let (left_ctx, right_ctx) = match assoc {
309 Assoc::Left => (prec, next_prec(prec)),
310 Assoc::Right => (next_prec(prec), prec),
311 };
312 write_expr(buf, &args[0], left_ctx);
313 buf.push(' ');
314 buf.push_str(sym);
315 buf.push(' ');
316 write_expr(buf, &args[1], right_ctx);
317 if needs_parens {
318 buf.push(')');
319 }
320 return;
321 }
322 }
323
324 if op == BuiltinOp::Edge && args.len() == 2 {
326 if let Expr::Lit(Literal::Str(edge_name)) = &args[1] {
327 let needs_parens = ctx > Prec::Atom;
328 if needs_parens {
329 buf.push('(');
330 }
331 write_expr(buf, &args[0], Prec::Atom);
332 buf.push_str(" -> ");
333 buf.push_str(edge_name);
334 if needs_parens {
335 buf.push(')');
336 }
337 return;
338 }
339 }
340
341 if op == BuiltinOp::Neg && args.len() == 1 {
343 let needs_parens = ctx > Prec::Unary;
344 if needs_parens {
345 buf.push('(');
346 }
347 buf.push('-');
348 write_expr(buf, &args[0], Prec::Atom);
349 if needs_parens {
350 buf.push(')');
351 }
352 return;
353 }
354
355 if op == BuiltinOp::Not && args.len() == 1 {
356 let needs_parens = ctx > Prec::Unary;
357 if needs_parens {
358 buf.push('(');
359 }
360 buf.push_str("not ");
361 write_expr(buf, &args[0], Prec::Atom);
362 if needs_parens {
363 buf.push(')');
364 }
365 return;
366 }
367
368 let needs_parens = ctx > Prec::App && !args.is_empty();
370 if needs_parens {
371 buf.push('(');
372 }
373 buf.push_str(builtin_name(op));
374 for arg in args {
375 buf.push(' ');
376 write_expr(buf, arg, Prec::Atom);
377 }
378 if needs_parens {
379 buf.push(')');
380 }
381}
382
383const fn infix_info(op: BuiltinOp) -> Option<(&'static str, Prec, Assoc)> {
387 match op {
388 BuiltinOp::Or => Some(("||", Prec::Or, Assoc::Left)),
389 BuiltinOp::And => Some(("&&", Prec::And, Assoc::Left)),
390 BuiltinOp::Eq => Some(("==", Prec::Cmp, Assoc::Right)),
391 BuiltinOp::Neq => Some(("/=", Prec::Cmp, Assoc::Right)),
392 BuiltinOp::Lt => Some(("<", Prec::Cmp, Assoc::Right)),
393 BuiltinOp::Lte => Some(("<=", Prec::Cmp, Assoc::Right)),
394 BuiltinOp::Gt => Some((">", Prec::Cmp, Assoc::Right)),
395 BuiltinOp::Gte => Some((">=", Prec::Cmp, Assoc::Right)),
396 BuiltinOp::Concat => Some(("++", Prec::Concat, Assoc::Right)),
397 BuiltinOp::Add => Some(("+", Prec::AddSub, Assoc::Left)),
398 BuiltinOp::Sub => Some(("-", Prec::AddSub, Assoc::Left)),
399 BuiltinOp::Mul => Some(("*", Prec::MulDiv, Assoc::Left)),
400 BuiltinOp::Div => Some(("/", Prec::MulDiv, Assoc::Left)),
401 BuiltinOp::Mod => Some(("%", Prec::MulDiv, Assoc::Left)),
402 _ => None,
403 }
404}
405
/// Returns the precedence level one step tighter than `p`, in declaration
/// order.
///
/// It is used for the operand of an infix operator that must not contain an
/// unparenthesized operator of the same level. `App` and `Atom` both map to
/// `Atom`, the tightest level.
const fn next_prec(p: Prec) -> Prec {
    match p {
        Prec::Top => Prec::Pipe,
        Prec::Pipe => Prec::Or,
        Prec::Or => Prec::And,
        Prec::And => Prec::Cmp,
        Prec::Cmp => Prec::Concat,
        Prec::Concat => Prec::AddSub,
        Prec::AddSub => Prec::MulDiv,
        Prec::MulDiv => Prec::Unary,
        Prec::Unary => Prec::App,
        // Already at (or next to) the tightest level: saturate at `Atom`.
        Prec::App | Prec::Atom => Prec::Atom,
    }
}
421
/// Returns the function-call spelling of `op`.
///
/// `write_builtin` uses it when a builtin is not printed with dedicated infix,
/// prefix or edge syntax. The match has no catch-all arm, so adding a
/// `BuiltinOp` variant requires adding its name here.
const fn builtin_name(op: BuiltinOp) -> &'static str {
    match op {
        // Arithmetic.
        BuiltinOp::Add => "add",
        BuiltinOp::Sub => "sub",
        BuiltinOp::Mul => "mul",
        BuiltinOp::Div => "div",
        BuiltinOp::Mod => "mod",
        BuiltinOp::Neg => "neg",
        BuiltinOp::Abs => "abs",
        BuiltinOp::Floor => "floor",
        BuiltinOp::Ceil => "ceil",
        // Comparison and logic.
        BuiltinOp::Eq => "eq",
        BuiltinOp::Neq => "neq",
        BuiltinOp::Lt => "lt",
        BuiltinOp::Lte => "lte",
        BuiltinOp::Gt => "gt",
        BuiltinOp::Gte => "gte",
        BuiltinOp::And => "and",
        BuiltinOp::Or => "or",
        BuiltinOp::Not => "not",
        // Strings.
        BuiltinOp::Concat => "concat",
        BuiltinOp::Len => "len",
        BuiltinOp::Slice => "slice",
        BuiltinOp::Upper => "upper",
        BuiltinOp::Lower => "lower",
        BuiltinOp::Trim => "trim",
        BuiltinOp::Split => "split",
        BuiltinOp::Join => "join",
        BuiltinOp::Replace => "replace",
        BuiltinOp::Contains => "contains",
        // Lists.
        BuiltinOp::Map => "map",
        BuiltinOp::Filter => "filter",
        BuiltinOp::Fold => "fold",
        BuiltinOp::Append => "append",
        BuiltinOp::Head => "head",
        BuiltinOp::Tail => "tail",
        BuiltinOp::Reverse => "reverse",
        BuiltinOp::FlatMap => "flat_map",
        BuiltinOp::Length => "length",
        // Records. Note the variant name differs from its spelling here.
        BuiltinOp::MergeRecords => "merge",
        BuiltinOp::Keys => "keys",
        BuiltinOp::Values => "values",
        BuiltinOp::HasField => "has_field",
        // Conversions and type tests.
        BuiltinOp::IntToFloat => "int_to_float",
        BuiltinOp::FloatToInt => "float_to_int",
        BuiltinOp::IntToStr => "int_to_str",
        BuiltinOp::FloatToStr => "float_to_str",
        BuiltinOp::StrToInt => "str_to_int",
        BuiltinOp::StrToFloat => "str_to_float",
        BuiltinOp::TypeOf => "type_of",
        BuiltinOp::IsNull => "is_null",
        BuiltinOp::IsList => "is_list",
        // Graph traversal.
        BuiltinOp::Edge => "edge",
        BuiltinOp::Children => "children",
        BuiltinOp::HasEdge => "has_edge",
        BuiltinOp::EdgeCount => "edge_count",
        BuiltinOp::Anchor => "anchor",
    }
}
482
483fn write_literal(buf: &mut String, lit: &Literal) {
485 match lit {
486 Literal::Bool(true) => buf.push_str("True"),
487 Literal::Bool(false) => buf.push_str("False"),
488 Literal::Int(n) => {
489 let _ = write!(buf, "{n}");
490 }
491 Literal::Float(f) => {
492 let s = format!("{f}");
495 if s.contains('.') {
496 buf.push_str(&s);
497 } else {
498 let _ = write!(buf, "{f}.0");
499 }
500 }
501 Literal::Str(s) => {
502 buf.push('"');
503 for ch in s.chars() {
505 match ch {
506 '\\' => buf.push_str("\\\\"),
507 '"' => buf.push_str("\\\""),
508 '\n' => buf.push_str("\\n"),
509 '\r' => buf.push_str("\\r"),
510 '\t' => buf.push_str("\\t"),
511 c => buf.push(c),
512 }
513 }
514 buf.push('"');
515 }
516 Literal::Bytes(bytes) => {
517 buf.push('[');
519 for (i, b) in bytes.iter().enumerate() {
520 if i > 0 {
521 buf.push_str(", ");
522 }
523 let _ = write!(buf, "{b}");
524 }
525 buf.push(']');
526 }
527 Literal::Null => buf.push_str("Nothing"),
528 Literal::Record(fields) => {
529 buf.push_str("{ ");
530 for (i, (name, val)) in fields.iter().enumerate() {
531 if i > 0 {
532 buf.push_str(", ");
533 }
534 buf.push_str(name);
535 buf.push_str(" = ");
536 write_literal(buf, val);
537 }
538 buf.push_str(" }");
539 }
540 Literal::List(items) => {
541 buf.push('[');
542 for (i, item) in items.iter().enumerate() {
543 if i > 0 {
544 buf.push_str(", ");
545 }
546 write_literal(buf, item);
547 }
548 buf.push(']');
549 }
550 Literal::Closure { param, body, .. } => {
551 buf.push('\\');
554 buf.push_str(param);
555 buf.push_str(" -> ");
556 write_expr(buf, body, Prec::Top);
557 }
558 }
559}
560
561fn write_pattern(buf: &mut String, pat: &Pattern) {
563 match pat {
564 Pattern::Wildcard => buf.push('_'),
565 Pattern::Var(name) => buf.push_str(name),
566 Pattern::Lit(lit) => write_literal(buf, lit),
567 Pattern::Record(fields) => {
568 buf.push_str("{ ");
569 for (i, (name, p)) in fields.iter().enumerate() {
570 if i > 0 {
571 buf.push_str(", ");
572 }
573 if let Pattern::Var(v) = p {
575 if v == name {
576 buf.push_str(name);
577 continue;
578 }
579 }
580 buf.push_str(name);
581 buf.push_str(" = ");
582 write_pattern(buf, p);
583 }
584 buf.push_str(" }");
585 }
586 Pattern::List(pats) => {
587 buf.push('[');
588 for (i, p) in pats.iter().enumerate() {
589 if i > 0 {
590 buf.push_str(", ");
591 }
592 write_pattern(buf, p);
593 }
594 buf.push(']');
595 }
596 Pattern::Constructor(name, args) => {
597 buf.push_str(name);
598 for arg in args {
599 buf.push(' ');
600 let needs_parens = matches!(arg, Pattern::Constructor(_, a) if !a.is_empty());
603 if needs_parens {
604 buf.push('(');
605 }
606 write_pattern(buf, arg);
607 if needs_parens {
608 buf.push(')');
609 }
610 }
611 }
612 }
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{parse, tokenize};

    /// Lexes and parses `input`, pretty-prints the result, then lexes and
    /// parses the printed text again and asserts both ASTs are equal.
    fn round_trip(input: &str) {
        let tokens1 = tokenize(input).unwrap_or_else(|e| panic!("first lex failed: {e}"));
        let expr1 = parse(&tokens1).unwrap_or_else(|e| panic!("first parse failed: {e:?}"));
        let printed = pretty_print(&expr1);
        let tokens2 = tokenize(&printed).unwrap_or_else(|e| {
            panic!("re-lex failed for {printed:?}: {e}");
        });
        let expr2 = parse(&tokens2).unwrap_or_else(|e| {
            panic!("re-parse failed for {printed:?}: {e:?}");
        });
        assert_eq!(
            expr1, expr2,
            "round trip failed.\n input: {input:?}\n printed: {printed:?}"
        );
    }

    /// Asserts that `pretty_print(expr)` produces exactly `expected`.
    fn prints_as(expr: &Expr, expected: &str) {
        let actual = pretty_print(expr);
        assert_eq!(actual, expected, "pretty_print mismatch");
    }

    // --- Literals ---

    #[test]
    fn lit_int() {
        prints_as(&Expr::Lit(Literal::Int(42)), "42");
    }

    #[test]
    fn lit_negative_int() {
        prints_as(&Expr::Lit(Literal::Int(-5)), "-5");
    }

    #[test]
    fn lit_float() {
        prints_as(&Expr::Lit(Literal::Float(3.125)), "3.125");
    }

    #[test]
    fn lit_string() {
        prints_as(&Expr::Lit(Literal::Str("hello".into())), r#""hello""#);
    }

    #[test]
    fn lit_string_escapes() {
        prints_as(
            &Expr::Lit(Literal::Str("say \"hi\"".into())),
            r#""say \"hi\"""#,
        );
    }

    #[test]
    fn lit_bool() {
        prints_as(&Expr::Lit(Literal::Bool(true)), "True");
        prints_as(&Expr::Lit(Literal::Bool(false)), "False");
    }

    #[test]
    fn lit_null() {
        prints_as(&Expr::Lit(Literal::Null), "Nothing");
    }

    #[test]
    fn lit_bytes() {
        prints_as(&Expr::Lit(Literal::Bytes(vec![1, 2, 3])), "[1, 2, 3]");
    }

    // --- Variables ---

    #[test]
    fn variable() {
        prints_as(&Expr::Var(Arc::from("x")), "x");
    }

    // --- Lambdas ---

    #[test]
    fn lambda_simple() {
        prints_as(
            &Expr::Lam(Arc::from("x"), Box::new(Expr::Var(Arc::from("x")))),
            "\\x -> x",
        );
    }

    #[test]
    fn lambda_multi_param() {
        prints_as(
            &Expr::Lam(
                Arc::from("x"),
                Box::new(Expr::Lam(
                    Arc::from("y"),
                    Box::new(Expr::Builtin(
                        BuiltinOp::Add,
                        vec![Expr::Var(Arc::from("x")), Expr::Var(Arc::from("y"))],
                    )),
                )),
            ),
            "\\x y -> x + y",
        );
    }

    #[test]
    fn lambda_round_trip() {
        round_trip("\\x -> x + 1");
        round_trip("\\x y -> x + y");
    }

    // --- Application ---

    #[test]
    fn app_simple() {
        prints_as(
            &Expr::App(
                Box::new(Expr::Var(Arc::from("f"))),
                Box::new(Expr::Var(Arc::from("x"))),
            ),
            "f x",
        );
    }

    #[test]
    fn app_chain() {
        prints_as(
            &Expr::App(
                Box::new(Expr::App(
                    Box::new(Expr::Var(Arc::from("f"))),
                    Box::new(Expr::Var(Arc::from("x"))),
                )),
                Box::new(Expr::Var(Arc::from("y"))),
            ),
            "f x y",
        );
    }

    #[test]
    fn app_complex_arg() {
        prints_as(
            // An application used as an argument must be parenthesized.
            &Expr::App(
                Box::new(Expr::Var(Arc::from("f"))),
                Box::new(Expr::App(
                    Box::new(Expr::Var(Arc::from("g"))),
                    Box::new(Expr::Var(Arc::from("x"))),
                )),
            ),
            "f (g x)",
        );
    }

    // --- Records ---

    #[test]
    fn record_simple() {
        prints_as(
            &Expr::Record(vec![
                (Arc::from("x"), Expr::Lit(Literal::Int(1))),
                (Arc::from("y"), Expr::Lit(Literal::Int(2))),
            ]),
            "{ x = 1, y = 2 }",
        );
    }

    #[test]
    fn record_punning() {
        prints_as(
            &Expr::Record(vec![
                (Arc::from("x"), Expr::Var(Arc::from("x"))),
                (Arc::from("y"), Expr::Var(Arc::from("y"))),
            ]),
            "{ x, y }",
        );
    }

    #[test]
    fn record_mixed_punning() {
        prints_as(
            &Expr::Record(vec![
                (Arc::from("x"), Expr::Var(Arc::from("x"))),
                (Arc::from("y"), Expr::Lit(Literal::Int(42))),
            ]),
            "{ x, y = 42 }",
        );
    }

    #[test]
    fn record_round_trip() {
        round_trip("{ name = x, age = 30 }");
        round_trip("{ x, y }");
    }

    // --- Lists ---

    #[test]
    fn list_simple() {
        prints_as(
            &Expr::List(vec![
                Expr::Lit(Literal::Int(1)),
                Expr::Lit(Literal::Int(2)),
                Expr::Lit(Literal::Int(3)),
            ]),
            "[1, 2, 3]",
        );
    }

    #[test]
    fn list_empty() {
        prints_as(&Expr::List(vec![]), "[]");
    }

    #[test]
    fn list_round_trip() {
        round_trip("[1, 2, 3]");
        round_trip("[]");
    }

    // --- Field access ---

    #[test]
    fn field_access() {
        prints_as(
            &Expr::Field(Box::new(Expr::Var(Arc::from("x"))), Arc::from("name")),
            "x.name",
        );
    }

    #[test]
    fn field_chain() {
        prints_as(
            &Expr::Field(
                Box::new(Expr::Field(
                    Box::new(Expr::Var(Arc::from("x"))),
                    Arc::from("a"),
                )),
                Arc::from("b"),
            ),
            "x.a.b",
        );
    }

    #[test]
    fn field_round_trip() {
        round_trip("x.name");
        round_trip("x.a.b");
    }

    // --- Edge traversal ---

    #[test]
    fn edge_traversal() {
        prints_as(
            &Expr::Builtin(
                BuiltinOp::Edge,
                vec![
                    Expr::Var(Arc::from("doc")),
                    Expr::Lit(Literal::Str("layers".into())),
                ],
            ),
            "doc -> layers",
        );
    }

    #[test]
    fn edge_chain() {
        prints_as(
            &Expr::Builtin(
                BuiltinOp::Edge,
                vec![
                    Expr::Builtin(
                        BuiltinOp::Edge,
                        vec![
                            Expr::Var(Arc::from("doc")),
                            Expr::Lit(Literal::Str("layers".into())),
                        ],
                    ),
                    Expr::Lit(Literal::Str("annotations".into())),
                ],
            ),
            "doc -> layers -> annotations",
        );
    }

    #[test]
    fn edge_round_trip() {
        round_trip("doc -> layers");
        round_trip("doc -> layers -> annotations");
    }

    // --- Infix operators ---

    #[test]
    fn infix_add() {
        prints_as(
            &Expr::Builtin(
                BuiltinOp::Add,
                vec![Expr::Var(Arc::from("x")), Expr::Lit(Literal::Int(1))],
            ),
            "x + 1",
        );
    }

    #[test]
    fn infix_precedence_no_parens() {
        prints_as(
            // 1 + (2 * 3): the tighter `*` needs no parentheses.
            &Expr::Builtin(
                BuiltinOp::Add,
                vec![
                    Expr::Lit(Literal::Int(1)),
                    Expr::Builtin(
                        BuiltinOp::Mul,
                        vec![Expr::Lit(Literal::Int(2)), Expr::Lit(Literal::Int(3))],
                    ),
                ],
            ),
            "1 + 2 * 3",
        );
    }

    #[test]
    fn infix_precedence_needs_parens() {
        prints_as(
            // (1 + 2) * 3: the looser `+` under `*` must be parenthesized.
            &Expr::Builtin(
                BuiltinOp::Mul,
                vec![
                    Expr::Builtin(
                        BuiltinOp::Add,
                        vec![Expr::Lit(Literal::Int(1)), Expr::Lit(Literal::Int(2))],
                    ),
                    Expr::Lit(Literal::Int(3)),
                ],
            ),
            "(1 + 2) * 3",
        );
    }

    #[test]
    fn infix_left_assoc_no_parens() {
        prints_as(
            // (1 + 2) + 3: left-nested `+` matches left associativity.
            &Expr::Builtin(
                BuiltinOp::Add,
                vec![
                    Expr::Builtin(
                        BuiltinOp::Add,
                        vec![Expr::Lit(Literal::Int(1)), Expr::Lit(Literal::Int(2))],
                    ),
                    Expr::Lit(Literal::Int(3)),
                ],
            ),
            "1 + 2 + 3",
        );
    }

    #[test]
    fn infix_right_assoc_needs_parens() {
        prints_as(
            // 1 + (2 + 3): right-nested `+` goes against its associativity.
            &Expr::Builtin(
                BuiltinOp::Add,
                vec![
                    Expr::Lit(Literal::Int(1)),
                    Expr::Builtin(
                        BuiltinOp::Add,
                        vec![Expr::Lit(Literal::Int(2)), Expr::Lit(Literal::Int(3))],
                    ),
                ],
            ),
            "1 + (2 + 3)",
        );
    }

    #[test]
    fn infix_concat_right_assoc() {
        prints_as(
            // a ++ (b ++ c): `++` is right-associative, so no parentheses.
            &Expr::Builtin(
                BuiltinOp::Concat,
                vec![
                    Expr::Var(Arc::from("a")),
                    Expr::Builtin(
                        BuiltinOp::Concat,
                        vec![Expr::Var(Arc::from("b")), Expr::Var(Arc::from("c"))],
                    ),
                ],
            ),
            "a ++ b ++ c",
        );
    }

    #[test]
    fn infix_comparison() {
        prints_as(
            &Expr::Builtin(
                BuiltinOp::Eq,
                vec![Expr::Var(Arc::from("x")), Expr::Lit(Literal::Int(1))],
            ),
            "x == 1",
        );
        prints_as(
            &Expr::Builtin(
                BuiltinOp::Neq,
                vec![Expr::Var(Arc::from("x")), Expr::Lit(Literal::Int(1))],
            ),
            "x /= 1",
        );
        prints_as(
            &Expr::Builtin(
                BuiltinOp::Lt,
                vec![Expr::Var(Arc::from("x")), Expr::Var(Arc::from("y"))],
            ),
            "x < y",
        );
    }

    #[test]
    fn infix_logical() {
        prints_as(
            &Expr::Builtin(
                BuiltinOp::And,
                vec![Expr::Var(Arc::from("a")), Expr::Var(Arc::from("b"))],
            ),
            "a && b",
        );
        prints_as(
            &Expr::Builtin(
                BuiltinOp::Or,
                vec![Expr::Var(Arc::from("a")), Expr::Var(Arc::from("b"))],
            ),
            "a || b",
        );
    }

    #[test]
    fn infix_round_trips() {
        round_trip("1 + 2");
        round_trip("1 + 2 * 3");
        round_trip("(1 + 2) * 3");
        round_trip("a && b || c");
        round_trip("x == 1");
        round_trip("x /= 1");
    }

    // --- Prefix operators ---

    #[test]
    fn prefix_neg() {
        prints_as(
            &Expr::Builtin(BuiltinOp::Neg, vec![Expr::Var(Arc::from("x"))]),
            "-x",
        );
    }

    #[test]
    fn prefix_not() {
        prints_as(
            &Expr::Builtin(BuiltinOp::Not, vec![Expr::Lit(Literal::Bool(true))]),
            "not True",
        );
    }

    #[test]
    fn prefix_round_trip() {
        round_trip("-x");
        round_trip("not True");
    }

    // --- Builtins in function-call form ---

    #[test]
    fn builtin_function_call() {
        prints_as(
            &Expr::Builtin(
                BuiltinOp::Map,
                vec![Expr::Var(Arc::from("f")), Expr::Var(Arc::from("xs"))],
            ),
            "map f xs",
        );
    }

    #[test]
    fn builtin_unary() {
        prints_as(
            &Expr::Builtin(BuiltinOp::Head, vec![Expr::Var(Arc::from("xs"))]),
            "head xs",
        );
    }

    #[test]
    fn builtin_round_trip() {
        round_trip("map f xs");
        round_trip("head xs");
        round_trip("filter f xs");
    }

    // --- Let ---

    #[test]
    fn let_simple() {
        prints_as(
            &Expr::Let {
                name: Arc::from("x"),
                value: Box::new(Expr::Lit(Literal::Int(1))),
                body: Box::new(Expr::Builtin(
                    BuiltinOp::Add,
                    vec![Expr::Var(Arc::from("x")), Expr::Lit(Literal::Int(1))],
                )),
            },
            "let x = 1 in x + 1",
        );
    }

    #[test]
    fn let_round_trip() {
        round_trip("let x = 1 in x + 1");
    }

    // --- If / then / else ---

    #[test]
    fn if_then_else() {
        let expr = Expr::Match {
            scrutinee: Box::new(Expr::Lit(Literal::Bool(true))),
            arms: vec![
                (
                    Pattern::Lit(Literal::Bool(true)),
                    Expr::Lit(Literal::Int(1)),
                ),
                (Pattern::Wildcard, Expr::Lit(Literal::Int(0))),
            ],
        };
        prints_as(&expr, "if True then 1 else 0");
    }

    #[test]
    fn if_round_trip() {
        round_trip("if True then 1 else 0");
    }

    // --- Case / of ---

    #[test]
    fn case_of() {
        let expr = Expr::Match {
            scrutinee: Box::new(Expr::Var(Arc::from("x"))),
            arms: vec![
                (
                    Pattern::Lit(Literal::Bool(true)),
                    Expr::Lit(Literal::Int(1)),
                ),
                (
                    Pattern::Lit(Literal::Bool(false)),
                    Expr::Lit(Literal::Int(0)),
                ),
            ],
        };
        prints_as(&expr, "case x of\n True -> 1\n False -> 0");
    }

    #[test]
    fn case_round_trip() {
        round_trip("case x of\n True -> 1\n False -> 0");
    }

    // --- Nesting ---

    #[test]
    fn nested_let_in_lambda() {
        round_trip("\\x -> let y = x + 1 in y * 2");
    }

    #[test]
    fn nested_if_in_let() {
        round_trip("let x = if True then 1 else 0 in x + 1");
    }

    #[test]
    fn lambda_as_arg() {
        prints_as(
            // A lambda used as an argument must be parenthesized.
            &Expr::App(
                Box::new(Expr::Var(Arc::from("f"))),
                Box::new(Expr::Lam(
                    Arc::from("x"),
                    Box::new(Expr::Var(Arc::from("x"))),
                )),
            ),
            "f (\\x -> x)",
        );
    }

    #[test]
    fn complex_expression_round_trip() {
        round_trip("\\f xs -> map (\\x -> f x + 1) xs");
    }

    // --- Patterns ---

    #[test]
    fn pattern_wildcard() {
        let mut buf = String::new();
        write_pattern(&mut buf, &Pattern::Wildcard);
        assert_eq!(buf, "_");
    }

    #[test]
    fn pattern_var() {
        let mut buf = String::new();
        write_pattern(&mut buf, &Pattern::Var(Arc::from("x")));
        assert_eq!(buf, "x");
    }

    #[test]
    fn pattern_lit() {
        let mut buf = String::new();
        write_pattern(&mut buf, &Pattern::Lit(Literal::Int(42)));
        assert_eq!(buf, "42");
    }

    #[test]
    fn pattern_list() {
        let mut buf = String::new();
        write_pattern(
            &mut buf,
            &Pattern::List(vec![
                Pattern::Var(Arc::from("x")),
                Pattern::Var(Arc::from("y")),
            ]),
        );
        assert_eq!(buf, "[x, y]");
    }

    #[test]
    fn pattern_record_punning() {
        let mut buf = String::new();
        write_pattern(
            &mut buf,
            &Pattern::Record(vec![
                (Arc::from("x"), Pattern::Var(Arc::from("x"))),
                (Arc::from("y"), Pattern::Var(Arc::from("y"))),
            ]),
        );
        assert_eq!(buf, "{ x, y }");
    }

    #[test]
    fn pattern_constructor() {
        let mut buf = String::new();
        write_pattern(
            &mut buf,
            &Pattern::Constructor(Arc::from("Just"), vec![Pattern::Var(Arc::from("x"))]),
        );
        assert_eq!(buf, "Just x");
    }

    // --- Indexing ---

    #[test]
    fn index_access() {
        prints_as(
            &Expr::Index(
                Box::new(Expr::Var(Arc::from("xs"))),
                Box::new(Expr::Lit(Literal::Int(0))),
            ),
            "xs[0]",
        );
    }

    // --- Compound literals ---

    #[test]
    fn literal_record() {
        prints_as(
            &Expr::Lit(Literal::Record(vec![
                (Arc::from("x"), Literal::Int(1)),
                (Arc::from("y"), Literal::Int(2)),
            ])),
            "{ x = 1, y = 2 }",
        );
    }

    #[test]
    fn literal_list() {
        prints_as(
            &Expr::Lit(Literal::List(vec![Literal::Int(1), Literal::Int(2)])),
            "[1, 2]",
        );
    }

    // --- Mixed-precedence round trips ---

    #[test]
    fn precedence_logical_and_comparison() {
        round_trip("x == 1 && y == 2");
    }

    #[test]
    fn precedence_arithmetic_in_comparison() {
        round_trip("x + 1 == y * 2");
    }

    #[test]
    fn concat_round_trip() {
        round_trip(r#""hello" ++ " world""#);
    }
}