use crate::chunk::Chunk;
use crate::error_handler::ErrorHandler;
use crate::expr::{BinaryExpr, CallExpr, ExprValue, LogicalExpr, UnaryExpr, VariableExpr};
use crate::function::Function;
use crate::opcodes::OpCode;
use crate::stmt::{
AssignStmt, BinaryAssignStmt, ExpressionStmt, ForStmt, IfStmt, InputStmt, PrintStmt,
ReturnStmt, Stmt, WhileStmt,
};
use crate::tokens::{Token, TokenType};
use crate::visitor::{ExprVisitor, StmtVisitor};
use std::collections::HashMap;
/// Compiles statement and expression trees (through the `StmtVisitor` and
/// `ExprVisitor` impls) into bytecode `Chunk`s.
pub(crate) struct Compiler {
/// Source line recorded into `lines` for every byte emitted by `emit_opcode`/`emit_byte`.
current_line: usize,
/// Variable name -> slot index for the chunk currently being compiled.
variables: HashMap<String, usize>,
/// Snapshot of the user-defined functions; a `Call` operand is a position in this list.
functions: Vec<Function>,
/// Native function name -> id emitted as the operand of a `Native` opcode.
native_functions: HashMap<String, usize>,
/// Number of variable slots allocated so far; also the index of the next free slot.
cur_variable: usize,
/// Collects errors reported while compiling.
error_handler: ErrorHandler,
/// Bytecode of the chunk currently being built.
instructions: Vec<u8>,
/// Source line of each byte in `instructions` (pushed in lockstep with it).
lines: Vec<usize>,
/// Literal pool of the current chunk; `Literal8`/`Literal16` operands index into it.
literals: Vec<ExprValue>,
}
impl Compiler {
    /// Creates a compiler with empty state and no registered functions.
    pub(crate) fn new() -> Self {
        Self {
            current_line: 0,
            variables: HashMap::new(),
            functions: Vec::new(),
            native_functions: HashMap::new(),
            cur_variable: 0,
            error_handler: ErrorHandler::new(),
            instructions: Vec::new(),
            lines: Vec::new(),
            literals: Vec::new(),
        }
    }

    /// Compiles the top-level `stmts` into a chunk, and compiles the body of
    /// every entry in `functions` that does not have a chunk yet (storing the
    /// result with `set_chunk`).
    ///
    /// The top-level variable table and slot counter stay on the compiler
    /// between calls, so a later call keeps the slots assigned by earlier ones.
    /// This compiler's errors are merged into `error_handler`. Returns `None` if
    /// this compiler has recorded any error, including errors from earlier calls
    /// that were not cleared with [`Compiler::clear_errors`].
    pub(crate) fn compile(
        &mut self,
        stmts: Vec<Stmt>,
        functions: &mut Vec<Function>,
        native_functions: HashMap<String, usize>,
        error_handler: &mut ErrorHandler,
    ) -> Option<Chunk> {
        // Call expressions are resolved by position in this snapshot.
        self.functions = functions.clone();
        self.native_functions = native_functions;
        self.reset();
        for stmt in &stmts {
            self.current_line = stmt.line;
            stmt.accept(self);
        }
        let chunk = self.finish_chunk(0);

        // Function bodies number their slots from 0, so stash BOTH the
        // top-level table and its slot counter and restore both afterwards.
        // Restoring only the table would let the next call allocate slots that
        // collide with existing top-level variables.
        let top_level_variables = std::mem::take(&mut self.variables);
        let top_level_slots = self.cur_variable;
        for func in functions.iter_mut().filter(|func| func.chunk.is_none()) {
            self.compile_function(func);
        }
        self.variables = top_level_variables;
        self.cur_variable = top_level_slots;

        error_handler.merge(&self.error_handler);
        if self.error_handler.had_errors {
            None
        } else {
            Some(chunk)
        }
    }

    /// Discards every error recorded so far.
    pub(crate) fn clear_errors(&mut self) {
        self.error_handler = ErrorHandler::new();
    }

