use std::mem;
use std::collections::HashMap;
use crate::error::{Error, ResolveError};
use crate::expr::{Expr, ExprVisitor};
use crate::stmt::{Stmt, StmtVisitor};
use crate::interpreter::Interpreter;
use crate::token::Token;
/// Kind of function body the resolver is currently inside.
///
/// The resolver consults this to validate `return` statements: a `return`
/// while the state is [`FunctionType::None`] is reported as an error, and so
/// is a `return` with a value while the state is [`FunctionType::Initializer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FunctionType {
    /// Not inside any function body.
    None,
    /// Inside a function declared with a function statement.
    Function,
    /// Inside a class method named `init`.
    Initializer,
    /// Inside any other class method.
    Method,
}
/// Kind of class declaration the resolver is currently inside.
///
/// The resolver consults this to validate `this` and `super` expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ClassType {
    /// Not inside any class declaration.
    None,
    /// Inside a class declaration (set before its superclass, if any, is seen).
    Class,
    /// Inside a class declaration that names a superclass.
    Subclass,
}
/// Kind of loop the resolver is currently inside.
///
/// The resolver consults this to reject `break` statements that appear while
/// the state is [`LoopType::None`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LoopType {
    /// Not inside any loop.
    None,
    /// Inside the condition or body of a `while` statement.
    While,
}
/// Static pass that walks statements and expressions before execution.
///
/// For each name it finds in an enclosing local scope it tells the
/// interpreter how many scopes away the name lives, and it reports
/// [`ResolveError`]s for misuse of variables, `return`, `break`, `this` and
/// `super`.
pub struct Resolver<'a, 'w> {
    /// Interpreter that receives the resolved scope distances.
    interpreter: &'a mut Interpreter<'w>,
    /// Stack of open local scopes, innermost last. Each scope maps a name to
    /// `false` once it is declared and to `true` once it is defined.
    scopes: Vec<HashMap<String, bool>>,
    /// Kind of function body currently being resolved.
    current_function: FunctionType,
    /// Kind of class declaration currently being resolved.
    current_class: ClassType,
    /// Kind of loop currently being resolved.
    current_loop: LoopType,
}
impl<'a, 'w> Resolver<'a, 'w> {
    /// Creates a resolver that records its results in `interpreter`.
    ///
    /// The resolver starts with no open scopes and outside of any function,
    /// class or loop.
    pub fn new(interpreter: &'a mut Interpreter<'w>) -> Self {
        Resolver {
            interpreter,
            scopes: Vec::new(),
            current_function: FunctionType::None,
            current_class: ClassType::None,
            current_loop: LoopType::None,
        }
    }

    /// Resolves a single expression by dispatching to the matching visitor method.
    fn resolve_expr(&mut self, expr: &Expr) {
        expr.accept(self);
    }

    /// Resolves a single statement by dispatching to the matching visitor method.
    fn resolve_stmt(&mut self, stmt: &Stmt) {
        stmt.accept(self);
    }

    /// Resolves every statement in `statements`, in order.
    ///
    /// Accepts any slice of statements; a `&Vec<Stmt>` coerces automatically.
    pub fn resolve(&mut self, statements: &[Stmt]) {
        for statement in statements {
            self.resolve_stmt(statement);
        }
    }

