use crate::engine::{FinxError, Result};
use crate::lexer::Token;
use crate::parser::{Expr, Stmt};
use crate::vm::{Function, Instruction, Value};
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
/// Result of resolving a variable that a nested function wants to capture: either a local slot
/// or an upvalue index belonging to an enclosing compiler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ResolvedInParent {
    /// Index of a local slot.
    Local(usize),
    /// Index into the enclosing function's captured-variable (upvalue) list.
    Upvalue(usize),
}
pub struct Compiler<'a> {
instructions: Vec<Instruction>,
globals: HashMap<String, usize>,
next_global: usize,
locals: Vec<HashMap<String, usize>>,
next_local: Vec<usize>,
upvalue_details: Vec<(String, usize, bool)>,
parent: Option<&'a Compiler<'a>>,
}
impl<'a> Compiler<'a> {
fn new_global_compiler() -> Self {
Compiler {
instructions: Vec::new(),
globals: HashMap::new(),
next_global: 0,
locals: vec![],
next_local: vec![],
upvalue_details: Vec::new(),
parent: None,
}
}
fn new_global_compiler_with_natives(native_names: &[String]) -> Self {
let mut compiler = Self::new_global_compiler();
for name in native_names {
compiler.allocate_global(name.clone());
}
compiler
}
fn new_for_function(parent_compiler: &'a Compiler<'a>) -> Self {
Compiler {
instructions: Vec::new(),
globals: parent_compiler.globals.clone(),
next_global: parent_compiler.next_global,
locals: vec![],
next_local: vec![],
upvalue_details: Vec::new(),
parent: Some(parent_compiler),
}
}
fn allocate_global(&mut self, name: String) -> usize {
let idx = self.next_global;
self.globals.insert(name, idx);
self.next_global += 1;
idx
}
fn get_or_allocate_global(&mut self, name: &str) -> usize {
if let Some(&idx) = self.globals.get(name) {
idx
} else {
self.allocate_global(name.to_string())
}
}
fn enter_scope(&mut self) {
self.locals.push(HashMap::new());
let current_next = *self.next_local.last().unwrap_or(&0);
self.next_local.push(current_next);
}
fn exit_scope(&mut self) {
let final_next = self.next_local.pop().unwrap();
self.locals.pop();
if let Some(outer_next) = self.next_local.last_mut() {
*outer_next = final_next;
}
}
fn declare_local(&mut self, name: String) -> usize {
let idx = *self.next_local.last_mut().unwrap();
self.locals.last_mut().unwrap().insert(name, idx);
*self.next_local.last_mut().unwrap() += 1;
idx
}
fn resolve_local(&self, name: &str) -> Option<usize> {
for scope in self.locals.iter().rev() {
if let Some(&idx) = scope.get(name) {
return Some(idx);
}
}
None
}
fn resolve_for_child_capture(&self, name: &str) -> Option<ResolvedInParent> {
if let Some(local_idx) = self.resolve_local(name) {
return Some(ResolvedInParent::Local(local_idx));
}
if let Some(pos) = self
.upvalue_details
.iter()
.position(|(up_name, _, _)| up_name == name)
{
return Some(ResolvedInParent::Upvalue(pos));
}
self.parent?.resolve_for_child_capture(name)
}
fn add_upvalue(&mut self, name: String, parent_index: usize, is_parent_upvalue: bool) -> usize {
for (i, (existing_name, existing_idx, existing_is_upvalue)) in
self.upvalue_details.iter().enumerate()
{
if existing_name == &name
&& *existing_idx == parent_index
&& *existing_is_upvalue == is_parent_upvalue
{
return i;
}
}
let upvalue_index = self.upvalue_details.len();
self.upvalue_details
.push((name, parent_index, is_parent_upvalue));
upvalue_index
}
fn resolve_local_or_upvalue(
&mut self,
name: &str,
) -> Option<std::result::Result<usize, usize>> {
if let Some(local_idx) = self.resolve_local(name) {
return Some(Ok(local_idx));
}
if let Some(parent_compiler) = self.parent {
match parent_compiler.resolve_for_child_capture(name) {
Some(ResolvedInParent::Local(parent_local_idx)) => {
let upvalue_idx = self.add_upvalue(name.to_string(), parent_local_idx, false);
Some(Err(upvalue_idx))
}
Some(ResolvedInParent::Upvalue(parent_upvalue_idx)) => {
let upvalue_idx = self.add_upvalue(name.to_string(), parent_upvalue_idx, true);
Some(Err(upvalue_idx))
}
None => None,
}
} else {
None
}
}
fn compile_function(
&'a self,
func_name: String,
params: Vec<String>,
body: Vec<Stmt>,
) -> Result<(Function, Vec<crate::vm::UpvalueSource>)> {
let mut func_compiler = Compiler::new_for_function(self);
if self.parent.is_none() {
if let Some(&global_idx) = self.globals.get(&func_name) {
func_compiler
.globals
.entry(func_name.clone())
.or_insert(global_idx);
}
}
func_compiler.enter_scope();
for param in ¶ms {
func_compiler.declare_local(param.clone());
}
if self.parent.is_some() {
func_compiler.declare_local(func_name.clone());
}
for stmt in &body {
func_compiler.compile_statement_recursive(stmt)?;
}
func_compiler.instructions.push(Instruction::Return);
let upvalue_sources = func_compiler
.upvalue_details
.iter()
.map(|(_, parent_idx, is_parent_upvalue)| {
if *is_parent_upvalue {
crate::vm::UpvalueSource::OuterUpvalue(*parent_idx)
} else {
crate::vm::UpvalueSource::Local(*parent_idx)
}
})
.collect();
let function = Function {
num_params: params.len(),
num_upvalues: func_compiler.upvalue_details.len(),
code: func_compiler.instructions,
};
Ok((function, upvalue_sources))
}
fn emit_variable_store(&mut self, name: &str) {
match self.resolve_local_or_upvalue(name) {
Some(Ok(local_idx)) => self.instructions.push(Instruction::StoreLocal(local_idx)),
Some(Err(upvalue_idx)) => self
.instructions
.push(Instruction::StoreUpvalue(upvalue_idx)),
None => {
let global_idx = self.get_or_allocate_global(name);
self.instructions.push(Instruction::StoreGlobal(global_idx));
}
}
}
fn emit_variable_load(&mut self, name: &str) -> Result<()> {
match self.resolve_local_or_upvalue(name) {
Some(Ok(local_idx)) => self.instructions.push(Instruction::LoadLocal(local_idx)),
Some(Err(upvalue_idx)) => self
.instructions
.push(Instruction::LoadUpvalue(upvalue_idx)),
None => {
if let Some(&global_idx) = self.globals.get(name) {
self.instructions.push(Instruction::LoadGlobal(global_idx));
} else {
return Err(FinxError::UndefinedVariable(name.to_string()));
}
}
}
Ok(())
}
fn compile_statement_recursive(&mut self, stmt: &Stmt) -> Result<()> {
match stmt {
Stmt::Let { name, value } => {
self.compile_expression_recursive(value)?;
let local_idx = self.declare_local(name.clone());
self.instructions.push(Instruction::StoreLocal(local_idx));
}
Stmt::Assign { name, value } => {
self.compile_expression_recursive(value)?;
self.emit_variable_store(name);
}
Stmt::Expr(expr) => {
self.compile_expression_recursive(expr)?;
}
Stmt::Fn { name, params, body } => {
let is_global = self.parent.is_none();
let var_idx = if is_global {
self.get_or_allocate_global(name)
} else {
self.declare_local(name.clone())
};
let (function, upvalue_sources) =
self.compile_function(name.clone(), params.clone(), body.clone())?;
self.instructions
.push(Instruction::Closure(function, upvalue_sources));
if is_global {
self.instructions.push(Instruction::StoreGlobal(var_idx));
} else {
self.instructions.push(Instruction::StoreLocal(var_idx));
}
}
Stmt::Return(expr) => {
self.compile_expression_recursive(expr)?;
self.instructions.push(Instruction::Return);
}
Stmt::If {
cond,
then_branch,
else_branch,
} => {
self.compile_if_statement(cond, then_branch, else_branch)?;
}
Stmt::While { cond, body } => {
self.compile_while_statement(cond, body)?;
}
Stmt::For {
var,
start,
end,
body,
} => {
self.compile_for_statement(var, start, end, body)?;
}
}
Ok(())
}
fn compile_if_statement(
&mut self,
cond: &Expr,
then_branch: &[Stmt],
else_branch: &Option<Box<Stmt>>,
) -> Result<()> {
self.compile_expression_recursive(cond)?;
let jump_if_false_idx = self.instructions.len();
self.instructions.push(Instruction::JumpIfFalse(0));
for stmt in then_branch {
self.compile_statement_recursive(stmt)?;
}
if let Some(else_stmt) = else_branch {
let jump_over_else_idx = self.instructions.len();
self.instructions.push(Instruction::Jump(0));
let else_start_offset = self.instructions.len();
self.instructions[jump_if_false_idx] = Instruction::JumpIfFalse(else_start_offset);
self.compile_statement_recursive(else_stmt)?;
let after_else_offset = self.instructions.len();
self.instructions[jump_over_else_idx] = Instruction::Jump(after_else_offset);
} else {
let after_then_offset = self.instructions.len();
self.instructions[jump_if_false_idx] = Instruction::JumpIfFalse(after_then_offset);
}
Ok(())
}
fn compile_while_statement(&mut self, cond: &Expr, body: &[Stmt]) -> Result<()> {
let loop_start = self.instructions.len();
self.compile_expression_recursive(cond)?;
let exit_jump_addr = self.instructions.len();
self.instructions.push(Instruction::JumpIfFalse(0));
for stmt in body {
self.compile_statement_recursive(stmt)?;
}
self.instructions.push(Instruction::Loop(loop_start));
let after_loop = self.instructions.len();
self.instructions[exit_jump_addr] = Instruction::JumpIfFalse(after_loop);
Ok(())
}
fn compile_for_statement(
&mut self,
var: &str,
start: &Expr,
end: &Expr,
body: &[Stmt],
) -> Result<()> {
self.enter_scope();
self.compile_expression_recursive(start)?;
let loop_var_idx = self.declare_local(var.to_string());
self.instructions
.push(Instruction::StoreLocal(loop_var_idx));
self.compile_expression_recursive(end)?;
let end_var_idx = self.declare_local(format!("{}__end", var));
self.instructions.push(Instruction::StoreLocal(end_var_idx));
let loop_start = self.instructions.len();
self.instructions.push(Instruction::LoadLocal(loop_var_idx));
self.instructions.push(Instruction::LoadLocal(end_var_idx));
self.instructions.push(Instruction::LessThan);
let exit_jump_addr = self.instructions.len();
self.instructions.push(Instruction::JumpIfFalse(0));
for stmt in body {
self.compile_statement_recursive(stmt)?;
}
self.instructions.push(Instruction::LoadLocal(loop_var_idx));
self.instructions
.push(Instruction::LoadConst(Value::Number(1.0)));
self.instructions.push(Instruction::Add);
self.instructions
.push(Instruction::StoreLocal(loop_var_idx));
self.instructions.push(Instruction::Loop(loop_start));
let after_loop = self.instructions.len();
self.instructions[exit_jump_addr] = Instruction::JumpIfFalse(after_loop);
self.exit_scope();
Ok(())
}
fn compile_expression_recursive(&mut self, expr: &Expr) -> Result<()> {
match expr {
Expr::Number(val) => {
self.instructions
.push(Instruction::LoadConst(Value::Number(*val)));
}
Expr::String(val) => {
self.instructions
.push(Instruction::LoadConst(Value::Str(Rc::new(val.clone()))));
}
Expr::Bool(val) => {
self.instructions
.push(Instruction::LoadConst(Value::Bool(*val)));
}
Expr::Null => {
self.instructions.push(Instruction::LoadConst(Value::Null));
}
Expr::Binary { left, op, right } => {
self.compile_binary_expression(left, op, right)?;
}
Expr::Call { callee, args } => {
self.compile_call_expression(callee, args)?;
}
Expr::Block(statements) => {
self.compile_block_expression(statements)?;
}
Expr::Identifier(name) => {
self.emit_variable_load(name)?;
}
}
Ok(())
}
fn compile_binary_expression(&mut self, left: &Expr, op: &Token, right: &Expr) -> Result<()> {
self.compile_expression_recursive(left)?;
self.compile_expression_recursive(right)?;
let instruction = match op {
Token::Plus => Instruction::Add,
Token::Minus => Instruction::Subtract,
Token::Star => Instruction::Multiply,
Token::Slash => Instruction::Divide,
Token::Percent => Instruction::Modulo,
Token::EqEq => Instruction::Equal,
Token::BangEq => Instruction::NotEqual,
Token::Lt => Instruction::LessThan,
Token::LtEq => Instruction::LessThanOrEqual,
Token::Gt => Instruction::GreaterThan,
Token::GtEq => Instruction::GreaterThanOrEqual,
_ => {
return Err(FinxError::CompilerError(format!(
"Unknown binary operator: {:?}",
op
)));
}
};
self.instructions.push(instruction);
Ok(())
}
fn compile_call_expression(&mut self, callee: &Expr, args: &[Expr]) -> Result<()> {
self.compile_expression_recursive(callee)?;
for arg in args {
self.compile_expression_recursive(arg)?;
}
self.instructions.push(Instruction::Call(args.len()));
Ok(())
}
fn compile_block_expression(&mut self, statements: &[Stmt]) -> Result<()> {
self.enter_scope();
if statements.is_empty() {
self.instructions.push(Instruction::LoadConst(Value::Null));
} else {
for stmt in &statements[..statements.len() - 1] {
self.compile_statement_recursive(stmt)?;
}
let last_stmt = &statements[statements.len() - 1];
match last_stmt {
Stmt::Expr(expr) => {
self.compile_expression_recursive(expr)?;
}
_ => {
self.compile_statement_recursive(last_stmt)?;
self.instructions.push(Instruction::LoadConst(Value::Null));
}
}
}
self.exit_scope();
Ok(())
}
fn compile_top_level_statement(&mut self, stmt: Stmt) -> Result<()> {
match &stmt {
Stmt::Let { name, value } => {
self.compile_expression_recursive(value)?;
let global_idx = self.get_or_allocate_global(name);
self.instructions.push(Instruction::StoreGlobal(global_idx));
}
Stmt::Assign { name, value } => {
self.compile_expression_recursive(value)?;
let global_idx = self.get_or_allocate_global(name);
self.instructions.push(Instruction::StoreGlobal(global_idx));
}
Stmt::Expr(expr) => {
self.compile_expression_recursive(expr)?;
self.instructions.push(Instruction::Pop);
}
Stmt::Fn { name, params, body } => {
let global_idx = self.get_or_allocate_global(name);
let (function, upvalue_sources) =
self.compile_function(name.clone(), params.clone(), body.clone())?;
self.instructions
.push(Instruction::Closure(function, upvalue_sources));
self.instructions.push(Instruction::StoreGlobal(global_idx));
}
Stmt::Return(_) => {
return Err(FinxError::CompilerError(
"Return statement not allowed at top level".to_string(),
));
}
Stmt::If { .. } => {
self.compile_statement_recursive(&stmt)?;
}
Stmt::While { .. } => {
self.compile_statement_recursive(&stmt)?;
}
Stmt::For { .. } => {
self.compile_statement_recursive(&stmt)?;
}
}
Ok(())
}
fn compile_top_level_statement_preserve_expr(&mut self, stmt: Stmt) -> Result<()> {
match &stmt {
Stmt::Let { name, value } => {
self.compile_expression_recursive(value)?;
let global_idx = self.get_or_allocate_global(name);
self.instructions.push(Instruction::StoreGlobal(global_idx));
}
Stmt::Assign { name, value } => {
self.compile_expression_recursive(value)?;
let global_idx = self.get_or_allocate_global(name);
self.instructions.push(Instruction::StoreGlobal(global_idx));
}
Stmt::Expr(expr) => {
self.compile_expression_recursive(expr)?;
}
Stmt::Fn { name, params, body } => {
let global_idx = self.get_or_allocate_global(name);
let (function, upvalue_sources) =
self.compile_function(name.clone(), params.clone(), body.clone())?;
self.instructions
.push(Instruction::Closure(function, upvalue_sources));
self.instructions.push(Instruction::StoreGlobal(global_idx));
}
Stmt::Return(_) => {
return Err(FinxError::CompilerError(
"Return statement not allowed at top level".to_string(),
));
}
Stmt::If { .. } => {
self.compile_statement_recursive(&stmt)?;
}
Stmt::While { .. } => {
self.compile_statement_recursive(&stmt)?;
}
Stmt::For { .. } => {
self.compile_statement_recursive(&stmt)?;
}
}
Ok(())
}
}
pub fn compile(source: Vec<Stmt>) -> Result<Vec<Instruction>> {
let mut compiler = Compiler::new_global_compiler();
for stmt in source {
compiler.compile_top_level_statement(stmt)?;
}
Ok(compiler.instructions)
}
/// Compiles a whole program for normal execution.
///
/// Each name in `native_names` is given the next global slot, in order, starting at 0, before
/// any global declared by the program. The value of every top-level expression statement is
/// discarded.
///
/// # Errors
/// Returns the first compile error encountered.
pub fn compile_with_natives(
    source: Vec<Stmt>,
    native_names: &[String],
) -> Result<Vec<Instruction>> {
    let mut compiler = Compiler::new_global_compiler_with_natives(native_names);
    source
        .into_iter()
        .try_for_each(|stmt| compiler.compile_top_level_statement(stmt))?;
    Ok(compiler.instructions)
}
pub fn compile_for_eval(source: Vec<Stmt>) -> Result<Vec<Instruction>> {
let mut compiler = Compiler::new_global_compiler();
for stmt in source {
compiler.compile_top_level_statement_preserve_expr(stmt)?;
}
Ok(compiler.instructions)
}
/// Compiles a program for evaluation.
///
/// Each name in `native_names` is given the next global slot, in order, starting at 0. Unlike
/// [`compile_with_natives`], top-level expression statements keep their values on the stack.
///
/// # Errors
/// Returns the first compile error encountered.
pub fn compile_for_eval_with_natives(
    source: Vec<Stmt>,
    native_names: &[String],
) -> Result<Vec<Instruction>> {
    let mut compiler = Compiler::new_global_compiler_with_natives(native_names);
    source
        .into_iter()
        .try_for_each(|stmt| compiler.compile_top_level_statement_preserve_expr(stmt))?;
    Ok(compiler.instructions)
}