use crate::engine::FinxError;
use std::cell::RefCell;
use std::collections::HashMap;
use std::ops::Add;
use std::rc::Rc;
/// Signature of a host-provided native function: it receives the call's arguments as a
/// slice (in call order) and returns a single result value.
pub type NativeFn = Rc<dyn Fn(&[Value]) -> Value + 'static>;
/// A host function exposed to bytecode, together with the metadata the VM checks at call time.
#[derive(Clone)]
pub struct NativeFunction {
/// The callback invoked when bytecode calls this function.
pub func: NativeFn,
/// Name the function is registered under; also reported in argument-count error messages.
pub name: String,
/// Exact number of arguments a call must supply; any other count is rejected by the VM.
pub num_params: usize,
}
impl std::fmt::Debug for NativeFunction {
    /// Shows the function's name and arity; the closure itself is opaque, so a fixed
    /// `"<closure>"` placeholder stands in for it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("NativeFunction");
        out.field("name", &self.name);
        out.field("num_params", &self.num_params);
        out.field("func", &"<closure>");
        out.finish()
    }
}
/// A runtime value manipulated by the VM.
///
/// Strings share their buffer through `Rc`, so cloning a `Value` never copies string data.
#[derive(Clone, Debug)]
pub enum Value {
/// A number; all numeric values are stored as `f64`.
Number(f64),
/// An immutable, reference-counted string.
Str(Rc<String>),
/// A boolean.
Bool(bool),
/// The absence of a value.
Null,
/// VM-internal handle to a closure: an index into the VM's closure table.
#[doc(hidden)]
_InternalUsize(usize),
/// VM-internal wrapper around a host-provided native function.
#[doc(hidden)]
_NativeFunction(NativeFunction),
}
impl std::fmt::Display for Value {
    /// User-facing rendering of a value.
    ///
    /// Numbers, strings and booleans print their payload, `Null` prints `null`, and the
    /// internal variants print a diagnostic placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Value::Null => f.write_str("null"),
            Value::Bool(flag) => write!(f, "{}", flag),
            Value::Number(num) => write!(f, "{}", num),
            Value::Str(text) => write!(f, "{}", text),
            Value::_InternalUsize(slot) => write!(f, "InternalUsize({slot})"),
            Value::_NativeFunction(_) => f.write_str("<native function>"),
        }
    }
}
impl Value {
    /// Returns the numeric payload, or `None` if this is not a `Number`.
    pub fn as_num(&self) -> Option<f64> {
        match *self {
            Value::Number(n) => Some(n),
            _ => None,
        }
    }

    /// Borrows the string payload, or returns `None` if this is not a `Str`.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            Value::Str(text) => Some(text.as_str()),
            _ => None,
        }
    }

    /// Returns the boolean payload, or `None` if this is not a `Bool`.
    pub fn as_bool(&self) -> Option<bool> {
        match *self {
            Value::Bool(flag) => Some(flag),
            _ => None,
        }
    }

    /// True only for `Null`.
    pub fn is_null(&self) -> bool {
        matches!(*self, Value::Null)
    }

    /// True for `Null` and `Bool(false)`.
    ///
    /// Every other value, including `0` and the empty string, is truthy.
    pub fn is_falsy(&self) -> bool {
        matches!(self, Value::Null | Value::Bool(false))
    }
}
impl Add for Value {
    type Output = Value;

    /// Implements the VM's `+` operator.
    ///
    /// Numbers add, strings concatenate into a new string, booleans combine with logical OR,
    /// and every other pairing (including mixed types) yields `Null`.
    fn add(self, rhs: Value) -> Value {
        match (self, rhs) {
            (Value::Number(x), Value::Number(y)) => Value::Number(x + y),
            (Value::Str(left), Value::Str(right)) => {
                let mut joined = String::with_capacity(left.len() + right.len());
                joined.push_str(&left);
                joined.push_str(&right);
                Value::Str(Rc::new(joined))
            }
            (Value::Bool(p), Value::Bool(q)) => Value::Bool(p || q),
            _ => Value::Null,
        }
    }
}
impl PartialEq for Value {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Value::Number(a), Value::Number(b)) => a == b,
(Value::Str(a), Value::Str(b)) => a == b,
(Value::Bool(a), Value::Bool(b)) => a == b,
(Value::Null, Value::Null) => true,
(Value::_NativeFunction(a), Value::_NativeFunction(b)) => {
std::ptr::eq(&a.func as *const _, &b.func as *const _)
}
_ => false,
}
}
}
impl From<f64> for Value {
fn from(i: f64) -> Self {
Value::Number(i)
}
}
impl From<bool> for Value {
fn from(b: bool) -> Self {
Value::Bool(b)
}
}
impl From<&str> for Value {
fn from(s: &str) -> Self {
Value::Str(Rc::new(s.to_string()))
}
}
impl From<String> for Value {
fn from(s: String) -> Self {
Value::Str(Rc::new(s))
}
}
impl From<()> for Value {
fn from(_: ()) -> Self {
Value::Null
}
}
/// Lifecycle state of a captured variable.
#[derive(Clone, Debug)]
enum UpvalueState {
/// The variable still lives on the VM stack at this absolute index; reads go to that slot.
Open(usize),
/// The variable has been detached from the stack and this upvalue holds its own copy.
Closed(Value),
}
/// A variable captured by a closure.
///
/// The state sits behind `Rc<RefCell<..>>`, so cloning an `Upvalue` yields another handle to
/// the same captured variable rather than an independent copy.
#[derive(Clone, Debug)]
struct Upvalue {
state: Rc<RefCell<UpvalueState>>,
}
impl Upvalue {
    /// Creates an upvalue that aliases the stack slot at absolute index `stack_index`.
    fn new_open(stack_index: usize) -> Self {
        let initial = UpvalueState::Open(stack_index);
        Self {
            state: Rc::new(RefCell::new(initial)),
        }
    }