    /// Resolves the parameters and body of `function` in a fresh scope.
    ///
    /// `r#type` becomes the current function kind while the body is resolved.
    /// The loop context is reset to [`LoopType::None`] for the duration of the
    /// body, so a `break` inside the function is only accepted when it sits in
    /// a loop of that same function, not in a loop that merely surrounds the
    /// function declaration. Both contexts are restored afterwards.
    ///
    /// # Panics
    ///
    /// Panics if `function` is not a `Stmt::Function`.
    fn resolve_function(&mut self, function: &Stmt, r#type: FunctionType) {
        let Stmt::Function(function) = function else { unreachable!() };

        let enclosing_function = mem::replace(&mut self.current_function, r#type);
        // A function body is a new control-flow boundary: loops outside the
        // function must not make `break` legal inside it.
        let enclosing_loop = mem::replace(&mut self.current_loop, LoopType::None);

        self.begin_scope();
        for param in &function.params {
            self.declare(param);
            self.define(param);
        }
        self.resolve(&function.body);
        self.end_scope();

        self.current_loop = enclosing_loop;
        self.current_function = enclosing_function;
    }

    /// Opens a new, empty innermost scope.
    fn begin_scope(&mut self) {
        self.scopes.push(HashMap::new());
    }

    /// Closes the innermost scope, if there is one.
    fn end_scope(&mut self) {
        self.scopes.pop();
    }

    /// Marks `name` as declared but not yet defined in the innermost scope.
    ///
    /// Does nothing when no scope is open. If the name already exists in the
    /// innermost scope a [`ResolveError`] is reported, and the entry is then
    /// reset to "declared" regardless.
    fn declare(&mut self, name: &Token) {
        let Some(scope) = self.scopes.last_mut() else {
            return;
        };
        if scope.contains_key(&name.lexeme) {
            ResolveError {
                token: name.clone(),
                message: format!("A variable is already defined with name '{}' in this scope", name.lexeme),
            }.throw();
        }
        scope.insert(name.lexeme.clone(), false);
    }

    /// Marks `name` as fully defined in the innermost scope.
    ///
    /// Does nothing when no scope is open.
    fn define(&mut self, name: &Token) {
        if let Some(scope) = self.scopes.last_mut() {
            scope.insert(name.lexeme.clone(), true);
        }
    }

    /// Looks `name` up in the open scopes, innermost first, and reports the
    /// match to the interpreter.
    ///
    /// The distance passed to the interpreter is the number of scopes between
    /// the innermost scope and the one holding the name (0 for the innermost).
    /// Nothing is reported when no open scope contains the name.
    fn resolve_local(&mut self, name: &Token) {
        let distance = self
            .scopes
            .iter()
            .rev()
            .position(|scope| scope.contains_key(&name.lexeme));
        if let Some(distance) = distance {
            self.interpreter.resolve(name, distance);
        }
    }
}
/// Expression half of the resolver: walks sub-expressions and resolves the
/// names they mention. Every method panics (via `unreachable!`) when handed an
/// expression of a different variant than its name indicates.
impl<'a, 'w> ExprVisitor<()> for Resolver<'a, 'w> {
    /// Resolves a variable read, reporting an error when the innermost scope
    /// has the name declared but not yet defined.
    fn visit_variable_expr(&mut self, expr: &Expr) {
        let Expr::Variable(variable) = expr else { unreachable!() };

        let state_in_innermost = self
            .scopes
            .last()
            .and_then(|scope| scope.get(&variable.name.lexeme));
        if state_in_innermost == Some(&false) {
            ResolveError {
                token: variable.name.clone(),
                message: "Cannot read local variable in its own initializer".to_string(),
            }.throw();
        }

        self.resolve_local(&variable.name);
    }

    /// Resolves the assigned value first, then the assignment target.
    fn visit_assign_expr(&mut self, expr: &Expr) {
        let Expr::Assign(assign) = expr else { unreachable!() };
        let (value, target) = (&assign.value, &assign.name);
        self.resolve_expr(value);
        self.resolve_local(target);
    }

    /// Literals mention no names, so there is nothing to resolve.
    fn visit_literal_expr(&mut self, expr: &Expr) {
        if !matches!(expr, Expr::Literal(_)) {
            unreachable!()
        }
    }

    /// Resolves the left operand, then the right operand.
    fn visit_logical_expr(&mut self, expr: &Expr) {
        let Expr::Logical(logical) = expr else { unreachable!() };
        let (lhs, rhs) = (&logical.left, &logical.right);
        self.resolve_expr(lhs);
        self.resolve_expr(rhs);
    }

    /// Resolves the operand.
    fn visit_unary_expr(&mut self, expr: &Expr) {
        let Expr::Unary(unary) = expr else { unreachable!() };
        let operand = &unary.expr;
        self.resolve_expr(operand);
    }

    /// Resolves the left operand, then the right operand.
    fn visit_binary_expr(&mut self, expr: &Expr) {
        let Expr::Binary(binary) = expr else { unreachable!() };
        let (lhs, rhs) = (&binary.left, &binary.right);
        self.resolve_expr(lhs);
        self.resolve_expr(rhs);
    }

