rocks_lang/
ast.rs

1use crate::literal::Literal;
2use crate::expr::{ExprVisitor, Expr};
3use crate::stmt::{StmtVisitor, Stmt};
4
/// Builds the parenthesized, prefix-notation rendering of a node.
///
/// Expands to a `String` of the form `(name child1 child2 ...)`. Each child is
/// rendered, in order, by calling `.accept($self)` on it, so `$self` must be
/// the visitor and at least one child must be supplied.
macro_rules! parenthesize {
    ( $self:ident, $name:expr, $( $x:expr ),+ ) => {
        {
            let mut rendered = String::from("(");
            rendered.push_str($name);
            $(
                rendered.push(' ');
                rendered.push_str(&$x.accept($self));
            )+
            rendered.push(')');

            rendered
        }
    };
}
22
/// Visitor that renders expressions and statements as parenthesized,
/// prefix-notation strings.
///
/// The printer holds no state, so it is trivially copyable and can be created
/// with `ASTPrinter`, `ASTPrinter {}` or `ASTPrinter::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ASTPrinter;
24
25impl ASTPrinter {
26    /// Prints the expression using visitor pattern.
27    pub fn print(&mut self, expr: Expr) -> String {
28        expr.accept(self)
29    }
30}
31
32impl ExprVisitor<String> for ASTPrinter {
33    fn visit_literal_expr(&mut self, expr: &Expr) -> String {
34        let Expr::Literal(literal) = expr else { unreachable!() };
35        literal.to_string()
36    }
37
38    fn visit_logical_expr(&mut self, expr: &Expr) -> String {
39        let Expr::Logical(data) = expr else { unreachable!() };
40        parenthesize!(self, &data.operator.lexeme, &data.left, &data.right)
41    }
42
43    fn visit_unary_expr(&mut self, expr: &Expr) -> String {
44        let Expr::Unary(data) = expr else { unreachable!() };
45        parenthesize!(self, &data.operator.lexeme, &data.expr)
46    }
47
48    fn visit_binary_expr(&mut self, expr: &Expr) -> String {
49        let Expr::Binary(data) = expr else { unreachable!() };
50        parenthesize!(self, &data.operator.lexeme, &data.left, &data.right)
51    }
52
53    fn visit_grouping_expr(&mut self, expr: &Expr) -> String {
54        let Expr::Grouping(data) = expr else { unreachable!() };
55        parenthesize!(self, "group", data.expr)
56    }
57
58    fn visit_variable_expr(&mut self, expr: &Expr) -> String {
59        let Expr::Variable(data) = expr else { unreachable!() };
60        data.name.lexeme.clone()
61    }
62
63    fn visit_assign_expr(&mut self, expr: &Expr) -> String {
64        let Expr::Assign(data) = expr else { unreachable!() };
65        parenthesize!(self, format!("= {}", &data.name.lexeme).as_str(), data.value)
66    }
67
68    fn visit_call_expr(&mut self, expr: &Expr) -> String {
69        let Expr::Call(data) = expr else { unreachable!() };
70        let mut string = String::new();
71        string += "(call ";
72        string += &data.callee.accept(self);
73        string += " (";
74        for arg in &data.arguments {
75            string += &arg.accept(self);
76            string += " ";
77        }
78        string = string.trim_end().to_string();
79        string += "))";
80        string
81    }
82
83    fn visit_get_expr(&mut self, expr: &Expr) -> String {
84        let Expr::Get(data) = expr else { unreachable!() };
85        let mut string = String::new();
86        string += "(get ";
87        string += &data.object.accept(self);
88        string += " ";
89        string += &data.name.lexeme;
90        string += ")";
91        string
92    }
93
94    fn visit_set_expr(&mut self, expr: &Expr) -> String {
95        let Expr::Set(data) = expr else { unreachable!() };
96        let mut string = String::new();
97        string += "(set ";
98        string += &data.object.accept(self);
99        string += " ";
100        string += &data.name.lexeme;
101        string += " ";
102        string += &data.value.accept(self);
103        string += ")";
104        string
105    }
106
107    fn visit_this_expr(&mut self, _expr: &Expr) -> String {
108        "this".to_string()
109    }
110
111    fn visit_super_expr(&mut self, expr: &Expr) -> String {
112        let Expr::Super(data) = expr else { unreachable!() };
113        let mut string = String::new();
114        string += "(super ";
115        string += &data.method.lexeme;
116        string += ")";
117        string
118    }
119}
120
121impl StmtVisitor<String> for ASTPrinter {
122    fn visit_expression_stmt(&mut self, stmt: &Stmt) -> String {
123        let Stmt::Expression(data) = stmt else { unreachable!() };
124        data.expr.accept(self)
125    }
126
127    fn visit_function_stmt(&mut self, stmt: &Stmt) -> String {
128        let Stmt::Function(data) = stmt else { unreachable!() };
129        let mut string = String::new();
130        string += "(fun ";
131        string += &data.name.lexeme;
132        string += " (";
133        for param in &data.params {
134            string += &param.lexeme;
135            string += " ";
136        }
137        string = string.trim_end().to_string();
138        string += ") { ";
139        for body in &data.body {
140            string += &body.accept(self);
141        }
142        string += " })";
143        string
144    }
145
146    fn visit_if_stmt(&mut self, stmt: &Stmt) -> String {
147        let Stmt::If(data) = stmt else { unreachable!() };
148        let mut string = String::new();
149        string += "(if ";
150        string += &data.condition.accept(self);
151        string += " ";
152        string += &data.then_branch.accept(self);
153        if let Some(else_branch) = &data.else_branch {
154            string += " else ";
155            string += &else_branch.accept(self);
156        }
157        string += ")";
158        string
159    }
160
161    fn visit_print_stmt(&mut self, stmt: &Stmt) -> String {
162        let Stmt::Print(data) = stmt else { unreachable!() };
163        parenthesize!(self, "print", data.expr)
164    }
165
166    fn visit_return_stmt(&mut self, stmt: &Stmt) -> String {
167        let Stmt::Return(data) = stmt else { unreachable!() };
168        parenthesize!(self, "return", data.value.clone().unwrap_or(Expr::Literal(Literal::Null)))
169    }
170
171    fn visit_var_stmt(&mut self, stmt: &Stmt) -> String {
172        let Stmt::Var(data) = stmt else { unreachable!() };
173        let mut string = String::new();
174        string += "(var ";
175        string += &data.name.lexeme;
176        if let Some(initializer) = &data.initializer {
177            string += " = ";
178            string += &initializer.accept(self);
179        }
180        string += ")";
181        string
182    }
183
184    fn visit_while_stmt(&mut self, stmt: &Stmt) -> String {
185        let Stmt::While(data) = stmt else { unreachable!() };
186        let mut string = String::new();
187        string += "(while ";
188        string += &data.condition.accept(self);
189        string += " ";
190        string += &data.body.accept(self);
191        string += ")";
192        string
193    }
194
195    fn visit_block_stmt(&mut self, stmt: &Stmt) -> String {
196        let Stmt::Block(data) = stmt else { unreachable!() };
197        let mut string = String::new();
198        string += "{";
199        for stmt in &data.statements {
200            string += " ";
201            string += &stmt.accept(self);
202        }
203        string += " }";
204        string
205    }
206
207    fn visit_class_stmt(&mut self, stmt: &Stmt) -> String {
208        let Stmt::Class(data) = stmt else { unreachable!() };
209        let mut string = String::new();
210        string += "(class ";
211        string += &data.name.lexeme;
212        if let Some(superclass) = &data.superclass {
213            string += " < ";
214            string += &superclass.accept(self);
215        }
216        for method in &data.methods {
217            string += " ";
218            string += &method.accept(self);
219        }
220        string += ")";
221        string
222    }
223
224    fn visit_break_stmt(&mut self, stmt: &Stmt) -> String {
225        let Stmt::Break(_) = stmt else { unreachable!() };
226        "break".to_string()
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scanner::Scanner;
    use crate::parser::Parser;

    /// Scans and parses `source`, renders every resulting statement with an
    /// `ASTPrinter`, and joins the renderings with single spaces.
    fn render(source: &str) -> String {
        let mut scanner = Scanner::new(source);
        let tokens = scanner.scan_tokens();
        let mut parser = Parser::new(tokens);
        let statements = parser.parse();
        let mut printer = ASTPrinter {};
        let rendered = statements.iter()
            .map(|stmt| stmt.accept(&mut printer))
            .collect::<Vec<String>>();
        rendered.join(" ")
    }

    #[test]
    fn test_ast_printer() {
        assert_eq!(
            render("var a = 1; var b = 2; print a + b;"),
            "(var a = 1) (var b = 2) (print (+ a b))"
        );
    }

    #[test]
    fn test_ast_printer_with_grouping() {
        assert_eq!(render("print (1 + 2) * 3;"), "(print (* (group (+ 1 2)) 3))");
    }

    #[test]
    fn test_ast_printer_with_if() {
        assert_eq!(
            render("if (a > 0) { print a; } else { print -a; }"),
            "(if (> a 0) { (print a) } else { (print (- a)) })"
        );
    }

    #[test]
    fn test_ast_printer_with_function() {
        assert_eq!(
            render("fun add(a, b) { return a + b; }"),
            "(fun add (a b) { (return (+ a b)) })"
        );
    }

    #[test]
    fn test_ast_printer_with_class() {
        assert_eq!(
            render("class Foo { bar() { print \"bar\"; } }"),
            "(class Foo (fun bar () { (print bar) }))"
        );
    }

    #[test]
    fn test_ast_printer_with_break() {
        assert_eq!(render("while (true) { break; }"), "(while true { break })");
    }

    #[test]
    fn test_ast_printer_with_assignment() {
        assert_eq!(render("var a = 1; a = 2;"), "(var a = 1) (= a 2)");
    }

    #[test]
    fn test_ast_printer_with_logical() {
        assert_eq!(render("true and false or true;"), "(or (and true false) true)");
    }

    #[test]
    fn test_ast_printer_with_call() {
        assert_eq!(render("foo(1, 2);"), "(call foo (1 2))");
    }

    #[test]
    fn test_ast_printer_with_get() {
        assert_eq!(render("foo.bar;"), "(get foo bar)");
    }

    #[test]
    fn test_ast_printer_with_set() {
        assert_eq!(render("foo.bar = 1;"), "(set foo bar 1)");
    }

    #[test]
    fn test_ast_printer_with_this() {
        assert_eq!(render("this.foo;"), "(get this foo)");
    }

    #[test]
    fn test_ast_printer_with_super() {
        assert_eq!(
            render("class a < b { init() { super.init(); } }"),
            "(class a < b (fun init () { (call (super init) ()) }))"
        );
    }
}