    /// Reads the captured value.
    ///
    /// For an open upvalue the value comes from `stack`. This panics if the aliased slot is
    /// out of range.
    fn get_value(&self, stack: &[Value]) -> Value {
        let state = self.state.borrow();
        match &*state {
            UpvalueState::Open(slot) => stack[*slot].clone(),
            UpvalueState::Closed(captured) => captured.clone(),
        }
    }

    /// Closes this upvalue with a copy of `stack[stack_idx_to_close]`.
    ///
    /// This only happens if the upvalue is currently open on exactly that slot; otherwise
    /// it does nothing.
    fn close_if_matches_index(&self, stack_idx_to_close: usize, stack: &[Value]) {
        let mut state = self.state.borrow_mut();
        let aliases_target =
            matches!(*state, UpvalueState::Open(slot) if slot == stack_idx_to_close);
        if aliases_target {
            *state = UpvalueState::Closed(stack[stack_idx_to_close].clone());
        }
    }
}
/// Where a new closure takes each of its upvalues from, relative to the frame that executes
/// the `Closure` instruction.
#[derive(Clone, Debug)]
pub enum UpvalueSource {
/// A local slot of the enclosing frame, given as an offset from that frame's base.
Local(usize),
/// An upvalue the enclosing frame already holds, shared by its index.
OuterUpvalue(usize),
}
/// One bytecode instruction.
///
/// Binary operators pop the right operand first, then the left one. Local indices are offsets
/// from the current frame's base; global indices address the VM's globals table.
#[derive(Clone, Debug)]
pub enum Instruction {
/// Push a copy of the constant.
LoadConst(Value),
/// Discard the top of the stack (does nothing if the stack is empty).
Pop,
/// Push the global at the given index.
LoadGlobal(usize),
/// Pop the top value into the global at the given index, growing the table with `Null`.
StoreGlobal(usize),
/// Push the current frame's local at the given slot.
LoadLocal(usize),
/// Pop the top value into the current frame's local slot, growing the stack with `Null`.
StoreLocal(usize),
/// Push the value of the current frame's upvalue at the given index.
LoadUpvalue(usize),
/// Pop the top value and assign it to the current frame's upvalue at the given index.
StoreUpvalue(usize),
/// Continue at the given instruction index; the index may equal the code length.
Jump(usize),
/// Pop a condition and jump to the given index if it is `null` or `false`.
JumpIfFalse(usize),
/// Call the callable located below the given number of arguments on the stack.
Call(usize),
/// Pop the return value and return from the current frame.
Return,
/// `+`: numbers add, strings concatenate, booleans OR; other pairings yield `Null`.
Add,
/// `-` on two numbers.
Subtract,
/// `*` on two numbers.
Multiply,
/// `/` on two numbers; a zero divisor is an error.
Divide,
/// `%` on two numbers; a zero divisor is an error.
Modulo,
/// Push whether the two operands are equal.
Equal,
/// Push whether the two operands differ.
NotEqual,
/// `<` on two numbers, pushing a boolean.
LessThan,
/// `<=` on two numbers, pushing a boolean.
LessThanOrEqual,
/// `>` on two numbers, pushing a boolean.
GreaterThan,
/// `>=` on two numbers, pushing a boolean.
GreaterThanOrEqual,
/// Instantiate the function template with the listed captures and push a closure handle.
Closure(Function, Vec<UpvalueSource>),
/// Jump back to the given instruction index, which must be below the code length.
Loop(usize),
/// Marker instruction; the VM treats it as a no-op.
ExitLoop,
}
/// A compiled function body plus its calling metadata.
///
/// This is also the template that a `Closure` instruction instantiates.
#[derive(Clone, Debug)]
pub struct Function {
/// The bytecode to execute.
pub code: Vec<Instruction>,
/// Exact number of arguments a call must supply; the arguments occupy the first local slots.
pub num_params: usize,
/// Number of upvalues a closure of this function must capture.
pub num_upvalues: usize,
}
/// One activation record on the VM call stack.
#[derive(Clone)]
struct Frame {
/// The function being executed.
func: Rc<Function>,
/// Index of the next instruction to execute in `func.code`.
ip: usize,
/// Absolute stack index of local slot 0 (the first argument) for this call.
base: usize,
/// Handles to the upvalues of the closure this frame is running.
upvalues: Vec<Upvalue>,
}
/// A function paired with the upvalues it captured when it was created.
#[derive(Clone)]
struct Closure {
func: Rc<Function>,
upvalues: Vec<Upvalue>,
}
/// The bytecode virtual machine: operand stack, call frames, globals and closure table.
pub struct VM {
/// Operand stack shared by all frames; each frame's locals start at its `base`.
stack: Vec<Value>,
/// Call stack; the last frame is the one executing.
frames: Vec<Frame>,
/// Global slots addressed by `LoadGlobal` / `StoreGlobal`.
globals: Vec<Value>,
/// Closure table; `Value::_InternalUsize(i)` refers to entry `i`.
functions: Vec<Closure>,
/// Registered native functions, keyed by name.
native_functions: HashMap<String, NativeFunction>,
/// Maximum number of frames allowed before a closure call is rejected.
max_recursion_depth: usize,
/// Hint that some upvalue may still be open; when false, closing upvalues is skipped.
has_open_upvalues: bool,
}
/// Base capacity used to pre-size the VM's vectors in `VM::new`.
const INITIAL_VEC_CAPACITY: usize = 64;
impl VM {
    /// Creates an empty VM with pre-sized vectors and a recursion limit of 1000 frames.
    pub fn new() -> Self {
        Self {
            stack: Vec::with_capacity(INITIAL_VEC_CAPACITY),
            frames: Vec::with_capacity(INITIAL_VEC_CAPACITY / 4),
            globals: Vec::with_capacity(INITIAL_VEC_CAPACITY),
            functions: Vec::with_capacity(INITIAL_VEC_CAPACITY / 2),
            native_functions: HashMap::new(),
            max_recursion_depth: 1000,
            has_open_upvalues: false,
        }
    }