    /// Resolves the parenthesised inner expression.
    fn visit_grouping_expr(&mut self, expr: &Expr) {
        let Expr::Grouping(grouping) = expr else { unreachable!() };
        let inner = &grouping.expr;
        self.resolve_expr(inner);
    }

    /// Resolves the callee, then each argument in order.
    fn visit_call_expr(&mut self, expr: &Expr) {
        let Expr::Call(call) = expr else { unreachable!() };
        self.resolve_expr(&call.callee);
        call.arguments
            .iter()
            .for_each(|argument| self.resolve_expr(argument));
    }

    /// Resolves only the object being accessed; the property name itself is
    /// not looked up here.
    fn visit_get_expr(&mut self, expr: &Expr) {
        let Expr::Get(get) = expr else { unreachable!() };
        let object = &get.object;
        self.resolve_expr(object);
    }

    /// Resolves the assigned value first, then the object being written to.
    fn visit_set_expr(&mut self, expr: &Expr) {
        let Expr::Set(set) = expr else { unreachable!() };
        let (value, object) = (&set.value, &set.object);
        self.resolve_expr(value);
        self.resolve_expr(object);
    }

    /// Resolves `this` like a local variable; outside of a class an error is
    /// reported and nothing is resolved.
    fn visit_this_expr(&mut self, expr: &Expr) {
        let Expr::This(this) = expr else { unreachable!() };

        let outside_class = matches!(self.current_class, ClassType::None);
        if outside_class {
            ResolveError {
                token: this.keyword.clone(),
                message: "Cannot use 'this' outside of a class".to_string(),
            }.throw();
            return;
        }

        self.resolve_local(&this.keyword);
    }

    /// Resolves `super` like a local variable. An error is reported when it is
    /// used outside of a class or in a class without a superclass; the keyword
    /// is still resolved afterwards in every case.
    fn visit_super_expr(&mut self, expr: &Expr) {
        let Expr::Super(super_expr) = expr else { unreachable!() };

        let complaint = match self.current_class {
            ClassType::Subclass => None,
            ClassType::None => Some("Cannot use 'super' outside of a class"),
            ClassType::Class => Some("Cannot use 'super' in a class with no superclass"),
        };
        if let Some(message) = complaint {
            ResolveError {
                token: super_expr.keyword.clone(),
                message: message.to_string(),
            }.throw();
        }

        self.resolve_local(&super_expr.keyword);
    }
}
/// Statement half of the resolver: manages scopes and the function / class /
/// loop context, and reports misplaced `return` and `break` statements. Every
/// method panics (via `unreachable!`) when handed a statement of a different
/// variant than its name indicates.
impl<'a, 'w> StmtVisitor<()> for Resolver<'a, 'w> {
    /// Resolves the block's statements inside a scope of their own.
    fn visit_block_stmt(&mut self, stmt: &Stmt) {
        let Stmt::Block(block) = stmt else { unreachable!() };
        self.begin_scope();
        self.resolve(&block.statements);
        self.end_scope();
    }

    /// Declares the variable, resolves its initializer (if any), then defines
    /// it, so the initializer sees the name as declared but not yet defined.
    fn visit_var_stmt(&mut self, stmt: &Stmt) {
        let Stmt::Var(var) = stmt else { unreachable!() };
        let name = &var.name;
        self.declare(name);
        if let Some(initializer) = &var.initializer {
            self.resolve_expr(initializer);
        }
        self.define(name);
    }

    /// Declares and defines the function name before resolving the body, so
    /// the body can refer to the function itself.
    fn visit_function_stmt(&mut self, stmt: &Stmt) {
        let Stmt::Function(function) = stmt else { unreachable!() };
        let name = &function.name;
        self.declare(name);
        self.define(name);
        self.resolve_function(stmt, FunctionType::Function);
    }

    /// Resolves the wrapped expression.
    fn visit_expression_stmt(&mut self, stmt: &Stmt) {
        let Stmt::Expression(expression) = stmt else { unreachable!() };
        self.resolve_expr(&expression.expr);
    }