    /// Compiles one function body into its own chunk. Parameters occupy the
    /// first slots in declaration order, and the body ends with an implicit
    /// `Nil`/`Ret`.
    fn compile_function(&mut self, func: &mut Function) {
        self.reset();
        self.variables = HashMap::new();
        self.cur_variable = 0;
        for (slot, param) in func.parameters.iter().enumerate() {
            self.variables.insert(param.lexeme.clone(), slot);
            self.cur_variable += 1;
        }
        for stmt in &func.stmts {
            self.current_line = stmt.line;
            stmt.accept(self);
        }
        // Falling off the end of the body returns nil.
        self.emit_opcode(OpCode::Nil);
        self.emit_opcode(OpCode::Ret);
        let chunk = self.finish_chunk(func.parameters.len());
        func.set_chunk(chunk);
    }

    /// Moves the buffers built so far into a `Chunk`, leaving them empty.
    /// `arity` is the number of parameters (0 for top-level code).
    fn finish_chunk(&mut self, arity: usize) -> Chunk {
        Chunk::new(
            std::mem::take(&mut self.instructions),
            std::mem::take(&mut self.lines),
            std::mem::take(&mut self.literals),
            self.cur_variable,
            arity,
        )
    }

    /// Clears the instruction, line and literal buffers (variables are kept).
    fn reset(&mut self) {
        self.instructions = Vec::new();
        self.literals = Vec::new();
        self.lines = Vec::new();
    }

    /// Appends an opcode byte tagged with the current source line.
    fn emit_opcode(&mut self, opcode: OpCode) {
        self.instructions.push(opcode as u8);
        self.lines.push(self.current_line);
    }

    /// Appends a raw operand byte tagged with the current source line.
    fn emit_byte(&mut self, byte: u8) {
        self.instructions.push(byte);
        self.lines.push(self.current_line);
    }

    /// Returns the slot of variable `name`, allocating the next free slot the
    /// first time the name is seen. If no slot can be allocated, an error is
    /// reported and slot 0 is returned.
    fn get_variable(&mut self, name: Token) -> u8 {
        if let Some(&slot) = self.variables.get(&name.lexeme) {
            return slot as u8;
        }
        // Slots are addressed with a single byte; `>=` also guards against a
        // counter that some other path pushed past 255.
        if self.cur_variable >= 255 {
            self.error_handler
                .error_token(&name, "Cannot have more than 255 local variables.");
            return 0;
        }
        let slot = self.cur_variable;
        self.variables.insert(name.lexeme, slot);
        self.cur_variable += 1;
        slot as u8
    }

    /// Emits `opcode` followed by a placeholder operand byte and returns the
    /// operand's position, to be filled in later by `patch_jmp`.
    fn begin_jmp(&mut self, opcode: OpCode) -> usize {
        self.emit_opcode(opcode);
        let operand_pos = self.instructions.len();
        self.emit_byte(0);
        operand_pos
    }

    /// Fills the operand at `pos` with the forward distance from the byte
    /// after that operand to the current end of the code. Reports an error
    /// (attributed to `line`) if the distance does not fit in one byte.
    fn patch_jmp(&mut self, pos: usize, line: usize) {
        let delta = self.instructions.len() - pos - 1;
        if delta > 255 {
            self.error_handler.error(line, "IF body is too long.");
            return;
        }
        self.instructions[pos] = delta as u8;
    }

    /// Returns the current code position, to be used as a loop start for
    /// `patch_jmp_back`.
    fn begin_jmp_back(&mut self) -> usize {
        self.instructions.len()
    }

    /// Emits `JmpBack` with the distance from the end of this two-byte
    /// instruction back to `pos` (assuming the VM applies the offset after
    /// reading the operand — TODO confirm against the VM). Reports an error
    /// (attributed to `line`) and emits nothing if the distance exceeds one byte.
    fn patch_jmp_back(&mut self, pos: usize, line: usize) {
        let delta = self.instructions.len() - pos + 2;
        if delta > 255 {
            self.error_handler.error(line, "WHILE body is too long.");
            return;
        }
        self.emit_opcode(OpCode::JmpBack);
        self.emit_byte(delta as u8);
    }
}
impl StmtVisitor<()> for Compiler {
    /// Evaluates the expression and discards its value with `Pop`.
    fn visit_expression_stmt(&mut self, stmt: &ExpressionStmt, _: usize) {
        stmt.value.accept(self);
        self.emit_opcode(OpCode::Pop);
    }

