mod anchor;
mod eval;
mod frame;
mod lua_val;
mod metamethod;
mod object;
mod stack;
mod table;
mod table_ops;
pub use anchor::Anchor;
pub use lua_val::LuaType;
pub use lua_val::RustFunc;
use indexmap::IndexMap;
use rand::SeedableRng;
use std::sync::Arc;
use super::Instr;
use super::Result;
use super::compiler;
use super::compiler::Bytecode;
use super::compiler::RuntimeCaches;
use super::error::Error;
use super::error::ErrorKind;
use super::error::StackFrame;
use super::error::TypeError;
use super::host::{DefaultCallbacks, HostCallbacks};
use super::instr::{ArgCount, Builtin, RetCount};
use anchor::Registry;
pub(super) use lua_val::Val;
pub(super) use object::ObjectPtr;
use object::{GcHeap, Markable, UpvaluePool, UpvalueRef};
use table::Table;
/// Marks everything reachable from the VM's GC root sets.
///
/// Each slice/map/registry argument is one root set; `mark_reachable` is
/// invoked on every one of them with the same `heap` and `upvalue_pool`.
/// Nothing is freed here: `State::gc_collect` calls this and then runs
/// `heap.collect()`.
#[hotpath::measure]
#[allow(
clippy::too_many_arguments,
reason = "single-call-site GC marking entry point; bundling into a struct adds boilerplate without clarifying anything"
)]
pub(super) fn mark_gc_roots(
heap: &GcHeap,
stack: &[Val],
globals: &IndexMap<String, Val>,
builtins: &[Val],
string_literals: &[Val],
active_call_roots: &[Val],
upvalue_pool: &UpvaluePool,
registry: &Registry,
) {
stack.mark_reachable(heap, upvalue_pool);
globals.mark_reachable(heap, upvalue_pool);
builtins.mark_reachable(heap, upvalue_pool);
string_literals.mark_reachable(heap, upvalue_pool);
active_call_roots.mark_reachable(heap, upvalue_pool);
registry.mark_reachable(heap, upvalue_pool);
}
/// One entry of `State::call_stack`: the bytecode of a call plus an
/// instruction index into it.
#[derive(Clone)]
pub(super) struct CallInfo {
/// Bytecode chunk associated with this call (shared via `Arc`).
pub(super) bytecode: Arc<Bytecode>,
/// Instruction index within `bytecode`. Line lookups in this file use
/// `ip - 1`, so this presumably points one past the instruction being
/// executed — TODO confirm against the eval loop.
pub(super) ip: usize,
}
/// The interpreter state: value stack, globals, GC heap, cost accounting,
/// host callbacks and the anchor registry.
pub struct State {
/// Global variables by name (read by `get_global`, written by `set_global*`).
pub(super) globals: IndexMap<String, Val>,
/// Values of globals whose name `Builtin::from_name` recognises, indexed by
/// builtin slot; kept in step with `globals` by `set_global_value_owned`.
pub(super) builtins: [Val; Builtin::COUNT],
/// Wrapping counter bumped whenever a builtin slot is written or the whole
/// environment is swapped. Presumably used to invalidate cached lookups —
/// TODO confirm at the reader.
pub(super) globals_version: u64,
/// The value stack; a GC root.
pub(super) stack: Vec<Val>,
/// Not used in this file — presumably the stack index where the current
/// frame starts (TODO confirm).
pub(super) stack_bottom: usize,
/// Garbage-collected heap holding strings and objects.
pub(super) heap: GcHeap,
/// Values passed to the collector as roots alongside the stack.
pub(super) string_literals: Vec<Val>,
/// Additional values passed to the collector as roots.
pub(super) active_call_roots: Vec<Val>,
/// Upvalue storage handed to every `mark_reachable` call.
pub(super) upvalue_pool: UpvaluePool,
/// Not used in this file — presumably (stack index, upvalue) pairs for
/// upvalues still pointing into the stack (TODO confirm).
pub(super) open_upvalues: Vec<(usize, UpvalueRef)>,
/// Not used in this file — presumably bookkeeping for vararg calls (TODO confirm).
pub(super) vararg_call_bases: Vec<usize>,
/// Budget left; decremented by `consume_cost`, may go negative.
pub(super) cost_remaining: i64,
/// Budget configured by `set_cost_budget` (`i64::MAX` by default).
pub(super) cost_budget: i64,
/// Total cost charged since the budget was last set.
pub(super) cost_used: u64,
/// Not used in this file — presumably the current metamethod nesting depth.
pub(super) metamethod_depth: u32,
/// Not used in this file — presumably the current call nesting depth.
pub(super) call_depth: u32,
/// Call records used for line lookup and stack traces.
pub(super) call_stack: Vec<CallInfo>,
/// Host hooks that receive print output and errors.
pub(super) callbacks: Box<dyn HostCallbacks + Send>,
/// Source name reported to the host callbacks, if any.
pub(super) current_source: Option<String>,
/// Opaque host-owned value, accessed through the `user_data*` methods.
user_data: Option<Box<dyn std::any::Any + Send>>,
/// Random number generator; seeded with 0 until `set_rng_seed` is called.
pub(super) rng: rand::rngs::StdRng,
/// Anchored values; a GC root.
pub(super) registry: Registry,
}
// Limits that are not referenced in this file — presumably enforced by the
// call and stack submodules (TODO confirm at the use sites).
// Presumably the upper bound for `State::call_depth`.
const MAX_CALL_DEPTH: u32 = 1000;
// Presumably the upper bound for the length of `State::stack`.
const MAX_STACK_SIZE: usize = 1_000_000;
impl State {
/// Threshold handed to `GcHeap::with_threshold` when a state is created.
const GC_INITIAL_THRESHOLD: usize = 20;
/// Creates a state that uses `DefaultCallbacks`; see `with_callbacks`.
pub fn new() -> Self {
Self::with_callbacks(Box::new(DefaultCallbacks))
}
/// Creates a state that reports to `callbacks`.
///
/// Builds the state with `empty_with_callbacks` and then calls `open_libs`
/// on it before handing it back.
pub fn with_callbacks(callbacks: Box<dyn HostCallbacks + Send>) -> Self {
    let mut state = Self::empty_with_callbacks(callbacks);
    state.open_libs();
    state
}
/// Creates a state with `DefaultCallbacks` and without calling `open_libs`.
pub fn empty() -> Self {
Self::empty_with_callbacks(Box::new(DefaultCallbacks))
}
pub(crate) fn empty_with_callbacks(callbacks: Box<dyn HostCallbacks + Send>) -> Self {
let state_id = anchor::next_state_id();
Self {
globals: IndexMap::new(),
builtins: std::array::from_fn(|_| Val::Nil),
globals_version: 0,
stack: Vec::with_capacity(256), stack_bottom: 0,
heap: GcHeap::with_threshold(Self::GC_INITIAL_THRESHOLD),
string_literals: Vec::with_capacity(64), active_call_roots: Vec::with_capacity(64),
upvalue_pool: UpvaluePool::new(),
open_upvalues: Vec::new(),
vararg_call_bases: Vec::new(),
cost_remaining: i64::MAX,
cost_budget: i64::MAX,
cost_used: 0,
metamethod_depth: 0,
call_depth: 0,
call_stack: Vec::with_capacity(64), callbacks,
current_source: None,
user_data: None,
rng: rand::rngs::StdRng::seed_from_u64(0),
registry: Registry::new(state_id),
}
}
/// Replaces the state's RNG with a new `StdRng` seeded from `seed`.
pub fn set_rng_seed(&mut self, seed: u64) {
self.rng = rand::rngs::StdRng::seed_from_u64(seed);
}
/// Sets the execution budget to `budget` and restarts the accounting:
/// the remaining budget becomes `budget` and the used counter becomes 0.
pub fn set_cost_budget(&mut self, budget: i64) {
self.cost_budget = budget;
self.cost_remaining = budget;
self.cost_used = 0;
}
/// Total cost charged via `consume_cost` since the budget was last set.
pub fn cost_used(&self) -> u64 {
self.cost_used
}
/// Budget still available; may be negative after an overshooting charge.
pub fn cost_remaining(&self) -> i64 {
self.cost_remaining
}
/// Charges `cost` units against the execution budget.
///
/// The budget is checked *before* charging, so a single charge may drive
/// `cost_remaining` below zero; the next non-zero charge then fails.
/// A zero `cost` never fails.
///
/// Both counters saturate rather than overflow: a charge larger than
/// `i64::MAX` is clamped when subtracted from the remaining budget, and
/// `cost_used` stops at `u64::MAX`.
///
/// # Errors
///
/// Returns `ErrorKind::BudgetExceeded` (carrying the used and configured
/// amounts) when `cost` is non-zero and no budget remains.
#[inline(always)]
pub fn consume_cost(&mut self, cost: u64) -> Result<()> {
    if cost > 0 && self.cost_remaining <= 0 {
        return Err(self.error(ErrorKind::BudgetExceeded {
            used: self.cost_used,
            budget: self.cost_budget,
        }));
    }
    // A plain `cost as i64` wraps to a negative number for costs above
    // `i64::MAX`, which would *increase* the remaining budget; clamp first.
    let charge = cost.min(i64::MAX as u64) as i64;
    self.cost_remaining = self.cost_remaining.saturating_sub(charge);
    self.cost_used = self.cost_used.saturating_add(cost);
    Ok(())
}
/// Stores `data` as the state's user data, replacing any previous value.
pub fn set_user_data<T: Send + 'static>(&mut self, data: T) {
self.user_data = Some(Box::new(data));
}
/// Borrows the stored user data as a `T`.
///
/// Returns `None` when nothing is stored or the stored value is not a `T`.
pub fn user_data<T: Send + 'static>(&self) -> Option<&T> {
    self.user_data
        .as_deref()
        .and_then(|data| data.downcast_ref::<T>())
}
/// Mutably borrows the stored user data as a `T`.
///
/// Returns `None` when nothing is stored or the stored value is not a `T`.
pub fn user_data_mut<T: Send + 'static>(&mut self) -> Option<&mut T> {
    self.user_data
        .as_deref_mut()
        .and_then(|data| data.downcast_mut::<T>())
}
/// Drops the stored user data, if any.
pub fn clear_user_data(&mut self) {
self.user_data = None;
}
/// Anchors the value on top of the stack and pops it.
///
/// Same checks as `anchor_at(-1)`; the value is popped only after it has
/// been registered, so a failed call does not pop anything.
///
/// # Errors
///
/// Fails when index `-1` is invalid or the top value is nil
/// (`ErrorKind::AnchorNil`).
pub fn anchor(&mut self) -> Result<Anchor> {
    let handle = self.anchor_at(-1)?;
    self.pop_val();
    Ok(handle)
}
pub fn anchor_at(&mut self, idx: isize) -> Result<Anchor> {
let val = self.at_index(idx)?;
if matches!(val, Val::Nil) {
return Err(Error::without_location(ErrorKind::AnchorNil));
}
Ok(self.registry.insert(val))
}
/// Anchors the function on top of the stack and pops it.
///
/// Same checks as `anchor_function_at(-1)`; the value is popped only after
/// it has been registered, so a failed call does not pop anything.
///
/// # Errors
///
/// Fails when index `-1` is invalid or the top value is not a function
/// (`TypeError::FunctionCall`).
pub fn anchor_function(&mut self) -> Result<Anchor> {
    let handle = self.anchor_function_at(-1)?;
    self.pop_val();
    Ok(handle)
}
/// Anchors the function at stack index `idx`, leaving the stack unchanged.
///
/// # Errors
///
/// Fails when `idx` is rejected by `at_index`, or with
/// `TypeError::FunctionCall` (carrying the actual type) when the value is
/// not a function.
pub fn anchor_function_at(&mut self, idx: isize) -> Result<Anchor> {
    let val = self.at_index(idx)?;
    let typ = val.typ(&self.heap);
    if typ == LuaType::Function {
        Ok(self.registry.insert(val))
    } else {
        Err(self.type_error(TypeError::FunctionCall(typ)))
    }
}
pub fn push_anchor(&mut self, a: Anchor) -> Result<()> {
match self.registry.get(a) {
Some(val) => {
self.stack.push(val);
Ok(())
}
None => Err(Error::without_location(ErrorKind::InvalidAnchor)),
}
}
pub fn call_anchor(&mut self, a: Anchor, args: ArgCount, rets: RetCount) -> Result<()> {
let val = match self.registry.get(a) {
Some(val) => val,
None => return Err(Error::without_location(ErrorKind::InvalidAnchor)),
};
let n_args = match args {
ArgCount::Fixed(n) => n as usize,
ArgCount::Dynamic => {
return Err(Error::without_location(ErrorKind::InternalError(
"call_anchor does not support ArgCount::Dynamic; use ArgCount::Fixed".into(),
)));
}
};
let insert_at = self.stack.len().checked_sub(n_args).ok_or_else(|| {
Error::without_location(ErrorKind::InvalidStackIndex {
index: -(n_args as isize) - 1,
})
})?;
self.stack.insert(insert_at, val);
self.call(args, rets)
}
/// Removes anchor `a` from the registry and returns what `Registry::remove`
/// reports — presumably `true` when the anchor existed (TODO confirm).
pub fn release_anchor(&mut self, a: Anchor) -> bool {
self.registry.remove(a)
}
/// Returns the type of the value behind anchor `a`, or `None` when the
/// registry has no value for it.
pub fn anchor_type(&self, a: Anchor) -> Option<LuaType> {
self.registry.get(a).map(|val| val.typ(&self.heap))
}
/// Number of entries reported by the anchor registry.
pub fn anchor_count(&self) -> usize {
self.registry.len()
}
/// Mutable access to the installed host callbacks.
pub fn callbacks_mut(&mut self) -> &mut dyn HostCallbacks {
self.callbacks.as_mut()
}
/// Installs `callbacks` and returns the previously installed ones.
pub fn replace_callbacks(
&mut self,
callbacks: Box<dyn HostCallbacks + Send>,
) -> Box<dyn HostCallbacks + Send> {
std::mem::replace(&mut self.callbacks, callbacks)
}
/// Number of objects the heap currently reports.
pub fn object_count(&self) -> usize {
self.heap.object_count()
}
/// Number of strings the heap currently reports.
pub fn string_count(&self) -> usize {
self.heap.string_count()
}
/// Total number of heap entries: objects plus strings (a count, not bytes).
pub fn heap_size(&self) -> usize {
    let objects = self.object_count();
    let strings = self.string_count();
    objects + strings
}
/// Whether the heap reports itself as full (`GcHeap::is_full`), which is
/// the condition `alloc_string` uses to trigger a collection.
pub fn gc_should_run(&self) -> bool {
self.heap.is_full()
}
/// Current collection threshold of the heap.
pub fn gc_threshold(&self) -> usize {
self.heap.threshold()
}
/// Sets the heap's collection threshold to `threshold`.
pub fn gc_set_threshold(&mut self, threshold: usize) {
self.heap.set_threshold(threshold);
}
/// Raises the collection threshold to `usize::MAX`.
///
/// The intent, per the name, is that automatic collection never triggers;
/// `gc_collect` can still be called explicitly.
pub fn gc_disable_auto(&mut self) {
    self.gc_set_threshold(usize::MAX);
}
/// Runs a full collection: marks from every root set held by the state
/// (stack, globals, builtins, string literals, active call roots, registry)
/// and then lets the heap collect.
#[hotpath::measure]
pub fn gc_collect(&mut self) {
mark_gc_roots(
&self.heap,
&self.stack,
&self.globals,
&self.builtins,
&self.string_literals,
&self.active_call_roots,
&self.upvalue_pool,
&self.registry,
);
self.heap.collect();
}
/// Forwards `message` to the host's `on_print` callback together with the
/// current source name and a line number.
///
/// The line is looked up in the innermost `call_stack` entry's `line_info`
/// at `ip - 1` (saturating at 0); it is 0 when there is no call entry or no
/// line information at that index.
pub(crate) fn host_print(&mut self, message: &str) {
    let line = match self.call_stack.last() {
        Some(top) => top
            .bytecode
            .line_info
            .get(top.ip.saturating_sub(1))
            .copied()
            .unwrap_or(0),
        None => 0,
    };
    let source_name = self.current_source.as_deref();
    self.callbacks.on_print(source_name, line, message);
}
/// Forwards `error` to the host's `on_error` callback together with the
/// current source name.
pub(crate) fn host_error(&mut self, error: &Error) {
let source = self.current_source.as_deref();
self.callbacks.on_error(source, error);
}
/// Source name currently recorded on the state, if any.
pub fn current_source(&self) -> Option<&str> {
self.current_source.as_deref()
}
/// Pushes the value of global `name` onto the stack.
///
/// Names recognised by `Builtin::from_name` are read from the builtin slot
/// table; everything else comes from the globals map, with the default
/// `Val` pushed when the name is not present.
#[hotpath::measure]
pub fn get_global(&mut self, name: &str) {
    let val = match Builtin::from_name(name) {
        Some(slot) => self.builtins[slot as usize],
        None => self.globals.get(name).copied().unwrap_or_default(),
    };
    self.stack.push(val);
}
/// Pops the top of the stack and stores it as global `name`.
#[hotpath::measure]
pub fn set_global(&mut self, name: &str) {
// Popped into a local first: the value must be taken off the stack before
// `self` is borrowed again for the store.
let val = self.pop_val();
self.set_global_value(name, val);
}
/// Stores `val` as global `name`; allocates an owned copy of the name and
/// delegates to `set_global_value_owned`.
pub(super) fn set_global_value(&mut self, name: &str, val: Val) {
self.set_global_value_owned(name.to_string(), val);
}
/// Stores `val` as global `name`, taking ownership of the name.
///
/// When `Builtin::from_name` recognises the name, the matching builtin slot
/// is updated too and `globals_version` is bumped (wrapping). Writes to any
/// other name leave `globals_version` untouched.
pub(super) fn set_global_value_owned(&mut self, name: String, val: Val) {
if let Some(slot) = Builtin::from_name(&name) {
self.builtins[slot as usize] = val;
self.globals_version = self.globals_version.wrapping_add(1);
}
// The globals map is written for builtin and non-builtin names alike.
self.globals.insert(name, val);
}
pub fn with_restricted_env<F, R>(&mut self, whitelist: &[&str], f: F) -> R
where
F: FnOnce(&mut Self) -> R,
{
let mut restricted_globals = IndexMap::new();
let mut restricted_builtins: [Val; Builtin::COUNT] = std::array::from_fn(|_| Val::Nil);
for name in whitelist {
if let Some(slot) = Builtin::from_name(name) {
restricted_builtins[slot as usize] = self.builtins[slot as usize];
}
if let Some(val) = self.globals.get(*name) {
restricted_globals.insert((*name).to_string(), *val);
}
}
let saved_globals = std::mem::replace(&mut self.globals, restricted_globals);
let saved_builtins = std::mem::replace(&mut self.builtins, restricted_builtins);
self.globals_version = self.globals_version.wrapping_add(1);
let result = f(self);
self.globals = saved_globals;
self.builtins = saved_builtins;
self.globals_version = self.globals_version.wrapping_add(1);
result
}
/// Allocates `bytes` as a heap string and returns it as a `Val::Str`.
///
/// When the heap reports itself as full, a collection runs first.
#[hotpath::measure]
pub(super) fn alloc_string(&mut self, bytes: impl AsRef<[u8]>) -> Val {
    if self.heap.is_full() {
        self.gc_collect();
    }
    Val::Str(self.heap.alloc_string(bytes.as_ref()))
}
pub fn error(&self, kind: ErrorKind) -> Error {
let pos = 0;
let column = 0;
Error::new(kind, pos, column)
}
/// Shorthand for `error(ErrorKind::TypeError(e))`.
pub(super) fn type_error(&self, e: TypeError) -> Error {
self.error(ErrorKind::TypeError(e))
}
/// Builds a stack trace, innermost frame first.
///
/// The first entry comes from `current_frame`. The remaining entries come
/// from `call_stack`, walked from newest to oldest with the newest entry
/// skipped. Each caller's line is read from its `line_info` at `ip - 1`
/// (index 0 when `ip` is 0) and falls back to 0 when that index has no
/// entry.
#[allow(private_interfaces)]
pub(super) fn build_stack_trace(&self, current_frame: &frame::Frame) -> Vec<StackFrame> {
    let mut trace = Vec::with_capacity(self.call_stack.len() + 1);
    trace.push(current_frame.to_stack_frame());
    trace.extend(self.call_stack.iter().rev().skip(1).map(|caller| {
        let code = &caller.bytecode;
        let line = code
            .line_info
            .get(caller.ip.saturating_sub(1))
            .copied()
            .unwrap_or(0);
        StackFrame {
            function_name: code.name.clone(),
            source: code.source.clone(),
            line,
        }
    }));
    trace
}
}
/// `State::default()` is the same as `State::new()`.
impl Default for State {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::Bytecode;
use super::Instr;
use super::State;
use super::compiler::parse_str;
use super::lua_val::Val;
use crate::instr::RetCount;
/// Evaluating the parsed source `a = 1` stores the number 1 in global `a`.
#[test]
fn vm_test01() {
let mut state = State::new();
let input = parse_str("a = 1").unwrap();
state.eval_chunk(input, 0).unwrap();
assert_eq!(Val::Num(1.0), *state.globals.get("a").unwrap());
}
/// Hand-assembled chunk: pushes string literals 1 and 2 ("a", "b"),
/// concatenates two values and stores the result in the global named by
/// string literal 0 ("key"). Expects `key == "ab"`.
#[test]
fn vm_test02() {
let mut state = State::new();
let input = Bytecode {
code: vec![
Instr::push_string(1),
Instr::push_string(2),
Instr::concat(2),
Instr::set_global(0),
Instr::ret(RetCount::Fixed(0)),
],
string_literals: vec!["key".into(), "a".into(), "b".into()],
..Bytecode::default()
};
state.eval_chunk(input, 0).unwrap();
let val = state.globals.get("key").unwrap();
assert_eq!(val.as_string(&state.heap), Some(&b"ab"[..]));
}
/// Hand-assembled chunk: pushes number literal 0 (2.5) twice, compares the
/// two for equality and stores the result in global `a`. Expects `true`.
#[test]
fn vm_test04() {
let mut state = State::new();
let input = Bytecode {
code: vec![
Instr::push_num(0),
Instr::push_num(0),
Instr::equal(),
Instr::set_global(0),
Instr::ret(RetCount::Fixed(0)),
],
number_literals: vec![2.5],
string_literals: vec!["a".into()],
..Bytecode::default()
};
state.eval_chunk(input, 0).unwrap();
assert_eq!(Val::Bool(true), *state.globals.get("a").unwrap());
}
/// Exercises `branch_false_keep` with a true condition: the branch must not
/// be taken, so the `true` is popped, `false` is pushed and stored in global
/// `key`. The instruction shape presumably matches `key = true and false` —
/// TODO confirm against the compiler output.
#[test]
fn vm_test05() {
let mut state = State::new();
let input = Bytecode {
code: vec![
Instr::push_bool(true),
Instr::branch_false_keep(2),
Instr::pop(),
Instr::push_bool(false),
Instr::set_global(0),
Instr::ret(RetCount::Fixed(0)),
],
string_literals: vec!["key".into()],
..Bytecode::default()
};
state.eval_chunk(input, 0).unwrap();
assert_eq!(Val::Bool(false), *state.globals.get("key").unwrap());
}
/// Exercises `branch_false` with a true condition: the branch must not be
/// taken, so number literal 0 (5.0) is stored in global `a`.
#[test]
fn vm_test06() {
let mut state = State::new();
let code = vec![
Instr::push_bool(true),
Instr::branch_false(3),
Instr::push_num(0),
Instr::set_global(0),
Instr::ret(RetCount::Fixed(0)),
];
let chunk = Bytecode {
code,
number_literals: vec![5.0],
string_literals: vec!["a".into()],
..Bytecode::default()
};
state.eval_chunk(chunk, 0).unwrap();
assert_eq!(Val::Num(5.0), *state.globals.get("a").unwrap());
}
/// Exercises `branch_false` with a false condition: `2.0 < 2.0` is false,
/// so the branch must skip the push/set pair and global `a` stays unset.
#[test]
fn vm_test07() {
let mut state = State::new();
let code = vec![
Instr::push_num(0),
Instr::push_num(0),
Instr::less(),
Instr::branch_false(2),
Instr::push_bool(true),
Instr::set_global(0),
Instr::ret(RetCount::Fixed(0)),
];
let chunk = Bytecode {
code,
number_literals: vec![2.0],
string_literals: vec!["a".into()],
..Bytecode::default()
};
state.eval_chunk(chunk, 0).unwrap();
assert!(state.globals.get("a").is_none());
}
/// Loop over a global: `a` starts at 0.0 and 10.0 is added while `a < 1.0`.
///
/// Covers `jump` with a negative offset together with `branch_false` as the
/// loop exit, and checks the value the loop leaves behind.
#[test]
fn vm_test08() {
    let code = vec![
        // a = 0.0 (number literal 2)
        Instr::push_num(2),
        Instr::set_global(0),
        // loop head: a < 1.0 (number literal 0)?
        Instr::get_global(0),
        Instr::push_num(0),
        Instr::less(),
        Instr::branch_false(5),
        // a = a + 10.0 (number literal 1)
        Instr::get_global(0),
        Instr::push_num(1),
        Instr::add(),
        Instr::set_global(0),
        // back to the loop head
        Instr::jump(-9),
        Instr::ret(RetCount::Fixed(0)),
    ];
    let chunk = Bytecode {
        code,
        number_literals: vec![1.0, 10.0, 0.0],
        string_literals: vec!["a".into()],
        ..Bytecode::default()
    };
    let mut state = State::new();
    state.eval_chunk(chunk, 0).unwrap();
    // Exactly one iteration runs (0 < 1); afterwards 10 < 1 ends the loop.
    assert_eq!(Val::Num(10.0), *state.globals.get("a").unwrap());
}
/// Loop over a local: local 0 starts at 1.0 (number literal 0) and 1.0
/// (literal 2) is added while it is less than 10.0 (literal 1). The final
/// value is copied into global `x`, which must be 10.0.
#[test]
fn vm_test09() {
let code = vec![
Instr::push_num(0),
Instr::set_local(0),
Instr::get_local(0),
Instr::push_num(1),
Instr::less(),
Instr::branch_false(5),
Instr::get_local(0),
Instr::push_num(2),
Instr::add(),
Instr::set_local(0),
Instr::jump(-9),
Instr::get_local(0),
Instr::set_global(0),
Instr::ret(RetCount::Fixed(0)),
];
let chunk = Bytecode {
code,
number_literals: vec![1.0, 10.0, 1.0],
string_literals: vec!["x".into()],
num_locals: 1,
..Bytecode::default()
};
let mut state = State::new();
state.eval_chunk(chunk, 0).unwrap();
assert_eq!(Val::Num(10.0), *state.globals.get("x").unwrap());
}
/// Numeric `for` whose body must never run: the three pushed values are
/// 6.0, 2.0 and 2.0 — presumably start, limit and step, in that order
/// (TODO confirm operand order of `for_prep`). The body would set global
/// `a`, so the test expects `a` to stay unset.
#[test]
fn vm_test10() {
let code = vec![
Instr::push_num(0), Instr::push_num(1), Instr::push_num(1), Instr::for_prep(0, 3),
Instr::push_num(0),
Instr::set_global(0), Instr::for_loop(0, -3),
Instr::ret(RetCount::Fixed(0)),
];
let chunk = Bytecode {
code,
number_literals: vec![6.0, 2.0],
string_literals: vec!["a".into()],
num_locals: 4,
..Bytecode::default()
};
let mut state = State::new();
state.eval_chunk(chunk, 0).unwrap();
assert!(state.globals.get("a").is_none());
}
/// End-to-end numeric `for` from source: summing `i` for 1..=3 into global
/// `a` must give 6.
#[test]
fn vm_test11() {
let text = "
a = 0
for i = 1, 3 do
a = a + i
end";
let chunk = parse_str(text).unwrap();
let mut state = State::new();
state.eval_chunk(chunk, 0).unwrap();
let a = state.globals.get("a").unwrap().as_num().unwrap();
assert_eq!(a, 6.0);
}
/// Host-driven collection: with automatic GC disabled, tables referenced by
/// globals survive an explicit `gc_collect`, and are reclaimed once the
/// globals are set to nil.
#[test]
fn gc_host_controlled() {
let mut state = State::new();
let initial_objects = state.object_count();
let initial_strings = state.string_count();
assert!(state.heap_size() >= initial_objects + initial_strings);
// Disabling automatic GC is observable as a threshold of usize::MAX.
state.gc_disable_auto();
assert_eq!(state.gc_threshold(), usize::MAX);
let code = parse_str("t1 = {} t2 = {} t3 = {}").unwrap();
state.eval_chunk(code, 0).unwrap();
assert!(state.object_count() > initial_objects);
// Everything is still reachable through globals: nothing may be freed.
let before_gc = state.object_count();
state.gc_collect();
assert_eq!(state.object_count(), before_gc);
// After dropping the references the next collection must shrink the heap.
let code = parse_str("t1 = nil t2 = nil t3 = nil").unwrap();
state.eval_chunk(code, 0).unwrap();
state.gc_collect();
assert!(state.object_count() < before_gc);
}
/// Threshold getter/setter round-trip on an empty state, followed by a
/// smoke test: evaluating a table allocation with the threshold set to 1
/// must complete without error (no further assertion is made).
#[test]
fn gc_threshold_control() {
let mut state = State::empty();
state.gc_set_threshold(100);
assert_eq!(state.gc_threshold(), 100);
assert!(!state.gc_should_run());
state.gc_set_threshold(1);
let code = parse_str("t = {}").unwrap();
state.eval_chunk(code, 0).unwrap();
}
/// A global function that calls a chunk-local function must still work when
/// the host invokes it after the defining chunk has finished running.
#[test]
fn callback_pattern_local_upvalue() {
use crate::{ArgCount, RetCount};
let mut state = State::new();
let code = r#"
local function helper()
return 42
end
function on_tick()
return helper()
end
"#;
// Load the chunk and run it once so that `on_tick` gets defined.
state.load_string(code).unwrap();
state.call(ArgCount::Fixed(0), RetCount::Fixed(0)).unwrap();
state.get_global("on_tick");
assert_eq!(state.typ(-1), crate::LuaType::Function);
state.call(ArgCount::Fixed(0), RetCount::Fixed(1)).unwrap();
let result = state.to_number(-1).unwrap();
assert_eq!(result, 42.0);
}
/// A chunk-local counter captured by a local function must keep its value
/// across repeated host calls: five calls to `tick` return 1 through 5.
#[test]
fn callback_pattern_mutable_upvalue() {
use crate::{ArgCount, RetCount};
let mut state = State::new();
let code = r#"
local counter = 0
local function increment()
counter = counter + 1
return counter
end
function tick()
return increment()
end
"#;
// Load the chunk and run it once so that `tick` gets defined.
state.load_string(code).unwrap();
state.call(ArgCount::Fixed(0), RetCount::Fixed(0)).unwrap();
for expected in 1..=5 {
state.get_global("tick");
state.call(ArgCount::Fixed(0), RetCount::Fixed(1)).unwrap();
let result = state.to_number(-1).unwrap();
// Drop the result so the stack does not grow across iterations.
state.pop(1);
assert_eq!(result, expected as f64);
}
}
/// Chain of local functions reaching a chunk-local variable: `callback`
/// calls `outer`, which calls `inner`, which reads `base`.
#[test]
fn callback_pattern_nested_locals() {
use crate::{ArgCount, RetCount};
let mut state = State::new();
let code = r#"
local base = 100
local function inner()
return base
end
local function outer()
return inner() + 10
end
function callback()
return outer() + 1
end
"#;
// Load the chunk and run it once so that `callback` gets defined.
state.load_string(code).unwrap();
state.call(ArgCount::Fixed(0), RetCount::Fixed(0)).unwrap();
state.get_global("callback");
state.call(ArgCount::Fixed(0), RetCount::Fixed(1)).unwrap();
let result = state.to_number(-1).unwrap();
// 100 (base) + 10 (outer) + 1 (callback)
assert_eq!(result, 111.0); }
/// A runtime error must carry a stack trace whose first frame reports the
/// source line of the failing statement: here the call of a table value on
/// line 3.
#[test]
fn error_line_numbers() {
use crate::{ArgCount, RetCount};
let mut state = State::new();
let code = "-- comment\nlocal t = {}\nt()";
state.load_string(code).unwrap();
let result = state.call(ArgCount::Fixed(0), RetCount::Fixed(0));
assert!(result.is_err());
let err = result.unwrap_err();
assert!(!err.stack_trace.is_empty());
assert_eq!(err.stack_trace[0].line, 3);
}
}