    /// Registers a host function under `name` and appends it to the globals table.
    ///
    /// The function becomes global number `globals.len()` (as it was before the call).
    /// `num_params` is the exact argument count enforced when the function is called.
    pub fn register_native_function(&mut self, name: &str, func: NativeFn, num_params: usize) {
        let native_func = NativeFunction {
            func,
            name: name.to_string(),
            num_params,
        };
        self.native_functions
            .insert(name.to_string(), native_func.clone());
        self.globals.push(Value::_NativeFunction(native_func));
    }

    /// Rebuilds the globals table so that global `i` is the native `native_names[i]`.
    ///
    /// # Errors
    /// Returns `FinxError::VmError` if any name was never registered. In that case the
    /// globals table is left untouched instead of half-rebuilt.
    pub fn setup_globals_for_natives(
        &mut self,
        native_names: &[String],
    ) -> std::result::Result<(), FinxError> {
        let natives = native_names
            .iter()
            .map(|name| {
                self.native_functions
                    .get(name)
                    .map(|native_func| Value::_NativeFunction(native_func.clone()))
                    .ok_or_else(|| {
                        FinxError::VmError(format!(
                            "Native function '{}' was not registered",
                            name
                        ))
                    })
            })
            .collect::<std::result::Result<Vec<_>, _>>()?;
        self.globals.clear();
        self.globals.extend(natives);
        Ok(())
    }

    /// Number of registered native functions.
    pub fn native_function_count(&self) -> usize {
        self.native_functions.len()
    }

    /// Names of all registered native functions, in unspecified (hash map) order.
    pub fn get_native_function_names(&self) -> Vec<String> {
        self.native_functions.keys().cloned().collect()
    }

    /// Sets the maximum number of frames allowed before a closure call is rejected.
    pub fn set_max_recursion_depth(&mut self, depth: usize) {
        self.max_recursion_depth = depth;
    }

    /// Grows the globals table with `Null` so it holds at least `capacity` entries.
    pub fn ensure_globals_capacity(&mut self, capacity: usize) {
        if self.globals.len() < capacity {
            self.globals.resize(capacity, Value::Null);
        }
    }

    /// Writes `value` into global slot `index`, growing the table with `Null` if needed.
    pub fn set_global_at_index(&mut self, index: usize, value: Value) {
        if index >= self.globals.len() {
            self.globals.resize(index + 1, Value::Null);
        }
        self.globals[index] = value;
    }

    /// Pops the top of the operand stack.
    ///
    /// Underflow (malformed bytecode) is reported as an error instead of panicking.
    fn pop_value(&mut self) -> std::result::Result<Value, FinxError> {
        self.stack
            .pop()
            .ok_or_else(|| FinxError::VmError("Stack underflow".to_string()))
    }

