// flexi-parse 0.11.0
//
// Simple, flexible parsing
use super::Ast;
use super::Expr;
use super::Function;
use super::Stmt;
use super::error_codes;

use std::cell::Cell;
use std::collections::HashMap;

use flexi_parse::Result;
use flexi_parse::error::Error;
use flexi_parse::new_error;
use flexi_parse::token::Ident;

/// The kind of function body the resolver is currently inside.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FunctionType {
    /// Not inside any function; `return` statements are reported as errors.
    None,
    /// Inside a function introduced by a function declaration statement.
    Function,
    /// Inside a class method whose name is not `init`.
    Method,
    /// Inside a class method named `init`.
    Initialiser,
}

/// The kind of class declaration the resolver is currently inside.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ClassType {
    /// Not inside any class; `this` and `super` are reported as errors.
    None,
    /// Inside a class that has no superclass; `super` is reported as an error.
    Class,
    /// Inside a class that declares a superclass.
    SubClass,
}

/// Mutable bookkeeping threaded through the resolver pass.
struct State {
    /// Stack of open scopes, innermost last. Each scope maps a name to
    /// whether it has been defined (`true`) or only declared (`false`).
    scopes: Vec<HashMap<String, bool>>,
    /// All errors recorded so far, merged into a single value; `None` if
    /// nothing has been reported.
    error: Option<Error>,
    /// Kind of function currently being resolved.
    current_function: FunctionType,
    /// Kind of class currently being resolved.
    current_class: ClassType,
}

impl State {
    /// Declares `name` in the innermost scope, marked as not yet defined.
    ///
    /// Does nothing when no scope is open. If the innermost scope already
    /// holds an entry for `name`, an `error_codes::SHADOW` error is recorded
    /// and the existing entry is left untouched.
    fn declare(&mut self, name: &Ident) {
        if let Some(scope) = self.scopes.last_mut() {
            if scope.contains_key(name.string()) {
                self.error(new_error(
                    "Already a variable with this name in scope".to_string(),
                    name,
                    error_codes::SHADOW,
                ));
            } else {
                scope.insert(name.string().to_owned(), false);
            }
        }
    }

    /// Marks `name` as defined in the innermost scope.
    ///
    /// Like `declare`, this does nothing when no scope is open: the resolver
    /// starts with an empty scope stack, so unconditionally unwrapping the
    /// innermost scope would panic on every top-level declaration.
    fn define(&mut self, name: &Ident) {
        if let Some(scope) = self.scopes.last_mut() {
            scope.insert(name.string().to_owned(), true);
        }
    }

    /// Looks `name` up in the open scopes, innermost first.
    ///
    /// On the first scope that contains `name`, stores in `distance` the
    /// number of scopes between the innermost scope and that one (`0` for the
    /// innermost scope itself). `distance` is left unchanged when no open
    /// scope contains `name`.
    fn resolve_local(&self, name: &Ident, distance: &Cell<Option<usize>>) {
        for (i, scope) in self.scopes.iter().enumerate().rev() {
            if scope.contains_key(name.string()) {
                distance.set(Some(self.scopes.len() - 1 - i));
                return;
            }
        }
    }

    /// Records `error`, merging it into any error recorded earlier so that a
    /// single pass can report several problems.
    fn error(&mut self, error: Error) {
        if let Some(existing_error) = &mut self.error {
            existing_error.add(error);
        } else {
            self.error = Some(error);
        }
    }

    /// Runs `f` with a fresh innermost scope, which is discarded afterwards.
    fn with_scope(&mut self, f: impl FnOnce(&mut State)) {
        self.scopes.push(HashMap::with_capacity(10));
        f(self);
        self.scopes.pop();
    }
}

impl Expr {
    /// Resolves every name used in this expression against the scopes in
    /// `state`.
    ///
    /// Sub-expressions are resolved recursively. For variable-like nodes the
    /// scope distance is written into the node's `distance` cell, and static
    /// errors (misplaced `this`/`super`, reading a variable inside its own
    /// initialiser) are recorded in `state` rather than returned.
    fn resolve(&self, state: &mut State) {
        match self {
            Expr::Assign {
                name,
                value,
                distance,
            } => {
                value.resolve(state);
                state.resolve_local(name, distance);
            }
            Expr::Binary(binary) => {
                binary.left().resolve(state);
                binary.right().resolve(state);
            }
            Expr::Call {
                callee,
                paren: _,
                arguments,
            } => {
                callee.resolve(state);
                for argument in arguments {
                    argument.resolve(state);
                }
            }
            // Property names are not looked up in scopes; only the object
            // expression is resolved.
            Expr::Get { object, name: _ } => object.resolve(state),
            Expr::Group(expr) => expr.resolve(state),
            Expr::Literal(_) => {}
            Expr::Logical(logical) => {
                logical.left().resolve(state);
                logical.right().resolve(state);
            }
            Expr::Set {
                object,
                name: _,
                value,
            } => {
                object.resolve(state);
                value.resolve(state);
            }
            Expr::Super {
                keyword,
                distance,
                dot: _,
                method: _,
            } => {
                if state.current_class == ClassType::None {
                    state.error(new_error(
                        "Can't use 'super' outside of a class".to_string(),
                        keyword,
                        error_codes::INVALID_SUPER,
                    ));
                } else if state.current_class == ClassType::Class {
                    state.error(new_error(
                        "Can't use 'super' in a class with no superclass".to_string(),
                        keyword,
                        error_codes::INVALID_SUPER,
                    ));
                }
                // The lookup still runs after an error has been recorded.
                state.resolve_local(keyword.ident(), distance);
            }
            Expr::This { keyword, distance } => {
                if state.current_class == ClassType::None {
                    state.error(new_error(
                        "Can't use 'this' outside of a class".to_string(),
                        keyword,
                        error_codes::THIS_OUTSIDE_CLASS,
                    ));
                }

                state.resolve_local(keyword.ident(), distance);
            }
            Expr::Unary(unary) => unary.right().resolve(state),
            Expr::Variable { name, distance } => {
                // An entry of `false` in the innermost scope means the name
                // has been declared but not yet defined, i.e. we are still
                // inside its initialiser.
                if let Some(scope) = state.scopes.last()
                    && scope.get(name.string()) == Some(&false)
                {
                    state.error(new_error(
                        "Can't read a variable in its own initialiser".to_string(),
                        name,
                        error_codes::INVALID_INITIALISER,
                    ));
                }
                state.resolve_local(name, distance);
            }
        }
    }
}