    /// Pushes the printed values last-to-first, then emits `Print <count>`.
    /// Reports an error if the count does not fit in the one-byte operand.
    fn visit_print_stmt(&mut self, stmt: &PrintStmt, line: usize) {
        let count = stmt.values.len();
        if count > u8::MAX as usize {
            // A silent `as u8` truncation would desynchronise the VM's stack.
            self.error_handler
                .error(line, "Cannot print more than 255 values.");
            return;
        }
        for expr in stmt.values.iter().rev() {
            expr.accept(self);
        }
        self.emit_opcode(OpCode::Print);
        self.emit_byte(count as u8);
    }

    /// Emits `Input <slot>` for the target variable (allocating it if new).
    fn visit_input_stmt(&mut self, stmt: &InputStmt, _: usize) {
        let var_id = self.get_variable(stmt.variable.clone());
        self.emit_opcode(OpCode::Input);
        self.emit_byte(var_id);
    }

    /// Evaluates the value and stores it with `SetVar <slot>`.
    fn visit_assign_stmt(&mut self, stmt: &AssignStmt, _: usize) {
        let var_id = self.get_variable(stmt.name.clone());
        stmt.value.accept(self);
        self.emit_opcode(OpCode::SetVar);
        self.emit_byte(var_id);
    }

    /// Layout: condition, `Jne` over the then-branch, then-branch, `Jmp` over
    /// the else-branch, else-branch.
    fn visit_if_stmt(&mut self, stmt: &IfStmt, line: usize) {
        stmt.condition.accept(self);
        let skip_then = self.begin_jmp(OpCode::Jne);
        for then_stmt in &stmt.then_stmts {
            then_stmt.accept(self);
        }
        // Keeps the then-branch from falling through into the else-branch.
        let skip_else = self.begin_jmp(OpCode::Jmp);
        self.patch_jmp(skip_then, line);
        for else_stmt in &stmt.else_stmts {
            else_stmt.accept(self);
        }
        self.patch_jmp(skip_else, line);
    }

    /// Layout: condition, `Jne` to the loop exit, body, `JmpBack` to the
    /// condition.
    fn visit_while_stmt(&mut self, stmt: &WhileStmt, line: usize) {
        let loop_start = self.begin_jmp_back();
        stmt.condition.accept(self);
        let exit_jmp = self.begin_jmp(OpCode::Jne);
        for body_stmt in &stmt.body {
            body_stmt.accept(self);
        }
        self.patch_jmp_back(loop_start, line);
        self.patch_jmp(exit_jmp, line);
    }

    /// Counted loop: the upper bound is evaluated once into an anonymous slot.
    /// The loop variable is set to the lower bound, tested with `LessEqual`
    /// against that slot each iteration, and advanced with `Inc`.
    fn visit_for_stmt(&mut self, stmt: &ForStmt, line: usize) {
        // Allocate the loop variable before the hidden bound slot.
        let var_id = self.get_variable(stmt.variable.clone());

        stmt.max_value.accept(self);
        let max_value_id = self.cur_variable as u8;
        if self.cur_variable >= 255 {
            self.error_handler
                .error_token(&stmt.variable, "Too many variables");
        } else {
            // Never advance past 255: slots are addressed by a single byte,
            // and a counter of 256 would wrap every later slot operand.
            self.cur_variable += 1;
        }
        self.emit_opcode(OpCode::SetVar);
        self.emit_byte(max_value_id);

        stmt.min_value.accept(self);
        self.emit_opcode(OpCode::SetVar);
        self.emit_byte(var_id);

        let loop_start = self.begin_jmp_back();
        self.emit_opcode(OpCode::Var);
        self.emit_byte(var_id);
        self.emit_opcode(OpCode::Var);
        self.emit_byte(max_value_id);
        self.emit_opcode(OpCode::LessEqual);
        let exit_jmp = self.begin_jmp(OpCode::Jne);
        for body_stmt in &stmt.body {
            body_stmt.accept(self);
        }
        self.emit_opcode(OpCode::Inc);
        self.emit_byte(var_id);
        self.patch_jmp_back(loop_start, line);
        self.patch_jmp(exit_jmp, line);
    }

