xbasic 0.3.2

A library that makes it easy to add a scripting language to your project, so your users can write their own arbitrary logic.
Documentation
use crate::chunk::Chunk;
use crate::error_handler::ErrorHandler;
use crate::expr::{BinaryExpr, CallExpr, ExprValue, LogicalExpr, UnaryExpr, VariableExpr};
use crate::function::Function;
use crate::opcodes::OpCode;
use crate::stmt::{
	AssignStmt, BinaryAssignStmt, ExpressionStmt, ForStmt, IfStmt, InputStmt, PrintStmt,
	ReturnStmt, Stmt, WhileStmt,
};
use crate::tokens::{Token, TokenType};
use crate::visitor::{ExprVisitor, StmtVisitor};
use std::collections::HashMap;

/// Bytecode compiler state. It implements the statement and expression visitors and
/// accumulates the instructions, per-byte line numbers and literals that `compile`
/// packs into a `Chunk`.
pub(crate) struct Compiler {
	/// Source line recorded for every byte pushed by `emit_opcode` / `emit_byte`.
	current_line: usize,
	/// Maps a variable name to its slot index.
	variables: HashMap<String, usize>,
	/// Copy of the user-defined functions; a call is encoded as an index into this list.
	functions: Vec<Function>,
	/// Maps a native function name to the id emitted after `OpCode::Native`.
	native_functions: HashMap<String, usize>,
	/// Next free variable slot index.
	cur_variable: usize,
	/// Errors collected while compiling; `compile` merges them into the caller's handler.
	error_handler: ErrorHandler,
	/// Bytecode emitted so far for the chunk currently being built.
	instructions: Vec<u8>,
	/// Source line of each byte in `instructions`; the emit helpers push to both together.
	lines: Vec<usize>,
	/// Literal table of the chunk being built, indexed by `Literal8` / `Literal16` operands.
	literals: Vec<ExprValue>,
}

impl Compiler {
	/// Creates a compiler with no variables, no known functions, no recorded errors and
	/// empty output buffers.
	pub(crate) fn new() -> Self {
		Self {
			// Output buffers for the chunk being built.
			instructions: Vec::new(),
			lines: Vec::new(),
			literals: Vec::new(),
			current_line: 0,
			// Variable bookkeeping.
			variables: HashMap::new(),
			cur_variable: 0,
			// Call targets, filled in by `compile`.
			functions: Vec::new(),
			native_functions: HashMap::new(),
			error_handler: ErrorHandler::new(),
		}
	}

	/// Compiles the top-level `stmts` into a `Chunk`, and compiles the body of every
	/// function in `functions` that does not have a chunk yet.
	///
	/// * `stmts` - top-level statements to compile.
	/// * `functions` - user-defined functions; call expressions are encoded as an index
	///   into this list, and functions without a chunk receive one via `set_chunk`.
	/// * `native_functions` - maps a native function name to the id emitted for it.
	/// * `error_handler` - receives this compiler's recorded errors via `merge`.
	///
	/// Returns `None` when this compiler's own error handler reports errors (they are
	/// kept until `clear_errors` is called), otherwise the chunk for the top-level
	/// statements. The top-level variable table survives the call, so a later call on
	/// the same compiler keeps using the same slots.
	pub(crate) fn compile(
		&mut self,
		stmts: Vec<Stmt>,
		functions: &mut Vec<Function>,
		native_functions: HashMap<String, usize>,
		error_handler: &mut ErrorHandler,
	) -> Option<Chunk> {
		// Call expressions look their target up in this snapshot.
		self.functions = functions.clone();
		self.native_functions = native_functions;
		self.reset();

		for stmt in &stmts {
			self.current_line = stmt.line;
			stmt.accept(self);
		}

		let chunk = Chunk::new(
			self.instructions.clone(),
			self.lines.clone(),
			self.literals.clone(),
			self.cur_variable,
			0,
		);

		// Function bodies number their variables from zero, starting with the parameters.
		// Park the top-level scope while they are compiled. Both the name table and the
		// slot counter must be restored afterwards: restoring only the names would leave
		// the counter at the last function's slot count, and the next top-level
		// variable would then reuse a slot that is already taken.
		let top_level_variables = std::mem::replace(&mut self.variables, HashMap::new());
		let top_level_cur_variable = self.cur_variable;
		for func in functions.iter_mut() {
			if func.chunk.is_some() {
				continue;
			}

			self.reset();
			self.variables.clear();
			self.cur_variable = 0;
			// Parameters occupy the first slots, in declaration order.
			for (i, param) in func.parameters.iter().enumerate() {
				self.variables.insert(param.lexeme.clone(), i);
				self.cur_variable += 1;
			}

			for stmt in &func.stmts {
				self.current_line = stmt.line;
				stmt.accept(self);
			}
			// Every function body ends with `Nil` + `Ret`, so falling off the end of the
			// body still executes a return.
			self.emit_opcode(OpCode::Nil);
			self.emit_opcode(OpCode::Ret);
			func.set_chunk(Chunk::new(
				self.instructions.clone(),
				self.lines.clone(),
				self.literals.clone(),
				self.cur_variable,
				func.parameters.len(),
			));
		}
		self.variables = top_level_variables;
		self.cur_variable = top_level_cur_variable;

		error_handler.merge(&self.error_handler);
		if self.error_handler.had_errors {
			None
		} else {
			Some(chunk)
		}
	}