    /// Pops the two operands of a binary instruction and returns them as `(left, right)`.
    ///
    /// The right-hand operand is on top of the stack, so it is popped first.
    fn pop_operands(&mut self) -> std::result::Result<(Value, Value), FinxError> {
        let b = self.pop_value()?;
        let a = self.pop_value()?;
        Ok((a, b))
    }

    /// Applies a numeric binary operator and pushes the numeric result.
    ///
    /// # Errors
    /// Returns `TypeError` if either operand is not a number, or `VmError` on underflow.
    fn execute_arithmetic_op<F>(
        &mut self,
        op: F,
        op_name: &str,
    ) -> std::result::Result<(), FinxError>
    where
        F: Fn(f64, f64) -> f64,
    {
        let (a, b) = self.pop_operands()?;
        if let (Some(a_num), Some(b_num)) = (a.as_num(), b.as_num()) {
            self.stack.push(Value::Number(op(a_num, b_num)));
            Ok(())
        } else {
            Err(FinxError::TypeError(format!(
                "{} requires two numbers, got: {} and {}",
                op_name, a, b
            )))
        }
    }

    /// Like `execute_arithmetic_op`, but rejects a zero divisor with `zero_message`.
    fn execute_division_op<F>(
        &mut self,
        op: F,
        op_name: &str,
        zero_message: &str,
    ) -> std::result::Result<(), FinxError>
    where
        F: Fn(f64, f64) -> f64,
    {
        let (a, b) = self.pop_operands()?;
        match (a.as_num(), b.as_num()) {
            (Some(_), Some(divisor)) if divisor == 0.0 => {
                Err(FinxError::VmError(zero_message.to_string()))
            }
            (Some(a_num), Some(b_num)) => {
                self.stack.push(Value::Number(op(a_num, b_num)));
                Ok(())
            }
            _ => Err(FinxError::TypeError(format!(
                "{} requires two numbers, got: {} and {}",
                op_name, a, b
            ))),
        }
    }

    /// Applies a numeric comparison and pushes the boolean result.
    ///
    /// # Errors
    /// Returns `TypeError` if either operand is not a number, or `VmError` on underflow.
    fn execute_comparison_op<F>(
        &mut self,
        op: F,
        op_name: &str,
    ) -> std::result::Result<(), FinxError>
    where
        F: Fn(f64, f64) -> bool,
    {
        let (a, b) = self.pop_operands()?;
        if let (Some(a_num), Some(b_num)) = (a.as_num(), b.as_num()) {
            self.stack.push(Value::Bool(op(a_num, b_num)));
            Ok(())
        } else {
            Err(FinxError::TypeError(format!(
                "{} requires two numbers, got: {} and {}",
                op_name, a, b
            )))
        }
    }

    /// Executes `Return` for frame `frame_idx`, the top frame.
    ///
    /// Steps:
    /// 1. Close the frame's captured locals.
    /// 2. Take the top value as the result, or `Null` if the frame left nothing.
    /// 3. Discard the frame's stack region.
    /// 4. Push the result for the caller, if there is one.
    fn handle_return(&mut self, frame_idx: usize) {
        let returning_frame_base = self.frames[frame_idx].base;
        // Close while every slot (including the one holding the return value) is still on
        // the stack, so no upvalue is left pointing at a discarded slot.
        self.close_upvalues_at_or_above(returning_frame_base, frame_idx);
        // Never pop below the frame's base: that would steal a value from the caller.
        let ret_val = if self.stack.len() > returning_frame_base {
            self.stack.pop().unwrap_or(Value::Null)
        } else {
            Value::Null
        };
        self.frames.pop();
        self.stack.truncate(returning_frame_base);
        if !self.frames.is_empty() {
            self.stack.push(ret_val);
        }
    }

    /// Pops a non-entry frame whose code ran off the end.
    ///
    /// The call yields `Null`, so the caller's stack stays balanced exactly as if the
    /// function had executed `Return` with a `Null` value.
    fn handle_implicit_return(&mut self, frame_idx: usize) {
        let returning_frame_base = self.frames[frame_idx].base;
        self.close_upvalues_at_or_above(returning_frame_base, frame_idx);
        self.frames.pop();
        self.stack.truncate(returning_frame_base);
        self.stack.push(Value::Null);
    }

    /// Closes every open upvalue that aliases stack slot `base_index` or above.
    ///
    /// Upvalues held by the closure table and by all frames except `exclude_frame_idx` are
    /// scanned. Afterwards `has_open_upvalues` records whether any scanned upvalue is
    /// still open.
    fn close_upvalues_at_or_above(&mut self, base_index: usize, exclude_frame_idx: usize) {
        if !self.has_open_upvalues {
            return;
        }
        let stack = &self.stack;
        let mut still_open = false;
        for func_closure in &self.functions {
            still_open |= Self::close_upvalue_list(&func_closure.upvalues, base_index, stack);
        }
        for (i, frame) in self.frames.iter().enumerate() {
            if i != exclude_frame_idx {
                still_open |= Self::close_upvalue_list(&frame.upvalues, base_index, stack);
            }
        }
        self.has_open_upvalues = still_open;
    }