    /// Evaluates the return value and emits `Ret`.
    fn visit_return_stmt(&mut self, stmt: &ReturnStmt, _: usize) {
        stmt.return_value.accept(self);
        self.emit_opcode(OpCode::Ret);
    }

    /// Compound assignment (`+=`, `-=`, ...): load, evaluate the right-hand
    /// side, apply the arithmetic opcode, store back.
    fn visit_binary_assign_stmt(&mut self, stmt: &BinaryAssignStmt, _: usize) {
        let var_id = self.get_variable(stmt.name.clone());
        self.emit_opcode(OpCode::Var);
        self.emit_byte(var_id);
        stmt.value.accept(self);
        let arithmetic = match stmt.operator {
            TokenType::PlusEqual => OpCode::Add,
            TokenType::MinusEqual => OpCode::Sub,
            TokenType::StarEqual => OpCode::Mul,
            TokenType::SlashEqual => OpCode::Div,
            TokenType::CaretEqual => OpCode::Pow,
            TokenType::PercentEqual => OpCode::Mod,
            _ => unreachable!(),
        };
        self.emit_opcode(arithmetic);
        self.emit_opcode(OpCode::SetVar);
        self.emit_byte(var_id);
    }
}
impl ExprVisitor<()> for Compiler {
fn visit_variable_expr(&mut self, expr: &VariableExpr) {
let var_id = self.get_variable(expr.name.clone());
self.emit_opcode(OpCode::Var);
self.emit_byte(var_id);
}
fn visit_logical_expr(&mut self, expr: &LogicalExpr) {
expr.left.accept(self);
expr.right.accept(self);
self.emit_opcode(match expr.operator {
TokenType::And => OpCode::And,
TokenType::Or => OpCode::Or,
_ => unreachable!(),
})
}
fn visit_binary_expr(&mut self, expr: &BinaryExpr) {
expr.left.accept(self);
expr.right.accept(self);
self.emit_opcode(match expr.operator {
TokenType::Plus => OpCode::Add,
TokenType::Minus => OpCode::Sub,
TokenType::Star => OpCode::Mul,
TokenType::Slash => OpCode::Div,
TokenType::Caret => OpCode::Pow,
TokenType::Percent => OpCode::Mod,
TokenType::Equal => OpCode::Equal,
TokenType::Less => OpCode::Less,
TokenType::LessEqual => OpCode::LessEqual,
TokenType::Greater => OpCode::Greater,
TokenType::GreaterEqual => OpCode::GreaterEqual,
_ => unreachable!(),
});
}
fn visit_unary_expr(&mut self, expr: &UnaryExpr) {
expr.right.accept(self);
self.emit_opcode(match expr.operator {
TokenType::Not => OpCode::Not,
TokenType::Minus => OpCode::Minus,
_ => unreachable!(),
});
}
fn visit_call_expr(&mut self, expr: &CallExpr) {
for arg in &expr.arguments {
arg.accept(self);
}
for (i, function) in self.functions.iter().enumerate() {
if function.name == expr.name {
self.emit_opcode(OpCode::Call);
self.emit_byte(i as u8);
return;
}
}
for (name, id) in &self.native_functions.clone() {
if name == &expr.name {
self.emit_opcode(OpCode::Native);
self.emit_byte(*id as u8);
return;
}
}
unreachable!();
}
fn visit_literal_expr(&mut self, expr: &ExprValue) {
let index = self.literals.len();
self.literals.push(expr.clone());
if index <= 255 {
self.emit_opcode(OpCode::Literal8);
self.emit_byte(index as u8);
} else if index <= 65535 {
self.emit_opcode(OpCode::Literal16);
self.emit_byte((index >> 8) as u8);
self.emit_byte((index & 0xFF) as u8);
} else {
self.error_handler
.errors
.push("Cannot have more than 65,535 literals".to_string());
}
}
}