Skip to main content

panproto_expr_parser/
pretty.rs

1//! Pretty printer for panproto expressions.
2//!
3//! Converts `panproto_expr::Expr` back into Haskell-style surface syntax.
4//! The output is designed to round-trip through the parser:
5//! `parse(tokenize(pretty_print(&e))) == e` for well-formed expressions.
6//!
7//! Parentheses are minimized using precedence awareness, and operators
8//! are printed in infix notation where the parser supports it.
9
10use std::fmt::Write;
11use std::sync::Arc;
12
13use panproto_expr::{BuiltinOp, Expr, Literal, Pattern};
14
/// Precedence levels (higher binds tighter).
///
/// These mirror the Pratt parser precedences in `parser.rs`.
///
/// The printer compares levels with the derived `Ord`: each writer is given
/// the level demanded by its surrounding syntax (`ctx`) and wraps itself in
/// parentheses when `ctx` is strictly greater than the level it prints at.
/// `Top` therefore never forces parentheses and `Atom` forces them on every
/// compound form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum Prec {
    /// Top level: no parens needed.
    Top = 0,
    /// Pipe operator (`&`).
    Pipe = 1,
    // NOTE(review): discriminant 2 is skipped. Presumably it corresponds to a
    // parser level with no printable builtin here — confirm against `parser.rs`.
    /// Logical or (`||`).
    Or = 3,
    /// Logical and (`&&`).
    And = 4,
    /// Comparison (`==`, `/=`, `<`, `<=`, `>`, `>=`).
    Cmp = 5,
    /// Concatenation (`++`).
    Concat = 6,
    /// Addition and subtraction (`+`, `-`).
    AddSub = 7,
    /// Multiplication, division, modulo (`*`, `/`, `%`, `mod`, `div`).
    MulDiv = 8,
    /// Unary prefix (`-`, `not`).
    Unary = 9,
    /// Function application.
    App = 10,
    /// Postfix (`.field`, `->edge`), atoms.
    Atom = 11,
}
43
/// Associativity of a binary operator.
///
/// Decides which operand of an infix operator may be printed at the
/// operator's own precedence without parentheses. That is the left operand
/// for `Left` and the right operand for `Right`. The other operand is printed
/// one level tighter, so a same-precedence child on that side is
/// parenthesized.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Assoc {
    /// `a op b op c` is printed for `(a op b) op c`.
    Left,
    /// `a op b op c` is printed for `a op (b op c)`.
    Right,
}
50
51/// Pretty print an expression to a string.
52///
53/// The output uses Haskell-style surface syntax with minimal parentheses.
54///
55/// # Examples
56///
57/// ```
58/// use std::sync::Arc;
59/// use panproto_expr::{Expr, Literal, BuiltinOp};
60/// use panproto_expr_parser::pretty_print;
61///
62/// let e = Expr::Builtin(BuiltinOp::Add, vec![
63///     Expr::Var(Arc::from("x")),
64///     Expr::Lit(Literal::Int(1)),
65/// ]);
66/// assert_eq!(pretty_print(&e), "x + 1");
67/// ```
68#[must_use]
69pub fn pretty_print(expr: &Expr) -> String {
70    let mut buf = String::new();
71    write_expr(&mut buf, expr, Prec::Top);
72    buf
73}
74
75/// Write an expression at the given precedence context.
76///
77/// If the expression's own precedence is lower than `ctx`, wraps in parens.
78fn write_expr(buf: &mut String, expr: &Expr, ctx: Prec) {
79    match expr {
80        Expr::Var(name) => buf.push_str(name),
81
82        Expr::Lit(lit) => write_literal(buf, lit),
83
84        Expr::Lam(param, body) => {
85            let needs_parens = ctx > Prec::Top;
86            if needs_parens {
87                buf.push('(');
88            }
89            write_lambda_chain(buf, param, body);
90            if needs_parens {
91                buf.push(')');
92            }
93        }
94
95        Expr::App(func, arg) => {
96            write_app(buf, expr, ctx);
97            let _ = (func, arg); // used inside write_app
98        }
99
100        Expr::Record(fields) => {
101            write_record_expr(buf, fields);
102        }
103
104        Expr::List(items) => {
105            buf.push('[');
106            for (i, item) in items.iter().enumerate() {
107                if i > 0 {
108                    buf.push_str(", ");
109                }
110                write_expr(buf, item, Prec::Top);
111            }
112            buf.push(']');
113        }
114
115        Expr::Field(inner, name) => {
116            write_expr(buf, inner, Prec::Atom);
117            buf.push('.');
118            buf.push_str(name);
119        }
120
121        Expr::Index(inner, idx) => {
122            write_expr(buf, inner, Prec::Atom);
123            buf.push('[');
124            write_expr(buf, idx, Prec::Top);
125            buf.push(']');
126        }
127
128        Expr::Match { scrutinee, arms } => {
129            write_match(buf, scrutinee, arms, ctx);
130        }
131
132        Expr::Let { name, value, body } => {
133            write_let(buf, name, value, body, ctx);
134        }
135
136        Expr::Builtin(op, args) => {
137            write_builtin(buf, *op, args, ctx);
138        }
139    }
140}
141
142/// Write a chain of nested lambdas as `\x y z -> body`.
143fn write_lambda_chain(buf: &mut String, first_param: &Arc<str>, first_body: &Expr) {
144    buf.push('\\');
145    buf.push_str(first_param);
146    let mut body = first_body;
147    while let Expr::Lam(param, inner) = body {
148        buf.push(' ');
149        buf.push_str(param);
150        body = inner;
151    }
152    buf.push_str(" -> ");
153    write_expr(buf, body, Prec::Top);
154}
155
156/// Write function application, collecting curried args: `f x y z`.
157fn write_app(buf: &mut String, expr: &Expr, ctx: Prec) {
158    let needs_parens = ctx > Prec::App;
159    if needs_parens {
160        buf.push('(');
161    }
162
163    // Collect the application spine.
164    let mut spine: Vec<&Expr> = Vec::new();
165    let mut head = expr;
166    while let Expr::App(func, arg) = head {
167        spine.push(arg);
168        head = func;
169    }
170    spine.reverse();
171
172    write_expr(buf, head, Prec::App);
173    for arg in &spine {
174        buf.push(' ');
175        write_expr(buf, arg, Prec::Atom);
176    }
177
178    if needs_parens {
179        buf.push(')');
180    }
181}
182
183/// Write a record expression with punning where appropriate.
184fn write_record_expr(buf: &mut String, fields: &[(Arc<str>, Expr)]) {
185    buf.push_str("{ ");
186    for (i, (name, val)) in fields.iter().enumerate() {
187        if i > 0 {
188            buf.push_str(", ");
189        }
190        // Record punning: `{ x }` when field name equals variable name.
191        if let Expr::Var(v) = val {
192            if v == name {
193                buf.push_str(name);
194                continue;
195            }
196        }
197        buf.push_str(name);
198        buf.push_str(" = ");
199        write_expr(buf, val, Prec::Top);
200    }
201    buf.push_str(" }");
202}
203
204/// Write a match expression.
205///
206/// Detects `if/then/else` patterns (two arms with True and Wildcard)
207/// and emits those in the shorter form.
208fn write_match(buf: &mut String, scrutinee: &Expr, arms: &[(Pattern, Expr)], ctx: Prec) {
209    // Detect if/then/else: Match with [Lit(Bool(true)) -> then, Wildcard -> else]
210    if arms.len() == 2 {
211        if let (Pattern::Lit(Literal::Bool(true)), then_branch) = &arms[0] {
212            if let (Pattern::Wildcard, else_branch) = &arms[1] {
213                let needs_parens = ctx > Prec::Top;
214                if needs_parens {
215                    buf.push('(');
216                }
217                buf.push_str("if ");
218                write_expr(buf, scrutinee, Prec::Top);
219                buf.push_str(" then ");
220                write_expr(buf, then_branch, Prec::Top);
221                buf.push_str(" else ");
222                write_expr(buf, else_branch, Prec::Top);
223                if needs_parens {
224                    buf.push(')');
225                }
226                return;
227            }
228        }
229    }
230
231    let needs_parens = ctx > Prec::Top;
232    if needs_parens {
233        buf.push('(');
234    }
235    buf.push_str("case ");
236    write_expr(buf, scrutinee, Prec::Top);
237    buf.push_str(" of\n");
238    for (i, (pat, body)) in arms.iter().enumerate() {
239        if i > 0 {
240            buf.push('\n');
241        }
242        buf.push_str("  ");
243        write_pattern(buf, pat);
244        buf.push_str(" -> ");
245        write_expr(buf, body, Prec::Top);
246    }
247    if needs_parens {
248        buf.push(')');
249    }
250}
251
252/// Write a let binding, collapsing nested lets into a layout block.
253fn write_let(buf: &mut String, name: &Arc<str>, value: &Expr, body: &Expr, ctx: Prec) {
254    let needs_parens = ctx > Prec::Top;
255    if needs_parens {
256        buf.push('(');
257    }
258
259    // Collect chained lets.
260    let mut bindings: Vec<(&Arc<str>, &Expr)> = vec![(name, value)];
261    let mut final_body = body;
262    while let Expr::Let {
263        name: n,
264        value: v,
265        body: b,
266    } = final_body
267    {
268        bindings.push((n, v));
269        final_body = b;
270    }
271
272    if bindings.len() == 1 {
273        buf.push_str("let ");
274        buf.push_str(name);
275        buf.push_str(" = ");
276        write_expr(buf, value, Prec::Top);
277        buf.push_str(" in ");
278    } else {
279        buf.push_str("let\n");
280        for (n, v) in &bindings {
281            buf.push_str("  ");
282            buf.push_str(n);
283            buf.push_str(" = ");
284            write_expr(buf, v, Prec::Top);
285            buf.push('\n');
286        }
287        buf.push_str("in ");
288    }
289    write_expr(buf, final_body, Prec::Top);
290
291    if needs_parens {
292        buf.push(')');
293    }
294}
295
296/// Write a builtin operation, using infix/prefix syntax where possible.
297fn write_builtin(buf: &mut String, op: BuiltinOp, args: &[Expr], ctx: Prec) {
298    // Try infix binary operators.
299    if let Some((sym, prec, assoc)) = infix_info(op) {
300        if args.len() == 2 {
301            let needs_parens = ctx > prec;
302            if needs_parens {
303                buf.push('(');
304            }
305            // For left-associative operators, the left child is fine at the
306            // same precedence but the right child needs to be tighter (to
307            // avoid ambiguity). Vice versa for right-associative.
308            let (left_ctx, right_ctx) = match assoc {
309                Assoc::Left => (prec, next_prec(prec)),
310                Assoc::Right => (next_prec(prec), prec),
311            };
312            write_expr(buf, &args[0], left_ctx);
313            buf.push(' ');
314            buf.push_str(sym);
315            buf.push(' ');
316            write_expr(buf, &args[1], right_ctx);
317            if needs_parens {
318                buf.push(')');
319            }
320            return;
321        }
322    }
323
324    // Edge traversal: `expr -> edge`
325    if op == BuiltinOp::Edge && args.len() == 2 {
326        if let Expr::Lit(Literal::Str(edge_name)) = &args[1] {
327            let needs_parens = ctx > Prec::Atom;
328            if needs_parens {
329                buf.push('(');
330            }
331            write_expr(buf, &args[0], Prec::Atom);
332            buf.push_str(" -> ");
333            buf.push_str(edge_name);
334            if needs_parens {
335                buf.push(')');
336            }
337            return;
338        }
339    }
340
341    // Unary prefix: negation and logical not.
342    if op == BuiltinOp::Neg && args.len() == 1 {
343        let needs_parens = ctx > Prec::Unary;
344        if needs_parens {
345            buf.push('(');
346        }
347        buf.push('-');
348        write_expr(buf, &args[0], Prec::Atom);
349        if needs_parens {
350            buf.push(')');
351        }
352        return;
353    }
354
355    if op == BuiltinOp::Not && args.len() == 1 {
356        let needs_parens = ctx > Prec::Unary;
357        if needs_parens {
358            buf.push('(');
359        }
360        buf.push_str("not ");
361        write_expr(buf, &args[0], Prec::Atom);
362        if needs_parens {
363            buf.push(')');
364        }
365        return;
366    }
367
368    // Fallback: function call syntax `name arg1 arg2 ...`
369    let needs_parens = ctx > Prec::App && !args.is_empty();
370    if needs_parens {
371        buf.push('(');
372    }
373    buf.push_str(builtin_name(op));
374    for arg in args {
375        buf.push(' ');
376        write_expr(buf, arg, Prec::Atom);
377    }
378    if needs_parens {
379        buf.push(')');
380    }
381}
382
383/// Map a builtin op to its infix operator symbol, precedence, and associativity.
384///
385/// Returns `None` for builtins that should use function call syntax.
386const fn infix_info(op: BuiltinOp) -> Option<(&'static str, Prec, Assoc)> {
387    match op {
388        BuiltinOp::Or => Some(("||", Prec::Or, Assoc::Left)),
389        BuiltinOp::And => Some(("&&", Prec::And, Assoc::Left)),
390        BuiltinOp::Eq => Some(("==", Prec::Cmp, Assoc::Right)),
391        BuiltinOp::Neq => Some(("/=", Prec::Cmp, Assoc::Right)),
392        BuiltinOp::Lt => Some(("<", Prec::Cmp, Assoc::Right)),
393        BuiltinOp::Lte => Some(("<=", Prec::Cmp, Assoc::Right)),
394        BuiltinOp::Gt => Some((">", Prec::Cmp, Assoc::Right)),
395        BuiltinOp::Gte => Some((">=", Prec::Cmp, Assoc::Right)),
396        BuiltinOp::Concat => Some(("++", Prec::Concat, Assoc::Right)),
397        BuiltinOp::Add => Some(("+", Prec::AddSub, Assoc::Left)),
398        BuiltinOp::Sub => Some(("-", Prec::AddSub, Assoc::Left)),
399        BuiltinOp::Mul => Some(("*", Prec::MulDiv, Assoc::Left)),
400        BuiltinOp::Div => Some(("/", Prec::MulDiv, Assoc::Left)),
401        BuiltinOp::Mod => Some(("%", Prec::MulDiv, Assoc::Left)),
402        _ => None,
403    }
404}
405
406/// Get the next higher precedence level.
407const fn next_prec(p: Prec) -> Prec {
408    match p {
409        Prec::Top => Prec::Pipe,
410        Prec::Pipe => Prec::Or,
411        Prec::Or => Prec::And,
412        Prec::And => Prec::Cmp,
413        Prec::Cmp => Prec::Concat,
414        Prec::Concat => Prec::AddSub,
415        Prec::AddSub => Prec::MulDiv,
416        Prec::MulDiv => Prec::Unary,
417        Prec::Unary => Prec::App,
418        Prec::App | Prec::Atom => Prec::Atom,
419    }
420}
421
/// Map a builtin op to its canonical function name for call syntax.
///
/// Used by the call-form fallback in `write_builtin` (`name arg1 arg2 ...`).
/// That fallback applies to builtins not printed infix, as an edge arrow, or
/// as a prefix operator. The match has no wildcard arm, so every `BuiltinOp`
/// variant must be given a name here.
///
/// NOTE(review): round-tripping relies on the parser mapping each name back
/// to the same `BuiltinOp`. `div` and `mod` are also listed as infix operator
/// keywords in the `Prec::MulDiv` docs — confirm that the call form
/// (`div a`, `mod a`) is accepted when those ops are not printed infix.
const fn builtin_name(op: BuiltinOp) -> &'static str {
    match op {
        BuiltinOp::Add => "add",
        BuiltinOp::Sub => "sub",
        BuiltinOp::Mul => "mul",
        BuiltinOp::Div => "div",
        BuiltinOp::Mod => "mod",
        BuiltinOp::Neg => "neg",
        BuiltinOp::Abs => "abs",
        BuiltinOp::Floor => "floor",
        BuiltinOp::Ceil => "ceil",
        BuiltinOp::Round => "round",
        BuiltinOp::Eq => "eq",
        BuiltinOp::Neq => "neq",
        BuiltinOp::Lt => "lt",
        BuiltinOp::Lte => "lte",
        BuiltinOp::Gt => "gt",
        BuiltinOp::Gte => "gte",
        BuiltinOp::And => "and",
        BuiltinOp::Or => "or",
        BuiltinOp::Not => "not",
        BuiltinOp::Concat => "concat",
        BuiltinOp::Len => "len",
        BuiltinOp::Slice => "slice",
        BuiltinOp::Upper => "upper",
        BuiltinOp::Lower => "lower",
        BuiltinOp::Trim => "trim",
        BuiltinOp::Split => "split",
        BuiltinOp::Join => "join",
        BuiltinOp::Replace => "replace",
        BuiltinOp::Contains => "contains",
        BuiltinOp::Map => "map",
        BuiltinOp::Filter => "filter",
        BuiltinOp::Fold => "fold",
        BuiltinOp::Append => "append",
        BuiltinOp::Head => "head",
        BuiltinOp::Tail => "tail",
        BuiltinOp::Reverse => "reverse",
        BuiltinOp::FlatMap => "flat_map",
        BuiltinOp::Length => "length",
        BuiltinOp::MergeRecords => "merge",
        BuiltinOp::Keys => "keys",
        BuiltinOp::Values => "values",
        BuiltinOp::HasField => "has_field",
        BuiltinOp::DefaultVal => "default",
        BuiltinOp::Clamp => "clamp",
        BuiltinOp::TruncateStr => "truncate_str",
        BuiltinOp::IntToFloat => "int_to_float",
        BuiltinOp::FloatToInt => "float_to_int",
        BuiltinOp::IntToStr => "int_to_str",
        BuiltinOp::FloatToStr => "float_to_str",
        BuiltinOp::StrToInt => "str_to_int",
        BuiltinOp::StrToFloat => "str_to_float",
        BuiltinOp::TypeOf => "type_of",
        BuiltinOp::IsNull => "is_null",
        BuiltinOp::IsList => "is_list",
        BuiltinOp::Edge => "edge",
        BuiltinOp::Children => "children",
        BuiltinOp::HasEdge => "has_edge",
        BuiltinOp::EdgeCount => "edge_count",
        BuiltinOp::Anchor => "anchor",
    }
}
486
487/// Write a literal value.
488fn write_literal(buf: &mut String, lit: &Literal) {
489    match lit {
490        Literal::Bool(true) => buf.push_str("True"),
491        Literal::Bool(false) => buf.push_str("False"),
492        Literal::Int(n) => {
493            let _ = write!(buf, "{n}");
494        }
495        Literal::Float(f) => {
496            // Ensure there is always a decimal point so the parser
497            // recognizes this as a float, not an int.
498            let s = format!("{f}");
499            if s.contains('.') {
500                buf.push_str(&s);
501            } else {
502                let _ = write!(buf, "{f}.0");
503            }
504        }
505        Literal::Str(s) => {
506            buf.push('"');
507            // Escape backslashes and double quotes.
508            for ch in s.chars() {
509                match ch {
510                    '\\' => buf.push_str("\\\\"),
511                    '"' => buf.push_str("\\\""),
512                    '\n' => buf.push_str("\\n"),
513                    '\r' => buf.push_str("\\r"),
514                    '\t' => buf.push_str("\\t"),
515                    c => buf.push(c),
516                }
517            }
518            buf.push('"');
519        }
520        Literal::Bytes(bytes) => {
521            // No native bytes syntax; emit as a list of ints.
522            buf.push('[');
523            for (i, b) in bytes.iter().enumerate() {
524                if i > 0 {
525                    buf.push_str(", ");
526                }
527                let _ = write!(buf, "{b}");
528            }
529            buf.push(']');
530        }
531        Literal::Null => buf.push_str("Nothing"),
532        Literal::Record(fields) => {
533            buf.push_str("{ ");
534            for (i, (name, val)) in fields.iter().enumerate() {
535                if i > 0 {
536                    buf.push_str(", ");
537                }
538                buf.push_str(name);
539                buf.push_str(" = ");
540                write_literal(buf, val);
541            }
542            buf.push_str(" }");
543        }
544        Literal::List(items) => {
545            buf.push('[');
546            for (i, item) in items.iter().enumerate() {
547                if i > 0 {
548                    buf.push_str(", ");
549                }
550                write_literal(buf, item);
551            }
552            buf.push(']');
553        }
554        Literal::Closure { param, body, .. } => {
555            // Print as a lambda; the captured env is lost but the
556            // expression form is preserved for round-tripping.
557            buf.push('\\');
558            buf.push_str(param);
559            buf.push_str(" -> ");
560            write_expr(buf, body, Prec::Top);
561        }
562    }
563}
564
565/// Write a pattern.
566fn write_pattern(buf: &mut String, pat: &Pattern) {
567    match pat {
568        Pattern::Wildcard => buf.push('_'),
569        Pattern::Var(name) => buf.push_str(name),
570        Pattern::Lit(lit) => write_literal(buf, lit),
571        Pattern::Record(fields) => {
572            buf.push_str("{ ");
573            for (i, (name, p)) in fields.iter().enumerate() {
574                if i > 0 {
575                    buf.push_str(", ");
576                }
577                // Record pattern punning: `{ x }` when field pattern is Var(x).
578                if let Pattern::Var(v) = p {
579                    if v == name {
580                        buf.push_str(name);
581                        continue;
582                    }
583                }
584                buf.push_str(name);
585                buf.push_str(" = ");
586                write_pattern(buf, p);
587            }
588            buf.push_str(" }");
589        }
590        Pattern::List(pats) => {
591            buf.push('[');
592            for (i, p) in pats.iter().enumerate() {
593                if i > 0 {
594                    buf.push_str(", ");
595                }
596                write_pattern(buf, p);
597            }
598            buf.push(']');
599        }
600        Pattern::Constructor(name, args) => {
601            buf.push_str(name);
602            for arg in args {
603                buf.push(' ');
604                // Wrap constructor args in parens if they are themselves
605                // constructors with args (to avoid ambiguity).
606                let needs_parens = matches!(arg, Pattern::Constructor(_, a) if !a.is_empty());
607                if needs_parens {
608                    buf.push('(');
609                }
610                write_pattern(buf, arg);
611                if needs_parens {
612                    buf.push(')');
613                }
614            }
615        }
616    }
617}
618
619#[cfg(test)]
620mod tests {
621    use super::*;
622    use crate::{parse, tokenize};
623
624    /// Parse a string, pretty print it, re-parse, and verify equality.
625    fn round_trip(input: &str) {
626        let tokens1 = tokenize(input).unwrap_or_else(|e| panic!("first lex failed: {e}"));
627        let expr1 = parse(&tokens1).unwrap_or_else(|e| panic!("first parse failed: {e:?}"));
628        let printed = pretty_print(&expr1);
629        let tokens2 = tokenize(&printed).unwrap_or_else(|e| {
630            panic!("re-lex failed for {printed:?}: {e}");
631        });
632        let expr2 = parse(&tokens2).unwrap_or_else(|e| {
633            panic!("re-parse failed for {printed:?}: {e:?}");
634        });
635        assert_eq!(
636            expr1, expr2,
637            "round trip failed.\n  input:   {input:?}\n  printed: {printed:?}"
638        );
639    }
640
641    /// Pretty print an expression built programmatically and check output.
642    fn prints_as(expr: &Expr, expected: &str) {
643        let actual = pretty_print(expr);
644        assert_eq!(actual, expected, "pretty_print mismatch");
645    }
646
647    // ── Literals ──────────────────────────────────────────────────
648
649    #[test]
650    fn lit_int() {
651        prints_as(&Expr::Lit(Literal::Int(42)), "42");
652    }
653
654    #[test]
655    fn lit_negative_int() {
656        prints_as(&Expr::Lit(Literal::Int(-5)), "-5");
657    }
658
659    #[test]
660    fn lit_float() {
661        prints_as(&Expr::Lit(Literal::Float(3.125)), "3.125");
662    }
663
664    #[test]
665    fn lit_string() {
666        prints_as(&Expr::Lit(Literal::Str("hello".into())), r#""hello""#);
667    }
668
669    #[test]
670    fn lit_string_escapes() {
671        prints_as(
672            &Expr::Lit(Literal::Str("say \"hi\"".into())),
673            r#""say \"hi\"""#,
674        );
675    }
676
677    #[test]
678    fn lit_bool() {
679        prints_as(&Expr::Lit(Literal::Bool(true)), "True");
680        prints_as(&Expr::Lit(Literal::Bool(false)), "False");
681    }
682
683    #[test]
684    fn lit_null() {
685        prints_as(&Expr::Lit(Literal::Null), "Nothing");
686    }
687
688    #[test]
689    fn lit_bytes() {
690        prints_as(&Expr::Lit(Literal::Bytes(vec![1, 2, 3])), "[1, 2, 3]");
691    }
692
693    // ── Variables ─────────────────────────────────────────────────
694
695    #[test]
696    fn variable() {
697        prints_as(&Expr::Var(Arc::from("x")), "x");
698    }
699
700    // ── Lambda ────────────────────────────────────────────────────
701
702    #[test]
703    fn lambda_simple() {
704        prints_as(
705            &Expr::Lam(Arc::from("x"), Box::new(Expr::Var(Arc::from("x")))),
706            "\\x -> x",
707        );
708    }
709
710    #[test]
711    fn lambda_multi_param() {
712        prints_as(
713            &Expr::Lam(
714                Arc::from("x"),
715                Box::new(Expr::Lam(
716                    Arc::from("y"),
717                    Box::new(Expr::Builtin(
718                        BuiltinOp::Add,
719                        vec![Expr::Var(Arc::from("x")), Expr::Var(Arc::from("y"))],
720                    )),
721                )),
722            ),
723            "\\x y -> x + y",
724        );
725    }
726
727    #[test]
728    fn lambda_round_trip() {
729        round_trip("\\x -> x + 1");
730        round_trip("\\x y -> x + y");
731    }
732
733    // ── Application ───────────────────────────────────────────────
734
735    #[test]
736    fn app_simple() {
737        prints_as(
738            &Expr::App(
739                Box::new(Expr::Var(Arc::from("f"))),
740                Box::new(Expr::Var(Arc::from("x"))),
741            ),
742            "f x",
743        );
744    }
745
746    #[test]
747    fn app_chain() {
748        prints_as(
749            &Expr::App(
750                Box::new(Expr::App(
751                    Box::new(Expr::Var(Arc::from("f"))),
752                    Box::new(Expr::Var(Arc::from("x"))),
753                )),
754                Box::new(Expr::Var(Arc::from("y"))),
755            ),
756            "f x y",
757        );
758    }
759
760    #[test]
761    fn app_complex_arg() {
762        // f (g x) should parenthesize the argument
763        prints_as(
764            &Expr::App(
765                Box::new(Expr::Var(Arc::from("f"))),
766                Box::new(Expr::App(
767                    Box::new(Expr::Var(Arc::from("g"))),
768                    Box::new(Expr::Var(Arc::from("x"))),
769                )),
770            ),
771            "f (g x)",
772        );
773    }
774
775    // ── Record ────────────────────────────────────────────────────
776
777    #[test]
778    fn record_simple() {
779        prints_as(
780            &Expr::Record(vec![
781                (Arc::from("x"), Expr::Lit(Literal::Int(1))),
782                (Arc::from("y"), Expr::Lit(Literal::Int(2))),
783            ]),
784            "{ x = 1, y = 2 }",
785        );
786    }
787
788    #[test]
789    fn record_punning() {
790        prints_as(
791            &Expr::Record(vec![
792                (Arc::from("x"), Expr::Var(Arc::from("x"))),
793                (Arc::from("y"), Expr::Var(Arc::from("y"))),
794            ]),
795            "{ x, y }",
796        );
797    }
798
799    #[test]
800    fn record_mixed_punning() {
801        prints_as(
802            &Expr::Record(vec![
803                (Arc::from("x"), Expr::Var(Arc::from("x"))),
804                (Arc::from("y"), Expr::Lit(Literal::Int(42))),
805            ]),
806            "{ x, y = 42 }",
807        );
808    }
809
810    #[test]
811    fn record_round_trip() {
812        round_trip("{ name = x, age = 30 }");
813        round_trip("{ x, y }");
814    }
815
816    // ── List ──────────────────────────────────────────────────────
817
818    #[test]
819    fn list_simple() {
820        prints_as(
821            &Expr::List(vec![
822                Expr::Lit(Literal::Int(1)),
823                Expr::Lit(Literal::Int(2)),
824                Expr::Lit(Literal::Int(3)),
825            ]),
826            "[1, 2, 3]",
827        );
828    }
829
830    #[test]
831    fn list_empty() {
832        prints_as(&Expr::List(vec![]), "[]");
833    }
834
835    #[test]
836    fn list_round_trip() {
837        round_trip("[1, 2, 3]");
838        round_trip("[]");
839    }
840
841    // ── Field access ──────────────────────────────────────────────
842
843    #[test]
844    fn field_access() {
845        prints_as(
846            &Expr::Field(Box::new(Expr::Var(Arc::from("x"))), Arc::from("name")),
847            "x.name",
848        );
849    }
850
851    #[test]
852    fn field_chain() {
853        prints_as(
854            &Expr::Field(
855                Box::new(Expr::Field(
856                    Box::new(Expr::Var(Arc::from("x"))),
857                    Arc::from("a"),
858                )),
859                Arc::from("b"),
860            ),
861            "x.a.b",
862        );
863    }
864
865    #[test]
866    fn field_round_trip() {
867        round_trip("x.name");
868        round_trip("x.a.b");
869    }
870
871    // ── Edge traversal ────────────────────────────────────────────
872
873    #[test]
874    fn edge_traversal() {
875        prints_as(
876            &Expr::Builtin(
877                BuiltinOp::Edge,
878                vec![
879                    Expr::Var(Arc::from("doc")),
880                    Expr::Lit(Literal::Str("layers".into())),
881                ],
882            ),
883            "doc -> layers",
884        );
885    }
886
887    #[test]
888    fn edge_chain() {
889        prints_as(
890            &Expr::Builtin(
891                BuiltinOp::Edge,
892                vec![
893                    Expr::Builtin(
894                        BuiltinOp::Edge,
895                        vec![
896                            Expr::Var(Arc::from("doc")),
897                            Expr::Lit(Literal::Str("layers".into())),
898                        ],
899                    ),
900                    Expr::Lit(Literal::Str("annotations".into())),
901                ],
902            ),
903            "doc -> layers -> annotations",
904        );
905    }
906
907    #[test]
908    fn edge_round_trip() {
909        round_trip("doc -> layers");
910        round_trip("doc -> layers -> annotations");
911    }
912
913    // ── Infix operators ───────────────────────────────────────────
914
915    #[test]
916    fn infix_add() {
917        prints_as(
918            &Expr::Builtin(
919                BuiltinOp::Add,
920                vec![Expr::Var(Arc::from("x")), Expr::Lit(Literal::Int(1))],
921            ),
922            "x + 1",
923        );
924    }
925
926    #[test]
927    fn infix_precedence_no_parens() {
928        // 1 + 2 * 3 should not need parens because * binds tighter.
929        prints_as(
930            &Expr::Builtin(
931                BuiltinOp::Add,
932                vec![
933                    Expr::Lit(Literal::Int(1)),
934                    Expr::Builtin(
935                        BuiltinOp::Mul,
936                        vec![Expr::Lit(Literal::Int(2)), Expr::Lit(Literal::Int(3))],
937                    ),
938                ],
939            ),
940            "1 + 2 * 3",
941        );
942    }
943
944    #[test]
945    fn infix_precedence_needs_parens() {
946        // (1 + 2) * 3 needs parens because + is lower than *.
947        prints_as(
948            &Expr::Builtin(
949                BuiltinOp::Mul,
950                vec![
951                    Expr::Builtin(
952                        BuiltinOp::Add,
953                        vec![Expr::Lit(Literal::Int(1)), Expr::Lit(Literal::Int(2))],
954                    ),
955                    Expr::Lit(Literal::Int(3)),
956                ],
957            ),
958            "(1 + 2) * 3",
959        );
960    }
961
962    #[test]
963    fn infix_left_assoc_no_parens() {
964        // 1 + 2 + 3 is left-associative, so (1+2)+3 needs no parens.
965        prints_as(
966            &Expr::Builtin(
967                BuiltinOp::Add,
968                vec![
969                    Expr::Builtin(
970                        BuiltinOp::Add,
971                        vec![Expr::Lit(Literal::Int(1)), Expr::Lit(Literal::Int(2))],
972                    ),
973                    Expr::Lit(Literal::Int(3)),
974                ],
975            ),
976            "1 + 2 + 3",
977        );
978    }
979
980    #[test]
981    fn infix_right_assoc_needs_parens() {
982        // For left-assoc +, 1 + (2 + 3) needs parens on the right.
983        prints_as(
984            &Expr::Builtin(
985                BuiltinOp::Add,
986                vec![
987                    Expr::Lit(Literal::Int(1)),
988                    Expr::Builtin(
989                        BuiltinOp::Add,
990                        vec![Expr::Lit(Literal::Int(2)), Expr::Lit(Literal::Int(3))],
991                    ),
992                ],
993            ),
994            "1 + (2 + 3)",
995        );
996    }
997
998    #[test]
999    fn infix_concat_right_assoc() {
1000        // ++ is right-associative, so a ++ (b ++ c) needs no parens.
1001        prints_as(
1002            &Expr::Builtin(
1003                BuiltinOp::Concat,
1004                vec![
1005                    Expr::Var(Arc::from("a")),
1006                    Expr::Builtin(
1007                        BuiltinOp::Concat,
1008                        vec![Expr::Var(Arc::from("b")), Expr::Var(Arc::from("c"))],
1009                    ),
1010                ],
1011            ),
1012            "a ++ b ++ c",
1013        );
1014    }
1015
1016    #[test]
1017    fn infix_comparison() {
1018        prints_as(
1019            &Expr::Builtin(
1020                BuiltinOp::Eq,
1021                vec![Expr::Var(Arc::from("x")), Expr::Lit(Literal::Int(1))],
1022            ),
1023            "x == 1",
1024        );
1025        prints_as(
1026            &Expr::Builtin(
1027                BuiltinOp::Neq,
1028                vec![Expr::Var(Arc::from("x")), Expr::Lit(Literal::Int(1))],
1029            ),
1030            "x /= 1",
1031        );
1032        prints_as(
1033            &Expr::Builtin(
1034                BuiltinOp::Lt,
1035                vec![Expr::Var(Arc::from("x")), Expr::Var(Arc::from("y"))],
1036            ),
1037            "x < y",
1038        );
1039    }
1040
1041    #[test]
1042    fn infix_logical() {
1043        prints_as(
1044            &Expr::Builtin(
1045                BuiltinOp::And,
1046                vec![Expr::Var(Arc::from("a")), Expr::Var(Arc::from("b"))],
1047            ),
1048            "a && b",
1049        );
1050        prints_as(
1051            &Expr::Builtin(
1052                BuiltinOp::Or,
1053                vec![Expr::Var(Arc::from("a")), Expr::Var(Arc::from("b"))],
1054            ),
1055            "a || b",
1056        );
1057    }
1058
1059    #[test]
1060    fn infix_round_trips() {
1061        round_trip("1 + 2");
1062        round_trip("1 + 2 * 3");
1063        round_trip("(1 + 2) * 3");
1064        round_trip("a && b || c");
1065        round_trip("x == 1");
1066        round_trip("x /= 1");
1067    }
1068
1069    // ── Prefix operators ──────────────────────────────────────────
1070
1071    #[test]
1072    fn prefix_neg() {
1073        prints_as(
1074            &Expr::Builtin(BuiltinOp::Neg, vec![Expr::Var(Arc::from("x"))]),
1075            "-x",
1076        );
1077    }
1078
1079    #[test]
1080    fn prefix_not() {
1081        prints_as(
1082            &Expr::Builtin(BuiltinOp::Not, vec![Expr::Lit(Literal::Bool(true))]),
1083            "not True",
1084        );
1085    }
1086
1087    #[test]
1088    fn prefix_round_trip() {
1089        round_trip("-x");
1090        round_trip("not True");
1091    }
1092
1093    // ── Builtin function call syntax ──────────────────────────────
1094
1095    #[test]
1096    fn builtin_function_call() {
1097        prints_as(
1098            &Expr::Builtin(
1099                BuiltinOp::Map,
1100                vec![Expr::Var(Arc::from("f")), Expr::Var(Arc::from("xs"))],
1101            ),
1102            "map f xs",
1103        );
1104    }
1105
1106    #[test]
1107    fn builtin_unary() {
1108        prints_as(
1109            &Expr::Builtin(BuiltinOp::Head, vec![Expr::Var(Arc::from("xs"))]),
1110            "head xs",
1111        );
1112    }
1113
1114    #[test]
1115    fn builtin_round_trip() {
1116        round_trip("map f xs");
1117        round_trip("head xs");
1118        round_trip("filter f xs");
1119    }
1120
1121    // ── Let ───────────────────────────────────────────────────────
1122
1123    #[test]
1124    fn let_simple() {
1125        prints_as(
1126            &Expr::Let {
1127                name: Arc::from("x"),
1128                value: Box::new(Expr::Lit(Literal::Int(1))),
1129                body: Box::new(Expr::Builtin(
1130                    BuiltinOp::Add,
1131                    vec![Expr::Var(Arc::from("x")), Expr::Lit(Literal::Int(1))],
1132                )),
1133            },
1134            "let x = 1 in x + 1",
1135        );
1136    }
1137
1138    #[test]
1139    fn let_round_trip() {
1140        round_trip("let x = 1 in x + 1");
1141    }
1142
1143    // ── If/then/else ──────────────────────────────────────────────
1144
1145    #[test]
1146    fn if_then_else() {
1147        let expr = Expr::Match {
1148            scrutinee: Box::new(Expr::Lit(Literal::Bool(true))),
1149            arms: vec![
1150                (
1151                    Pattern::Lit(Literal::Bool(true)),
1152                    Expr::Lit(Literal::Int(1)),
1153                ),
1154                (Pattern::Wildcard, Expr::Lit(Literal::Int(0))),
1155            ],
1156        };
1157        prints_as(&expr, "if True then 1 else 0");
1158    }
1159
1160    #[test]
1161    fn if_round_trip() {
1162        round_trip("if True then 1 else 0");
1163    }
1164
1165    // ── Case/of ───────────────────────────────────────────────────
1166
1167    #[test]
1168    fn case_of() {
1169        let expr = Expr::Match {
1170            scrutinee: Box::new(Expr::Var(Arc::from("x"))),
1171            arms: vec![
1172                (
1173                    Pattern::Lit(Literal::Bool(true)),
1174                    Expr::Lit(Literal::Int(1)),
1175                ),
1176                (
1177                    Pattern::Lit(Literal::Bool(false)),
1178                    Expr::Lit(Literal::Int(0)),
1179                ),
1180            ],
1181        };
1182        prints_as(&expr, "case x of\n  True -> 1\n  False -> 0");
1183    }
1184
1185    #[test]
1186    fn case_round_trip() {
1187        round_trip("case x of\n  True -> 1\n  False -> 0");
1188    }
1189
1190    // ── Nested expressions ────────────────────────────────────────
1191
1192    #[test]
1193    fn nested_let_in_lambda() {
1194        round_trip("\\x -> let y = x + 1 in y * 2");
1195    }
1196
1197    #[test]
1198    fn nested_if_in_let() {
1199        round_trip("let x = if True then 1 else 0 in x + 1");
1200    }
1201
1202    #[test]
1203    fn lambda_as_arg() {
1204        // f (\x -> x) should parenthesize the lambda argument
1205        prints_as(
1206            &Expr::App(
1207                Box::new(Expr::Var(Arc::from("f"))),
1208                Box::new(Expr::Lam(
1209                    Arc::from("x"),
1210                    Box::new(Expr::Var(Arc::from("x"))),
1211                )),
1212            ),
1213            "f (\\x -> x)",
1214        );
1215    }
1216
1217    #[test]
1218    fn complex_expression_round_trip() {
1219        round_trip("\\f xs -> map (\\x -> f x + 1) xs");
1220    }
1221
1222    // ── Pattern printing ──────────────────────────────────────────
1223
1224    #[test]
1225    fn pattern_wildcard() {
1226        let mut buf = String::new();
1227        write_pattern(&mut buf, &Pattern::Wildcard);
1228        assert_eq!(buf, "_");
1229    }
1230
1231    #[test]
1232    fn pattern_var() {
1233        let mut buf = String::new();
1234        write_pattern(&mut buf, &Pattern::Var(Arc::from("x")));
1235        assert_eq!(buf, "x");
1236    }
1237
1238    #[test]
1239    fn pattern_lit() {
1240        let mut buf = String::new();
1241        write_pattern(&mut buf, &Pattern::Lit(Literal::Int(42)));
1242        assert_eq!(buf, "42");
1243    }
1244
1245    #[test]
1246    fn pattern_list() {
1247        let mut buf = String::new();
1248        write_pattern(
1249            &mut buf,
1250            &Pattern::List(vec![
1251                Pattern::Var(Arc::from("x")),
1252                Pattern::Var(Arc::from("y")),
1253            ]),
1254        );
1255        assert_eq!(buf, "[x, y]");
1256    }
1257
1258    #[test]
1259    fn pattern_record_punning() {
1260        let mut buf = String::new();
1261        write_pattern(
1262            &mut buf,
1263            &Pattern::Record(vec![
1264                (Arc::from("x"), Pattern::Var(Arc::from("x"))),
1265                (Arc::from("y"), Pattern::Var(Arc::from("y"))),
1266            ]),
1267        );
1268        assert_eq!(buf, "{ x, y }");
1269    }
1270
1271    #[test]
1272    fn pattern_constructor() {
1273        let mut buf = String::new();
1274        write_pattern(
1275            &mut buf,
1276            &Pattern::Constructor(Arc::from("Just"), vec![Pattern::Var(Arc::from("x"))]),
1277        );
1278        assert_eq!(buf, "Just x");
1279    }
1280
1281    // ── Index ─────────────────────────────────────────────────────
1282
1283    #[test]
1284    fn index_access() {
1285        prints_as(
1286            &Expr::Index(
1287                Box::new(Expr::Var(Arc::from("xs"))),
1288                Box::new(Expr::Lit(Literal::Int(0))),
1289            ),
1290            "xs[0]",
1291        );
1292    }
1293
1294    // ── Literal record and list ───────────────────────────────────
1295
1296    #[test]
1297    fn literal_record() {
1298        prints_as(
1299            &Expr::Lit(Literal::Record(vec![
1300                (Arc::from("x"), Literal::Int(1)),
1301                (Arc::from("y"), Literal::Int(2)),
1302            ])),
1303            "{ x = 1, y = 2 }",
1304        );
1305    }
1306
1307    #[test]
1308    fn literal_list() {
1309        prints_as(
1310            &Expr::Lit(Literal::List(vec![Literal::Int(1), Literal::Int(2)])),
1311            "[1, 2]",
1312        );
1313    }
1314
1315    // ── Mixed precedence round trips ──────────────────────────────
1316
1317    #[test]
1318    fn precedence_logical_and_comparison() {
1319        round_trip("x == 1 && y == 2");
1320    }
1321
1322    #[test]
1323    fn precedence_arithmetic_in_comparison() {
1324        round_trip("x + 1 == y * 2");
1325    }
1326
1327    #[test]
1328    fn concat_round_trip() {
1329        round_trip(r#""hello" ++ " world""#);
1330    }
1331}