    /// Closes the open upvalues in `upvalues` that alias slot `base_index` or above.
    ///
    /// Returns whether any upvalue in the list remains open afterwards.
    fn close_upvalue_list(upvalues: &[Upvalue], base_index: usize, stack: &[Value]) -> bool {
        let mut still_open = false;
        for upval in upvalues {
            let mut state = upval.state.borrow_mut();
            if let UpvalueState::Open(idx) = *state {
                if idx >= base_index {
                    // A slot that was never materialized on the stack holds no value yet.
                    let value = stack.get(idx).cloned().unwrap_or(Value::Null);
                    *state = UpvalueState::Closed(value);
                } else {
                    still_open = true;
                }
            }
        }
        still_open
    }

    /// Executes `JumpIfFalse`: pops the condition and, if it is falsy, jumps to `new_ip`.
    ///
    /// An empty stack counts as `Null`, i.e. falsy.
    ///
    /// # Errors
    /// Returns `VmError` if the jump is taken and `new_ip` lies past `code_len`.
    fn handle_conditional_jump(
        &mut self,
        frame_idx: usize,
        new_ip: usize,
        code_len: usize,
    ) -> std::result::Result<(), FinxError> {
        let condition = self.stack.pop().unwrap_or(Value::Null);
        if condition.is_falsy() {
            if new_ip > code_len {
                return Err(FinxError::VmError(format!(
                    "JumpIfFalse out of bounds: {} (code length: {})",
                    new_ip, code_len
                )));
            }
            self.frames[frame_idx].ip = new_ip;
        }
        Ok(())
    }

    /// Executes `Call(arg_count)`.
    ///
    /// Expected stack layout: `[.., callee, arg0, .., arg{n-1}]`.
    ///
    /// # Errors
    /// Returns `VmError` if the stack is too short, or if the callee is not a closure
    /// handle or native function.
    fn handle_function_call(&mut self, arg_count: usize) -> std::result::Result<(), FinxError> {
        let function_stack_idx = arg_count
            .checked_add(1)
            .and_then(|needed| self.stack.len().checked_sub(needed))
            .ok_or_else(|| {
                FinxError::VmError(format!(
                    "Stack underflow: call with {} argument(s) but only {} value(s) on the stack",
                    arg_count,
                    self.stack.len()
                ))
            })?;
        let function_obj = self.stack[function_stack_idx].clone();
        match function_obj {
            Value::_InternalUsize(closure_idx) => {
                self.handle_closure_call(closure_idx, arg_count, function_stack_idx)
            }
            Value::_NativeFunction(native_func) => {
                self.handle_native_call(native_func, arg_count, function_stack_idx)
            }
            _ => Err(FinxError::VmError(format!(
                "Expected closure or native function on stack, got: {:?}",
                function_obj
            ))),
        }
    }

    /// Pushes a new frame for closure `closure_idx`.
    ///
    /// The callee value is removed from the stack, so its arguments start exactly at the
    /// new frame's base.
    ///
    /// # Errors
    /// Returns `VmError` on recursion-limit overflow, an unknown closure index, or an
    /// argument-count mismatch.
    fn handle_closure_call(
        &mut self,
        closure_idx: usize,
        arg_count: usize,
        function_stack_idx: usize,
    ) -> std::result::Result<(), FinxError> {
        if self.frames.len() >= self.max_recursion_depth {
            return Err(FinxError::VmError(format!(
                "Maximum recursion depth of {} exceeded. This prevents infinite recursion and memory exhaustion.",
                self.max_recursion_depth
            )));
        }
        let closure = self.functions.get(closure_idx).ok_or_else(|| {
            FinxError::VmError(format!("Invalid closure index: {}", closure_idx))
        })?;
        if closure.func.num_params != arg_count {
            return Err(FinxError::VmError(format!(
                "Mismatched argument count: expected {}, got {}",
                closure.func.num_params, arg_count
            )));
        }
        let frame = Frame {
            func: Rc::clone(&closure.func),
            ip: 0,
            base: function_stack_idx,
            upvalues: closure.upvalues.clone(),
        };
        // Remove the callee so `LoadLocal(0)` addresses the first argument at `base`.
        self.stack.remove(function_stack_idx);
        if !frame.upvalues.is_empty() {
            self.has_open_upvalues = true;
        }
        self.frames.push(frame);
        Ok(())
    }

    /// Invokes a native function.
    ///
    /// The callee and its arguments are replaced on the stack with the function's result.
    ///
    /// # Errors
    /// Returns `VmError` on an argument-count mismatch.
    fn handle_native_call(
        &mut self,
        native_func: NativeFunction,
        arg_count: usize,
        function_stack_idx: usize,
    ) -> std::result::Result<(), FinxError> {
        if native_func.num_params != arg_count {
            return Err(FinxError::VmError(format!(
                "Mismatched argument count for native function '{}': expected {}, got {}",
                native_func.name, native_func.num_params, arg_count
            )));
        }
        // Everything above the callee is exactly its `arg_count` arguments, because
        // `handle_function_call` derived `function_stack_idx` from the stack length.
        let args = self.stack.split_off(function_stack_idx + 1);
        self.stack.truncate(function_stack_idx);
        let result = (native_func.func)(&args);
        self.stack.push(result);
        Ok(())
    }

