Skip to main content

lex_syntax/
printer.rs

//! Pretty-printer for the syntax tree, designed so that
//! `parse(text) → print → parse` round-trips on the canonical AST.

4use crate::syntax::*;
5use std::fmt::Write;
6
7pub fn print_program(program: &Program) -> String {
8    let mut p = Printer::new();
9    p.program(program);
10    p.out
11}
12
/// State for a single printing pass over the syntax tree.
struct Printer {
    /// Source text produced so far.
    out: String,
    /// Current nesting depth; `write_indent` emits two spaces per level.
    indent: usize,
}

18impl Printer {
19    fn new() -> Self { Self { out: String::new(), indent: 0 } }
20
21    fn write_indent(&mut self) {
22        for _ in 0..self.indent {
23            self.out.push_str("  ");
24        }
25    }
26
27    fn nl(&mut self) {
28        self.out.push('\n');
29    }
30
31    fn program(&mut self, p: &Program) {
32        for (i, item) in p.items.iter().enumerate() {
33            if i > 0 { self.nl(); }
34            self.item(item);
35        }
36        if !p.items.is_empty() {
37            self.nl();
38        }
39    }
40
41    fn item(&mut self, item: &Item) {
42        match item {
43            Item::Import(i) => {
44                writeln!(self.out, "import \"{}\" as {}", i.reference, i.alias).unwrap();
45            }
46            Item::TypeDecl(td) => self.type_decl(td),
47            Item::FnDecl(fd) => self.fn_decl(fd),
48        }
49    }
50
51    fn type_decl(&mut self, td: &TypeDecl) {
52        write!(self.out, "type {}", td.name).unwrap();
53        if !td.params.is_empty() {
54            write!(self.out, "[{}]", td.params.join(", ")).unwrap();
55        }
56        write!(self.out, " = ").unwrap();
57        self.type_expr(&td.definition);
58        self.nl();
59    }
60
61    fn fn_decl(&mut self, fd: &FnDecl) {
62        write!(self.out, "fn {}", fd.name).unwrap();
63        if !fd.type_params.is_empty() {
64            write!(self.out, "[{}]", fd.type_params.join(", ")).unwrap();
65        }
66        write!(self.out, "(").unwrap();
67        for (i, p) in fd.params.iter().enumerate() {
68            if i > 0 { write!(self.out, ", ").unwrap(); }
69            write!(self.out, "{} :: ", p.name).unwrap();
70            self.type_expr(&p.ty);
71        }
72        write!(self.out, ") -> ").unwrap();
73        self.effects(&fd.effects);
74        self.type_expr(&fd.return_type);
75        write!(self.out, " ").unwrap();
76        self.block(&fd.body);
77        self.nl();
78    }
79
80    fn effects(&mut self, effects: &[Effect]) {
81        if effects.is_empty() { return; }
82        write!(self.out, "[").unwrap();
83        for (i, e) in effects.iter().enumerate() {
84            if i > 0 { write!(self.out, ", ").unwrap(); }
85            write!(self.out, "{}", e.name).unwrap();
86            if let Some(arg) = &e.arg {
87                match arg {
88                    EffectArg::Str(s) => write!(self.out, "(\"{}\")", s).unwrap(),
89                    EffectArg::Int(n) => write!(self.out, "({})", n).unwrap(),
90                    EffectArg::Ident(s) => write!(self.out, "({})", s).unwrap(),
91                }
92            }
93        }
94        write!(self.out, "] ").unwrap();
95    }
96
97    fn type_expr(&mut self, t: &TypeExpr) {
98        match t {
99            TypeExpr::Named { name, args } => {
100                write!(self.out, "{}", name).unwrap();
101                if !args.is_empty() {
102                    write!(self.out, "[").unwrap();
103                    for (i, a) in args.iter().enumerate() {
104                        if i > 0 { write!(self.out, ", ").unwrap(); }
105                        self.type_expr(a);
106                    }
107                    write!(self.out, "]").unwrap();
108                }
109            }
110            TypeExpr::Record(fs) => {
111                write!(self.out, "{{ ").unwrap();
112                for (i, f) in fs.iter().enumerate() {
113                    if i > 0 { write!(self.out, ", ").unwrap(); }
114                    write!(self.out, "{} :: ", f.name).unwrap();
115                    self.type_expr(&f.ty);
116                }
117                write!(self.out, " }}").unwrap();
118            }
119            TypeExpr::Tuple(items) => {
120                write!(self.out, "(").unwrap();
121                for (i, it) in items.iter().enumerate() {
122                    if i > 0 { write!(self.out, ", ").unwrap(); }
123                    self.type_expr(it);
124                }
125                write!(self.out, ")").unwrap();
126            }
127            TypeExpr::Function { params, effects, ret } => {
128                write!(self.out, "(").unwrap();
129                for (i, p) in params.iter().enumerate() {
130                    if i > 0 { write!(self.out, ", ").unwrap(); }
131                    self.type_expr(p);
132                }
133                write!(self.out, ") -> ").unwrap();
134                self.effects(effects);
135                self.type_expr(ret);
136            }
137            TypeExpr::Union(variants) => {
138                for (i, v) in variants.iter().enumerate() {
139                    if i > 0 { write!(self.out, " | ").unwrap(); }
140                    write!(self.out, "{}", v.name).unwrap();
141                    if let Some(payload) = &v.payload {
142                        write!(self.out, "(").unwrap();
143                        self.type_expr(payload);
144                        write!(self.out, ")").unwrap();
145                    }
146                }
147            }
148            TypeExpr::Refined { base, binding, predicate } => {
149                self.type_expr(base);
150                write!(self.out, "{{{} | ", binding).unwrap();
151                self.expr(predicate);
152                write!(self.out, "}}").unwrap();
153            }
154        }
155    }
156
157    fn block(&mut self, b: &Block) {
158        write!(self.out, "{{").unwrap();
159        self.indent += 1;
160        for stmt in &b.statements {
161            self.nl();
162            self.write_indent();
163            self.statement(stmt);
164        }
165        self.nl();
166        self.write_indent();
167        self.expr(&b.result);
168        self.indent -= 1;
169        self.nl();
170        self.write_indent();
171        write!(self.out, "}}").unwrap();
172    }
173
174    fn statement(&mut self, s: &Statement) {
175        match s {
176            Statement::Let { name, ty, value } => {
177                write!(self.out, "let {}", name).unwrap();
178                if let Some(ty) = ty {
179                    write!(self.out, " :: ").unwrap();
180                    self.type_expr(ty);
181                }
182                write!(self.out, " := ").unwrap();
183                self.expr(value);
184            }
185            Statement::Expr(e) => self.expr(e),
186        }
187    }
188
189    fn expr(&mut self, e: &Expr) {
190        self.expr_prec(e, 0);
191    }
192
193    fn expr_prec(&mut self, e: &Expr, parent_prec: u8) {
194        match e {
195            Expr::Lit(l) => self.literal(l),
196            Expr::Var(n) => { write!(self.out, "{}", n).unwrap(); }
197            Expr::Block(b) => self.block(b),
198            Expr::Call { callee, args } => {
199                self.expr_prec(callee, 100);
200                write!(self.out, "(").unwrap();
201                for (i, a) in args.iter().enumerate() {
202                    if i > 0 { write!(self.out, ", ").unwrap(); }
203                    self.expr(a);
204                }
205                write!(self.out, ")").unwrap();
206            }
207            Expr::Pipe { left, right } => {
208                if parent_prec > 0 { write!(self.out, "(").unwrap(); }
209                self.expr_prec(left, 1);
210                write!(self.out, " |> ").unwrap();
211                self.expr_prec(right, 1);
212                if parent_prec > 0 { write!(self.out, ")").unwrap(); }
213            }
214            Expr::Try(inner) => {
215                self.expr_prec(inner, 100);
216                write!(self.out, "?").unwrap();
217            }
218            Expr::Field { value, field } => {
219                self.expr_prec(value, 100);
220                write!(self.out, ".{}", field).unwrap();
221            }
222            Expr::BinOp { op, lhs, rhs } => {
223                let prec = op.precedence() + 10;
224                if parent_prec > prec { write!(self.out, "(").unwrap(); }
225                self.expr_prec(lhs, prec);
226                write!(self.out, " {} ", op.as_str()).unwrap();
227                self.expr_prec(rhs, prec + 1);
228                if parent_prec > prec { write!(self.out, ")").unwrap(); }
229            }
230            Expr::UnaryOp { op, expr } => {
231                let s = match op { UnaryOp::Neg => "-", UnaryOp::Not => "not " };
232                write!(self.out, "{}", s).unwrap();
233                self.expr_prec(expr, 100);
234            }
235            Expr::If { cond, then_block, else_block } => {
236                write!(self.out, "if ").unwrap();
237                self.expr(cond);
238                write!(self.out, " ").unwrap();
239                self.block(then_block);
240                write!(self.out, " else ").unwrap();
241                self.block(else_block);
242            }
243            Expr::Match { scrutinee, arms } => {
244                write!(self.out, "match ").unwrap();
245                self.expr(scrutinee);
246                write!(self.out, " {{").unwrap();
247                self.indent += 1;
248                for arm in arms {
249                    self.nl();
250                    self.write_indent();
251                    self.pattern(&arm.pattern);
252                    write!(self.out, " => ").unwrap();
253                    self.expr(&arm.body);
254                    write!(self.out, ",").unwrap();
255                }
256                self.indent -= 1;
257                self.nl();
258                self.write_indent();
259                write!(self.out, "}}").unwrap();
260            }
261            Expr::RecordLit(fields) => {
262                write!(self.out, "{{ ").unwrap();
263                for (i, f) in fields.iter().enumerate() {
264                    if i > 0 { write!(self.out, ", ").unwrap(); }
265                    write!(self.out, "{}: ", f.name).unwrap();
266                    self.expr(&f.value);
267                }
268                write!(self.out, " }}").unwrap();
269            }
270            Expr::TupleLit(items) => {
271                write!(self.out, "(").unwrap();
272                for (i, it) in items.iter().enumerate() {
273                    if i > 0 { write!(self.out, ", ").unwrap(); }
274                    self.expr(it);
275                }
276                write!(self.out, ")").unwrap();
277            }
278            Expr::ListLit(items) => {
279                write!(self.out, "[").unwrap();
280                for (i, it) in items.iter().enumerate() {
281                    if i > 0 { write!(self.out, ", ").unwrap(); }
282                    self.expr(it);
283                }
284                write!(self.out, "]").unwrap();
285            }
286            Expr::Constructor { name, args } => {
287                write!(self.out, "{}", name).unwrap();
288                if !args.is_empty() {
289                    write!(self.out, "(").unwrap();
290                    for (i, a) in args.iter().enumerate() {
291                        if i > 0 { write!(self.out, ", ").unwrap(); }
292                        self.expr(a);
293                    }
294                    write!(self.out, ")").unwrap();
295                }
296            }
297            Expr::Ascription { value, ty } => {
298                write!(self.out, "(").unwrap();
299                self.expr(value);
300                write!(self.out, " :: ").unwrap();
301                self.type_expr(ty);
302                write!(self.out, ")").unwrap();
303            }
304            Expr::Lambda(l) => {
305                write!(self.out, "fn (").unwrap();
306                for (i, p) in l.params.iter().enumerate() {
307                    if i > 0 { write!(self.out, ", ").unwrap(); }
308                    write!(self.out, "{} :: ", p.name).unwrap();
309                    self.type_expr(&p.ty);
310                }
311                write!(self.out, ") -> ").unwrap();
312                self.effects(&l.effects);
313                self.type_expr(&l.return_type);
314                write!(self.out, " ").unwrap();
315                self.block(&l.body);
316            }
317        }
318    }
319
320    fn literal(&mut self, l: &Literal) {
321        match l {
322            Literal::Int(n) => write!(self.out, "{}", n).unwrap(),
323            Literal::Float(n) => write!(self.out, "{}", format_float(*n)).unwrap(),
324            Literal::Str(s) => write!(self.out, "\"{}\"", escape(s)).unwrap(),
325            Literal::Bytes(b) => {
326                write!(self.out, "b\"").unwrap();
327                for &c in b {
328                    if c.is_ascii() && (c as char).is_ascii_graphic() && c != b'"' && c != b'\\' {
329                        self.out.push(c as char);
330                    } else {
331                        write!(self.out, "\\x{:02x}", c).unwrap();
332                    }
333                }
334                write!(self.out, "\"").unwrap();
335            }
336            Literal::Bool(b) => write!(self.out, "{}", b).unwrap(),
337            Literal::Unit => write!(self.out, "()").unwrap(),
338        }
339    }
340
341    fn pattern(&mut self, p: &Pattern) {
342        match p {
343            Pattern::Lit(l) => self.literal(l),
344            Pattern::Var(n) => { write!(self.out, "{}", n).unwrap(); }
345            Pattern::Wild => { write!(self.out, "_").unwrap(); }
346            Pattern::Constructor { name, args } => {
347                write!(self.out, "{}", name).unwrap();
348                if !args.is_empty() {
349                    write!(self.out, "(").unwrap();
350                    for (i, a) in args.iter().enumerate() {
351                        if i > 0 { write!(self.out, ", ").unwrap(); }
352                        self.pattern(a);
353                    }
354                    write!(self.out, ")").unwrap();
355                }
356            }
357            Pattern::Record { fields, rest: _ } => {
358                write!(self.out, "{{ ").unwrap();
359                for (i, f) in fields.iter().enumerate() {
360                    if i > 0 { write!(self.out, ", ").unwrap(); }
361                    write!(self.out, "{}", f.name).unwrap();
362                    if let Some(p) = &f.pattern {
363                        write!(self.out, ": ").unwrap();
364                        self.pattern(p);
365                    }
366                }
367                write!(self.out, " }}").unwrap();
368            }
369            Pattern::Tuple(items) => {
370                write!(self.out, "(").unwrap();
371                for (i, it) in items.iter().enumerate() {
372                    if i > 0 { write!(self.out, ", ").unwrap(); }
373                    self.pattern(it);
374                }
375                write!(self.out, ")").unwrap();
376            }
377        }
378    }
379}
380
/// Escapes `s` for the body of a double-quoted string literal.
///
/// Backslash, double quote, newline, tab and carriage return become the
/// two-character escapes `\\`, `\"`, `\n`, `\t`, `\r`. Every other character
/// is copied through unchanged.
///
/// NOTE(review): other control characters (e.g. NUL) are emitted raw; confirm
/// the lexer accepts them inside string literals.
fn escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        let sequence = match ch {
            '\\' => Some("\\\\"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\t' => Some("\\t"),
            '\r' => Some("\\r"),
            _ => None,
        };
        match sequence {
            Some(seq) => escaped.push_str(seq),
            None => escaped.push(ch),
        }
    }
    escaped
}

/// Formats a float literal.
///
/// Finite values with no fractional part are printed with exactly one decimal
/// place (`3.0`, not `3`), so the text always contains a `.`. This is
/// presumably so it re-parses as a float rather than an integer. All other
/// values use `f64`'s `Display` output.
///
/// NOTE(review): NaN and the infinities come out as `NaN`, `inf` and `-inf`,
/// which are unlikely to be valid literals in the source language. Confirm.
fn format_float(n: f64) -> String {
    let whole_number = n.is_finite() && n.trunc() == n;
    if !whole_number {
        return n.to_string();
    }
    format!("{:.1}", n)
}