1use std::mem;
2use std::collections::HashMap;
3
4use crate::error::{Error, ResolveError};
5use crate::expr::{Expr, ExprVisitor};
6use crate::stmt::{Stmt, StmtVisitor};
7use crate::interpreter::Interpreter;
8use crate::token::Token;
9
/// Kind of function body the resolver is currently walking.
///
/// Used to validate `return` statements: returning at all is rejected when
/// the state is [`FunctionType::None`], and returning a value is rejected
/// inside an [`FunctionType::Initializer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FunctionType {
    /// Not inside any function body (the resolver's initial state).
    None,
    /// Inside a function declared with a function statement.
    Function,
    /// Inside a class method named `init`.
    Initializer,
    /// Inside any other class method.
    Method,
}
16
/// Kind of class declaration the resolver is currently walking.
///
/// Used to validate `this` and `super` expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ClassType {
    /// Not inside any class declaration (the resolver's initial state).
    None,
    /// Inside a class declared without a superclass.
    Class,
    /// Inside a class declared with a superclass.
    Subclass,
}
22
/// Kind of loop the resolver is currently walking.
///
/// Used to validate `break` statements.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LoopType {
    /// Not inside any loop body (the resolver's initial state).
    None,
    /// Inside the condition or body of a `while` statement.
    While,
}
27
28pub struct Resolver<'a, 'w> {
29 interpreter: &'a mut Interpreter<'w>,
30 scopes: Vec<HashMap<String, bool>>,
31 current_function: FunctionType,
32 current_class: ClassType,
33 current_loop: LoopType,
34}
35
36impl<'a, 'w> Resolver<'a, 'w> {
37 pub fn new(interpreter: &'a mut Interpreter<'w>) -> Self {
38 Resolver {
39 interpreter,
40 scopes: vec![],
41 current_function: FunctionType::None,
42 current_class: ClassType::None,
43 current_loop: LoopType::None,
44 }
45 }
46
47 fn resolve_expr(&mut self, expr: &Expr) {
48 expr.accept(self);
49 }
50
51 fn resolve_stmt(&mut self, stmt: &Stmt) {
52 stmt.accept(self);
53 }
54
55 pub fn resolve(&mut self, statements: &Vec<Stmt>) {
56 for statement in statements {
57 self.resolve_stmt(statement)
58 }
59 }
60
61 fn resolve_function(&mut self, function: &Stmt, r#type: FunctionType) {
62 let Stmt::Function(function) = function else { unreachable!() };
63
64 let enclosing_function = mem::replace(&mut self.current_function, r#type);
65
66 self.begin_scope();
67 for param in &function.params {
68 self.declare(param);
69 self.define(param);
70 }
71 self.resolve(&function.body);
72 self.end_scope();
73
74 self.current_function = enclosing_function;
75 }
76
77 fn begin_scope(&mut self) {
78 self.scopes.push(HashMap::new());
79 }
80
81 fn end_scope(&mut self) {
82 self.scopes.pop();
83 }
84
85 fn declare(&mut self, name: &Token) {
86 if self.scopes.is_empty() {
87 return;
88 }
89
90 let scope = self.scopes.last_mut().expect("stack to be not empty");
91 if scope.contains_key(&name.lexeme) {
92 ResolveError {
93 token: name.clone(),
94 message: format!("A variable is already defined with name '{}' in this scope", name.lexeme),
95 }.throw();
96 }
97 scope.insert(name.lexeme.to_owned(), false);
98 }
99
100 fn define(&mut self, name: &Token) {
101 if self.scopes.is_empty() {
102 return;
103 }
104
105 self.scopes
106 .last_mut()
107 .expect("stack to be not empty")
108 .insert(name.lexeme.to_owned(), true);
109 }
110
111 fn resolve_local(&mut self, name: &Token) {
112 for (i, scope) in self.scopes.iter().rev().enumerate() {
113 if scope.contains_key(&name.lexeme) {
114 self.interpreter.resolve(name, i);
115 return;
116 }
117 }
118 }
119}
120
121impl<'a, 'w> ExprVisitor<()> for Resolver<'a, 'w> {
122 fn visit_variable_expr(&mut self, expr: &Expr) {
123 let Expr::Variable(variable) = expr else { unreachable!() };
124
125 if let Some(scope) = self.scopes.last() {
126 if let Some(entry) = scope.get(&variable.name.lexeme) {
127 if !entry {
128 ResolveError {
129 token: variable.name.to_owned(),
130 message: "Cannot read local variable in its own initializer".to_string(),
131 }.throw();
132 }
133 }
134 }
135
136 self.resolve_local(&variable.name);
137 }
138
139 fn visit_assign_expr(&mut self, expr: &Expr) {
140 let Expr::Assign(assign) = expr else { unreachable!() };
141
142 self.resolve_expr(&assign.value);
143 self.resolve_local(&assign.name);
144 }
145
146 fn visit_literal_expr(&mut self, expr: &Expr) {
147 let Expr::Literal(_) = expr else { unreachable!() };
148
149 return;
150 }
151
152 fn visit_logical_expr(&mut self, expr: &Expr) {
153 let Expr::Logical(logical) = expr else { unreachable!() };
154
155 self.resolve_expr(&logical.left);
156 self.resolve_expr(&logical.right);
157 }
158
159 fn visit_unary_expr(&mut self, expr: &Expr) {
160 let Expr::Unary(unary) = expr else { unreachable!() };
161
162 self.resolve_expr(&unary.expr);
163 }
164
165 fn visit_binary_expr(&mut self, expr: &Expr) {
166 let Expr::Binary(binary) = expr else { unreachable!() };
167
168 self.resolve_expr(&binary.left);
169 self.resolve_expr(&binary.right);
170 }
171
172 fn visit_grouping_expr(&mut self, expr: &Expr) {
173 let Expr::Grouping(grouping) = expr else { unreachable!() };
174
175 self.resolve_expr(&grouping.expr);
176 }
177
178 fn visit_call_expr(&mut self, expr: &Expr) {
179 let Expr::Call(call) = expr else { unreachable!() };
180
181 self.resolve_expr(&call.callee);
182
183 for argument in &call.arguments {
184 self.resolve_expr(argument);
185 }
186 }
187
188 fn visit_get_expr(&mut self, expr: &Expr) {
189 let Expr::Get(get) = expr else { unreachable!() };
190
191 self.resolve_expr(&get.object);
192 }
193
194 fn visit_set_expr(&mut self, expr: &Expr) {
195 let Expr::Set(set) = expr else { unreachable!() };
196
197 self.resolve_expr(&set.value);
198 self.resolve_expr(&set.object);
199 }
200
201 fn visit_this_expr(&mut self, expr: &Expr) {
202 let Expr::This(this) = expr else { unreachable!() };
203
204 if let ClassType::None = self.current_class {
205 ResolveError {
206 token: this.keyword.clone(),
207 message: "Cannot use 'this' outside of a class".to_string(),
208 }.throw();
209
210 return;
211 }
212
213 self.resolve_local(&this.keyword);
214 }
215
216 fn visit_super_expr(&mut self, expr: &Expr) {
217 let Expr::Super(super_expr) = expr else { unreachable!() };
218
219 match self.current_class {
220 ClassType::Subclass => (),
221 ClassType::None => ResolveError {
222 token: super_expr.keyword.clone(),
223 message: "Cannot use 'super' outside of a class".to_string()
224 }.throw(),
225 _ => ResolveError {
226 token: super_expr.keyword.clone(),
227 message: "Cannot use 'super' in a class with no superclass".to_string(),
228 }.throw(),
229 }
230
231 self.resolve_local(&super_expr.keyword);
232 }
233}
234
235impl<'a, 'w> StmtVisitor<()> for Resolver<'a, 'w> {
236 fn visit_block_stmt(&mut self, stmt: &Stmt) {
237 let Stmt::Block(block) = stmt else { unreachable!() };
238
239 self.begin_scope();
240 self.resolve(&block.statements);
241 self.end_scope();
242 }
243
244 fn visit_var_stmt(&mut self, stmt: &Stmt) {
245 let Stmt::Var(var) = stmt else { unreachable!() };
246
247 self.declare(&var.name);
248 if let Some(initializer) = &var.initializer {
249 self.resolve_expr(initializer);
250 }
251 self.define(&var.name);
252 }
253
254 fn visit_function_stmt(&mut self, stmt: &Stmt) {
255 let Stmt::Function(function) = stmt else { unreachable!() };
256
257 self.declare(&function.name);
258 self.define(&function.name);
259
260 self.resolve_function(stmt, FunctionType::Function);
261 }
262
263 fn visit_expression_stmt(&mut self, stmt: &Stmt) {
264 let Stmt::Expression(expr) = stmt else { unreachable!() };
265
266 self.resolve_expr(&expr.expr);
267 }
268
269 fn visit_if_stmt(&mut self, stmt: &Stmt) {
270 let Stmt::If(if_stmt) = stmt else { unreachable!() };
271
272 self.resolve_expr(&if_stmt.condition);
273 self.resolve_stmt(&if_stmt.then_branch);
274 if let Some(else_branch) = &if_stmt.else_branch {
275 self.resolve_stmt(else_branch);
276 }
277 }
278
279 fn visit_print_stmt(&mut self, stmt: &Stmt) {
280 let Stmt::Print(print) = stmt else { unreachable!() };
281
282 self.resolve_expr(&print.expr);
283 }
284
285 fn visit_return_stmt(&mut self, stmt: &Stmt) {
286 let Stmt::Return(return_stmt) = stmt else { unreachable!() };
287
288 if let FunctionType::None = self.current_function {
289 ResolveError {
290 token: return_stmt.keyword.clone(),
291 message: "Cannot return from top-level code".to_string(),
292 }.throw();
293 }
294
295 if let Some(value) = &return_stmt.value {
296 if let FunctionType::Initializer = self.current_function {
297 ResolveError {
298 token: return_stmt.keyword.clone(),
299 message: "Cannot return a value from an initializer".to_string(),
300 }.throw();
301 return;
302 }
303
304 self.resolve_expr(value);
305 }
306 }
307
308 fn visit_break_stmt(&mut self, stmt: &Stmt) {
309 let Stmt::Break(break_stmt) = stmt else { unreachable!() };
310
311 if let LoopType::None = self.current_loop {
312 ResolveError {
313 token: break_stmt.keyword.clone(),
314 message: "Cannot break outside of a loop".to_string(),
315 }.throw();
316 }
317 }
318
319 fn visit_while_stmt(&mut self, stmt: &Stmt) {
320 let Stmt::While(while_stmt) = stmt else { unreachable!() };
321
322 let enclosing_loop = mem::replace(&mut self.current_loop, LoopType::While);
323
324 self.resolve_expr(&while_stmt.condition);
325 self.resolve_stmt(&while_stmt.body);
326
327 self.current_loop = enclosing_loop;
328 }
329
330 fn visit_class_stmt(&mut self, stmt: &Stmt) {
331 let Stmt::Class(class_stmt) = stmt else { unreachable!() };
332
333 let enclosing_class = mem::replace(&mut self.current_class, ClassType::Class);
334
335 self.declare(&class_stmt.name);
336 self.define(&class_stmt.name);
337
338 if let Some(ref superclass) = class_stmt.superclass {
339 if let Expr::Variable(variable) = superclass {
340 if class_stmt.name.lexeme == variable.name.lexeme {
341 ResolveError {
342 token: variable.name.clone(),
343 message: "A class cannot inherit from itself".to_string(),
344 }.throw();
345 }
346 } else {
347 unreachable!();
348 }
349
350 self.current_class = ClassType::Subclass;
351
352 self.resolve_expr(superclass);
353
354 self.begin_scope();
355 self.scopes
356 .last_mut()
357 .expect("stack to be not empty")
358 .insert("super".to_string(), true);
359 }
360
361 self.begin_scope();
362 self.scopes
363 .last_mut()
364 .expect("stack to be not empty")
365 .insert("this".to_string(), true);
366
367 for method in &class_stmt.methods {
368 if let Stmt::Function(function) = method {
369 let decleration = if function.name.lexeme.eq("init") {
370 FunctionType::Initializer
371 } else {
372 FunctionType::Method
373 };
374 self.resolve_function(method, decleration);
375 } else {
376 unreachable!();
377 }
378 }
379
380 self.end_scope();
381
382 if class_stmt.superclass.is_some() {
383 self.end_scope();
384 }
385
386 self.current_class = enclosing_class;
387 }
388}