    /// Executes `Closure`: captures upvalues from the current frame and pushes a closure
    /// handle.
    ///
    /// A capture-free closure whose code matches an existing table entry reuses that entry.
    ///
    /// # Errors
    /// Returns `VmError` if there is no active frame, an outer upvalue index is invalid, or
    /// the capture count does not match `func_template.num_upvalues`.
    fn handle_closure_creation(
        &mut self,
        func_template: Function,
        upvalue_sources: Vec<UpvalueSource>,
    ) -> std::result::Result<(), FinxError> {
        let current_frame = self.frames.last().ok_or_else(|| {
            FinxError::VmError("Closure instruction executed without an active frame".to_string())
        })?;
        let mut captured_upvalues = Vec::with_capacity(func_template.num_upvalues);
        for source in upvalue_sources {
            match source {
                UpvalueSource::Local(local_idx_in_enclosing) => {
                    // Capture by stack slot. The upvalue stays open (aliasing the local)
                    // until the enclosing frame returns and closes it.
                    captured_upvalues.push(Upvalue::new_open(
                        current_frame.base + local_idx_in_enclosing,
                    ));
                }
                UpvalueSource::OuterUpvalue(upvalue_idx_in_enclosing) => {
                    let outer = current_frame
                        .upvalues
                        .get(upvalue_idx_in_enclosing)
                        .ok_or_else(|| {
                            FinxError::VmError(format!(
                                "Invalid enclosing upvalue index: {}",
                                upvalue_idx_in_enclosing
                            ))
                        })?;
                    // Cloning shares the `Rc` state, so both closures see one variable.
                    captured_upvalues.push(outer.clone());
                }
            }
        }
        if func_template.num_upvalues != captured_upvalues.len() {
            return Err(FinxError::VmError(format!(
                "Mismatched upvalue count for closure: expected {}, got {}",
                func_template.num_upvalues,
                captured_upvalues.len()
            )));
        }
        let existing = self.functions.iter().position(|existing_closure| {
            self.closures_are_equivalent(existing_closure, &func_template, &captured_upvalues)
        });
        if let Some(existing_idx) = existing {
            self.stack.push(Value::_InternalUsize(existing_idx));
            return Ok(());
        }
        let has_upvalues = !captured_upvalues.is_empty();
        let closure_idx = self.functions.len();
        self.functions.push(Closure {
            func: Rc::new(func_template),
            upvalues: captured_upvalues,
        });
        if has_upvalues {
            self.has_open_upvalues = true;
        }
        self.stack.push(Value::_InternalUsize(closure_idx));
        Ok(())
    }

    /// Whether `existing_closure` can stand in for a new closure of `new_func`.
    ///
    /// Only capture-free closures are shared. A closure with upvalues is always distinct,
    /// because each instance captures its own variables.
    fn closures_are_equivalent(
        &self,
        existing_closure: &Closure,
        new_func: &Function,
        _new_upvalues: &[Upvalue],
    ) -> bool {
        let existing = &existing_closure.func;
        existing.num_params == new_func.num_params
            && existing.num_upvalues == new_func.num_upvalues
            && new_func.num_upvalues == 0
            && existing.code.len() == new_func.code.len()
            && existing
                .code
                .iter()
                .zip(&new_func.code)
                .all(|(a, b)| self.instructions_equivalent(a, b))
    }

    /// Instruction-level equality used for closure deduplication.
    ///
    /// Nested `Closure` instructions never compare equal.
    fn instructions_equivalent(&self, instr1: &Instruction, instr2: &Instruction) -> bool {
        use Instruction::*;
        match (instr1, instr2) {
            (LoadConst(v1), LoadConst(v2)) => self.values_equivalent(v1, v2),
            (Pop, Pop) => true,
            (LoadGlobal(i1), LoadGlobal(i2)) => i1 == i2,
            (StoreGlobal(i1), StoreGlobal(i2)) => i1 == i2,
            (LoadLocal(i1), LoadLocal(i2)) => i1 == i2,
            (StoreLocal(i1), StoreLocal(i2)) => i1 == i2,
            (LoadUpvalue(i1), LoadUpvalue(i2)) => i1 == i2,
            (StoreUpvalue(i1), StoreUpvalue(i2)) => i1 == i2,
            (Jump(i1), Jump(i2)) => i1 == i2,
            (JumpIfFalse(i1), JumpIfFalse(i2)) => i1 == i2,
            (Call(i1), Call(i2)) => i1 == i2,
            (Return, Return) => true,
            (Add, Add) => true,
            (Subtract, Subtract) => true,
            (Multiply, Multiply) => true,
            (Divide, Divide) => true,
            (Modulo, Modulo) => true,
            (Equal, Equal) => true,
            (NotEqual, NotEqual) => true,
            (LessThan, LessThan) => true,
            (LessThanOrEqual, LessThanOrEqual) => true,
            (GreaterThan, GreaterThan) => true,
            (GreaterThanOrEqual, GreaterThanOrEqual) => true,
            (Loop(i1), Loop(i2)) => i1 == i2,
            (ExitLoop, ExitLoop) => true,
            (Closure(_, _), Closure(_, _)) => false,
            _ => false,
        }
    }