impl Function {
    /// Resolves this function's parameters and body inside a fresh scope.
    ///
    /// `kind` becomes the current function kind for the duration of the call;
    /// the previous kind is put back once the body has been resolved.
    fn resolve(&self, state: &mut State, kind: FunctionType) {
        // Swap in the new kind and remember the one it replaces.
        let previous_kind = std::mem::replace(&mut state.current_function, kind);

        state.with_scope(|inner| {
            // Parameters are declared and defined immediately, so they are
            // usable from the first statement of the body.
            for parameter in &self.params {
                inner.declare(parameter);
                inner.define(parameter);
            }
            for statement in &self.body {
                statement.resolve(inner);
            }
        });

        state.current_function = previous_kind;
    }
}

impl Stmt {
    /// Resolves every name declared or used by this statement.
    ///
    /// Declarations are added to the innermost scope in `state`, nested
    /// statements and expressions are resolved recursively, and static errors
    /// (self-inheritance, `return` outside a function, ...) are recorded in
    /// `state` rather than returned.
    fn resolve(&self, state: &mut State) {
        match self {
            Stmt::Block(stmts) => {
                state.with_scope(|state| {
                    for stmt in stmts {
                        stmt.resolve(state);
                    }
                });
            }
            Stmt::Class {
                name,
                superclass,
                superclass_distance,
                methods,
            } => {
                let enclosing = state.current_class;
                state.current_class = ClassType::Class;

                state.declare(name);
                state.define(name);

                // Outer scope: holds `super` when a superclass is present.
                state.with_scope(|state| {
                    if let Some(superclass) = superclass {
                        if name.string() == superclass.string() {
                            state.error(new_error(
                                "A class can't inherit from itself".to_string(),
                                superclass,
                                error_codes::CYCLICAL_INHERITANCE,
                            ));
                        }

                        state.current_class = ClassType::SubClass;

                        state.resolve_local(superclass, superclass_distance);

                        state
                            .scopes
                            .last_mut()
                            .expect("with_scope has just pushed a scope")
                            .insert("super".to_string(), true);
                    }

                    // Inner scope: holds `this` for the method bodies.
                    state.with_scope(|state| {
                        state
                            .scopes
                            .last_mut()
                            .expect("with_scope has just pushed a scope")
                            .insert("this".to_string(), true);

                        for method in methods {
                            let kind = if method.name.string() == "init" {
                                FunctionType::Initialiser
                            } else {
                                FunctionType::Method
                            };
                            method.resolve(state, kind);
                        }
                    });
                });

                state.current_class = enclosing;
            }
            Stmt::Expr(expr) | Stmt::Print(expr) => expr.resolve(state),
            Stmt::Function(function) => {
                // Declare *and* define the name before resolving the body so
                // the function can refer to itself recursively. Declaring it
                // twice would report a bogus redeclaration and leave the name
                // marked as undefined.
                state.declare(&function.name);
                // Globals are not tracked by the resolver, so there is
                // nothing to mark as defined when no scope is open.
                if !state.scopes.is_empty() {
                    state.define(&function.name);
                }
                function.resolve(state, FunctionType::Function);
            }
            Stmt::If {
                condition,
                then_branch,
                else_branch,
            } => {
                condition.resolve(state);
                then_branch.resolve(state);
                if let Some(else_branch) = else_branch {
                    else_branch.resolve(state);
                }
            }
            Stmt::Return { keyword, value } => {
                if state.current_function == FunctionType::None {
                    state.error(new_error(
                        "Can't return from top-level code".to_string(),
                        keyword,
                        error_codes::RETURN_OUTSIDE_FUNCTION,
                    ));
                }
                if let Some(value) = value {
                    value.resolve(state);
                }
            }
            Stmt::Variable { name, initialiser } => {
                // The name is declared before the initialiser is resolved and
                // defined only afterwards, so a reference to the variable
                // inside its own initialiser can be detected.
                state.declare(name);
                if let Some(initialiser) = initialiser {
                    initialiser.resolve(state);
                }
                state.define(name);
            }
            Stmt::While { condition, body } => {
                condition.resolve(state);
                body.resolve(state);
            }
        }
    }
}

/// Runs the resolver over every top-level statement of `ast`.
///
/// Starts outside any function or class and with no open scopes. Returns
/// `Ok(())` when no errors were recorded, otherwise the accumulated error.
pub(super) fn resolve(ast: &Ast) -> Result<()> {
    let mut resolver_state = State {
        current_class: ClassType::None,
        current_function: FunctionType::None,
        error: None,
        scopes: Vec::new(),
    };

    for statement in &ast.0 {
        statement.resolve(&mut resolver_state);
    }

    match resolver_state.error {
        Some(error) => Err(error),
        None => Ok(()),
    }
}