	/// Replaces this compiler's error handler with a fresh one, dropping whatever it had
	/// recorded so far.
	pub(crate) fn clear_errors(&mut self) {
		self.error_handler = ErrorHandler::new();
	}

	/// Empties the three output buffers (bytecode, line table, literal table) so a new
	/// chunk can be built. Variables, functions and recorded errors are left untouched.
	fn reset(&mut self) {
		self.lines.clear();
		self.literals.clear();
		self.instructions.clear();
	}

	/// Appends `opcode` to the instruction stream as a single byte, tagged with the
	/// current source line.
	fn emit_opcode(&mut self, opcode: OpCode) {
		self.emit_byte(opcode as u8);
	}

	/// Appends one raw byte (typically an operand) to the instruction stream and records
	/// the current source line for it, keeping `lines` the same length as `instructions`.
	fn emit_byte(&mut self, byte: u8) {
		let line = self.current_line;
		self.lines.push(line);
		self.instructions.push(byte);
	}

	/// Returns the slot index of the variable called `name`, allocating the next free
	/// slot the first time a name is seen.
	///
	/// Slot indices are emitted as one-byte operands and at most 255 slots (0..=254) are
	/// handed out. When the limit is reached an error is reported against `name` and `0`
	/// is returned as a placeholder, so compilation can carry on and report other errors.
	fn get_variable(&mut self, name: Token) -> u8 {
		// Known variable: a single lookup is enough.
		if let Some(&index) = self.variables.get(&name.lexeme) {
			return index as u8;
		}

		// `>=` rather than `==`: the counter is also advanced outside this function, so it
		// may already be past the limit. An equality test would then be skipped and
		// the index would wrap around in the `as u8` cast below.
		if self.cur_variable >= 255 {
			self.error_handler
				.error_token(&name, "Cannot have more than 255 local variables.");
			return 0;
		}

		let index = self.cur_variable;
		self.variables.insert(name.lexeme, index);
		self.cur_variable += 1;
		index as u8
	}

	/// Emits the forward jump `opcode` followed by a zero placeholder operand and returns
	/// the offset of that operand, to be filled in later by `patch_jmp`.
	fn begin_jmp(&mut self, opcode: OpCode) -> usize {
		self.emit_opcode(opcode);
		self.emit_byte(0); // Placeholder, overwritten by `patch_jmp`
		// The placeholder is the byte that was just pushed.
		self.instructions.len() - 1
	}

	/// Completes a forward jump started with `begin_jmp`: writes the number of bytes
	/// between the operand at `pos` and the current end of the instruction stream into
	/// that operand.
	///
	/// The distance has to fit in one byte; otherwise an error is reported for `line` and
	/// the placeholder is left unchanged. Note that the message mentions IF although the
	/// WHILE and FOR visitors patch their forward jumps through this function as well.
	fn patch_jmp(&mut self, pos: usize, line: usize) {
		let distance = self.instructions.len() - pos - 1;
		if distance <= 255 {
			self.instructions[pos] = distance as u8;
		} else {
			self.error_handler.error(line, "IF body is too long.");
		}
	}

	/// Returns the current length of the instruction stream, i.e. the offset the next
	/// emitted byte will get. Hand the value to `patch_jmp_back` to jump back to it.
	fn begin_jmp_back(&mut self) -> usize {
		self.instructions.len()
	}