    /// Resolves the condition, the then-branch and, when present, the
    /// else-branch.
    fn visit_if_stmt(&mut self, stmt: &Stmt) {
        let Stmt::If(if_stmt) = stmt else { unreachable!() };
        self.resolve_expr(&if_stmt.condition);
        self.resolve_stmt(&if_stmt.then_branch);
        if let Some(otherwise) = &if_stmt.else_branch {
            self.resolve_stmt(otherwise);
        }
    }

    /// Resolves the printed expression.
    fn visit_print_stmt(&mut self, stmt: &Stmt) {
        let Stmt::Print(print) = stmt else { unreachable!() };
        self.resolve_expr(&print.expr);
    }

    /// Validates a `return` and resolves its value.
    ///
    /// Reports an error outside of any function. When a value is present
    /// inside an initializer, reports an error and leaves the value
    /// unresolved.
    fn visit_return_stmt(&mut self, stmt: &Stmt) {
        let Stmt::Return(return_stmt) = stmt else { unreachable!() };

        if matches!(self.current_function, FunctionType::None) {
            ResolveError {
                token: return_stmt.keyword.clone(),
                message: "Cannot return from top-level code".to_string(),
            }.throw();
        }

        let Some(value) = &return_stmt.value else {
            return;
        };
        if matches!(self.current_function, FunctionType::Initializer) {
            ResolveError {
                token: return_stmt.keyword.clone(),
                message: "Cannot return a value from an initializer".to_string(),
            }.throw();
            return;
        }
        self.resolve_expr(value);
    }

    /// Reports an error when `break` appears while no loop is being resolved.
    fn visit_break_stmt(&mut self, stmt: &Stmt) {
        let Stmt::Break(break_stmt) = stmt else { unreachable!() };
        let outside_loop = matches!(self.current_loop, LoopType::None);
        if outside_loop {
            ResolveError {
                token: break_stmt.keyword.clone(),
                message: "Cannot break outside of a loop".to_string(),
            }.throw();
        }
    }

    /// Resolves the condition and body with the loop context set to
    /// `LoopType::While`, restoring the previous context afterwards.
    fn visit_while_stmt(&mut self, stmt: &Stmt) {
        let Stmt::While(while_stmt) = stmt else { unreachable!() };
        let outer_loop = mem::replace(&mut self.current_loop, LoopType::While);
        self.resolve_expr(&while_stmt.condition);
        self.resolve_stmt(&while_stmt.body);
        self.current_loop = outer_loop;
    }

    /// Resolves a class declaration.
    ///
    /// Declares and defines the class name, resolves the superclass (reporting
    /// a class that names itself as its superclass), then resolves each method
    /// inside a scope that defines `this`. With a superclass, that scope is
    /// nested in one that defines `super`. The enclosing class context is
    /// restored at the end.
    fn visit_class_stmt(&mut self, stmt: &Stmt) {
        let Stmt::Class(class_stmt) = stmt else { unreachable!() };
        let outer_class = mem::replace(&mut self.current_class, ClassType::Class);

        self.declare(&class_stmt.name);
        self.define(&class_stmt.name);

        let has_superclass = class_stmt.superclass.is_some();
        if let Some(superclass) = &class_stmt.superclass {
            let Expr::Variable(variable) = superclass else { unreachable!() };
            if variable.name.lexeme == class_stmt.name.lexeme {
                ResolveError {
                    token: variable.name.clone(),
                    message: "A class cannot inherit from itself".to_string(),
                }.throw();
            }

            self.current_class = ClassType::Subclass;
            self.resolve_expr(superclass);

            // Extra scope that makes `super` resolvable inside the methods.
            self.begin_scope();
            let super_scope = self.scopes.last_mut().expect("stack to be not empty");
            super_scope.insert("super".to_string(), true);
        }

        // Scope that makes `this` resolvable inside the methods.
        self.begin_scope();
        let this_scope = self.scopes.last_mut().expect("stack to be not empty");
        this_scope.insert("this".to_string(), true);

        for method in &class_stmt.methods {
            let Stmt::Function(function) = method else { unreachable!() };
            let kind = if function.name.lexeme == "init" {
                FunctionType::Initializer
            } else {
                FunctionType::Method
            };
            self.resolve_function(method, kind);
        }

        self.end_scope();
        if has_superclass {
            self.end_scope();
        }

        self.current_class = outer_class;
    }
}