lox/
resolver.rs

1use super::error_codes;
2use super::Ast;
3use super::Expr;
4use super::Function;
5use super::Stmt;
6
7use std::cell::Cell;
8use std::collections::HashMap;
9
10use flexi_parse::error::Error;
11use flexi_parse::new_error;
12use flexi_parse::token::Ident;
13use flexi_parse::Result;
14
/// The kind of function body the resolver is currently inside.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum FunctionType {
    /// Not inside any function (top-level code).
    None,
    /// A plain function declaration.
    Function,
    /// A class method other than `init`.
    Method,
    /// A class method named `init`.
    Initialiser,
}
22
/// The kind of class body the resolver is currently inside.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ClassType {
    /// Not inside any class body.
    None,
    /// Inside a class that declares no superclass.
    Class,
    /// Inside a class that declares a superclass.
    SubClass,
}
29
30struct State {
31    scopes: Vec<HashMap<String, bool>>,
32    error: Option<Error>,
33    current_function: FunctionType,
34    current_class: ClassType,
35}
36
37impl State {
38    fn declare(&mut self, name: &Ident) {
39        if let Some(scope) = self.scopes.last_mut() {
40            if scope.contains_key(name.string()) {
41                self.error(new_error(
42                    "Already a variable with this name in scope".to_string(),
43                    name,
44                    error_codes::SHADOW,
45                ));
46            } else {
47                scope.insert(name.string().to_owned(), false);
48            }
49        }
50    }
51
52    fn define(&mut self, name: &Ident) {
53        if let Some(scope) = self.scopes.last_mut() {
54            scope.insert(name.string().to_owned(), true);
55        }
56    }
57
58    fn resolve_local(&self, name: &Ident, distance: &Cell<Option<usize>>) {
59        for (i, scope) in self.scopes.iter().enumerate().rev() {
60            if scope.contains_key(name.string()) {
61                distance.set(Some(self.scopes.len() - 1 - i));
62                return;
63            }
64        }
65    }
66
67    fn error(&mut self, error: Error) {
68        if let Some(existing_error) = &mut self.error {
69            existing_error.add(error);
70        } else {
71            self.error = Some(error);
72        }
73    }
74
75    fn begin_scope(&mut self) {
76        self.scopes.push(HashMap::new());
77    }
78
79    fn end_scope(&mut self) {
80        self.scopes.pop();
81    }
82}
83
84impl Expr {
85    fn resolve(&self, state: &mut State) {
86        match self {
87            Expr::Assign {
88                name,
89                value,
90                distance,
91            } => {
92                value.resolve(state);
93                state.resolve_local(name, distance);
94            }
95            Expr::Binary(binary) => {
96                binary.left().resolve(state);
97                binary.right().resolve(state);
98            }
99            Expr::Call {
100                callee,
101                paren: _,
102                arguments,
103            } => {
104                callee.resolve(state);
105                for argument in arguments {
106                    argument.resolve(state);
107                }
108            }
109            Expr::Get { object, name: _ } => object.resolve(state),
110            Expr::Group(expr) => expr.resolve(state),
111            Expr::Literal(_) => {}
112            Expr::Logical(logical) => {
113                logical.left().resolve(state);
114                logical.right().resolve(state);
115            }
116            Expr::Set {
117                object,
118                name: _,
119                value,
120            } => {
121                object.resolve(state);
122                value.resolve(state);
123            }
124            Expr::Super {
125                keyword,
126                distance,
127                dot: _,
128                method: _,
129            } => {
130                if state.current_class == ClassType::None {
131                    state.error(new_error(
132                        "Can't use 'super' outside of a class".to_string(),
133                        keyword,
134                        error_codes::INVALID_SUPER,
135                    ));
136                } else if state.current_class == ClassType::Class {
137                    state.error(new_error(
138                        "Can't use 'super' in a class with no superclass".to_string(),
139                        keyword,
140                        error_codes::INVALID_SUPER,
141                    ));
142                }
143                state.resolve_local(keyword.ident(), distance);
144            }
145            Expr::This { keyword, distance } => {
146                if state.current_class == ClassType::None {
147                    state.error(new_error(
148                        "Can't use 'this' outside of a class".to_string(),
149                        keyword,
150                        error_codes::THIS_OUTSIDE_CLASS,
151                    ));
152                }
153
154                state.resolve_local(keyword.ident(), distance);
155            }
156            Expr::Unary(unary) => unary.right().resolve(state),
157            Expr::Variable { name, distance } => {
158                if let Some(scope) = state.scopes.last() {
159                    if scope.get(name.string()) == Some(&false) {
160                        state.error(new_error(
161                            "Can't read a variable in its own intialiser".to_string(),
162                            name,
163                            error_codes::INVALID_INITIALISER,
164                        ));
165                    }
166                }
167                state.resolve_local(name, distance);
168            }
169        }
170    }
171}
172
173impl Function {
174    fn resolve(&self, state: &mut State, kind: FunctionType) {
175        let enclosing = state.current_function;
176        state.current_function = kind;
177        state.begin_scope();
178        for param in &self.params {
179            state.declare(param);
180            state.define(param);
181        }
182        for stmt in &self.body {
183            stmt.resolve(state);
184        }
185        state.end_scope();
186        state.current_function = enclosing;
187    }
188}
189
190impl Stmt {
191    fn resolve(&self, state: &mut State) {
192        match self {
193            Stmt::Block(stmts) => {
194                state.begin_scope();
195                for stmt in stmts {
196                    stmt.resolve(state);
197                }
198                state.end_scope();
199            }
200            Stmt::Class {
201                name,
202                superclass,
203                superclass_distance,
204                methods,
205            } => {
206                let enclosing = state.current_class;
207                state.current_class = ClassType::Class;
208
209                state.declare(name);
210                state.define(name);
211
212                if let Some(superclass) = superclass {
213                    if name.string() == superclass.string() {
214                        state.error(new_error(
215                            "A class can't inherit from itself".to_string(),
216                            superclass,
217                            error_codes::CYCLICAL_INHERITANCE,
218                        ));
219                    }
220
221                    state.current_class = ClassType::SubClass;
222
223                    state.resolve_local(superclass, superclass_distance);
224
225                    state.begin_scope();
226                    state
227                        .scopes
228                        .last_mut()
229                        .unwrap()
230                        .insert("super".to_string(), true);
231                }
232
233                state.begin_scope();
234                state
235                    .scopes
236                    .last_mut()
237                    .unwrap()
238                    .insert("this".to_string(), true);
239
240                for method in methods {
241                    let kind = if method.name.string() == "init" {
242                        FunctionType::Initialiser
243                    } else {
244                        FunctionType::Method
245                    };
246                    method.resolve(state, kind);
247                }
248
249                state.end_scope();
250
251                if superclass.is_some() {
252                    state.end_scope();
253                }
254
255                state.current_class = enclosing;
256            }
257            Stmt::Expr(expr) | Stmt::Print(expr) => expr.resolve(state),
258            Stmt::Function(function) => {
259                state.declare(&function.name);
260                state.declare(&function.name);
261                function.resolve(state, FunctionType::Function);
262            }
263            Stmt::If {
264                condition,
265                then_branch,
266                else_branch,
267            } => {
268                condition.resolve(state);
269                then_branch.resolve(state);
270                if let Some(else_branch) = else_branch {
271                    else_branch.resolve(state);
272                }
273            }
274            Stmt::Return { keyword, value } => {
275                if state.current_function == FunctionType::None {
276                    state.error(new_error(
277                        "Can't return from top-level code".to_string(),
278                        keyword,
279                        error_codes::RETURN_OUTSIDE_FUNCTION,
280                    ));
281                }
282                if let Some(value) = value {
283                    value.resolve(state);
284                }
285            }
286            Stmt::Variable { name, initialiser } => {
287                state.declare(name);
288                if let Some(initialiser) = initialiser {
289                    initialiser.resolve(state);
290                }
291                state.define(name);
292            }
293            Stmt::While { condition, body } => {
294                condition.resolve(state);
295                body.resolve(state);
296            }
297        }
298    }
299}
300
301pub(super) fn resolve(ast: &Ast) -> Result<()> {
302    let mut state = State {
303        scopes: vec![],
304        error: None,
305        current_function: FunctionType::None,
306        current_class: ClassType::None,
307    };
308    for stmt in &ast.0 {
309        stmt.resolve(&mut state);
310    }
311    state.error.map_or(Ok(()), Err)
312}