	/// Emits `JmpBack` with a one-byte operand targeting `pos`, an offset previously
	/// obtained from `begin_jmp_back`.
	///
	/// The operand is the distance from `pos` to the current end of the stream plus 2,
	/// which covers the two bytes of the `JmpBack` instruction itself. If it does not fit
	/// in one byte, an error is reported for `line` and nothing is emitted. The message
	/// mentions WHILE although the FOR visitor uses this function too.
	fn patch_jmp_back(&mut self, pos: usize, line: usize) {
		let distance = self.instructions.len() - pos + 2;
		if distance <= 255 {
			self.emit_opcode(OpCode::JmpBack);
			self.emit_byte(distance as u8);
		} else {
			self.error_handler.error(line, "WHILE body is too long.");
		}
	}
}

impl StmtVisitor<()> for Compiler {
	/// Compiles an expression used as a statement: the expression's code followed by
	/// `OpCode::Pop`.
	fn visit_expression_stmt(&mut self, stmt: &ExpressionStmt, _: usize) {
		stmt.value.accept(self);
		// Presumably discards the expression's result, which a bare statement never uses.
		self.emit_opcode(OpCode::Pop);
	}

	/// Compiles a print statement: the values are compiled in reverse order, followed by
	/// `OpCode::Print` and a one-byte operand holding the number of values.
	///
	/// The count has to fit in that single byte. More than 255 values is reported as a
	/// compile error instead of being silently truncated by the `as u8` cast, which would
	/// make the emitted count disagree with the number of values compiled.
	fn visit_print_stmt(&mut self, stmt: &PrintStmt, _: usize) {
		for expr in stmt.values.iter().rev() {
			expr.accept(self);
		}

		let count = stmt.values.len();
		if count > 255 {
			self.error_handler.error(
				self.current_line,
				"Cannot print more than 255 values in one statement.",
			);
			return;
		}

		self.emit_opcode(OpCode::Print);
		self.emit_byte(count as u8);
	}

	/// Compiles an input statement: `OpCode::Input` followed by the slot index of the
	/// target variable. The slot is allocated here if the name has not been seen before.
	fn visit_input_stmt(&mut self, stmt: &InputStmt, _: usize) {
		let var_id = self.get_variable(stmt.variable.clone());
		self.emit_opcode(OpCode::Input);
		self.emit_byte(var_id);
	}

	/// Compiles an assignment: the value expression's code followed by `OpCode::SetVar`
	/// and the slot index of the target variable.
	fn visit_assign_stmt(&mut self, stmt: &AssignStmt, _: usize) {
		// The target's slot is resolved before the value is compiled, so a brand-new
		// target gets its slot ahead of any variable first seen inside the value.
		let var_id = self.get_variable(stmt.name.clone());
		stmt.value.accept(self);
		self.emit_opcode(OpCode::SetVar);
		self.emit_byte(var_id);
	}

	/// Compiles an IF statement with the layout
	/// `condition, Jne -> else, then-branch, Jmp -> end, else-branch`.
	///
	/// Both jumps are emitted with placeholder operands and patched once their target
	/// offset is known. `line` is used for the error raised when a jump is too long.
	fn visit_if_stmt(&mut self, stmt: &IfStmt, line: usize) {
		// The condition's code comes first; the `Jne` operand is patched to land on the
		// else-branch, so it presumably jumps when the condition does not hold.
		stmt.condition.accept(self);
		let skip_then = self.begin_jmp(OpCode::Jne);

		for inner in &stmt.then_stmts {
			inner.accept(self);
		}

		// At the end of the then-branch, hop over the else-branch. The `Jne` target is
		// patched only after this jump is emitted, so it lands just past it.
		let skip_else = self.begin_jmp(OpCode::Jmp);
		self.patch_jmp(skip_then, line);

		for inner in &stmt.else_stmts {
			inner.accept(self);
		}
		self.patch_jmp(skip_else, line);
	}

	/// Compiles a WHILE loop with the layout
	/// `condition, Jne -> exit, body, JmpBack -> condition`.
	///
	/// `line` is used for the errors raised when either jump distance is too long.
	fn visit_while_stmt(&mut self, stmt: &WhileStmt, line: usize) {
		// The loop re-enters at the condition, so remember where its code starts.
		let loop_top = self.begin_jmp_back();
		stmt.condition.accept(self);
		let exit_jmp = self.begin_jmp(OpCode::Jne);

		for inner in &stmt.body {
			inner.accept(self);
		}

		// The backward jump must be emitted before the exit jump is patched, so that the
		// exit lands after it.
		self.patch_jmp_back(loop_top, line);
		self.patch_jmp(exit_jmp, line);
	}

