Skip to main content

panproto_expr_parser/
pretty.rs

1//! Pretty printer for panproto expressions.
2//!
3//! Converts `panproto_expr::Expr` back into Haskell-style surface syntax.
4//! The output is designed to round-trip through the parser:
5//! `parse(tokenize(pretty_print(&e))) == e` for well-formed expressions.
6//!
7//! Parentheses are minimized using precedence awareness, and operators
8//! are printed in infix notation where the parser supports it.
9
10use std::fmt::Write;
11use std::sync::Arc;
12
13use panproto_expr::{BuiltinOp, Expr, Literal, Pattern};
14
/// Binding strength of a syntactic position; a larger value binds more tightly.
///
/// The discriminants follow the Pratt-parser precedences declared in
/// `parser.rs`, so the derived ordering can be compared directly to decide
/// whether a sub-expression has to be parenthesised.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum Prec {
    /// Outermost position; nothing needs parentheses here.
    Top = 0,
    /// The pipe operator `&`.
    Pipe = 1,
    /// Logical disjunction `||`.
    Or = 3,
    /// Logical conjunction `&&`.
    And = 4,
    /// Comparisons: `==`, `/=`, `<`, `<=`, `>`, `>=`.
    Cmp = 5,
    /// Concatenation `++`.
    Concat = 6,
    /// Additive operators `+` and `-`.
    AddSub = 7,
    /// Multiplicative operators `*`, `/`, `%`, `mod`, `div`.
    MulDiv = 8,
    /// Prefix operators: unary `-` and `not`.
    Unary = 9,
    /// Juxtaposition, i.e. function application.
    App = 10,
    /// Atoms and postfix forms (`.field`, `->edge`).
    Atom = 11,
}
43
/// Which operand of a binary operator may repeat the operator unparenthesised.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Assoc {
    /// `a op b op c` groups as `(a op b) op c`.
    Left,
    /// `a op b op c` groups as `a op (b op c)`.
    Right,
}
50
51/// Pretty print an expression to a string.
52///
53/// The output uses Haskell-style surface syntax with minimal parentheses.
54///
55/// # Examples
56///
57/// ```
58/// use std::sync::Arc;
59/// use panproto_expr::{Expr, Literal, BuiltinOp};
60/// use panproto_expr_parser::pretty_print;
61///
62/// let e = Expr::Builtin(BuiltinOp::Add, vec![
63///     Expr::Var(Arc::from("x")),
64///     Expr::Lit(Literal::Int(1)),
65/// ]);
66/// assert_eq!(pretty_print(&e), "x + 1");
67/// ```
68#[must_use]
69pub fn pretty_print(expr: &Expr) -> String {
70    let mut buf = String::new();
71    write_expr(&mut buf, expr, Prec::Top);
72    buf
73}
74
75/// Write an expression at the given precedence context.
76///
77/// If the expression's own precedence is lower than `ctx`, wraps in parens.
78fn write_expr(buf: &mut String, expr: &Expr, ctx: Prec) {
79    match expr {
80        Expr::Var(name) => buf.push_str(name),
81
82        Expr::Lit(lit) => write_literal(buf, lit),
83
84        Expr::Lam(param, body) => {
85            let needs_parens = ctx > Prec::Top;
86            if needs_parens {
87                buf.push('(');
88            }
89            write_lambda_chain(buf, param, body);
90            if needs_parens {
91                buf.push(')');
92            }
93        }
94
95        Expr::App(func, arg) => {
96            write_app(buf, expr, ctx);
97            let _ = (func, arg); // used inside write_app
98        }
99
100        Expr::Record(fields) => {
101            write_record_expr(buf, fields);
102        }
103
104        Expr::List(items) => {
105            buf.push('[');
106            for (i, item) in items.iter().enumerate() {
107                if i > 0 {
108                    buf.push_str(", ");
109                }
110                write_expr(buf, item, Prec::Top);
111            }
112            buf.push(']');
113        }
114
115        Expr::Field(inner, name) => {
116            write_expr(buf, inner, Prec::Atom);
117            buf.push('.');
118            buf.push_str(name);
119        }
120
121        Expr::Index(inner, idx) => {
122            write_expr(buf, inner, Prec::Atom);
123            buf.push('[');
124            write_expr(buf, idx, Prec::Top);
125            buf.push(']');
126        }
127
128        Expr::Match { scrutinee, arms } => {
129            write_match(buf, scrutinee, arms, ctx);
130        }
131
132        Expr::Let { name, value, body } => {
133            write_let(buf, name, value, body, ctx);
134        }
135
136        Expr::Builtin(op, args) => {
137            write_builtin(buf, *op, args, ctx);
138        }
139    }
140}
141
142/// Write a chain of nested lambdas as `\x y z -> body`.
143fn write_lambda_chain(buf: &mut String, first_param: &Arc<str>, first_body: &Expr) {
144    buf.push('\\');
145    buf.push_str(first_param);
146    let mut body = first_body;
147    while let Expr::Lam(param, inner) = body {
148        buf.push(' ');
149        buf.push_str(param);
150        body = inner;
151    }
152    buf.push_str(" -> ");
153    write_expr(buf, body, Prec::Top);
154}
155
156/// Write function application, collecting curried args: `f x y z`.
157fn write_app(buf: &mut String, expr: &Expr, ctx: Prec) {
158    let needs_parens = ctx > Prec::App;
159    if needs_parens {
160        buf.push('(');
161    }
162
163    // Collect the application spine.
164    let mut spine: Vec<&Expr> = Vec::new();
165    let mut head = expr;
166    while let Expr::App(func, arg) = head {
167        spine.push(arg);
168        head = func;
169    }
170    spine.reverse();
171
172    write_expr(buf, head, Prec::App);
173    for arg in &spine {
174        buf.push(' ');
175        write_expr(buf, arg, Prec::Atom);
176    }
177
178    if needs_parens {
179        buf.push(')');
180    }
181}
182
183/// Write a record expression with punning where appropriate.
184fn write_record_expr(buf: &mut String, fields: &[(Arc<str>, Expr)]) {
185    buf.push_str("{ ");
186    for (i, (name, val)) in fields.iter().enumerate() {
187        if i > 0 {
188            buf.push_str(", ");
189        }
190        // Record punning: `{ x }` when field name equals variable name.
191        if let Expr::Var(v) = val {
192            if v == name {
193                buf.push_str(name);
194                continue;
195            }
196        }
197        buf.push_str(name);
198        buf.push_str(" = ");
199        write_expr(buf, val, Prec::Top);
200    }
201    buf.push_str(" }");
202}
203
204/// Write a match expression.
205///
206/// Detects `if/then/else` patterns (two arms with True and Wildcard)
207/// and emits those in the shorter form.
208fn write_match(buf: &mut String, scrutinee: &Expr, arms: &[(Pattern, Expr)], ctx: Prec) {
209    // Detect if/then/else: Match with [Lit(Bool(true)) -> then, Wildcard -> else]
210    if arms.len() == 2 {
211        if let (Pattern::Lit(Literal::Bool(true)), then_branch) = &arms[0] {
212            if let (Pattern::Wildcard, else_branch) = &arms[1] {
213                let needs_parens = ctx > Prec::Top;
214                if needs_parens {
215                    buf.push('(');
216                }
217                buf.push_str("if ");
218                write_expr(buf, scrutinee, Prec::Top);
219                buf.push_str(" then ");
220                write_expr(buf, then_branch, Prec::Top);
221                buf.push_str(" else ");
222                write_expr(buf, else_branch, Prec::Top);
223                if needs_parens {
224                    buf.push(')');
225                }
226                return;
227            }
228        }
229    }
230
231    let needs_parens = ctx > Prec::Top;
232    if needs_parens {
233        buf.push('(');
234    }
235    buf.push_str("case ");
236    write_expr(buf, scrutinee, Prec::Top);
237    buf.push_str(" of\n");
238    for (i, (pat, body)) in arms.iter().enumerate() {
239        if i > 0 {
240            buf.push('\n');
241        }
242        buf.push_str("  ");
243        write_pattern(buf, pat);
244        buf.push_str(" -> ");
245        write_expr(buf, body, Prec::Top);
246    }
247    if needs_parens {
248        buf.push(')');
249    }
250}
251
252/// Write a let binding, collapsing nested lets into a layout block.
253fn write_let(buf: &mut String, name: &Arc<str>, value: &Expr, body: &Expr, ctx: Prec) {
254    let needs_parens = ctx > Prec::Top;
255    if needs_parens {
256        buf.push('(');
257    }
258
259    // Collect chained lets.
260    let mut bindings: Vec<(&Arc<str>, &Expr)> = vec![(name, value)];
261    let mut final_body = body;
262    while let Expr::Let {
263        name: n,
264        value: v,
265        body: b,
266    } = final_body
267    {
268        bindings.push((n, v));
269        final_body = b;
270    }
271
272    if bindings.len() == 1 {
273        buf.push_str("let ");
274        buf.push_str(name);
275        buf.push_str(" = ");
276        write_expr(buf, value, Prec::Top);
277        buf.push_str(" in ");
278    } else {
279        buf.push_str("let\n");
280        for (n, v) in &bindings {
281            buf.push_str("  ");
282            buf.push_str(n);
283            buf.push_str(" = ");
284            write_expr(buf, v, Prec::Top);
285            buf.push('\n');
286        }
287        buf.push_str("in ");
288    }
289    write_expr(buf, final_body, Prec::Top);
290
291    if needs_parens {
292        buf.push(')');
293    }
294}
295
296/// Write a builtin operation, using infix/prefix syntax where possible.
297fn write_builtin(buf: &mut String, op: BuiltinOp, args: &[Expr], ctx: Prec) {
298    // Try infix binary operators.
299    if let Some((sym, prec, assoc)) = infix_info(op) {
300        if args.len() == 2 {
301            let needs_parens = ctx > prec;
302            if needs_parens {
303                buf.push('(');
304            }
305            // For left-associative operators, the left child is fine at the
306            // same precedence but the right child needs to be tighter (to
307            // avoid ambiguity). Vice versa for right-associative.
308            let (left_ctx, right_ctx) = match assoc {
309                Assoc::Left => (prec, next_prec(prec)),
310                Assoc::Right => (next_prec(prec), prec),
311            };
312            write_expr(buf, &args[0], left_ctx);
313            buf.push(' ');
314            buf.push_str(sym);
315            buf.push(' ');
316            write_expr(buf, &args[1], right_ctx);
317            if needs_parens {
318                buf.push(')');
319            }
320            return;
321        }
322    }
323
324    // Edge traversal: `expr -> edge`
325    if op == BuiltinOp::Edge && args.len() == 2 {
326        if let Expr::Lit(Literal::Str(edge_name)) = &args[1] {
327            let needs_parens = ctx > Prec::Atom;
328            if needs_parens {
329                buf.push('(');
330            }
331            write_expr(buf, &args[0], Prec::Atom);
332            buf.push_str(" -> ");
333            buf.push_str(edge_name);
334            if needs_parens {
335                buf.push(')');
336            }
337            return;
338        }
339    }
340
341    // Unary prefix: negation and logical not.
342    if op == BuiltinOp::Neg && args.len() == 1 {
343        let needs_parens = ctx > Prec::Unary;
344        if needs_parens {
345            buf.push('(');
346        }
347        buf.push('-');
348        write_expr(buf, &args[0], Prec::Atom);
349        if needs_parens {
350            buf.push(')');
351        }
352        return;
353    }
354
355    if op == BuiltinOp::Not && args.len() == 1 {
356        let needs_parens = ctx > Prec::Unary;
357        if needs_parens {
358            buf.push('(');
359        }
360        buf.push_str("not ");
361        write_expr(buf, &args[0], Prec::Atom);
362        if needs_parens {
363            buf.push(')');
364        }
365        return;
366    }
367
368    // Fallback: function call syntax `name arg1 arg2 ...`
369    let needs_parens = ctx > Prec::App && !args.is_empty();
370    if needs_parens {
371        buf.push('(');
372    }
373    buf.push_str(builtin_name(op));
374    for arg in args {
375        buf.push(' ');
376        write_expr(buf, arg, Prec::Atom);
377    }
378    if needs_parens {
379        buf.push(')');
380    }
381}
382
383/// Map a builtin op to its infix operator symbol, precedence, and associativity.
384///
385/// Returns `None` for builtins that should use function call syntax.
386const fn infix_info(op: BuiltinOp) -> Option<(&'static str, Prec, Assoc)> {
387    match op {
388        BuiltinOp::Or => Some(("||", Prec::Or, Assoc::Left)),
389        BuiltinOp::And => Some(("&&", Prec::And, Assoc::Left)),
390        BuiltinOp::Eq => Some(("==", Prec::Cmp, Assoc::Right)),
391        BuiltinOp::Neq => Some(("/=", Prec::Cmp, Assoc::Right)),
392        BuiltinOp::Lt => Some(("<", Prec::Cmp, Assoc::Right)),
393        BuiltinOp::Lte => Some(("<=", Prec::Cmp, Assoc::Right)),
394        BuiltinOp::Gt => Some((">", Prec::Cmp, Assoc::Right)),
395        BuiltinOp::Gte => Some((">=", Prec::Cmp, Assoc::Right)),
396        BuiltinOp::Concat => Some(("++", Prec::Concat, Assoc::Right)),
397        BuiltinOp::Add => Some(("+", Prec::AddSub, Assoc::Left)),
398        BuiltinOp::Sub => Some(("-", Prec::AddSub, Assoc::Left)),
399        BuiltinOp::Mul => Some(("*", Prec::MulDiv, Assoc::Left)),
400        BuiltinOp::Div => Some(("/", Prec::MulDiv, Assoc::Left)),
401        BuiltinOp::Mod => Some(("%", Prec::MulDiv, Assoc::Left)),
402        _ => None,
403    }
404}
405
406/// Get the next higher precedence level.
407const fn next_prec(p: Prec) -> Prec {
408    match p {
409        Prec::Top => Prec::Pipe,
410        Prec::Pipe => Prec::Or,
411        Prec::Or => Prec::And,
412        Prec::And => Prec::Cmp,
413        Prec::Cmp => Prec::Concat,
414        Prec::Concat => Prec::AddSub,
415        Prec::AddSub => Prec::MulDiv,
416        Prec::MulDiv => Prec::Unary,
417        Prec::Unary => Prec::App,
418        Prec::App | Prec::Atom => Prec::Atom,
419    }
420}
421
422/// Map a builtin op to its canonical function name for call syntax.
423const fn builtin_name(op: BuiltinOp) -> &'static str {
424    match op {
425        BuiltinOp::Add => "add",
426        BuiltinOp::Sub => "sub",
427        BuiltinOp::Mul => "mul",
428        BuiltinOp::Div => "div",
429        BuiltinOp::Mod => "mod",
430        BuiltinOp::Neg => "neg",
431        BuiltinOp::Abs => "abs",
432        BuiltinOp::Floor => "floor",
433        BuiltinOp::Ceil => "ceil",
434        BuiltinOp::Eq => "eq",
435        BuiltinOp::Neq => "neq",
436        BuiltinOp::Lt => "lt",
437        BuiltinOp::Lte => "lte",
438        BuiltinOp::Gt => "gt",
439        BuiltinOp::Gte => "gte",
440        BuiltinOp::And => "and",
441        BuiltinOp::Or => "or",
442        BuiltinOp::Not => "not",
443        BuiltinOp::Concat => "concat",
444        BuiltinOp::Len => "len",
445        BuiltinOp::Slice => "slice",
446        BuiltinOp::Upper => "upper",
447        BuiltinOp::Lower => "lower",
448        BuiltinOp::Trim => "trim",
449        BuiltinOp::Split => "split",
450        BuiltinOp::Join => "join",
451        BuiltinOp::Replace => "replace",
452        BuiltinOp::Contains => "contains",
453        BuiltinOp::Map => "map",
454        BuiltinOp::Filter => "filter",
455        BuiltinOp::Fold => "fold",
456        BuiltinOp::Append => "append",
457        BuiltinOp::Head => "head",
458        BuiltinOp::Tail => "tail",
459        BuiltinOp::Reverse => "reverse",
460        BuiltinOp::FlatMap => "flat_map",
461        BuiltinOp::Length => "length",
462        BuiltinOp::MergeRecords => "merge",
463        BuiltinOp::Keys => "keys",
464        BuiltinOp::Values => "values",
465        BuiltinOp::HasField => "has_field",
466        BuiltinOp::IntToFloat => "int_to_float",
467        BuiltinOp::FloatToInt => "float_to_int",
468        BuiltinOp::IntToStr => "int_to_str",
469        BuiltinOp::FloatToStr => "float_to_str",
470        BuiltinOp::StrToInt => "str_to_int",
471        BuiltinOp::StrToFloat => "str_to_float",
472        BuiltinOp::TypeOf => "type_of",
473        BuiltinOp::IsNull => "is_null",
474        BuiltinOp::IsList => "is_list",
475        BuiltinOp::Edge => "edge",
476        BuiltinOp::Children => "children",
477        BuiltinOp::HasEdge => "has_edge",
478        BuiltinOp::EdgeCount => "edge_count",
479        BuiltinOp::Anchor => "anchor",
480    }
481}
482
483/// Write a literal value.
484fn write_literal(buf: &mut String, lit: &Literal) {
485    match lit {
486        Literal::Bool(true) => buf.push_str("True"),
487        Literal::Bool(false) => buf.push_str("False"),
488        Literal::Int(n) => {
489            let _ = write!(buf, "{n}");
490        }
491        Literal::Float(f) => {
492            // Ensure there is always a decimal point so the parser
493            // recognizes this as a float, not an int.
494            let s = format!("{f}");
495            if s.contains('.') {
496                buf.push_str(&s);
497            } else {
498                let _ = write!(buf, "{f}.0");
499            }
500        }
501        Literal::Str(s) => {
502            buf.push('"');
503            // Escape backslashes and double quotes.
504            for ch in s.chars() {
505                match ch {
506                    '\\' => buf.push_str("\\\\"),
507                    '"' => buf.push_str("\\\""),
508                    '\n' => buf.push_str("\\n"),
509                    '\r' => buf.push_str("\\r"),
510                    '\t' => buf.push_str("\\t"),
511                    c => buf.push(c),
512                }
513            }
514            buf.push('"');
515        }
516        Literal::Bytes(bytes) => {
517            // No native bytes syntax; emit as a list of ints.
518            buf.push('[');
519            for (i, b) in bytes.iter().enumerate() {
520                if i > 0 {
521                    buf.push_str(", ");
522                }
523                let _ = write!(buf, "{b}");
524            }
525            buf.push(']');
526        }
527        Literal::Null => buf.push_str("Nothing"),
528        Literal::Record(fields) => {
529            buf.push_str("{ ");
530            for (i, (name, val)) in fields.iter().enumerate() {
531                if i > 0 {
532                    buf.push_str(", ");
533                }
534                buf.push_str(name);
535                buf.push_str(" = ");
536                write_literal(buf, val);
537            }
538            buf.push_str(" }");
539        }
540        Literal::List(items) => {
541            buf.push('[');
542            for (i, item) in items.iter().enumerate() {
543                if i > 0 {
544                    buf.push_str(", ");
545                }
546                write_literal(buf, item);
547            }
548            buf.push(']');
549        }
550        Literal::Closure { param, body, .. } => {
551            // Print as a lambda; the captured env is lost but the
552            // expression form is preserved for round-tripping.
553            buf.push('\\');
554            buf.push_str(param);
555            buf.push_str(" -> ");
556            write_expr(buf, body, Prec::Top);
557        }
558    }
559}
560
561/// Write a pattern.
562fn write_pattern(buf: &mut String, pat: &Pattern) {
563    match pat {
564        Pattern::Wildcard => buf.push('_'),
565        Pattern::Var(name) => buf.push_str(name),
566        Pattern::Lit(lit) => write_literal(buf, lit),
567        Pattern::Record(fields) => {
568            buf.push_str("{ ");
569            for (i, (name, p)) in fields.iter().enumerate() {
570                if i > 0 {
571                    buf.push_str(", ");
572                }
573                // Record pattern punning: `{ x }` when field pattern is Var(x).
574                if let Pattern::Var(v) = p {
575                    if v == name {
576                        buf.push_str(name);
577                        continue;
578                    }
579                }
580                buf.push_str(name);
581                buf.push_str(" = ");
582                write_pattern(buf, p);
583            }
584            buf.push_str(" }");
585        }
586        Pattern::List(pats) => {
587            buf.push('[');
588            for (i, p) in pats.iter().enumerate() {
589                if i > 0 {
590                    buf.push_str(", ");
591                }
592                write_pattern(buf, p);
593            }
594            buf.push(']');
595        }
596        Pattern::Constructor(name, args) => {
597            buf.push_str(name);
598            for arg in args {
599                buf.push(' ');
600                // Wrap constructor args in parens if they are themselves
601                // constructors with args (to avoid ambiguity).
602                let needs_parens = matches!(arg, Pattern::Constructor(_, a) if !a.is_empty());
603                if needs_parens {
604                    buf.push('(');
605                }
606                write_pattern(buf, arg);
607                if needs_parens {
608                    buf.push(')');
609                }
610            }
611        }
612    }
613}
614
615#[cfg(test)]
616mod tests {
617    use super::*;
618    use crate::{parse, tokenize};
619
620    /// Parse a string, pretty print it, re-parse, and verify equality.
621    fn round_trip(input: &str) {
622        let tokens1 = tokenize(input).unwrap_or_else(|e| panic!("first lex failed: {e}"));
623        let expr1 = parse(&tokens1).unwrap_or_else(|e| panic!("first parse failed: {e:?}"));
624        let printed = pretty_print(&expr1);
625        let tokens2 = tokenize(&printed).unwrap_or_else(|e| {
626            panic!("re-lex failed for {printed:?}: {e}");
627        });
628        let expr2 = parse(&tokens2).unwrap_or_else(|e| {
629            panic!("re-parse failed for {printed:?}: {e:?}");
630        });
631        assert_eq!(
632            expr1, expr2,
633            "round trip failed.\n  input:   {input:?}\n  printed: {printed:?}"
634        );
635    }
636
637    /// Pretty print an expression built programmatically and check output.
638    fn prints_as(expr: &Expr, expected: &str) {
639        let actual = pretty_print(expr);
640        assert_eq!(actual, expected, "pretty_print mismatch");
641    }
642
643    // ── Literals ──────────────────────────────────────────────────
644
645    #[test]
646    fn lit_int() {
647        prints_as(&Expr::Lit(Literal::Int(42)), "42");
648    }
649
650    #[test]
651    fn lit_negative_int() {
652        prints_as(&Expr::Lit(Literal::Int(-5)), "-5");
653    }
654
655    #[test]
656    fn lit_float() {
657        prints_as(&Expr::Lit(Literal::Float(3.125)), "3.125");
658    }
659
660    #[test]
661    fn lit_string() {
662        prints_as(&Expr::Lit(Literal::Str("hello".into())), r#""hello""#);
663    }
664
665    #[test]
666    fn lit_string_escapes() {
667        prints_as(
668            &Expr::Lit(Literal::Str("say \"hi\"".into())),
669            r#""say \"hi\"""#,
670        );
671    }
672
673    #[test]
674    fn lit_bool() {
675        prints_as(&Expr::Lit(Literal::Bool(true)), "True");
676        prints_as(&Expr::Lit(Literal::Bool(false)), "False");
677    }
678
679    #[test]
680    fn lit_null() {
681        prints_as(&Expr::Lit(Literal::Null), "Nothing");
682    }
683
684    #[test]
685    fn lit_bytes() {
686        prints_as(&Expr::Lit(Literal::Bytes(vec![1, 2, 3])), "[1, 2, 3]");
687    }
688
689    // ── Variables ─────────────────────────────────────────────────
690
691    #[test]
692    fn variable() {
693        prints_as(&Expr::Var(Arc::from("x")), "x");
694    }
695
696    // ── Lambda ────────────────────────────────────────────────────
697
698    #[test]
699    fn lambda_simple() {
700        prints_as(
701            &Expr::Lam(Arc::from("x"), Box::new(Expr::Var(Arc::from("x")))),
702            "\\x -> x",
703        );
704    }
705
706    #[test]
707    fn lambda_multi_param() {
708        prints_as(
709            &Expr::Lam(
710                Arc::from("x"),
711                Box::new(Expr::Lam(
712                    Arc::from("y"),
713                    Box::new(Expr::Builtin(
714                        BuiltinOp::Add,
715                        vec![Expr::Var(Arc::from("x")), Expr::Var(Arc::from("y"))],
716                    )),
717                )),
718            ),
719            "\\x y -> x + y",
720        );
721    }
722
723    #[test]
724    fn lambda_round_trip() {
725        round_trip("\\x -> x + 1");
726        round_trip("\\x y -> x + y");
727    }
728
729    // ── Application ───────────────────────────────────────────────
730
731    #[test]
732    fn app_simple() {
733        prints_as(
734            &Expr::App(
735                Box::new(Expr::Var(Arc::from("f"))),
736                Box::new(Expr::Var(Arc::from("x"))),
737            ),
738            "f x",
739        );
740    }
741
742    #[test]
743    fn app_chain() {
744        prints_as(
745            &Expr::App(
746                Box::new(Expr::App(
747                    Box::new(Expr::Var(Arc::from("f"))),
748                    Box::new(Expr::Var(Arc::from("x"))),
749                )),
750                Box::new(Expr::Var(Arc::from("y"))),
751            ),
752            "f x y",
753        );
754    }
755
756    #[test]
757    fn app_complex_arg() {
758        // f (g x) should parenthesize the argument
759        prints_as(
760            &Expr::App(
761                Box::new(Expr::Var(Arc::from("f"))),
762                Box::new(Expr::App(
763                    Box::new(Expr::Var(Arc::from("g"))),
764                    Box::new(Expr::Var(Arc::from("x"))),
765                )),
766            ),
767            "f (g x)",
768        );
769    }
770
771    // ── Record ────────────────────────────────────────────────────
772
773    #[test]
774    fn record_simple() {
775        prints_as(
776            &Expr::Record(vec![
777                (Arc::from("x"), Expr::Lit(Literal::Int(1))),
778                (Arc::from("y"), Expr::Lit(Literal::Int(2))),
779            ]),
780            "{ x = 1, y = 2 }",
781        );
782    }
783
784    #[test]
785    fn record_punning() {
786        prints_as(
787            &Expr::Record(vec![
788                (Arc::from("x"), Expr::Var(Arc::from("x"))),
789                (Arc::from("y"), Expr::Var(Arc::from("y"))),
790            ]),
791            "{ x, y }",
792        );
793    }
794
795    #[test]
796    fn record_mixed_punning() {
797        prints_as(
798            &Expr::Record(vec![
799                (Arc::from("x"), Expr::Var(Arc::from("x"))),
800                (Arc::from("y"), Expr::Lit(Literal::Int(42))),
801            ]),
802            "{ x, y = 42 }",
803        );
804    }
805
806    #[test]
807    fn record_round_trip() {
808        round_trip("{ name = x, age = 30 }");
809        round_trip("{ x, y }");
810    }
811
812    // ── List ──────────────────────────────────────────────────────
813
814    #[test]
815    fn list_simple() {
816        prints_as(
817            &Expr::List(vec![
818                Expr::Lit(Literal::Int(1)),
819                Expr::Lit(Literal::Int(2)),
820                Expr::Lit(Literal::Int(3)),
821            ]),
822            "[1, 2, 3]",
823        );
824    }
825
826    #[test]
827    fn list_empty() {
828        prints_as(&Expr::List(vec![]), "[]");
829    }
830
831    #[test]
832    fn list_round_trip() {
833        round_trip("[1, 2, 3]");
834        round_trip("[]");
835    }
836
837    // ── Field access ──────────────────────────────────────────────
838
839    #[test]
840    fn field_access() {
841        prints_as(
842            &Expr::Field(Box::new(Expr::Var(Arc::from("x"))), Arc::from("name")),
843            "x.name",
844        );
845    }
846
847    #[test]
848    fn field_chain() {
849        prints_as(
850            &Expr::Field(
851                Box::new(Expr::Field(
852                    Box::new(Expr::Var(Arc::from("x"))),
853                    Arc::from("a"),
854                )),
855                Arc::from("b"),
856            ),
857            "x.a.b",
858        );
859    }
860
861    #[test]
862    fn field_round_trip() {
863        round_trip("x.name");
864        round_trip("x.a.b");
865    }
866
867    // ── Edge traversal ────────────────────────────────────────────
868
869    #[test]
870    fn edge_traversal() {
871        prints_as(
872            &Expr::Builtin(
873                BuiltinOp::Edge,
874                vec![
875                    Expr::Var(Arc::from("doc")),
876                    Expr::Lit(Literal::Str("layers".into())),
877                ],
878            ),
879            "doc -> layers",
880        );
881    }
882
883    #[test]
884    fn edge_chain() {
885        prints_as(
886            &Expr::Builtin(
887                BuiltinOp::Edge,
888                vec![
889                    Expr::Builtin(
890                        BuiltinOp::Edge,
891                        vec![
892                            Expr::Var(Arc::from("doc")),
893                            Expr::Lit(Literal::Str("layers".into())),
894                        ],
895                    ),
896                    Expr::Lit(Literal::Str("annotations".into())),
897                ],
898            ),
899            "doc -> layers -> annotations",
900        );
901    }
902
903    #[test]
904    fn edge_round_trip() {
905        round_trip("doc -> layers");
906        round_trip("doc -> layers -> annotations");
907    }
908
909    // ── Infix operators ───────────────────────────────────────────
910
911    #[test]
912    fn infix_add() {
913        prints_as(
914            &Expr::Builtin(
915                BuiltinOp::Add,
916                vec![Expr::Var(Arc::from("x")), Expr::Lit(Literal::Int(1))],
917            ),
918            "x + 1",
919        );
920    }
921
922    #[test]
923    fn infix_precedence_no_parens() {
924        // 1 + 2 * 3 should not need parens because * binds tighter.
925        prints_as(
926            &Expr::Builtin(
927                BuiltinOp::Add,
928                vec![
929                    Expr::Lit(Literal::Int(1)),
930                    Expr::Builtin(
931                        BuiltinOp::Mul,
932                        vec![Expr::Lit(Literal::Int(2)), Expr::Lit(Literal::Int(3))],
933                    ),
934                ],
935            ),
936            "1 + 2 * 3",
937        );
938    }
939
940    #[test]
941    fn infix_precedence_needs_parens() {
942        // (1 + 2) * 3 needs parens because + is lower than *.
943        prints_as(
944            &Expr::Builtin(
945                BuiltinOp::Mul,
946                vec![
947                    Expr::Builtin(
948                        BuiltinOp::Add,
949                        vec![Expr::Lit(Literal::Int(1)), Expr::Lit(Literal::Int(2))],
950                    ),
951                    Expr::Lit(Literal::Int(3)),
952                ],
953            ),
954            "(1 + 2) * 3",
955        );
956    }
957
958    #[test]
959    fn infix_left_assoc_no_parens() {
960        // 1 + 2 + 3 is left-associative, so (1+2)+3 needs no parens.
961        prints_as(
962            &Expr::Builtin(
963                BuiltinOp::Add,
964                vec![
965                    Expr::Builtin(
966                        BuiltinOp::Add,
967                        vec![Expr::Lit(Literal::Int(1)), Expr::Lit(Literal::Int(2))],
968                    ),
969                    Expr::Lit(Literal::Int(3)),
970                ],
971            ),
972            "1 + 2 + 3",
973        );
974    }
975
976    #[test]
977    fn infix_right_assoc_needs_parens() {
978        // For left-assoc +, 1 + (2 + 3) needs parens on the right.
979        prints_as(
980            &Expr::Builtin(
981                BuiltinOp::Add,
982                vec![
983                    Expr::Lit(Literal::Int(1)),
984                    Expr::Builtin(
985                        BuiltinOp::Add,
986                        vec![Expr::Lit(Literal::Int(2)), Expr::Lit(Literal::Int(3))],
987                    ),
988                ],
989            ),
990            "1 + (2 + 3)",
991        );
992    }
993
#[test]
fn infix_concat_right_assoc() {
    // `++` is right-associative, so the right-nested chain `a ++ (b ++ c)`
    // prints with no parentheses.
    let var = |name: &str| Expr::Var(Arc::from(name));
    let tail = Expr::Builtin(BuiltinOp::Concat, vec![var("b"), var("c")]);
    let whole = Expr::Builtin(BuiltinOp::Concat, vec![var("a"), tail]);
    prints_as(&whole, "a ++ b ++ c");
}
1011
#[test]
fn infix_comparison() {
    // Comparison operators print infix; `Neq` uses the Haskell spelling `/=`.
    let x = || Expr::Var(Arc::from("x"));
    let cases = [
        (BuiltinOp::Eq, vec![x(), Expr::Lit(Literal::Int(1))], "x == 1"),
        (BuiltinOp::Neq, vec![x(), Expr::Lit(Literal::Int(1))], "x /= 1"),
        (BuiltinOp::Lt, vec![x(), Expr::Var(Arc::from("y"))], "x < y"),
    ];
    for (op, args, expected) in cases {
        prints_as(&Expr::Builtin(op, args), expected);
    }
}
1036
#[test]
fn infix_logical() {
    // Boolean connectives are rendered with their symbolic infix spelling.
    for (op, expected) in [(BuiltinOp::And, "a && b"), (BuiltinOp::Or, "a || b")] {
        let operands = vec![Expr::Var(Arc::from("a")), Expr::Var(Arc::from("b"))];
        prints_as(&Expr::Builtin(op, operands), expected);
    }
}
1054
#[test]
fn infix_round_trips() {
    // Every infix source below must survive a parse/print/parse cycle.
    let sources = [
        "1 + 2",
        "1 + 2 * 3",
        "(1 + 2) * 3",
        "a && b || c",
        "x == 1",
        "x /= 1",
    ];
    for src in sources {
        round_trip(src);
    }
}
1064
1065    // ── Prefix operators ──────────────────────────────────────────
1066
#[test]
fn prefix_neg() {
    // Unary negation prints as a bare `-` prefix on its operand.
    let operand = Expr::Var(Arc::from("x"));
    let negated = Expr::Builtin(BuiltinOp::Neg, vec![operand]);
    prints_as(&negated, "-x");
}
1074
#[test]
fn prefix_not() {
    // Logical negation prints as the keyword `not` followed by its operand.
    let truth = Expr::Lit(Literal::Bool(true));
    let negated = Expr::Builtin(BuiltinOp::Not, vec![truth]);
    prints_as(&negated, "not True");
}
1082
#[test]
fn prefix_round_trip() {
    // Prefix operators must survive a parse/print/parse cycle.
    for src in ["-x", "not True"] {
        round_trip(src);
    }
}
1088
1089    // ── Builtin function call syntax ──────────────────────────────
1090
#[test]
fn builtin_function_call() {
    // A named builtin prints as a prefix application: `map f xs`.
    let args = vec![Expr::Var(Arc::from("f")), Expr::Var(Arc::from("xs"))];
    prints_as(&Expr::Builtin(BuiltinOp::Map, args), "map f xs");
}
1101
#[test]
fn builtin_unary() {
    // A single-argument named builtin prints as `name arg`.
    let list = Expr::Var(Arc::from("xs"));
    let call = Expr::Builtin(BuiltinOp::Head, vec![list]);
    prints_as(&call, "head xs");
}
1109
#[test]
fn builtin_round_trip() {
    // Named builtin calls must survive a parse/print/parse cycle.
    for src in ["map f xs", "head xs", "filter f xs"] {
        round_trip(src);
    }
}
1116
1117    // ── Let ───────────────────────────────────────────────────────
1118
#[test]
fn let_simple() {
    // A `let` prints as `let name = value in body`.
    let body = Expr::Builtin(
        BuiltinOp::Add,
        vec![Expr::Var(Arc::from("x")), Expr::Lit(Literal::Int(1))],
    );
    let binding = Expr::Let {
        name: Arc::from("x"),
        value: Box::new(Expr::Lit(Literal::Int(1))),
        body: Box::new(body),
    };
    prints_as(&binding, "let x = 1 in x + 1");
}
1133
#[test]
fn let_round_trip() {
    // A `let` binding must survive a parse/print/parse cycle.
    let src = "let x = 1 in x + 1";
    round_trip(src);
}
1138
1139    // ── If/then/else ──────────────────────────────────────────────
1140
#[test]
fn if_then_else() {
    // A match with a `True` literal arm followed by a wildcard fallback is
    // expected to print using `if`/`then`/`else` syntax.
    let then_arm = (Pattern::Lit(Literal::Bool(true)), Expr::Lit(Literal::Int(1)));
    let else_arm = (Pattern::Wildcard, Expr::Lit(Literal::Int(0)));
    let expr = Expr::Match {
        scrutinee: Box::new(Expr::Lit(Literal::Bool(true))),
        arms: vec![then_arm, else_arm],
    };
    prints_as(&expr, "if True then 1 else 0");
}
1155
#[test]
fn if_round_trip() {
    // An `if` expression must survive a parse/print/parse cycle.
    let source = "if True then 1 else 0";
    round_trip(source);
}
1160
1161    // ── Case/of ───────────────────────────────────────────────────
1162
#[test]
fn case_of() {
    // This two-literal-arm match prints as a multi-line `case ... of`, with
    // one arm per line indented by two spaces.
    let arm = |b: bool, n| (Pattern::Lit(Literal::Bool(b)), Expr::Lit(Literal::Int(n)));
    let expr = Expr::Match {
        scrutinee: Box::new(Expr::Var(Arc::from("x"))),
        arms: vec![arm(true, 1), arm(false, 0)],
    };
    prints_as(&expr, "case x of\n  True -> 1\n  False -> 0");
}
1180
#[test]
fn case_round_trip() {
    // A multi-line `case` expression must survive a parse/print/parse cycle.
    let src = ["case x of", "  True -> 1", "  False -> 0"].join("\n");
    round_trip(&src);
}
1185
1186    // ── Nested expressions ────────────────────────────────────────
1187
#[test]
fn nested_let_in_lambda() {
    // A `let` inside a lambda body must round-trip without extra parens.
    round_trip(r"\x -> let y = x + 1 in y * 2");
}
1192
#[test]
fn nested_if_in_let() {
    // An `if` used as a `let` value must round-trip intact.
    let src = "let x = if True then 1 else 0 in x + 1";
    round_trip(src);
}
1197
#[test]
fn lambda_as_arg() {
    // A lambda passed as a function argument must be parenthesized.
    let identity = Expr::Lam(Arc::from("x"), Box::new(Expr::Var(Arc::from("x"))));
    let call = Expr::App(Box::new(Expr::Var(Arc::from("f"))), Box::new(identity));
    prints_as(&call, r"f (\x -> x)");
}
1212
#[test]
fn complex_expression_round_trip() {
    // Multi-parameter lambda containing a parenthesized lambda argument.
    round_trip(r"\f xs -> map (\x -> f x + 1) xs");
}
1217
1218    // ── Pattern printing ──────────────────────────────────────────
1219
#[test]
fn pattern_wildcard() {
    // The wildcard pattern is rendered as a bare underscore.
    let rendered = {
        let mut out = String::new();
        write_pattern(&mut out, &Pattern::Wildcard);
        out
    };
    assert_eq!(rendered, "_");
}
1226
#[test]
fn pattern_var() {
    // A variable pattern is rendered as just its name.
    let pattern = Pattern::Var(Arc::from("x"));
    let rendered = {
        let mut out = String::new();
        write_pattern(&mut out, &pattern);
        out
    };
    assert_eq!(rendered, "x");
}
1233
#[test]
fn pattern_lit() {
    // A literal pattern is rendered the same way as the literal itself.
    let pattern = Pattern::Lit(Literal::Int(42));
    let rendered = {
        let mut out = String::new();
        write_pattern(&mut out, &pattern);
        out
    };
    assert_eq!(rendered, "42");
}
1240
#[test]
fn pattern_list() {
    // A list pattern is rendered in brackets with comma-space separators.
    let pattern = Pattern::List(vec![
        Pattern::Var(Arc::from("x")),
        Pattern::Var(Arc::from("y")),
    ]);
    let rendered = {
        let mut out = String::new();
        write_pattern(&mut out, &pattern);
        out
    };
    assert_eq!(rendered, "[x, y]");
}
1253
#[test]
fn pattern_record_punning() {
    // Fields that bind a variable of the same name are printed in punned
    // form, e.g. `{ x, y }`.
    let pattern = Pattern::Record(vec![
        (Arc::from("x"), Pattern::Var(Arc::from("x"))),
        (Arc::from("y"), Pattern::Var(Arc::from("y"))),
    ]);
    let rendered = {
        let mut out = String::new();
        write_pattern(&mut out, &pattern);
        out
    };
    assert_eq!(rendered, "{ x, y }");
}
1266
#[test]
fn pattern_constructor() {
    // A constructor pattern prints its name followed by its sub-patterns.
    let just_x = Pattern::Constructor(Arc::from("Just"), vec![Pattern::Var(Arc::from("x"))]);
    let rendered = {
        let mut out = String::new();
        write_pattern(&mut out, &just_x);
        out
    };
    assert_eq!(rendered, "Just x");
}
1276
1277    // ── Index ─────────────────────────────────────────────────────
1278
#[test]
fn index_access() {
    // Indexing prints in postfix bracket form directly after the indexed value.
    let target = Box::new(Expr::Var(Arc::from("xs")));
    let position = Box::new(Expr::Lit(Literal::Int(0)));
    prints_as(&Expr::Index(target, position), "xs[0]");
}
1289
1290    // ── Literal record and list ───────────────────────────────────
1291
#[test]
fn literal_record() {
    // Record literals print as `{ field = value, ... }`.
    let fields = vec![
        (Arc::from("x"), Literal::Int(1)),
        (Arc::from("y"), Literal::Int(2)),
    ];
    prints_as(&Expr::Lit(Literal::Record(fields)), "{ x = 1, y = 2 }");
}
1302
#[test]
fn literal_list() {
    // List literals print in brackets with comma-space separators.
    let items = (1..=2).map(Literal::Int).collect();
    prints_as(&Expr::Lit(Literal::List(items)), "[1, 2]");
}
1310
1311    // ── Mixed precedence round trips ──────────────────────────────
1312
#[test]
fn precedence_logical_and_comparison() {
    // Comparisons bind tighter than `&&`, so no parens are needed here.
    let src = "x == 1 && y == 2";
    round_trip(src);
}
1317
#[test]
fn precedence_arithmetic_in_comparison() {
    // Arithmetic binds tighter than `==`, so no parens are needed here.
    let src = "x + 1 == y * 2";
    round_trip(src);
}
1322
#[test]
fn concat_round_trip() {
    // String concatenation of two string literals must round-trip intact.
    let src = "\"hello\" ++ \" world\"";
    round_trip(src);
}
1327}