1use super::error_codes;
2use super::Ast;
3use super::Expr;
4use super::Function;
5use super::Stmt;
6
7use std::cell::Cell;
8use std::collections::HashMap;
9
10use flexi_parse::error::Error;
11use flexi_parse::new_error;
12use flexi_parse::token::Ident;
13use flexi_parse::Result;
14
/// The kind of function body currently being resolved.
///
/// Used to validate `return` statements and to distinguish class
/// initialisers from ordinary methods.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum FunctionType {
    /// Not inside any function body (top-level code).
    None,
    /// A free-standing function declaration.
    Function,
    /// A class method whose name is not `init`.
    Method,
    /// A class method named `init`.
    Initialiser,
}
22
/// The kind of class body currently being resolved.
///
/// Used to validate uses of `this` and `super`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ClassType {
    /// Not inside any class body.
    None,
    /// Inside a class declared without a superclass.
    Class,
    /// Inside a class declared with a superclass.
    SubClass,
}
29
30struct State {
31 scopes: Vec<HashMap<String, bool>>,
32 error: Option<Error>,
33 current_function: FunctionType,
34 current_class: ClassType,
35}
36
37impl State {
38 fn declare(&mut self, name: &Ident) {
39 if let Some(scope) = self.scopes.last_mut() {
40 if scope.contains_key(name.string()) {
41 self.error(new_error(
42 "Already a variable with this name in scope".to_string(),
43 name,
44 error_codes::SHADOW,
45 ));
46 } else {
47 scope.insert(name.string().to_owned(), false);
48 }
49 }
50 }
51
52 fn define(&mut self, name: &Ident) {
53 if let Some(scope) = self.scopes.last_mut() {
54 scope.insert(name.string().to_owned(), true);
55 }
56 }
57
58 fn resolve_local(&self, name: &Ident, distance: &Cell<Option<usize>>) {
59 for (i, scope) in self.scopes.iter().enumerate().rev() {
60 if scope.contains_key(name.string()) {
61 distance.set(Some(self.scopes.len() - 1 - i));
62 return;
63 }
64 }
65 }
66
67 fn error(&mut self, error: Error) {
68 if let Some(existing_error) = &mut self.error {
69 existing_error.add(error);
70 } else {
71 self.error = Some(error);
72 }
73 }
74
75 fn begin_scope(&mut self) {
76 self.scopes.push(HashMap::new());
77 }
78
79 fn end_scope(&mut self) {
80 self.scopes.pop();
81 }
82}
83
84impl Expr {
85 fn resolve(&self, state: &mut State) {
86 match self {
87 Expr::Assign {
88 name,
89 value,
90 distance,
91 } => {
92 value.resolve(state);
93 state.resolve_local(name, distance);
94 }
95 Expr::Binary(binary) => {
96 binary.left().resolve(state);
97 binary.right().resolve(state);
98 }
99 Expr::Call {
100 callee,
101 paren: _,
102 arguments,
103 } => {
104 callee.resolve(state);
105 for argument in arguments {
106 argument.resolve(state);
107 }
108 }
109 Expr::Get { object, name: _ } => object.resolve(state),
110 Expr::Group(expr) => expr.resolve(state),
111 Expr::Literal(_) => {}
112 Expr::Logical(logical) => {
113 logical.left().resolve(state);
114 logical.right().resolve(state);
115 }
116 Expr::Set {
117 object,
118 name: _,
119 value,
120 } => {
121 object.resolve(state);
122 value.resolve(state);
123 }
124 Expr::Super {
125 keyword,
126 distance,
127 dot: _,
128 method: _,
129 } => {
130 if state.current_class == ClassType::None {
131 state.error(new_error(
132 "Can't use 'super' outside of a class".to_string(),
133 keyword,
134 error_codes::INVALID_SUPER,
135 ));
136 } else if state.current_class == ClassType::Class {
137 state.error(new_error(
138 "Can't use 'super' in a class with no superclass".to_string(),
139 keyword,
140 error_codes::INVALID_SUPER,
141 ));
142 }
143 state.resolve_local(keyword.ident(), distance);
144 }
145 Expr::This { keyword, distance } => {
146 if state.current_class == ClassType::None {
147 state.error(new_error(
148 "Can't use 'this' outside of a class".to_string(),
149 keyword,
150 error_codes::THIS_OUTSIDE_CLASS,
151 ));
152 }
153
154 state.resolve_local(keyword.ident(), distance);
155 }
156 Expr::Unary(unary) => unary.right().resolve(state),
157 Expr::Variable { name, distance } => {
158 if let Some(scope) = state.scopes.last() {
159 if scope.get(name.string()) == Some(&false) {
160 state.error(new_error(
161 "Can't read a variable in its own intialiser".to_string(),
162 name,
163 error_codes::INVALID_INITIALISER,
164 ));
165 }
166 }
167 state.resolve_local(name, distance);
168 }
169 }
170 }
171}
172
173impl Function {
174 fn resolve(&self, state: &mut State, kind: FunctionType) {
175 let enclosing = state.current_function;
176 state.current_function = kind;
177 state.begin_scope();
178 for param in &self.params {
179 state.declare(param);
180 state.define(param);
181 }
182 for stmt in &self.body {
183 stmt.resolve(state);
184 }
185 state.end_scope();
186 state.current_function = enclosing;
187 }
188}
189
190impl Stmt {
191 fn resolve(&self, state: &mut State) {
192 match self {
193 Stmt::Block(stmts) => {
194 state.begin_scope();
195 for stmt in stmts {
196 stmt.resolve(state);
197 }
198 state.end_scope();
199 }
200 Stmt::Class {
201 name,
202 superclass,
203 superclass_distance,
204 methods,
205 } => {
206 let enclosing = state.current_class;
207 state.current_class = ClassType::Class;
208
209 state.declare(name);
210 state.define(name);
211
212 if let Some(superclass) = superclass {
213 if name.string() == superclass.string() {
214 state.error(new_error(
215 "A class can't inherit from itself".to_string(),
216 superclass,
217 error_codes::CYCLICAL_INHERITANCE,
218 ));
219 }
220
221 state.current_class = ClassType::SubClass;
222
223 state.resolve_local(superclass, superclass_distance);
224
225 state.begin_scope();
226 state
227 .scopes
228 .last_mut()
229 .unwrap()
230 .insert("super".to_string(), true);
231 }
232
233 state.begin_scope();
234 state
235 .scopes
236 .last_mut()
237 .unwrap()
238 .insert("this".to_string(), true);
239
240 for method in methods {
241 let kind = if method.name.string() == "init" {
242 FunctionType::Initialiser
243 } else {
244 FunctionType::Method
245 };
246 method.resolve(state, kind);
247 }
248
249 state.end_scope();
250
251 if superclass.is_some() {
252 state.end_scope();
253 }
254
255 state.current_class = enclosing;
256 }
257 Stmt::Expr(expr) | Stmt::Print(expr) => expr.resolve(state),
258 Stmt::Function(function) => {
259 state.declare(&function.name);
260 state.declare(&function.name);
261 function.resolve(state, FunctionType::Function);
262 }
263 Stmt::If {
264 condition,
265 then_branch,
266 else_branch,
267 } => {
268 condition.resolve(state);
269 then_branch.resolve(state);
270 if let Some(else_branch) = else_branch {
271 else_branch.resolve(state);
272 }
273 }
274 Stmt::Return { keyword, value } => {
275 if state.current_function == FunctionType::None {
276 state.error(new_error(
277 "Can't return from top-level code".to_string(),
278 keyword,
279 error_codes::RETURN_OUTSIDE_FUNCTION,
280 ));
281 }
282 if let Some(value) = value {
283 value.resolve(state);
284 }
285 }
286 Stmt::Variable { name, initialiser } => {
287 state.declare(name);
288 if let Some(initialiser) = initialiser {
289 initialiser.resolve(state);
290 }
291 state.define(name);
292 }
293 Stmt::While { condition, body } => {
294 condition.resolve(state);
295 body.resolve(state);
296 }
297 }
298 }
299}
300
301pub(super) fn resolve(ast: &Ast) -> Result<()> {
302 let mut state = State {
303 scopes: vec![],
304 error: None,
305 current_function: FunctionType::None,
306 current_class: ClassType::None,
307 };
308 for stmt in &ast.0 {
309 stmt.resolve(&mut state);
310 }
311 state.error.map_or(Ok(()), Err)
312}