use std::rc::Rc;
use super::StdLib;
use crate::value::{CallResult, FuncBuiltin, FuncDef, Thread, ThreadState};
use crate::vm::{Local, OpResult, StackFrame};
use crate::{Error, LuaError, Result, Value, VM};
/// Registers the `coroutine` library functions with `stdlib`.
///
/// `yield` is registered with `func_raw` rather than `func`: unlike the other
/// entries it returns a `CallResult` instead of a `Value`, which is how it
/// hands `CallResult::Yield` back to the VM.
pub(super) fn module(stdlib: &mut StdLib) -> Result<()> {
    stdlib
        .module("coroutine")
        .func("close", close)?
        .func("create", create)?
        .func("isyieldable", isyieldable)?
        .func("resume", resume)?
        .func("running", running)?
        .func("status", status)?
        .func("wrap", wrap)?
        .func_raw("yield", r#yield)?;
    Ok(())
}
fn close(vm: &mut VM) -> Result<Value> {
let co = vm.arg_thread(0)?;
let mut co = co
.try_borrow_mut()
.map_err(|_| Error::from_lua(LuaError::CoroutineCloseRunning))?;
match co.state {
ThreadState::Normal => return err!(LuaError::CoroutineCloseNormal),
ThreadState::Running => return err!(LuaError::CoroutineCloseRunning),
_ => {}
}
co.state = ThreadState::Dead;
let mut err = co.error.take().map(|(e, _, _)| Error::from_lua(e));
while let Some(frame) = co.frames.pop() {
for loc in frame.locals.into_iter().rev() {
let val = match loc {
Local::Stack { val, to_close, .. } if to_close => val,
Local::Heap { var, to_close, .. } if to_close => var.take(),
Local::Temp | Local::Stack { .. } | Local::Heap { .. } => continue,
};
if val.is_truthy() {
if let Some(meta) = vm.get_metatable(&val) {
let func = meta.borrow().get(&Value::str("__close")).to_func()?;
let args = vec![
val.clone(),
match err.as_ref() {
Some(e) => e.to_value(),
None => Value::Nil,
},
];
if let Err(e) = vm.call_recursive(func, args) {
err = Some(e);
}
}
}
}
}
Ok(match err {
Some(err) => Value::Mult(vec![Value::Bool(false), err.into_value()]),
None => Value::Bool(true),
})
}
/// `coroutine.create(f)`: allocates a new thread whose body is the function `f`
/// and returns it as a thread value. The body does not run until the first
/// `coroutine.resume`.
fn create(vm: &mut VM) -> Result<Value> {
    let body = vm.arg_func(0)?;
    let handle = vm.alloc_thread(Thread::from(body));
    Ok(Value::Thread(handle))
}
fn isyieldable(vm: &mut VM) -> Result<Value> {
let thread = match vm.arg_opt(0) {
Some(v) => v.to_thread()?,
None => vm.thread.clone(),
};
let thread = thread.borrow();
let across_boundary = if matches!(thread.state, ThreadState::Running) {
__is_across_rust_boundary(&vm.frames[0..vm.frames.len() - 1])
} else {
__is_across_rust_boundary(&thread.frames)
};
Ok(Value::Bool(thread.func.is_some() && !across_boundary))
}
/// Returns `true` when any frame in `frames` is executing a builtin (Rust)
/// function, i.e. its `func_def` is `FuncDef::Builtin`.
fn __is_across_rust_boundary(frames: &[StackFrame]) -> bool {
    frames
        .iter()
        .any(|frame| matches!(&*frame.func_def, FuncDef::Builtin(..)))
}
/// `coroutine.resume(co, ...)`: starts or continues the coroutine `co`,
/// passing it the remaining arguments.
///
/// Returns `true` followed by the values the coroutine yielded or returned, or
/// `false` followed by an error value. Resuming a coroutine that is not
/// resumable (running, normal or dead) produces `false` plus a message rather
/// than raising an error.
///
/// While `co` runs, the VM's current thread, frames, program counter and code
/// are swapped out for the coroutine's and restored afterwards. The coroutine's
/// own frames and resume point are saved back into it.
fn resume(vm: &mut VM) -> Result<Value> {
    let co = vm.arg_thread(0)?;
    let args = vm.arg_split(1);
    let state = co.borrow().state;
    // Only a freshly created or a suspended coroutine can be resumed; the other
    // states are reported as a `false, message` pair.
    match state {
        ThreadState::Normal | ThreadState::Running => {
            return Ok(Value::Mult(vec![
                Value::Bool(false),
                Value::string(format!("{}", LuaError::CoroutineResumeNonSuspended)),
            ]));
        }
        ThreadState::Dead => {
            return Ok(Value::Mult(vec![
                Value::Bool(false),
                Value::string(format!("{}", LuaError::CoroutineResumeDead)),
            ]))
        }
        _ => {}
    };
    // Park the calling thread: it is `Normal` while `co` runs, and its
    // execution context is stashed so it can be reinstated afterwards.
    let save_thread = vm.thread.clone();
    save_thread.borrow_mut().state = ThreadState::Normal;
    let save_frames = vm.frames.split_off(0);
    let save_pc = vm.pc;
    let save_code = vm.code.clone();
    vm.thread = co.clone();
    let res = match state {
        // First resume: call the coroutine's body with `args`.
        ThreadState::Created => {
            let func = {
                let mut thread = co.borrow_mut();
                thread.state = ThreadState::Running;
                match &thread.func {
                    Some(func) => func.clone(),
                    None => unreachable!(),
                }
            };
            match vm.call_prepare(func, args, None, None, false, None) {
                Ok(CallResult::Return(ret)) => Ok(OpResult::Return(ret)),
                // A frame was pushed; execute it in the VM loop.
                Ok(CallResult::Continue) => vm.run_loop(),
                Ok(CallResult::Yield(yld)) => Ok(OpResult::Yield(yld)),
                Err(e) => Err(e),
            }
        }
        // Later resume: reinstate the coroutine's saved frames and resume point.
        ThreadState::Suspended => {
            {
                let mut thread = co.borrow_mut();
                thread.state = ThreadState::Running;
                vm.frames = thread.frames.split_off(0);
                let (code, pc) = thread.pc.as_ref().unwrap();
                vm.pc = *pc;
                vm.code = code.clone();
                // Discard the top saved frame (presumably the frame of the
                // `yield` call that suspended the coroutine — TODO confirm). If
                // it names a return register, the resume arguments are stored
                // there in the frame below as the result of that call.
                let yld = vm.frames.pop().unwrap(); if let Some(yld_reg) = yld.ret_reg {
                    let frame = vm.frames.last_mut().unwrap();
                    frame.regs[yld_reg] = Value::Mult(args);
                }
            }
            vm.run_loop()
        }
        _ => unreachable!(),
    };
    // Save the coroutine's context back into it (on every outcome), then
    // restore the caller's context and mark the caller running again.
    let mut co = co.borrow_mut();
    co.frames = vm.frames.split_off(0);
    co.pc = Some((vm.code.clone(), vm.pc));
    vm.thread = save_thread;
    vm.thread.borrow_mut().state = ThreadState::Running;
    vm.frames = save_frames;
    vm.pc = save_pc;
    vm.code = save_code;
    // Derive the coroutine's new state from how it stopped. On a Lua error the
    // error's type, position and trace are kept in `co.error` (which
    // `coroutine.close` later takes and reports).
    let res = match res {
        Ok(res) => Ok(match res {
            OpResult::Yield(yld) => {
                co.state = ThreadState::Suspended;
                yld
            }
            OpResult::Return(res) => {
                co.state = ThreadState::Dead;
                res
            }
        }),
        Err(e) => {
            co.state = ThreadState::Dead;
            if let Error::Lua { typ, pos, trace } = &e {
                co.error = Some((typ.clone(), pos.clone(), trace.clone()));
            }
            Err(e)
        }
    };
    // Errors inside the coroutine are returned as `false, err`, not raised.
    Ok(match res {
        Ok(res) => {
            let mut res = res.into_vec();
            res.insert(0, Value::Bool(true));
            Value::Mult(res)
        }
        Err(e) => Value::Mult(vec![Value::Bool(false), e.into_value()]),
    })
}
/// `coroutine.running()`: returns the current thread together with a boolean
/// that is `true` when it is the main thread (the one without a body function).
fn running(vm: &mut VM) -> Result<Value> {
    let current = vm.thread.clone();
    let is_main = {
        let inner = current.borrow();
        inner.func.is_none()
    };
    let values = vec![Value::Thread(current), Value::Bool(is_main)];
    Ok(Value::Mult(values))
}
/// `coroutine.status(co)`: names the state of `co` as one of `"suspended"`
/// (not yet started, or paused in a yield), `"running"`, `"normal"` or `"dead"`.
fn status(vm: &mut VM) -> Result<Value> {
    let handle = vm.arg_thread(0)?;
    let state = handle.borrow().state;
    let label = match state {
        ThreadState::Running => "running",
        ThreadState::Normal => "normal",
        ThreadState::Dead => "dead",
        ThreadState::Created | ThreadState::Suspended => "suspended",
    };
    Ok(Value::str(label))
}
/// `coroutine.wrap(f)`: creates a coroutine with body `f` and returns a builtin
/// function that resumes it on every call.
///
/// Calling the returned function forwards its arguments to `resume`. On success
/// it returns the yielded/returned values (without the leading `true`). On
/// failure it raises the error value reported by `resume` (or `nil` if none).
fn wrap(vm: &mut VM) -> Result<Value> {
    let f = vm.arg_func(0)?;
    let thread = Value::Thread(vm.alloc_thread(Thread::from(f)));
    Ok(Value::Func(vm.alloc_builtin(FuncBuiltin {
        module: "coroutine",
        name: "__coroutine_wrap",
        func: Rc::new(move |vm| {
            // Reuse `resume` by making the coroutine the first argument of this
            // builtin's frame, ahead of the caller's arguments.
            vm.frames
                .last_mut()
                .unwrap()
                .varargs
                .insert(0, thread.clone());
            let mut res = resume(vm)?.into_vec();
            if res.remove(0).to_bool()? {
                return Ok(Value::Mult(res));
            }
            // Do not touch the coroutine's state here. `resume` already marks
            // it dead when it raised an error. When the failure is "cannot
            // resume non-suspended coroutine", the coroutine is running or
            // normal and must stay alive rather than be reported as dead.
            vm.frames.last_mut().unwrap().varargs = vec![thread.clone()];
            let err_value = res.into_iter().next().unwrap_or(Value::Nil);
            err!(LuaError::CustomValue(err_value))
        }),
    })))
}
/// `coroutine.yield(...)`: suspends the running coroutine, handing its
/// arguments back to the resumer as a `CallResult::Yield`.
///
/// The arguments are passed to `vm.debug_return` before yielding.
///
/// # Errors
/// Fails on the main thread (no body function). It also fails when any frame
/// on the VM stack is marked `recursive` (presumably a nested call made from
/// Rust, which cannot be suspended).
fn r#yield(vm: &mut VM) -> Result<CallResult> {
    let is_main = vm.thread.borrow().func.is_none();
    if is_main {
        return err!(LuaError::CoroutineYieldMain);
    }
    let inside_rust_call = vm.frames.iter().any(|f| f.recursive);
    if inside_rust_call {
        return err!(empty, LuaError::CoroutineYieldRust);
    }
    let values = Value::Mult(vm.arg_split(0));
    vm.debug_return(&values)?;
    Ok(CallResult::Yield(values))
}