1use crate::ast::{BinOp, Expr, ExprKind, Literal, Pat, UnaryOp};
2use lasso::Rodeo;
3
4pub trait Visitor<'bump, Ctx>: Sized {
5 fn visit_expr(&mut self, _expr: &'bump Expr<'bump>, _ctx: &mut Ctx) {}
6
7 fn visit_pat(&mut self, _pat: &'bump Pat<'bump>, _ctx: &mut Ctx) {}
8
9 fn walk_expr(&mut self, expr: &'bump Expr<'bump>, ctx: &mut Ctx) {
10 self.visit_expr(expr, ctx);
11 match &*expr.kind {
12 ExprKind::Literal(_) | ExprKind::Var(_) => {}
13 ExprKind::If {
14 cond,
15 then_expr,
16 else_expr,
17 } => {
18 self.walk_expr(cond, ctx);
19 self.walk_expr(then_expr, ctx);
20 self.walk_expr(else_expr, ctx);
21 }
22 ExprKind::BinOp { op: _, lhs, rhs } => {
23 self.walk_expr(lhs, ctx);
24 self.walk_expr(rhs, ctx);
25 }
26 ExprKind::UnaryOp { op: _, rhs } => {
27 self.walk_expr(rhs, ctx);
28 }
29 ExprKind::Let {
30 name: _,
31 value,
32 body,
33 rec: _,
34 } => {
35 self.walk_expr(value, ctx);
36 self.walk_expr(body, ctx);
37 }
38 ExprKind::Match { scrutinee, arms } => {
39 self.walk_expr(scrutinee, ctx);
40 for (pat, body) in arms.iter() {
41 self.walk_pat(pat, ctx);
42 self.walk_expr(body, ctx);
43 }
44 }
45 ExprKind::Abs { param: _, body } => {
46 self.walk_expr(body, ctx);
47 }
48 ExprKind::App { func, arg } => {
49 self.walk_expr(func, ctx);
50 self.walk_expr(arg, ctx);
51 }
52 }
53 }
54
55 fn walk_pat(&mut self, pat: &'bump Pat<'bump>, ctx: &mut Ctx) {
56 self.visit_pat(pat, ctx);
57 match pat {
58 Pat::Wildcard | Pat::Var(_) | Pat::Lit(_) => {}
59 Pat::Or(a, b) => {
60 self.walk_pat(a, ctx);
61 self.walk_pat(b, ctx);
62 }
63 }
64 }
65}
66
67pub struct PrintCtx<'a> {
71 pub rodeo: &'a Rodeo,
72 pub indent: usize,
73 pub output: String,
74}
75
76impl<'a> PrintCtx<'a> {
77 pub fn new(rodeo: &'a Rodeo) -> Self {
78 Self {
79 rodeo,
80 indent: 0,
81 output: String::new(),
82 }
83 }
84}
85
86pub struct AstPrinter;
87
88pub struct DisplayExpr<'a, 'bump> {
89 expr: &'bump Expr<'bump>,
90 rodeo: &'a Rodeo,
91}
92
93impl std::fmt::Display for DisplayExpr<'_, '_> {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 let mut ctx = PrintCtx::new(self.rodeo);
96 AstPrinter.visit_expr(self.expr, &mut ctx);
97 f.write_str(&ctx.output)
98 }
99}
100
101impl AstPrinter {
102 pub fn print_expr<'bump, 'a>(expr: &'bump Expr<'bump>, rodeo: &'a Rodeo) -> DisplayExpr<'a, 'bump> {
103 DisplayExpr { expr, rodeo }
104 }
105}
106
107impl<'bump> Expr<'bump> {
108 pub fn display<'a>(&'bump self, rodeo: &'a Rodeo) -> DisplayExpr<'a, 'bump> {
109 DisplayExpr { expr: self, rodeo }
110 }
111}
112
/// Returns the leading whitespace for nesting depth `level`: the `" "` unit repeated
/// `level` times.
fn indent(level: usize) -> String {
    let unit = " ";
    unit.repeat(level)
}
116
117fn binop_prec(op: &BinOp) -> u8 {
118 match op {
119 BinOp::Or => 1,
120 BinOp::And => 2,
121 BinOp::Eq | BinOp::NotEq | BinOp::Lt | BinOp::Gt | BinOp::Le | BinOp::Ge => 3,
122 BinOp::Add | BinOp::Sub => 4,
123 BinOp::Mul | BinOp::Div => 5,
124 BinOp::Pow => 6,
125 }
126}
127
128fn needs_parens_simple(expr: &Expr) -> bool {
129 matches!(
130 &*expr.kind,
131 ExprKind::If { .. }
132 | ExprKind::Let { .. }
133 | ExprKind::Match { .. }
134 | ExprKind::BinOp { .. }
135 | ExprKind::UnaryOp { .. }
136 )
137}
138
139fn parexpr<'bump>(
140 visitor: &mut AstPrinter,
141 expr: &'bump Expr<'bump>,
142 ctx: &mut PrintCtx,
143 wrap: bool,
144) {
145 if wrap && needs_parens_simple(expr) {
146 ctx.output.push('(');
147 visitor.visit_expr(expr, ctx);
148 ctx.output.push(')');
149 } else {
150 visitor.visit_expr(expr, ctx);
151 }
152}
153
154fn parexpr_binop<'bump>(
155 visitor: &mut AstPrinter,
156 child: &'bump Expr<'bump>,
157 parent_op: &BinOp,
158 ctx: &mut PrintCtx,
159) {
160 let needs_parens = match &*child.kind {
161 ExprKind::BinOp { op: child_op, .. } => binop_prec(child_op) < binop_prec(parent_op),
162 _ => needs_parens_simple(child),
163 };
164 if needs_parens {
165 ctx.output.push('(');
166 visitor.visit_expr(child, ctx);
167 ctx.output.push(')');
168 } else {
169 visitor.visit_expr(child, ctx);
170 }
171}
172
173impl<'bump> Visitor<'bump, PrintCtx<'_>> for AstPrinter {
174 fn visit_expr(&mut self, expr: &'bump Expr<'bump>, ctx: &mut PrintCtx<'_>) {
175 match &*expr.kind {
176 ExprKind::Literal(lit) => match lit {
177 Literal::Int(n) => ctx.output.push_str(&n.to_string()),
178 Literal::Float(f) => ctx.output.push_str(&f.to_string()),
179 Literal::Bool(b) => ctx.output.push_str(if *b { "true" } else { "false" }),
180 },
181 ExprKind::Var(ident) => {
182 ctx.output.push_str(ctx.rodeo.resolve(&ident.0));
183 }
184 ExprKind::If {
185 cond,
186 then_expr,
187 else_expr,
188 } => {
189 ctx.output.push_str("if ");
190 self.visit_expr(cond, ctx);
191 ctx.output.push('\n');
192 ctx.output.push_str(&indent(ctx.indent));
193 ctx.output.push_str("then\n");
194 ctx.output.push_str(&indent(ctx.indent + 1));
195 self.visit_expr(then_expr, ctx);
196 ctx.output.push('\n');
197 ctx.output.push_str(&indent(ctx.indent));
198 ctx.output.push_str("else\n");
199 ctx.output.push_str(&indent(ctx.indent + 1));
200 self.visit_expr(else_expr, ctx);
201 }
202 ExprKind::BinOp { op, lhs, rhs } => {
203 let op_str = match op {
204 BinOp::Eq => " == ",
205 BinOp::NotEq => " != ",
206 BinOp::Lt => " < ",
207 BinOp::Gt => " > ",
208 BinOp::Le => " <= ",
209 BinOp::Ge => " >= ",
210 BinOp::And => " && ",
211 BinOp::Or => " || ",
212 BinOp::Add => " + ",
213 BinOp::Sub => " - ",
214 BinOp::Mul => " * ",
215 BinOp::Div => " / ",
216 BinOp::Pow => " ^ ",
217 };
218 parexpr_binop(self, lhs, op, ctx);
219 ctx.output.push_str(op_str);
220 parexpr_binop(self, rhs, op, ctx);
221 }
222 ExprKind::UnaryOp { op, rhs } => {
223 let op_str = match op {
224 UnaryOp::Neg => "-",
225 UnaryOp::Not => "!",
226 };
227 ctx.output.push_str(op_str);
228 parexpr(self, rhs, ctx, true);
229 }
230 ExprKind::Let {
231 name,
232 value,
233 body,
234 rec,
235 } => {
236 if *rec {
237 ctx.output.push_str("let rec ");
238 } else {
239 ctx.output.push_str("let ");
240 }
241 ctx.output.push_str(ctx.rodeo.resolve(&name.0));
242 ctx.output.push_str(" = ");
243 ctx.indent += 1;
244 self.visit_expr(value, ctx);
245 ctx.output.push('\n');
246 ctx.indent -= 1;
247 ctx.output.push_str(&indent(ctx.indent));
248 ctx.output.push_str("in\n");
249 ctx.output.push_str(&indent(ctx.indent + 1));
250 self.visit_expr(body, ctx);
251 }
252 ExprKind::Match { scrutinee, arms } => {
253 ctx.output.push_str("match ");
254 self.visit_expr(scrutinee, ctx);
255 ctx.output.push('\n');
256 for (i, (pat, body)) in arms.iter().enumerate() {
257 ctx.output.push_str(&indent(ctx.indent));
258 ctx.output.push_str("| ");
259 self.visit_pat(pat, ctx);
260 ctx.output.push_str(" => ");
261 self.visit_expr(body, ctx);
262 if i < arms.len() - 1 {
263 ctx.output.push('\n');
264 }
265 }
266 }
267 ExprKind::Abs { param, body } => {
268 ctx.output.push('|');
269 ctx.output.push_str(ctx.rodeo.resolve(¶m.0));
270 ctx.output.push_str("| ");
271 self.visit_expr(body, ctx);
272 }
273 ExprKind::App { func, arg } => {
274 parexpr(self, func, ctx, true);
275 ctx.output.push(' ');
276 parexpr(self, arg, ctx, true);
277 }
278 }
279 }
280
281 fn visit_pat(&mut self, pat: &'bump Pat<'bump>, ctx: &mut PrintCtx<'_>) {
282 match pat {
283 Pat::Wildcard => ctx.output.push('_'),
284 Pat::Var(ident) => ctx.output.push_str(ctx.rodeo.resolve(&ident.0)),
285 Pat::Lit(lit) => match lit {
286 Literal::Int(n) => ctx.output.push_str(&n.to_string()),
287 Literal::Float(f) => ctx.output.push_str(&f.to_string()),
288 Literal::Bool(b) => ctx.output.push_str(if *b { "true" } else { "false" }),
289 },
290 Pat::Or(a, b) => {
291 self.visit_pat(a, ctx);
292 ctx.output.push_str(" | ");
293 self.visit_pat(b, ctx);
294 }
295 }
296 }
297}