1use crate::literal::Literal;
2use crate::expr::{ExprVisitor, Expr};
3use crate::stmt::{StmtVisitor, Stmt};
4
/// Builds the string `(<name> <child> <child> ...)`.
///
/// * `$self` - identifier of the visitor handed to every child's `accept`.
/// * `$name` - a `&str`-compatible expression placed right after the opening
///   parenthesis.
/// * `$x`    - one or more expressions; each is rendered by calling
///   `.accept($self)` on it and is preceded by a single space.
///
/// The macro evaluates to the finished `String`.
macro_rules! parenthesize {
    ( $self:ident, $name:expr, $( $x:expr ),+ ) => {{
        let mut rendered = String::from("(");
        rendered.push_str($name);
        $(
            rendered.push(' ');
            rendered.push_str(&$x.accept($self));
        )+
        rendered.push(')');
        rendered
    }};
}
22
/// Stateless visitor that renders expressions and statements as
/// parenthesized, Lisp-like strings.
///
/// The printer carries no data, so it can be created with `ASTPrinter {}`,
/// `ASTPrinter`, or `ASTPrinter::default()`.
#[derive(Debug, Default)]
pub struct ASTPrinter;
24
25impl ASTPrinter {
26 pub fn print(&mut self, expr: Expr) -> String {
28 expr.accept(self)
29 }
30}
31
32impl ExprVisitor<String> for ASTPrinter {
33 fn visit_literal_expr(&mut self, expr: &Expr) -> String {
34 let Expr::Literal(literal) = expr else { unreachable!() };
35 literal.to_string()
36 }
37
38 fn visit_logical_expr(&mut self, expr: &Expr) -> String {
39 let Expr::Logical(data) = expr else { unreachable!() };
40 parenthesize!(self, &data.operator.lexeme, &data.left, &data.right)
41 }
42
43 fn visit_unary_expr(&mut self, expr: &Expr) -> String {
44 let Expr::Unary(data) = expr else { unreachable!() };
45 parenthesize!(self, &data.operator.lexeme, &data.expr)
46 }
47
48 fn visit_binary_expr(&mut self, expr: &Expr) -> String {
49 let Expr::Binary(data) = expr else { unreachable!() };
50 parenthesize!(self, &data.operator.lexeme, &data.left, &data.right)
51 }
52
53 fn visit_grouping_expr(&mut self, expr: &Expr) -> String {
54 let Expr::Grouping(data) = expr else { unreachable!() };
55 parenthesize!(self, "group", data.expr)
56 }
57
58 fn visit_variable_expr(&mut self, expr: &Expr) -> String {
59 let Expr::Variable(data) = expr else { unreachable!() };
60 data.name.lexeme.clone()
61 }
62
63 fn visit_assign_expr(&mut self, expr: &Expr) -> String {
64 let Expr::Assign(data) = expr else { unreachable!() };
65 parenthesize!(self, format!("= {}", &data.name.lexeme).as_str(), data.value)
66 }
67
68 fn visit_call_expr(&mut self, expr: &Expr) -> String {
69 let Expr::Call(data) = expr else { unreachable!() };
70 let mut string = String::new();
71 string += "(call ";
72 string += &data.callee.accept(self);
73 string += " (";
74 for arg in &data.arguments {
75 string += &arg.accept(self);
76 string += " ";
77 }
78 string = string.trim_end().to_string();
79 string += "))";
80 string
81 }
82
83 fn visit_get_expr(&mut self, expr: &Expr) -> String {
84 let Expr::Get(data) = expr else { unreachable!() };
85 let mut string = String::new();
86 string += "(get ";
87 string += &data.object.accept(self);
88 string += " ";
89 string += &data.name.lexeme;
90 string += ")";
91 string
92 }
93
94 fn visit_set_expr(&mut self, expr: &Expr) -> String {
95 let Expr::Set(data) = expr else { unreachable!() };
96 let mut string = String::new();
97 string += "(set ";
98 string += &data.object.accept(self);
99 string += " ";
100 string += &data.name.lexeme;
101 string += " ";
102 string += &data.value.accept(self);
103 string += ")";
104 string
105 }
106
107 fn visit_this_expr(&mut self, _expr: &Expr) -> String {
108 "this".to_string()
109 }
110
111 fn visit_super_expr(&mut self, expr: &Expr) -> String {
112 let Expr::Super(data) = expr else { unreachable!() };
113 let mut string = String::new();
114 string += "(super ";
115 string += &data.method.lexeme;
116 string += ")";
117 string
118 }
119}
120
121impl StmtVisitor<String> for ASTPrinter {
122 fn visit_expression_stmt(&mut self, stmt: &Stmt) -> String {
123 let Stmt::Expression(data) = stmt else { unreachable!() };
124 data.expr.accept(self)
125 }
126
127 fn visit_function_stmt(&mut self, stmt: &Stmt) -> String {
128 let Stmt::Function(data) = stmt else { unreachable!() };
129 let mut string = String::new();
130 string += "(fun ";
131 string += &data.name.lexeme;
132 string += " (";
133 for param in &data.params {
134 string += ¶m.lexeme;
135 string += " ";
136 }
137 string = string.trim_end().to_string();
138 string += ") { ";
139 for body in &data.body {
140 string += &body.accept(self);
141 }
142 string += " })";
143 string
144 }
145
146 fn visit_if_stmt(&mut self, stmt: &Stmt) -> String {
147 let Stmt::If(data) = stmt else { unreachable!() };
148 let mut string = String::new();
149 string += "(if ";
150 string += &data.condition.accept(self);
151 string += " ";
152 string += &data.then_branch.accept(self);
153 if let Some(else_branch) = &data.else_branch {
154 string += " else ";
155 string += &else_branch.accept(self);
156 }
157 string += ")";
158 string
159 }
160
161 fn visit_print_stmt(&mut self, stmt: &Stmt) -> String {
162 let Stmt::Print(data) = stmt else { unreachable!() };
163 parenthesize!(self, "print", data.expr)
164 }
165
166 fn visit_return_stmt(&mut self, stmt: &Stmt) -> String {
167 let Stmt::Return(data) = stmt else { unreachable!() };
168 parenthesize!(self, "return", data.value.clone().unwrap_or(Expr::Literal(Literal::Null)))
169 }
170
171 fn visit_var_stmt(&mut self, stmt: &Stmt) -> String {
172 let Stmt::Var(data) = stmt else { unreachable!() };
173 let mut string = String::new();
174 string += "(var ";
175 string += &data.name.lexeme;
176 if let Some(initializer) = &data.initializer {
177 string += " = ";
178 string += &initializer.accept(self);
179 }
180 string += ")";
181 string
182 }
183
184 fn visit_while_stmt(&mut self, stmt: &Stmt) -> String {
185 let Stmt::While(data) = stmt else { unreachable!() };
186 let mut string = String::new();
187 string += "(while ";
188 string += &data.condition.accept(self);
189 string += " ";
190 string += &data.body.accept(self);
191 string += ")";
192 string
193 }
194
195 fn visit_block_stmt(&mut self, stmt: &Stmt) -> String {
196 let Stmt::Block(data) = stmt else { unreachable!() };
197 let mut string = String::new();
198 string += "{";
199 for stmt in &data.statements {
200 string += " ";
201 string += &stmt.accept(self);
202 }
203 string += " }";
204 string
205 }
206
207 fn visit_class_stmt(&mut self, stmt: &Stmt) -> String {
208 let Stmt::Class(data) = stmt else { unreachable!() };
209 let mut string = String::new();
210 string += "(class ";
211 string += &data.name.lexeme;
212 if let Some(superclass) = &data.superclass {
213 string += " < ";
214 string += &superclass.accept(self);
215 }
216 for method in &data.methods {
217 string += " ";
218 string += &method.accept(self);
219 }
220 string += ")";
221 string
222 }
223
224 fn visit_break_stmt(&mut self, stmt: &Stmt) -> String {
225 let Stmt::Break(_) = stmt else { unreachable!() };
226 "break".to_string()
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scanner::Scanner;
    use crate::parser::Parser;

    /// Scans and parses `source`, prints every resulting statement with an
    /// `ASTPrinter`, and joins the printed statements with a single space.
    fn render(source: &str) -> String {
        let mut scanner = Scanner::new(source);
        let tokens = scanner.scan_tokens();
        let mut parser = Parser::new(tokens);
        let statements = parser.parse();
        let mut printer = ASTPrinter {};
        let printed: Vec<String> = statements
            .iter()
            .map(|stmt| stmt.accept(&mut printer))
            .collect();
        printed.join(" ")
    }

    #[test]
    fn test_ast_printer() {
        let ast = render("var a = 1; var b = 2; print a + b;");
        assert_eq!(ast, "(var a = 1) (var b = 2) (print (+ a b))");
    }

    #[test]
    fn test_ast_printer_with_grouping() {
        let ast = render("print (1 + 2) * 3;");
        assert_eq!(ast, "(print (* (group (+ 1 2)) 3))");
    }

    #[test]
    fn test_ast_printer_with_if() {
        let ast = render("if (a > 0) { print a; } else { print -a; }");
        assert_eq!(ast, "(if (> a 0) { (print a) } else { (print (- a)) })");
    }

    #[test]
    fn test_ast_printer_with_function() {
        let ast = render("fun add(a, b) { return a + b; }");
        assert_eq!(ast, "(fun add (a b) { (return (+ a b)) })");
    }

    #[test]
    fn test_ast_printer_with_class() {
        let ast = render("class Foo { bar() { print \"bar\"; } }");
        assert_eq!(ast, "(class Foo (fun bar () { (print bar) }))");
    }

    #[test]
    fn test_ast_printer_with_break() {
        let ast = render("while (true) { break; }");
        assert_eq!(ast, "(while true { break })");
    }

    #[test]
    fn test_ast_printer_with_assignment() {
        let ast = render("var a = 1; a = 2;");
        assert_eq!(ast, "(var a = 1) (= a 2)");
    }

    #[test]
    fn test_ast_printer_with_logical() {
        let ast = render("true and false or true;");
        assert_eq!(ast, "(or (and true false) true)");
    }

    #[test]
    fn test_ast_printer_with_call() {
        let ast = render("foo(1, 2);");
        assert_eq!(ast, "(call foo (1 2))");
    }

    #[test]
    fn test_ast_printer_with_get() {
        let ast = render("foo.bar;");
        assert_eq!(ast, "(get foo bar)");
    }

    #[test]
    fn test_ast_printer_with_set() {
        let ast = render("foo.bar = 1;");
        assert_eq!(ast, "(set foo bar 1)");
    }

    #[test]
    fn test_ast_printer_with_this() {
        let ast = render("this.foo;");
        assert_eq!(ast, "(get this foo)");
    }

    #[test]
    fn test_ast_printer_with_super() {
        let ast = render("class a < b { init() { super.init(); } }");
        assert_eq!(ast, "(class a < b (fun init () { (call (super init) ()) }))");
    }
}