use std::fmt;
use std::rc::Rc;
use super::{Code, FuncClosure, FuncDef, Local, OpResult, StackFrame, Value};
use crate::error::Result;
use crate::value::CallResult;
use crate::{LuaError, VM};
/// Describes what happens with a call's result once the call recorded in a
/// stack frame finishes.
///
/// A value of this type is stored in the frame's `ret` field. It is built by
/// `call_prepare_builtin` / `call_prepare_closure` from the optional
/// `(FuncCont, bool)` argument: `None` yields `Lua`, `Some((cont, true))`
/// yields `RustLua`, and `Some((cont, false))` yields `Rust`.
#[derive(Debug)]
pub enum ReturnType {
/// Resume the saved code object at the saved program counter (the values of
/// `VM::code` / `VM::pc` captured when the frame was created).
Lua(Rc<Code>, usize),
/// First pass the result through the continuation, then resume the saved
/// code object at the saved program counter.
RustLua(FuncCont, Rc<Code>, usize),
/// Pass the result through the continuation only; no code position is saved.
/// On the builtin return path in `VM::call_prepare`, the continuation's
/// output is then treated as the result for the next frame down the stack.
Rust(FuncCont),
}
/// A named Rust continuation that can be attached to a call frame through
/// `ReturnType::RustLua` or `ReturnType::Rust`.
pub struct FuncCont {
/// Static label for the continuation; the `Debug` impl prints it as
/// `cont:<name>`.
pub name: &'static str,
/// The continuation itself. It receives the VM and the result of the call it
/// was attached to, and the value it returns replaces that result.
/// NOTE(review): in this file it is only ever invoked with `Ok(..)`;
/// presumably error paths elsewhere pass `Err(..)` — confirm against callers.
pub func: Rc<dyn Fn(&mut VM, Result<Value>) -> Result<Value>>,
}
/// Debug output is the fixed prefix `cont:` followed by the continuation's
/// name; the closure itself is not printable and is left out.
impl fmt::Debug for FuncCont {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("cont:")?;
        f.write_str(self.name)
    }
}
impl VM {
/// Upper bound for `self.rust_call_depth`; `check_call_depth` reports
/// `LuaError::StackOverflow` once the counter is greater than this value.
const MAX_RUST_CALL_DEPTH: usize = 150;
/// Frame-stack size limit used by `check_call_depth`: a stack of exactly this
/// many frames reports `LuaError::StackOverflow`, and a stack that has grown
/// past `MAX_CALL_DEPTH / 10 * 11` frames reports `LuaError::ErrorHandler`.
const MAX_CALL_DEPTH: usize = 1_000_000;
/// Sets up a call to `func_def` with `args`; builtins are executed right away.
///
/// What happens depends on the kind of function:
///
/// * `FuncDef::Defined` — a frame is prepared by `call_prepare_closure` (which
///   also switches `self.code` / `self.pc` to the callee) and
///   `CallResult::Continue` is returned; the function body itself is not run
///   here.
/// * `FuncDef::BuiltinRaw` — a frame is pushed and the raw function is invoked;
///   whatever `CallResult` it produces is returned unchanged. The frame is not
///   popped by this method.
/// * `FuncDef::Builtin` — a frame is pushed, the builtin is invoked, and its
///   result is then unwound through the frame stack (see the loop below)
///   before being returned as `CallResult::Return`.
///
/// In every case the frame is set up first, then `check_call_depth` and
/// `debug_call` run, and an error from either is propagated.
///
/// # Parameters
/// * `ret` — optional continuation plus a flag; it selects the frame's
///   `ReturnType` (`None` → `Lua`, `true` → `RustLua`, `false` → `Rust`).
/// * `reg_ret` — stored in the new frame's `ret_reg` field.
/// * `tail` — only forwarded to `call_prepare_closure`; builtins ignore it.
/// * `meta` — stored in the new frame's `meta` field.
///
/// # Panics
/// Panics if the frame stack is empty while unwinding a builtin's result, or
/// if a frame popped during that unwinding is marked `recursive`.
pub fn call_prepare(
    &mut self,
    func_def: Rc<FuncDef>,
    args: Vec<Value>,
    ret: Option<(FuncCont, bool)>,
    reg_ret: Option<usize>,
    tail: bool,
    meta: Option<&'static str>,
) -> Result<CallResult> {
    match &*func_def {
        FuncDef::Defined(closure) => {
            self.call_prepare_closure(Rc::clone(&func_def), closure, args, ret, reg_ret, tail, meta);
            self.check_call_depth()?;
            self.debug_call()?;
            Ok(CallResult::Continue)
        }
        FuncDef::BuiltinRaw(raw) => {
            self.call_prepare_builtin(Rc::clone(&func_def), args, ret, reg_ret, meta);
            self.check_call_depth()?;
            self.debug_call()?;
            (raw.func)(self)
        }
        FuncDef::Builtin(builtin) => {
            self.call_prepare_builtin(Rc::clone(&func_def), args, ret, reg_ret, meta);
            self.check_call_depth()?;
            self.debug_call()?;

            let native = builtin.func.clone();
            let mut value = native(self)?;

            // Unwind: pop the finished frame and act on its return target. A
            // `Rust` target has no code position to resume, so its
            // continuation output becomes the result of the next frame down
            // and the loop goes around again.
            loop {
                self.debug_return(&value)?;

                let frame = self.frames.pop().unwrap();
                assert!(!frame.recursive);
                if let Some(handle) = frame.protected {
                    self.thread.borrow_mut().protected = handle;
                }

                let (code, pc) = match frame.ret {
                    ReturnType::Rust(cont) => {
                        value = (cont.func)(self, Ok(value))?;
                        continue;
                    }
                    ReturnType::RustLua(cont, code, pc) => {
                        // The continuation runs before the code position is
                        // restored, so a failure leaves `code`/`pc` untouched.
                        value = (cont.func)(self, Ok(value))?;
                        (code, pc)
                    }
                    ReturnType::Lua(code, pc) => (code, pc),
                };
                self.code = code;
                self.pc = pc;
                return Ok(CallResult::Return(value));
            }
        }
    }
}
/// Calls `func_def` with `args` and returns its result.
///
/// Shorthand for `call_recursive_meta` with no `meta` tag.
pub fn call_recursive(&mut self, func_def: Rc<FuncDef>, args: Vec<Value>) -> Result<Value> {
    let no_meta = None;
    self.call_recursive_meta(func_def, args, no_meta)
}
/// Calls `func_def` with `args` and returns its result, tagging the new frame
/// with `meta`.
///
/// The call is prepared through `call_prepare` with no continuation, no return
/// register and `tail == false`:
///
/// * if that already produced a value, the value is returned;
/// * if it reported a yield, `LuaError::CoroutineYieldRust` is returned;
/// * otherwise the freshly prepared top frame is marked `recursive` and
///   `run_loop` is driven until it returns a value.
///
/// # Panics
/// Panics if the frame stack is empty after a `Continue` result, or if
/// `run_loop` reports a yield.
pub fn call_recursive_meta(
    &mut self,
    func_def: Rc<FuncDef>,
    args: Vec<Value>,
    meta: Option<&'static str>,
) -> Result<Value> {
    let prepared = self.call_prepare(func_def, args, None, None, false, meta)?;
    match prepared {
        CallResult::Continue => {
            self.frames.last_mut().unwrap().recursive = true;
            match self.run_loop()? {
                OpResult::Return(value) => Ok(value),
                OpResult::Yield(..) => unreachable!(),
            }
        }
        CallResult::Return(value) => Ok(value),
        CallResult::Yield(..) => return err!(LuaError::CoroutineYieldRust),
    }
}
/// Verifies that the current call nesting is within limits.
///
/// # Errors
/// * `LuaError::StackOverflow` when `rust_call_depth` is greater than
///   `MAX_RUST_CALL_DEPTH`, or when the frame stack holds exactly
///   `MAX_CALL_DEPTH` frames.
/// * `LuaError::ErrorHandler` when the frame stack has grown beyond
///   `MAX_CALL_DEPTH / 10 * 11` frames.
///
/// Frame counts strictly between those two thresholds are accepted.
/// NOTE(review): presumably this headroom lets error handling run after an
/// overflow was reported — confirm against the error-handling code.
fn check_call_depth(&self) -> Result<()> {
    if self.rust_call_depth > Self::MAX_RUST_CALL_DEPTH {
        return err!(LuaError::StackOverflow);
    }

    let depth = self.frames.len();
    if depth == Self::MAX_CALL_DEPTH {
        return err!(LuaError::StackOverflow);
    }
    if depth > Self::MAX_CALL_DEPTH && depth > Self::MAX_CALL_DEPTH / 10 * 11 {
        return err!(LuaError::ErrorHandler);
    }

    Ok(())
}
/// Pushes a stack frame for a builtin function.
///
/// The frame has no registers, locals or upvalues; all of `args` are stored as
/// the frame's `varargs`. The return target is derived from `ret`: with no
/// continuation the current `code`/`pc` are saved (`ReturnType::Lua`); with a
/// continuation the flag selects between `RustLua` (which also saves the
/// current `code`/`pc`) and `Rust`.
///
/// `self.code` and `self.pc` are left unchanged.
fn call_prepare_builtin(
    &mut self,
    func_def: Rc<FuncDef>,
    args: Vec<Value>,
    ret: Option<(FuncCont, bool)>,
    ret_reg: Option<usize>,
    meta: Option<&'static str>,
) {
    let return_to = match ret {
        None => ReturnType::Lua(self.code.clone(), self.pc),
        Some((cont, true)) => ReturnType::RustLua(cont, self.code.clone(), self.pc),
        Some((cont, false)) => ReturnType::Rust(cont),
    };

    let frame = StackFrame {
        recursive: false,
        tail: false,
        protected: None,
        func_def,
        ret: return_to,
        ret_reg,
        regs: Vec::new(),
        locals: Vec::new(),
        ups: Vec::new(),
        varargs: args,
        meta,
        transfer: (0, 0),
    };
    self.frames.push(frame);
}
/// Prepares a stack frame for the closure `func` and switches the VM to its
/// code.
///
/// Arguments are bound to the closure's parameters in order; missing
/// arguments become `Value::Nil`. Surplus arguments are kept as the frame's
/// `varargs` when the closure is variadic and discarded otherwise. The
/// closure's upvalues are cloned into the frame.
///
/// With `tail == false` a new frame is pushed, and its return target is
/// derived from `ret` in the same way as in `call_prepare_builtin`.
/// With `tail == true` the current top frame is reused instead: it is marked
/// as a tail call and its function, registers, locals, upvalues, varargs and
/// `transfer` are replaced, while its `ret`, `ret_reg`, `meta`, `protected`
/// and `recursive` fields are left as they were (`ret`, `ret_reg` and `meta`
/// passed to this call are ignored).
///
/// Finally `self.pc` is reset to 0 and `self.code` is set to the closure's
/// code.
///
/// # Panics
/// Panics if the closure has more parameters than `locals_cap`, or if
/// `tail` is set while the frame stack is empty.
fn call_prepare_closure(
    &mut self,
    func_def: Rc<FuncDef>,
    func: &FuncClosure,
    args: Vec<Value>,
    ret: Option<(FuncCont, bool)>,
    ret_reg: Option<usize>,
    tail: bool,
    meta: Option<&'static str>,
) {
    let mut remaining = args.into_iter();

    // Parameters occupy the first local slots; every other slot starts as a
    // temporary.
    let mut locals = vec![Local::Temp; func.locals_cap];
    for (slot, param) in func.params.iter().enumerate() {
        let val = remaining.next().unwrap_or(Value::Nil);
        locals[slot] = Local::Stack {
            val,
            name: param.clone(),
            to_close: false,
        };
    }

    let mut ups = Vec::with_capacity(func.ups.len());
    ups.extend(func.ups.iter().map(|(var, name)| (var.clone(), name.clone())));

    // Whatever the parameters did not consume is the variadic tail.
    let varargs: Vec<Value> = if func.varargs {
        remaining.collect()
    } else {
        Vec::new()
    };
    let regs = vec![Value::Nil; func.regs];

    if tail {
        // Reuse the caller's frame; its return target stays in place.
        let frame = self.frames.last_mut().unwrap();
        frame.tail = true;
        frame.func_def = func_def;
        frame.regs = regs;
        frame.locals = locals;
        frame.ups = ups;
        frame.varargs = varargs;
        frame.transfer = (0, 0);
    } else {
        let return_to = match ret {
            None => ReturnType::Lua(self.code.clone(), self.pc),
            Some((cont, true)) => ReturnType::RustLua(cont, self.code.clone(), self.pc),
            Some((cont, false)) => ReturnType::Rust(cont),
        };
        self.frames.push(StackFrame {
            recursive: false,
            tail: false,
            protected: None,
            func_def,
            ret: return_to,
            ret_reg,
            regs,
            locals,
            ups,
            varargs,
            meta,
            transfer: (0, 0),
        });
    }

    self.pc = 0;
    self.code = func.code.clone();
}
}