rocks_lang/
resolver.rs

1use std::mem;
2use std::collections::HashMap;
3
4use crate::error::{Error, ResolveError};
5use crate::expr::{Expr, ExprVisitor};
6use crate::stmt::{Stmt, StmtVisitor};
7use crate::interpreter::Interpreter;
8use crate::token::Token;
9
/// The kind of function body the resolver is currently inside.
///
/// Used to validate `return` statements: returning at top level and
/// returning a value from an initializer are both reported as errors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FunctionType {
    /// Not inside any function body (top-level code).
    None,
    /// Inside a function declared with a function statement.
    Function,
    /// Inside a class method named `init`.
    Initializer,
    /// Inside any other class method.
    Method,
}
16
/// The kind of class declaration the resolver is currently inside.
///
/// Used to validate `this` and `super` expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ClassType {
    /// Not inside any class declaration.
    None,
    /// Inside a class that declares no superclass.
    Class,
    /// Inside a class that declares a superclass.
    Subclass,
}
22
/// Whether the resolver is currently inside a loop body.
///
/// Used to validate `break` statements.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LoopType {
    /// Not inside any loop.
    None,
    /// Inside a `while` statement.
    While,
}
27
28pub struct Resolver<'a, 'w> {
29    interpreter: &'a mut Interpreter<'w>,
30    scopes: Vec<HashMap<String, bool>>,
31    current_function: FunctionType,
32    current_class: ClassType,
33    current_loop: LoopType,
34}
35
36impl<'a, 'w> Resolver<'a, 'w> {
37    pub fn new(interpreter: &'a mut Interpreter<'w>) -> Self {
38        Resolver {
39            interpreter,
40            scopes: vec![],
41            current_function: FunctionType::None,
42            current_class: ClassType::None,
43            current_loop: LoopType::None,
44        }
45    }
46
47    fn resolve_expr(&mut self, expr: &Expr) {
48        expr.accept(self);
49    }
50
51    fn resolve_stmt(&mut self, stmt: &Stmt) {
52        stmt.accept(self);
53    }
54
55    pub fn resolve(&mut self, statements: &Vec<Stmt>) {
56        for statement in statements {
57            self.resolve_stmt(statement)
58        }
59    }
60
61    fn resolve_function(&mut self, function: &Stmt, r#type: FunctionType) {
62        let Stmt::Function(function) = function else { unreachable!() };
63
64        let enclosing_function = mem::replace(&mut self.current_function, r#type);
65
66        self.begin_scope();
67        for param in &function.params {
68            self.declare(param);
69            self.define(param);
70        }
71        self.resolve(&function.body);
72        self.end_scope();
73
74        self.current_function = enclosing_function;
75    }
76
77    fn begin_scope(&mut self) {
78        self.scopes.push(HashMap::new());
79    }
80
81    fn end_scope(&mut self) {
82        self.scopes.pop();
83    }
84
85    fn declare(&mut self, name: &Token) {
86        if self.scopes.is_empty() {
87            return;
88        }
89
90        let scope = self.scopes.last_mut().expect("stack to be not empty");
91        if scope.contains_key(&name.lexeme) {
92            ResolveError {
93                token: name.clone(),
94                message: format!("A variable is already defined with name '{}' in this scope", name.lexeme),
95            }.throw();
96        }
97        scope.insert(name.lexeme.to_owned(), false);
98    }
99
100    fn define(&mut self, name: &Token) {
101        if self.scopes.is_empty() {
102            return;
103        }
104
105        self.scopes
106            .last_mut()
107            .expect("stack to be not empty")
108            .insert(name.lexeme.to_owned(), true);
109    }
110
111    fn resolve_local(&mut self, name: &Token) {
112        for (i, scope) in self.scopes.iter().rev().enumerate() {
113            if scope.contains_key(&name.lexeme) {
114                self.interpreter.resolve(name, i);
115                return;
116            }
117        }
118    }
119}
120
121impl<'a, 'w> ExprVisitor<()> for Resolver<'a, 'w> {
122    fn visit_variable_expr(&mut self, expr: &Expr) {
123        let Expr::Variable(variable) = expr else { unreachable!() };
124
125        if let Some(scope) = self.scopes.last() {
126            if let Some(entry) = scope.get(&variable.name.lexeme) {
127                if !entry {
128                    ResolveError {
129                        token: variable.name.to_owned(),
130                        message: "Cannot read local variable in its own initializer".to_string(),
131                    }.throw();
132                }
133            }
134        }
135
136        self.resolve_local(&variable.name);
137    }
138
139    fn visit_assign_expr(&mut self, expr: &Expr) {
140        let Expr::Assign(assign) = expr else { unreachable!() };
141
142        self.resolve_expr(&assign.value);
143        self.resolve_local(&assign.name);
144    }
145
146    fn visit_literal_expr(&mut self, expr: &Expr) {
147        let Expr::Literal(_) = expr else { unreachable!() };
148
149        return;
150    }
151
152    fn visit_logical_expr(&mut self, expr: &Expr) {
153        let Expr::Logical(logical) = expr else { unreachable!() };
154
155        self.resolve_expr(&logical.left);
156        self.resolve_expr(&logical.right);
157    }
158
159    fn visit_unary_expr(&mut self, expr: &Expr) {
160        let Expr::Unary(unary) = expr else { unreachable!() };
161
162        self.resolve_expr(&unary.expr);
163    }
164
165    fn visit_binary_expr(&mut self, expr: &Expr) {
166        let Expr::Binary(binary) = expr else { unreachable!() };
167
168        self.resolve_expr(&binary.left);
169        self.resolve_expr(&binary.right);
170    }
171
172    fn visit_grouping_expr(&mut self, expr: &Expr) {
173        let Expr::Grouping(grouping) = expr else { unreachable!() };
174
175        self.resolve_expr(&grouping.expr);
176    }
177
178    fn visit_call_expr(&mut self, expr: &Expr) {
179        let Expr::Call(call) = expr else { unreachable!() };
180
181        self.resolve_expr(&call.callee);
182
183        for argument in &call.arguments {
184            self.resolve_expr(argument);
185        }
186    }
187
188    fn visit_get_expr(&mut self, expr: &Expr) {
189        let Expr::Get(get) = expr else { unreachable!() };
190
191        self.resolve_expr(&get.object);
192    }
193
194    fn visit_set_expr(&mut self, expr: &Expr) {
195        let Expr::Set(set) = expr else { unreachable!() };
196
197        self.resolve_expr(&set.value);
198        self.resolve_expr(&set.object);
199    }
200
201    fn visit_this_expr(&mut self, expr: &Expr) {
202        let Expr::This(this) = expr else { unreachable!() };
203
204        if let ClassType::None = self.current_class {
205            ResolveError {
206                token: this.keyword.clone(),
207                message: "Cannot use 'this' outside of a class".to_string(),
208            }.throw();
209
210            return;
211        }
212
213        self.resolve_local(&this.keyword);
214    }
215
216    fn visit_super_expr(&mut self, expr: &Expr) {
217        let Expr::Super(super_expr) = expr else { unreachable!() };
218
219        match self.current_class {
220            ClassType::Subclass => (),
221            ClassType::None => ResolveError {
222                token: super_expr.keyword.clone(),
223                message: "Cannot use 'super' outside of a class".to_string()
224            }.throw(),
225            _ => ResolveError {
226                token: super_expr.keyword.clone(),
227                message: "Cannot use 'super' in a class with no superclass".to_string(),
228            }.throw(),
229        }
230
231        self.resolve_local(&super_expr.keyword);
232    }
233}
234
235impl<'a, 'w> StmtVisitor<()> for Resolver<'a, 'w> {
236    fn visit_block_stmt(&mut self, stmt: &Stmt) {
237        let Stmt::Block(block) = stmt else { unreachable!() };
238
239        self.begin_scope();
240        self.resolve(&block.statements);
241        self.end_scope();
242    }
243
244    fn visit_var_stmt(&mut self, stmt: &Stmt) {
245        let Stmt::Var(var) = stmt else { unreachable!() };
246
247        self.declare(&var.name);
248        if let Some(initializer) = &var.initializer {
249            self.resolve_expr(initializer);
250        }
251        self.define(&var.name);
252    }
253
254    fn visit_function_stmt(&mut self, stmt: &Stmt) {
255        let Stmt::Function(function) = stmt else { unreachable!() };
256
257        self.declare(&function.name);
258        self.define(&function.name);
259
260        self.resolve_function(stmt, FunctionType::Function);
261    }
262
263    fn visit_expression_stmt(&mut self, stmt: &Stmt) {
264        let Stmt::Expression(expr) = stmt else { unreachable!() };
265
266        self.resolve_expr(&expr.expr);
267    }
268
269    fn visit_if_stmt(&mut self, stmt: &Stmt) {
270        let Stmt::If(if_stmt) = stmt else { unreachable!() };
271
272        self.resolve_expr(&if_stmt.condition);
273        self.resolve_stmt(&if_stmt.then_branch);
274        if let Some(else_branch) = &if_stmt.else_branch {
275            self.resolve_stmt(else_branch);
276        }
277    }
278
279    fn visit_print_stmt(&mut self, stmt: &Stmt) {
280        let Stmt::Print(print) = stmt else { unreachable!() };
281
282        self.resolve_expr(&print.expr);
283    }
284
285    fn visit_return_stmt(&mut self, stmt: &Stmt) {
286        let Stmt::Return(return_stmt) = stmt else { unreachable!() };
287
288        if let FunctionType::None = self.current_function {
289            ResolveError {
290                token: return_stmt.keyword.clone(),
291                message: "Cannot return from top-level code".to_string(),
292            }.throw();
293        }
294
295        if let Some(value) = &return_stmt.value {
296            if let FunctionType::Initializer = self.current_function {
297                ResolveError {
298                    token: return_stmt.keyword.clone(),
299                    message: "Cannot return a value from an initializer".to_string(),
300                }.throw();
301                return;
302            }
303
304            self.resolve_expr(value);
305        }
306    }
307
308    fn visit_break_stmt(&mut self, stmt: &Stmt) {
309        let Stmt::Break(break_stmt) = stmt else { unreachable!() };
310
311        if let LoopType::None = self.current_loop {
312            ResolveError {
313                token: break_stmt.keyword.clone(),
314                message: "Cannot break outside of a loop".to_string(),
315            }.throw();
316        }
317    }
318
319    fn visit_while_stmt(&mut self, stmt: &Stmt) {
320        let Stmt::While(while_stmt) = stmt else { unreachable!() };
321
322        let enclosing_loop = mem::replace(&mut self.current_loop, LoopType::While);
323
324        self.resolve_expr(&while_stmt.condition);
325        self.resolve_stmt(&while_stmt.body);
326
327        self.current_loop = enclosing_loop;
328    }
329
330    fn visit_class_stmt(&mut self, stmt: &Stmt) {
331        let Stmt::Class(class_stmt) = stmt else { unreachable!() };
332
333        let enclosing_class = mem::replace(&mut self.current_class, ClassType::Class);
334
335        self.declare(&class_stmt.name);
336        self.define(&class_stmt.name);
337
338        if let Some(ref superclass) = class_stmt.superclass {
339            if let Expr::Variable(variable) = superclass {
340                if class_stmt.name.lexeme == variable.name.lexeme {
341                    ResolveError {
342                        token: variable.name.clone(),
343                        message: "A class cannot inherit from itself".to_string(),
344                    }.throw();
345                }
346            } else {
347                unreachable!();
348            }
349
350            self.current_class = ClassType::Subclass;
351
352            self.resolve_expr(superclass);
353
354            self.begin_scope();
355            self.scopes
356                .last_mut()
357                .expect("stack to be not empty")
358                .insert("super".to_string(), true);
359        }
360
361        self.begin_scope();
362        self.scopes
363            .last_mut()
364            .expect("stack to be not empty")
365            .insert("this".to_string(), true);
366
367        for method in &class_stmt.methods {
368            if let Stmt::Function(function) = method {
369                let decleration = if function.name.lexeme.eq("init") {
370                    FunctionType::Initializer
371                } else {
372                    FunctionType::Method
373                };
374                self.resolve_function(method, decleration);
375            } else {
376                unreachable!();
377            }
378        }
379
380        self.end_scope();
381
382        if class_stmt.superclass.is_some() {
383            self.end_scope();
384        }
385
386        self.current_class = enclosing_class;
387    }
388}