	/// Compiles a FOR loop: `variable` is set to `min_value` and the body runs while
	/// `variable <= max_value`, with `variable` incremented after each pass.
	///
	/// `max_value` is evaluated once, before the loop, and stored in a hidden unnamed
	/// slot. If no slot is left for it, an error is reported against the loop variable
	/// and slot 0 is used as a placeholder so that the rest of the loop is still
	/// compiled and checked. `line` is used for the errors raised when a jump is too long.
	fn visit_for_stmt(&mut self, stmt: &ForStmt, line: usize) {
		// Make sure the loop variable has its slot before the bounds are compiled.
		let var_id = self.get_variable(stmt.variable.clone());

		stmt.max_value.accept(self);

		// Allocate the hidden slot for max_value. On failure the counter is left alone:
		// bumping it past the limit made the old `cur_variable as u8 - 1` underflow.
		let max_value_id = if self.cur_variable >= 255 {
			self.error_handler
				.error_token(&stmt.variable, "Too many variables");
			0
		} else {
			let id = self.cur_variable as u8;
			self.cur_variable += 1;
			id
		};
		self.emit_opcode(OpCode::SetVar);
		self.emit_byte(max_value_id);

		// set x to min_value
		stmt.min_value.accept(self);
		self.emit_opcode(OpCode::SetVar);
		self.emit_byte(var_id);

		let comparison_jmp = self.begin_jmp_back();

		// check that x <= max value
		self.emit_opcode(OpCode::Var);
		self.emit_byte(var_id);
		self.emit_opcode(OpCode::Var);
		self.emit_byte(max_value_id);
		self.emit_opcode(OpCode::LessEqual);

		let for_end_jmp = self.begin_jmp(OpCode::Jne);

		// body
		for stmt in &stmt.body {
			stmt.accept(self);
		}

		// increment x
		self.emit_opcode(OpCode::Inc);
		self.emit_byte(var_id);

		// jmp back to comparison
		self.patch_jmp_back(comparison_jmp, line);

		// x > max value: the exit jump lands right after the backward jump
		self.patch_jmp(for_end_jmp, line);
	}

	/// Compiles a return statement: the return value's code followed by `OpCode::Ret`.
	fn visit_return_stmt(&mut self, stmt: &ReturnStmt, _: usize) {
		stmt.return_value.accept(self);
		self.emit_opcode(OpCode::Ret);
	}

	/// Compiles a compound assignment (`+=`, `-=`, `*=`, `/=`, `^=`, `%=`) as
	/// `Var slot, <value>, <arithmetic opcode>, SetVar slot`.
	///
	/// Panics via `unreachable!` if the statement carries any other operator.
	fn visit_binary_assign_stmt(&mut self, stmt: &BinaryAssignStmt, _: usize) {
		let slot = self.get_variable(stmt.name.clone());

		// Left operand: the variable's current value.
		self.emit_opcode(OpCode::Var);
		self.emit_byte(slot);

		// Right operand.
		stmt.value.accept(self);

		let arithmetic = match stmt.operator {
			TokenType::PlusEqual => OpCode::Add,
			TokenType::MinusEqual => OpCode::Sub,
			TokenType::StarEqual => OpCode::Mul,
			TokenType::SlashEqual => OpCode::Div,
			TokenType::CaretEqual => OpCode::Pow,
			TokenType::PercentEqual => OpCode::Mod,
			_ => unreachable!(),
		};
		self.emit_opcode(arithmetic);

		// Store the result back into the same slot.
		self.emit_opcode(OpCode::SetVar);
		self.emit_byte(slot);
	}
}

impl ExprVisitor<()> for Compiler {
	/// Compiles a variable read: `OpCode::Var` followed by the variable's slot index.
	///
	/// `get_variable` allocates a slot for a name it has not seen before, so reading a
	/// variable that was never assigned is not reported as an error here.
	fn visit_variable_expr(&mut self, expr: &VariableExpr) {
		let var_id = self.get_variable(expr.name.clone());
		self.emit_opcode(OpCode::Var);
		self.emit_byte(var_id);
	}