    /// Constant-level equality used for closure deduplication.
    ///
    /// Native functions match when they share the same closure allocation.
    fn values_equivalent(&self, val1: &Value, val2: &Value) -> bool {
        match (val1, val2) {
            (Value::Number(n1), Value::Number(n2)) => n1 == n2,
            (Value::Str(s1), Value::Str(s2)) => s1 == s2,
            (Value::Bool(b1), Value::Bool(b2)) => b1 == b2,
            (Value::Null, Value::Null) => true,
            (Value::_InternalUsize(i1), Value::_InternalUsize(i2)) => i1 == i2,
            // Compare the shared allocation, not the addresses of the two `Rc` handles.
            (Value::_NativeFunction(f1), Value::_NativeFunction(f2)) => {
                Rc::ptr_eq(&f1.func, &f2.func)
            }
            _ => false,
        }
    }

    /// Runs `entry_closure` as the bottom frame until the call stack is empty.
    ///
    /// When the entry frame runs off the end of its code, its stack contents are left in
    /// place for the caller to inspect.
    ///
    /// # Errors
    /// Propagates the first execution error. The abandoned frames are discarded so a reused
    /// VM does not resume the failed program on its next run.
    fn run(&mut self, entry_closure: Closure) -> std::result::Result<(), FinxError> {
        self.frames.push(Frame {
            func: entry_closure.func,
            ip: 0,
            base: 0,
            upvalues: entry_closure.upvalues,
        });
        let result = self.execute();
        if result.is_err() {
            self.frames.clear();
        }
        result
    }

