Skip to main content

mest_core/
visitor.rs

1use crate::ast::{BinOp, Expr, ExprKind, Literal, Pat, UnaryOp};
2use lasso::Rodeo;
3
4pub trait Visitor<'bump, Ctx>: Sized {
5    fn visit_expr(&mut self, _expr: &'bump Expr<'bump>, _ctx: &mut Ctx) {}
6
7    fn visit_pat(&mut self, _pat: &'bump Pat<'bump>, _ctx: &mut Ctx) {}
8
9    fn walk_expr(&mut self, expr: &'bump Expr<'bump>, ctx: &mut Ctx) {
10        self.visit_expr(expr, ctx);
11        match &*expr.kind {
12            ExprKind::Literal(_) | ExprKind::Var(_) => {}
13            ExprKind::If {
14                cond,
15                then_expr,
16                else_expr,
17            } => {
18                self.walk_expr(cond, ctx);
19                self.walk_expr(then_expr, ctx);
20                self.walk_expr(else_expr, ctx);
21            }
22            ExprKind::BinOp { op: _, lhs, rhs } => {
23                self.walk_expr(lhs, ctx);
24                self.walk_expr(rhs, ctx);
25            }
26            ExprKind::UnaryOp { op: _, rhs } => {
27                self.walk_expr(rhs, ctx);
28            }
29            ExprKind::Let {
30                name: _,
31                value,
32                body,
33                rec: _,
34            } => {
35                self.walk_expr(value, ctx);
36                self.walk_expr(body, ctx);
37            }
38            ExprKind::Match { scrutinee, arms } => {
39                self.walk_expr(scrutinee, ctx);
40                for (pat, body) in arms.iter() {
41                    self.walk_pat(pat, ctx);
42                    self.walk_expr(body, ctx);
43                }
44            }
45            ExprKind::Abs { param: _, body } => {
46                self.walk_expr(body, ctx);
47            }
48            ExprKind::App { func, arg } => {
49                self.walk_expr(func, ctx);
50                self.walk_expr(arg, ctx);
51            }
52        }
53    }
54
55    fn walk_pat(&mut self, pat: &'bump Pat<'bump>, ctx: &mut Ctx) {
56        self.visit_pat(pat, ctx);
57        match pat {
58            Pat::Wildcard | Pat::Var(_) | Pat::Lit(_) => {}
59            Pat::Or(a, b) => {
60                self.walk_pat(a, ctx);
61                self.walk_pat(b, ctx);
62            }
63        }
64    }
65}
66
67// TODO: VisitMut requires `kind: &'bump mut ExprKind` on `Expr` to walk children.
68// Add once the AST supports mutable node access.
69
70pub struct PrintCtx<'a> {
71    pub rodeo: &'a Rodeo,
72    pub indent: usize,
73    pub output: String,
74}
75
76impl<'a> PrintCtx<'a> {
77    pub fn new(rodeo: &'a Rodeo) -> Self {
78        Self {
79            rodeo,
80            indent: 0,
81            output: String::new(),
82        }
83    }
84}
85
/// Stateless pretty-printer; all printing state lives in [`PrintCtx`].
pub struct AstPrinter;
87
88pub struct DisplayExpr<'a, 'bump> {
89    expr: &'bump Expr<'bump>,
90    rodeo: &'a Rodeo,
91}
92
93impl std::fmt::Display for DisplayExpr<'_, '_> {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        let mut ctx = PrintCtx::new(self.rodeo);
96        AstPrinter.visit_expr(self.expr, &mut ctx);
97        f.write_str(&ctx.output)
98    }
99}
100
101impl AstPrinter {
102    pub fn print_expr<'bump, 'a>(expr: &'bump Expr<'bump>, rodeo: &'a Rodeo) -> DisplayExpr<'a, 'bump> {
103        DisplayExpr { expr, rodeo }
104    }
105}
106
107impl<'bump> Expr<'bump> {
108    pub fn display<'a>(&'bump self, rodeo: &'a Rodeo) -> DisplayExpr<'a, 'bump> {
109        DisplayExpr { expr: self, rodeo }
110    }
111}
112
/// Returns the whitespace prefix for nesting depth `level`: two spaces per level.
fn indent(level: usize) -> String {
    const INDENT_UNIT: &str = "  ";
    INDENT_UNIT.repeat(level)
}
116
117fn binop_prec(op: &BinOp) -> u8 {
118    match op {
119        BinOp::Or => 1,
120        BinOp::And => 2,
121        BinOp::Eq | BinOp::NotEq | BinOp::Lt | BinOp::Gt | BinOp::Le | BinOp::Ge => 3,
122        BinOp::Add | BinOp::Sub => 4,
123        BinOp::Mul | BinOp::Div => 5,
124        BinOp::Pow => 6,
125    }
126}
127
128fn needs_parens_simple(expr: &Expr) -> bool {
129    matches!(
130        &*expr.kind,
131        ExprKind::If { .. }
132            | ExprKind::Let { .. }
133            | ExprKind::Match { .. }
134            | ExprKind::BinOp { .. }
135            | ExprKind::UnaryOp { .. }
136    )
137}
138
139fn parexpr<'bump>(
140    visitor: &mut AstPrinter,
141    expr: &'bump Expr<'bump>,
142    ctx: &mut PrintCtx,
143    wrap: bool,
144) {
145    if wrap && needs_parens_simple(expr) {
146        ctx.output.push('(');
147        visitor.visit_expr(expr, ctx);
148        ctx.output.push(')');
149    } else {
150        visitor.visit_expr(expr, ctx);
151    }
152}
153
154fn parexpr_binop<'bump>(
155    visitor: &mut AstPrinter,
156    child: &'bump Expr<'bump>,
157    parent_op: &BinOp,
158    ctx: &mut PrintCtx,
159) {
160    let needs_parens = match &*child.kind {
161        ExprKind::BinOp { op: child_op, .. } => binop_prec(child_op) < binop_prec(parent_op),
162        _ => needs_parens_simple(child),
163    };
164    if needs_parens {
165        ctx.output.push('(');
166        visitor.visit_expr(child, ctx);
167        ctx.output.push(')');
168    } else {
169        visitor.visit_expr(child, ctx);
170    }
171}
172
173impl<'bump> Visitor<'bump, PrintCtx<'_>> for AstPrinter {
174    fn visit_expr(&mut self, expr: &'bump Expr<'bump>, ctx: &mut PrintCtx<'_>) {
175        match &*expr.kind {
176            ExprKind::Literal(lit) => match lit {
177                Literal::Int(n) => ctx.output.push_str(&n.to_string()),
178                Literal::Float(f) => ctx.output.push_str(&f.to_string()),
179                Literal::Bool(b) => ctx.output.push_str(if *b { "true" } else { "false" }),
180            },
181            ExprKind::Var(ident) => {
182                ctx.output.push_str(ctx.rodeo.resolve(&ident.0));
183            }
184            ExprKind::If {
185                cond,
186                then_expr,
187                else_expr,
188            } => {
189                ctx.output.push_str("if ");
190                self.visit_expr(cond, ctx);
191                ctx.output.push('\n');
192                ctx.output.push_str(&indent(ctx.indent));
193                ctx.output.push_str("then\n");
194                ctx.output.push_str(&indent(ctx.indent + 1));
195                self.visit_expr(then_expr, ctx);
196                ctx.output.push('\n');
197                ctx.output.push_str(&indent(ctx.indent));
198                ctx.output.push_str("else\n");
199                ctx.output.push_str(&indent(ctx.indent + 1));
200                self.visit_expr(else_expr, ctx);
201            }
202            ExprKind::BinOp { op, lhs, rhs } => {
203                let op_str = match op {
204                    BinOp::Eq => " == ",
205                    BinOp::NotEq => " != ",
206                    BinOp::Lt => " < ",
207                    BinOp::Gt => " > ",
208                    BinOp::Le => " <= ",
209                    BinOp::Ge => " >= ",
210                    BinOp::And => " && ",
211                    BinOp::Or => " || ",
212                    BinOp::Add => " + ",
213                    BinOp::Sub => " - ",
214                    BinOp::Mul => " * ",
215                    BinOp::Div => " / ",
216                    BinOp::Pow => " ^ ",
217                };
218                parexpr_binop(self, lhs, op, ctx);
219                ctx.output.push_str(op_str);
220                parexpr_binop(self, rhs, op, ctx);
221            }
222            ExprKind::UnaryOp { op, rhs } => {
223                let op_str = match op {
224                    UnaryOp::Neg => "-",
225                    UnaryOp::Not => "!",
226                };
227                ctx.output.push_str(op_str);
228                parexpr(self, rhs, ctx, true);
229            }
230            ExprKind::Let {
231                name,
232                value,
233                body,
234                rec,
235            } => {
236                if *rec {
237                    ctx.output.push_str("let rec ");
238                } else {
239                    ctx.output.push_str("let ");
240                }
241                ctx.output.push_str(ctx.rodeo.resolve(&name.0));
242                ctx.output.push_str(" = ");
243                ctx.indent += 1;
244                self.visit_expr(value, ctx);
245                ctx.output.push('\n');
246                ctx.indent -= 1;
247                ctx.output.push_str(&indent(ctx.indent));
248                ctx.output.push_str("in\n");
249                ctx.output.push_str(&indent(ctx.indent + 1));
250                self.visit_expr(body, ctx);
251            }
252            ExprKind::Match { scrutinee, arms } => {
253                ctx.output.push_str("match ");
254                self.visit_expr(scrutinee, ctx);
255                ctx.output.push('\n');
256                for (i, (pat, body)) in arms.iter().enumerate() {
257                    ctx.output.push_str(&indent(ctx.indent));
258                    ctx.output.push_str("| ");
259                    self.visit_pat(pat, ctx);
260                    ctx.output.push_str(" => ");
261                    self.visit_expr(body, ctx);
262                    if i < arms.len() - 1 {
263                        ctx.output.push('\n');
264                    }
265                }
266            }
267            ExprKind::Abs { param, body } => {
268                ctx.output.push('|');
269                ctx.output.push_str(ctx.rodeo.resolve(&param.0));
270                ctx.output.push_str("| ");
271                self.visit_expr(body, ctx);
272            }
273            ExprKind::App { func, arg } => {
274                parexpr(self, func, ctx, true);
275                ctx.output.push(' ');
276                parexpr(self, arg, ctx, true);
277            }
278        }
279    }
280
281    fn visit_pat(&mut self, pat: &'bump Pat<'bump>, ctx: &mut PrintCtx<'_>) {
282        match pat {
283            Pat::Wildcard => ctx.output.push('_'),
284            Pat::Var(ident) => ctx.output.push_str(ctx.rodeo.resolve(&ident.0)),
285            Pat::Lit(lit) => match lit {
286                Literal::Int(n) => ctx.output.push_str(&n.to_string()),
287                Literal::Float(f) => ctx.output.push_str(&f.to_string()),
288                Literal::Bool(b) => ctx.output.push_str(if *b { "true" } else { "false" }),
289            },
290            Pat::Or(a, b) => {
291                self.visit_pat(a, ctx);
292                ctx.output.push_str(" | ");
293                self.visit_pat(b, ctx);
294            }
295        }
296    }
297}