	/// Compiles `left AND right` / `left OR right`: both operands' code, then the `And`
	/// or `Or` opcode. The right operand is always compiled in full; no jump is emitted
	/// to skip it.
	///
	/// Panics via `unreachable!` if the expression carries any other operator.
	fn visit_logical_expr(&mut self, expr: &LogicalExpr) {
		expr.left.accept(self);
		expr.right.accept(self);

		let opcode = match expr.operator {
			TokenType::And => OpCode::And,
			TokenType::Or => OpCode::Or,
			_ => unreachable!(),
		};
		self.emit_opcode(opcode)
	}

	/// Compiles a binary arithmetic or comparison expression: the left operand's code,
	/// the right operand's code, then the opcode matching the operator token.
	///
	/// Panics via `unreachable!` if the operator is not one of the listed tokens.
	fn visit_binary_expr(&mut self, expr: &BinaryExpr) {
		expr.left.accept(self);
		expr.right.accept(self);

		let opcode = match expr.operator {
			// Arithmetic
			TokenType::Plus => OpCode::Add,
			TokenType::Minus => OpCode::Sub,
			TokenType::Star => OpCode::Mul,
			TokenType::Slash => OpCode::Div,
			TokenType::Caret => OpCode::Pow,
			TokenType::Percent => OpCode::Mod,
			// Comparison
			TokenType::Equal => OpCode::Equal,
			TokenType::Less => OpCode::Less,
			TokenType::LessEqual => OpCode::LessEqual,
			TokenType::Greater => OpCode::Greater,
			TokenType::GreaterEqual => OpCode::GreaterEqual,
			_ => unreachable!(),
		};
		self.emit_opcode(opcode);
	}

	/// Compiles a unary expression: the operand's code, then `Not` for a `Not` token or
	/// `Minus` for a `Minus` token.
	///
	/// Panics via `unreachable!` if the expression carries any other operator.
	fn visit_unary_expr(&mut self, expr: &UnaryExpr) {
		expr.right.accept(self);

		let opcode = match expr.operator {
			TokenType::Not => OpCode::Not,
			TokenType::Minus => OpCode::Minus,
			_ => unreachable!(),
		};
		self.emit_opcode(opcode);
	}

	/// Compiles a call: the arguments' code in source order, then either
	/// `Call <index into the user function list>` or `Native <native function id>`.
	///
	/// A user-defined function takes precedence over a native function with the same
	/// name. Panics via `unreachable!` when the name matches neither.
	fn visit_call_expr(&mut self, expr: &CallExpr) {
		// Arguments
		for arg in &expr.arguments {
			arg.accept(self);
		}

		// User function: the operand is its position in `self.functions`.
		let user_function = self
			.functions
			.iter()
			.position(|function| function.name == expr.name);
		if let Some(index) = user_function {
			self.emit_opcode(OpCode::Call);
			self.emit_byte(index as u8);
			return;
		}

		// Native function: a direct map lookup, no need to clone or scan the map.
		if let Some(&id) = self.native_functions.get(&expr.name) {
			self.emit_opcode(OpCode::Native);
			self.emit_byte(id as u8);
			return;
		}

		// NOTE(review): presumably an earlier stage rejects calls to unknown functions — confirm.
		unreachable!();
	}

	/// Compiles a literal: appends it to the literal table and emits the instruction
	/// that refers to it by index.
	///
	/// Indices up to 255 use `Literal8` with a one-byte operand; indices up to 65535 use
	/// `Literal16` with the index split into high byte, then low byte. Past that an error
	/// is reported and the literal is not added to the table.
	fn visit_literal_expr(&mut self, expr: &ExprValue) {
		let index = self.literals.len();

		if index > 65535 {
			// Too many literals. Report through the same `error` helper as every other
			// site instead of pushing onto `errors` directly.
			// NOTE(review): presumably `error` is also what sets `had_errors`, which
			// `compile` checks to reject the chunk — confirm against `ErrorHandler`.
			self.error_handler
				.error(self.current_line, "Cannot have more than 65,535 literals");
			return;
		}

		self.literals.push(expr.clone());
		if index <= 255 {
			// TODO we need an integration test for > 255 literals
			self.emit_opcode(OpCode::Literal8);
			self.emit_byte(index as u8);
		} else {
			self.emit_opcode(OpCode::Literal16);
			self.emit_byte((index >> 8) as u8);
			self.emit_byte((index & 0xFF) as u8);
		}
	}
}