    /// The interpreter loop: fetches and executes instructions of the top frame.
    fn execute(&mut self) -> std::result::Result<(), FinxError> {
        while let Some(frame_idx) = self.frames.len().checked_sub(1) {
            let (func, current_ip, func_base) = {
                let frame = &self.frames[frame_idx];
                (Rc::clone(&frame.func), frame.ip, frame.base)
            };
            let code_len = func.code.len();
            if current_ip >= code_len {
                if frame_idx == 0 {
                    // The entry frame finished: keep its stack contents for the caller.
                    self.frames.pop();
                } else {
                    self.handle_implicit_return(frame_idx);
                }
                continue;
            }
            self.frames[frame_idx].ip += 1;
            // `func` is a local `Rc`, so the instruction is matched in place and only the
            // payloads that are actually needed get cloned.
            match func.code[current_ip] {
                Instruction::LoadConst(ref val) => self.stack.push(val.clone()),
                Instruction::Pop => {
                    self.stack.pop();
                }
                Instruction::LoadGlobal(i) => {
                    let val = self.globals.get(i).cloned().ok_or_else(|| {
                        FinxError::VmError(format!("Global index out of bounds: {}", i))
                    })?;
                    self.stack.push(val);
                }
                Instruction::StoreGlobal(i) => {
                    let val = self.pop_value()?;
                    self.set_global_at_index(i, val);
                }
                Instruction::LoadLocal(i) => {
                    let val = self.stack.get(func_base + i).cloned().ok_or_else(|| {
                        FinxError::VmError(format!("Local slot {} is not initialized", i))
                    })?;
                    self.stack.push(val);
                }
                Instruction::StoreLocal(i) => {
                    let val = self.pop_value()?;
                    let slot = func_base + i;
                    if slot >= self.stack.len() {
                        // First assignment to a local with no slot yet: grow the stack.
                        self.stack.resize(slot + 1, Value::Null);
                    }
                    self.stack[slot] = val;
                }
                Instruction::LoadUpvalue(i) => {
                    let val = self.frames[frame_idx]
                        .upvalues
                        .get(i)
                        .ok_or_else(|| {
                            FinxError::VmError(format!("Upvalue index out of bounds: {}", i))
                        })?
                        .get_value(&self.stack);
                    self.stack.push(val);
                }
                Instruction::StoreUpvalue(i) => {
                    let val = self.pop_value()?;
                    let upval = self.frames[frame_idx].upvalues.get(i).ok_or_else(|| {
                        FinxError::VmError(format!("Upvalue index out of bounds: {}", i))
                    })?;
                    let mut state = upval.state.borrow_mut();
                    // An open upvalue aliases a live stack slot of an enclosing frame. Write
                    // through to that slot so the enclosing function observes the assignment.
                    let open_slot = match *state {
                        UpvalueState::Open(slot) if slot < self.stack.len() => Some(slot),
                        _ => None,
                    };
                    match open_slot {
                        Some(slot) => self.stack[slot] = val,
                        None => *state = UpvalueState::Closed(val),
                    }
                }
                Instruction::Jump(new_ip) => {
                    if new_ip > code_len {
                        return Err(FinxError::VmError(format!(
                            "Jump out of bounds: {} (current IP: {}, code length: {})",
                            new_ip, current_ip, code_len
                        )));
                    }
                    self.frames[frame_idx].ip = new_ip;
                }
                Instruction::JumpIfFalse(new_ip) => {
                    self.handle_conditional_jump(frame_idx, new_ip, code_len)?;
                }
                Instruction::Add => {
                    let (a, b) = self.pop_operands()?;
                    self.stack.push(a + b);
                }
                Instruction::Subtract => {
                    self.execute_arithmetic_op(|a, b| a - b, "Subtract")?;
                }
                Instruction::Multiply => {
                    self.execute_arithmetic_op(|a, b| a * b, "Multiply")?;
                }
                Instruction::Divide => {
                    self.execute_division_op(|a, b| a / b, "Divide", "Division by zero")?;
                }
                Instruction::Modulo => {
                    self.execute_division_op(|a, b| a % b, "Modulo", "Modulo by zero")?;
                }
                Instruction::Equal => {
                    let (a, b) = self.pop_operands()?;
                    self.stack.push(Value::Bool(a == b));
                }
                Instruction::NotEqual => {
                    let (a, b) = self.pop_operands()?;
                    self.stack.push(Value::Bool(a != b));
                }
                Instruction::LessThan => {
                    self.execute_comparison_op(|a, b| a < b, "LessThan")?;
                }
                Instruction::LessThanOrEqual => {
                    self.execute_comparison_op(|a, b| a <= b, "LessThanOrEqual")?;
                }
                Instruction::GreaterThan => {
                    self.execute_comparison_op(|a, b| a > b, "GreaterThan")?;
                }
                Instruction::GreaterThanOrEqual => {
                    self.execute_comparison_op(|a, b| a >= b, "GreaterThanOrEqual")?;
                }
                Instruction::Return => self.handle_return(frame_idx),
                Instruction::Call(arg_count) => {
                    self.handle_function_call(arg_count)?;
                }
                Instruction::Closure(ref func_template, ref upvalue_sources) => {
                    self.handle_closure_creation(func_template.clone(), upvalue_sources.clone())?;
                }
                Instruction::Loop(target_ip) => {
                    if target_ip >= code_len {
                        return Err(FinxError::VmError(format!(
                            "Loop target out of bounds: {} (code length: {})",
                            target_ip, code_len
                        )));
                    }
                    self.frames[frame_idx].ip = target_ip;
                }
                Instruction::ExitLoop => {
                    // Marker only; the VM has nothing to do here.
                }
            }
        }
        Ok(())
    }
}
/// Executes `code` as the top-level function of a brand-new VM, discarding the VM afterwards.
///
/// # Errors
/// Propagates any error raised while executing `code`.
pub fn run_code(code: Vec<Instruction>) -> std::result::Result<(), FinxError> {
    let program = Function {
        code,
        num_params: 0,
        num_upvalues: 0,
    };
    let entry = Closure {
        func: Rc::new(program),
        upvalues: Vec::new(),
    };
    VM::new().run(entry)
}
/// Executes `code` as a top-level function on the supplied VM.
///
/// On success the VM is handed back, with its globals, closures and stack as execution
/// left them.
///
/// # Errors
/// Propagates any error raised while executing `code`. In that case the VM is dropped.
pub fn run_code_with_vm(mut vm: VM, code: Vec<Instruction>) -> std::result::Result<VM, FinxError> {
    let entry = Closure {
        func: Rc::new(Function {
            code,
            num_params: 0,
            num_upvalues: 0,
        }),
        upvalues: Vec::new(),
    };
    match vm.run(entry) {
        Ok(()) => Ok(vm),
        Err(err) => Err(err),
    }
}
/// Runs `code` as a fresh top-level function on `vm` and returns the value it produced.
///
/// The result is the top of the stack if execution left the stack taller than it was on
/// entry; otherwise it is `Value::Null`. The value is read, not popped, so it stays on the
/// VM's stack.
///
/// # Errors
/// Propagates any error raised while executing `code`.
pub fn eval_code_with_vm(
    vm: &mut VM,
    code: Vec<Instruction>,
) -> std::result::Result<Value, FinxError> {
    let program = Rc::new(Function {
        code,
        num_params: 0,
        num_upvalues: 0,
    });
    let depth_before = vm.stack.len();
    vm.run(Closure {
        func: program,
        upvalues: Vec::new(),
    })?;
    // Only values pushed above the pre-run depth count as this evaluation's result.
    let result = vm
        .stack
        .get(depth_before..)
        .and_then(|pushed| pushed.last())
        .cloned()
        .unwrap_or(Value::Null);
    Ok(result)
}