use crate::lib_registry::LibraryModule;
use crate::lua_value::LuaValue;
use crate::lua_vm::{LuaError, LuaResult, LuaState};
/// Builds the `coroutine` library module.
///
/// Registers the Lua-visible names `create`, `resume`, `yield`, `status`,
/// `running`, `wrap`, `isyieldable` and `close`. Each name is bound via
/// `lib_module!` to the native `coroutine_<name>` function defined in this file.
pub fn create_coroutine_lib() -> LibraryModule {
crate::lib_module!("coroutine", {
// Lua field name => native implementation.
"create" => coroutine_create,
"resume" => coroutine_resume,
"yield" => coroutine_yield,
"status" => coroutine_status,
"running" => coroutine_running,
"wrap" => coroutine_wrap,
"isyieldable" => coroutine_isyieldable,
"close" => coroutine_close,
})
}
/// Implements `coroutine.create(f)`.
///
/// Accepts either a Lua function or a native (C) function. It asks the VM to
/// allocate a new thread whose body is `f` and returns that thread as the
/// single result.
///
/// # Errors
/// Raises "coroutine.create requires a function argument" when argument #1 is
/// absent or is neither kind of function. Errors from `create_thread` and
/// `push_value` are propagated.
fn coroutine_create(l: &mut LuaState) -> LuaResult<usize> {
    let body = match l.get_arg(1) {
        Some(f) if f.is_function() || f.is_cfunction() => f,
        _ => return Err(l.error("coroutine.create requires a function argument".to_string())),
    };
    let new_thread = l.vm_mut().create_thread(body)?;
    l.push_value(new_thread)?;
    Ok(1)
}
/// Implements `coroutine.resume(co, ...)`.
///
/// Forwards every argument after `co` to the VM's `resume_thread`.
///
/// * On success, returns `true` followed by every value `resume_thread`
///   produced.
/// * On failure, returns `false` plus an error value. That value is the
///   thread's `error_object` when it is non-nil. Otherwise it is the thread's
///   error message as a string, or `nil` when that message is empty.
///
/// `error_object` is only read here. It is left in place on the thread.
///
/// # Errors
/// Raises "coroutine.resume requires a thread argument" when argument #1 is
/// missing or not a thread.
fn coroutine_resume(l: &mut LuaState) -> LuaResult<usize> {
    let co_val = match l.get_arg(1) {
        Some(v) if v.is_thread() => v,
        _ => return Err(l.error("coroutine.resume requires a thread argument".to_string())),
    };

    // Everything after the thread itself is handed to the coroutine.
    let forwarded: Vec<LuaValue> = l.get_args().into_iter().skip(1).collect();

    let outcome = l.vm_mut().resume_thread(co_val, forwarded);
    match outcome {
        Ok((_, values)) => {
            let produced = values.len();
            l.push_value(LuaValue::boolean(true))?;
            for value in values {
                l.push_value(value)?;
            }
            Ok(produced + 1)
        }
        Err(err) => {
            let err_val = if let Some(co) = co_val.as_thread_mut() {
                if co.error_object.is_nil() {
                    // No error object recorded: fall back to the textual message.
                    let text = co.get_error_msg(err);
                    if text.is_empty() {
                        LuaValue::nil()
                    } else {
                        l.create_string(&text)?
                    }
                } else {
                    co.error_object
                }
            } else {
                LuaValue::nil()
            };
            l.push_value(LuaValue::boolean(false))?;
            l.push_value(err_val)?;
            Ok(2)
        }
    }
}
/// Implements `coroutine.yield(...)`.
///
/// Hands all arguments to `LuaState::do_yield`. Yielding is refused while the
/// caller's `nny` counter is non-zero, with one of two errors:
///
/// * "attempt to yield from outside a coroutine" when the caller is the main
///   thread.
/// * "attempt to yield across a C-call boundary" otherwise.
///
/// Returns `Ok(0)` when `do_yield` returns successfully. Errors from
/// `do_yield` are propagated.
/// NOTE(review): the yield itself is presumably signalled through that error
/// path; confirm against `do_yield`.
fn coroutine_yield(l: &mut LuaState) -> LuaResult<usize> {
    if l.nny > 0 {
        let reason = if l.is_main_thread() {
            "attempt to yield from outside a coroutine"
        } else {
            "attempt to yield across a C-call boundary"
        };
        return Err(l.error(reason.to_string()));
    }
    let yielded = l.get_args();
    l.do_yield(yielded)?;
    Ok(0)
}
/// Implements `coroutine.status(co)`.
///
/// Pushes one of the VM's constant strings `str_running`, `str_suspended`,
/// `str_normal` or `str_dead`. The result describes `co` as seen from the
/// calling thread `l`:
///
/// * main thread: "running" if the caller is the main thread, otherwise
///   "normal", because the main thread is then waiting inside a `resume`.
/// * `dead` flag set: "dead".
/// * active frames: "suspended" if yielded, "running" if it is the caller,
///   otherwise "normal".
/// * no frames: "suspended" when its stack is non-empty (not yet started),
///   otherwise "dead".
///
/// # Errors
/// Raises "coroutine.status requires a thread argument" when argument #1 is
/// missing or not a thread.
fn coroutine_status(l: &mut LuaState) -> LuaResult<usize> {
    let thread_val = match l.get_arg(1) {
        Some(t) if t.is_thread() => t,
        _ => return Err(l.error("coroutine.status requires a thread argument".to_string())),
    };

    let cs = &l.vm_mut().const_strings;
    let str_running = cs.str_running;
    let str_suspended = cs.str_suspended;
    let str_normal = cs.str_normal;
    let str_dead = cs.str_dead;

    let status_val = if let Some(thread) = thread_val.as_thread_mut() {
        if thread.is_main_thread() {
            // Previously this was unconditionally "running". Called from inside a
            // coroutine, the main thread is blocked in `resume`, so it is "normal".
            if l.is_main_thread() {
                str_running
            } else {
                str_normal
            }
        } else if thread.dead {
            str_dead
        } else if thread.call_depth() > 0 {
            if thread.is_yielded() {
                str_suspended
            } else if std::ptr::eq(l as *const LuaState, thread as *const LuaState) {
                str_running
            } else {
                str_normal
            }
        } else if !thread.stack().is_empty() {
            str_suspended
        } else {
            str_dead
        }
    } else {
        str_dead
    };
    l.push_value(status_val)?;
    Ok(1)
}
/// Implements `coroutine.running()`.
///
/// Returns two values: the calling thread wrapped as a Lua thread value, and
/// `true` if that thread is the main thread (`false` otherwise).
fn coroutine_running(l: &mut LuaState) -> LuaResult<usize> {
    // NOTE(review): `thread_ptr` is `unsafe`; its contract is not visible here.
    let current = LuaValue::thread(unsafe { l.thread_ptr() });
    let is_main = l.is_main_thread();
    l.push_value(current)?;
    l.push_value(LuaValue::boolean(is_main))?;
    Ok(2)
}
/// Implements `coroutine.wrap(f)`.
///
/// Creates a new thread running `f`, which may be a Lua or native function.
/// Returns a native closure around `coroutine_wrap_call` that holds the thread
/// as its only upvalue.
///
/// # Errors
/// Raises "coroutine.wrap requires a function argument" when argument #1 is
/// missing or is not a function.
fn coroutine_wrap(l: &mut LuaState) -> LuaResult<usize> {
    let Some(body) = l
        .get_arg(1)
        .filter(|f| f.is_function() || f.is_cfunction())
    else {
        return Err(l.error("coroutine.wrap requires a function argument".to_string()));
    };

    let vm = l.vm_mut();
    let co = vm.create_thread(body)?;
    let wrapper = vm.create_c_closure(coroutine_wrap_call, vec![co])?;
    l.push_value(wrapper)?;
    Ok(1)
}
/// Native body of the functions returned by `coroutine.wrap`.
///
/// Reads the thread stored as the first upvalue of the C closure in the
/// innermost call frame (index `call_depth() - 1`). It resumes that thread with
/// all call arguments and returns the produced values directly, without a
/// leading status boolean.
///
/// # Errors
/// * "invalid wrapped coroutine" if no thread upvalue can be found.
/// * If the resume fails, the thread's `error_msg` is moved onto the caller,
///   together with its `error_object` when that is non-nil. The function then
///   returns `LuaError::RuntimeError`.
///
/// NOTE(review): unlike reference Lua's `auxwrap`, nothing here closes the
/// failed coroutine or prefixes position info. Confirm whether that is handled
/// elsewhere.
fn coroutine_wrap_call(l: &mut LuaState) -> LuaResult<usize> {
    let upvalue = if let Some(frame_idx) = l.call_depth().checked_sub(1)
        && let Some(callee) = l.get_frame_func(frame_idx)
        && let Some(closure) = callee.as_cclosure()
    {
        closure.upvalues().first().copied()
    } else {
        None
    };
    let Some(co_val) = upvalue.filter(|v| v.is_thread()) else {
        return Err(l.error("invalid wrapped coroutine".to_string()));
    };

    let call_args = l.get_args();
    let outcome = l.vm_mut().resume_thread(co_val, call_args);
    match outcome {
        Ok((_, values)) => {
            let produced = values.len();
            for value in values {
                l.push_value(value)?;
            }
            Ok(produced)
        }
        Err(_) => {
            if let Some(co) = co_val.as_thread_mut() {
                // Hand the failed thread's error state over to the caller.
                let err_obj = std::mem::take(&mut co.error_object);
                let err_msg = std::mem::take(&mut co.error_msg);
                if !err_obj.is_nil() {
                    l.error_object = err_obj;
                }
                l.error_msg = err_msg;
            }
            Err(LuaError::RuntimeError)
        }
    }
}
/// Implements `coroutine.isyieldable([co])`.
///
/// Returns `true` when the target's `nny` counter is zero. The target is the
/// thread given as argument #1, or the calling thread when no argument is
/// supplied.
///
/// # Errors
/// Raises "value is not a thread" when an argument is supplied that is not a
/// thread.
fn coroutine_isyieldable(l: &mut LuaState) -> LuaResult<usize> {
    let yieldable = match l.get_arg(1) {
        None => l.nny == 0,
        Some(value) => match value.as_thread_mut() {
            Some(co) => co.nny == 0,
            None => return Err(l.error("value is not a thread".to_string())),
        },
    };
    l.push_value(LuaValue::boolean(yieldable))?;
    Ok(1)
}
/// Implements `coroutine.close([co])`.
///
/// `co` may be a thread, or `nil`/absent to mean the calling thread. Any other
/// value raises "bad argument #1 to 'close' (coroutine expected)".
///
/// * Target is the caller itself:
///   * The main thread raises "cannot close main thread".
///   * If the caller is already closing, `true` is returned.
///   * Otherwise its to-be-closed values are closed (errors discarded) and its
///     upvalues are closed. It then returns `LuaError::CloseThread`.
/// * Target is active but not the caller ("normal"): raises
///   "cannot close a normal coroutine".
/// * Target is dead or suspended:
///   * Its to-be-closed values and upvalues are closed, every call frame is
///     popped and its stack is truncated.
///   * Returns `true`, or `false` plus an error value when closing failed or
///     the thread still holds a non-nil `error_object`.
///   * A `LuaError::Yield` raised while closing is propagated unchanged.
fn coroutine_close(l: &mut LuaState) -> LuaResult<usize> {
    // Where the target stands relative to the calling thread.
    enum Target {
        // Flagged dead, or nothing left on its stack.
        Dead,
        // Yielded, or created but not yet started.
        Suspended,
        // Active, but not the caller.
        Normal,
        // The caller itself.
        Running,
    }

    let target_val = match l.get_arg(1) {
        Some(v) if v.is_thread() => v,
        Some(v) if !v.is_nil() => {
            return Err(l.error("bad argument #1 to 'close' (coroutine expected)".to_string()));
        }
        // Absent or explicit nil: operate on the calling thread.
        _ => LuaValue::thread(unsafe { l.thread_ptr() }),
    };

    let Some(co) = target_val.as_thread_mut() else {
        l.push_value(LuaValue::boolean(true))?;
        return Ok(1);
    };

    let target = if std::ptr::eq(l as *const LuaState, co as *const LuaState) {
        Target::Running
    } else if co.dead {
        Target::Dead
    } else if co.is_yielded() {
        Target::Suspended
    } else if co.call_depth() > 0 {
        Target::Normal
    } else if co.stack().is_empty() {
        Target::Dead
    } else {
        Target::Suspended
    };

    match target {
        Target::Dead | Target::Suspended => {}
        Target::Normal => {
            return Err(l.error("cannot close a normal coroutine".to_string()));
        }
        Target::Running => {
            if co.is_main_thread() {
                return Err(l.error("cannot close main thread".to_string()));
            }
            if l.is_closing {
                l.push_value(LuaValue::boolean(true))?;
                return Ok(1);
            }
            l.is_closing = true;
            // Errors from closing to-be-closed values are deliberately discarded here.
            let _ = l.close_tbc_with_error(0, LuaValue::nil());
            l.close_upvalues(0);
            l.is_closing = false;
            return Err(LuaError::CloseThread);
        }
    }

    // Dead or suspended target: release its resources and reset its stack.
    let close_result = co.close_tbc_with_error(0, LuaValue::nil());
    co.close_upvalues(0);
    while co.call_depth() > 0 {
        co.pop_frame();
    }
    co.stack_truncate();

    let failure = match close_result {
        Ok(()) if co.error_object.is_nil() => None,
        Ok(()) => Some(std::mem::take(&mut co.error_object)),
        Err(LuaError::Yield) => return Err(LuaError::Yield),
        Err(_) => {
            let err_obj = std::mem::take(&mut co.error_object);
            Some(if err_obj.is_nil() {
                let msg = std::mem::take(&mut co.error_msg);
                if msg.is_empty() {
                    LuaValue::nil()
                } else {
                    l.create_string(&msg)?
                }
            } else {
                err_obj
            })
        }
    };

    match failure {
        None => {
            l.push_value(LuaValue::boolean(true))?;
            Ok(1)
        }
        Some(err_val) => {
            l.push_value(LuaValue::boolean(false))?;
            l.push_value(err_val)?;
            Ok(2)
        }
    }
}