use lua_types::{
error::LuaError,
value::LuaValue,
LuaType,
LuaStatus,
gc::GcRef,
};
use crate::state_stub::{LuaState, LuaStateStubExt as _, lua_CFunction, upvalue_index};
// Coroutine status codes. Each code is also the index of its human-readable
// name in `STAT_NAMES`, so the two must stay in lock-step.

/// The coroutine is the thread currently executing.
const COS_RUN: i32 = 0;
/// The coroutine finished, raised an error, or is no longer registered.
const COS_DEAD: i32 = 1;
/// The coroutine is suspended: it yielded, or has not started yet.
const COS_YIELD: i32 = 2;
/// The coroutine is active but not current (it is waiting on another thread).
const COS_NORM: i32 = 3;

/// Names reported by `coroutine.status`, indexed by the `COS_*` codes.
const STAT_NAMES: [&[u8]; 4] = [
    b"running",   // COS_RUN
    b"dead",      // COS_DEAD
    b"suspended", // COS_YIELD
    b"normal",    // COS_NORM
];
/// Registration table for the `coroutine` library.
///
/// Each entry pairs the Lua-visible field name with the native function that
/// implements it. `open_coroutine` passes this table to `new_lib`.
pub const CO_FUNCS: &[(&[u8], lua_CFunction)] = &[
    (b"create", co_create),
    (b"resume", co_resume),
    (b"running", co_running),
    (b"status", co_status),
    (b"wrap", co_wrap),
    (b"yield", co_yield),
    (b"isyieldable", co_isyieldable),
    (b"close", co_close),
];
/// Returns argument 1 as a coroutine (thread) handle.
///
/// # Errors
///
/// Returns `LuaError::type_arg_error(1, "thread", ..)` when argument 1 is not
/// a thread.
fn get_co(state: &mut LuaState) -> Result<GcRef<lua_types::value::LuaThread>, LuaError> {
    // Bind the handle directly rather than testing `is_none()` and then
    // `expect`-ing it: the happy path has no panic site at all.
    if let Some(co) = state.to_thread(1) {
        return Ok(co);
    }
    let got = state.arg(1);
    Err(LuaError::type_arg_error(1, "thread", &got))
}
/// Classifies coroutine `co` relative to the running thread.
///
/// Returns one of the `COS_*` codes, which is also an index into `STAT_NAMES`.
/// The checks run in this order:
/// 1. `co` is the current thread -> `COS_RUN`.
/// 2. `co` is the main thread (and not current) -> `COS_NORM`.
/// 3. `co` has no entry in `GlobalState::threads` -> `COS_DEAD`.
/// 4. `co`'s state is already borrowed -> `COS_NORM`.
/// 5. Otherwise the result depends on its status, call-info depth and stack height.
fn aux_status(state: &mut LuaState, co: &GcRef<lua_types::value::LuaThread>) -> i32 {
    let co_id = co.id;
    // Clone the per-thread state handle out of the global table so the global
    // borrow is released before the thread's own state is borrowed below.
    let entry_rc = {
        let g = state.global();
        if co_id == g.current_thread_id {
            return COS_RUN;
        }
        // The main thread is not current here; presumably it is waiting on a
        // coroutine it resumed, so report it as "normal".
        if co_id == g.main_thread_id {
            return COS_NORM;
        }
        match g.threads.get(&co_id) {
            Some(e) => e.state.clone(),
            None => return COS_DEAD,
        }
    };
    // `aux_resume` keeps a mutable borrow of a thread's state for the whole
    // resume. A failed borrow therefore means `co` is mid-resume but not
    // current: it has resumed another coroutine.
    let co_state = match entry_rc.try_borrow() {
        Ok(state) => state,
        Err(_) => {
            return COS_NORM;
        }
    };
    let raw_status = co_state.status;
    if raw_status == LuaStatus::Yield as u8 {
        return COS_YIELD;
    }
    // Any status other than Ok/Yield is an error status: the coroutine died.
    if raw_status != LuaStatus::Ok as u8 {
        return COS_DEAD;
    }
    // A call-info index above the base entry means frames are active.
    let has_frames = co_state.ci.as_usize() > 0;
    if has_frames {
        return COS_NORM;
    }
    // No active frames. Count the values above the base call-info's function
    // slot (the equivalent of `lua_gettop`).
    let ci_func = co_state.call_info[0].func.0;
    let top = co_state.top.0;
    let lua_gettop = top as i64 - ci_func as i64 - 1;
    // An empty stack means the coroutine has finished. Otherwise values are
    // waiting, presumably the body of a not-yet-started coroutine.
    if lua_gettop == 0 {
        COS_DEAD
    } else {
        COS_YIELD
    }
}
/// Resumes coroutine `co` with the top `narg` values of `state`'s stack as
/// arguments.
///
/// Returns:
/// - On success (status `Ok` or `Yield`): the number of result values now
///   pushed onto `state`.
/// - On failure: `-1`, with exactly one error value pushed onto `state`. That
///   value is one of the fixed messages below (or nil if the message could not
///   be interned), or the value the coroutine left on top of its stack when it
///   failed.
///
/// Fixed failure messages:
/// - "cannot resume dead coroutine" (thread no longer registered)
/// - "not enough arguments to resume"
/// - "cannot resume non-suspended coroutine" (its state is already borrowed)
/// - "too many arguments to resume"
/// - "too many results to resume"
///
/// Around the actual `lua_resume` call it does two things:
/// - It copies the parent's open-upvalue stack slots into
///   `GlobalState::cross_thread_upvals`, then writes them back to the parent
///   stack afterwards. NOTE(review): presumably this lets the coroutine
///   read/write upvalues living on the parent's stack; confirm.
/// - It pushes a snapshot of the parent's stack and open upvalues
///   (`push_parent_gc_snapshot`) and pops it again on every exit path after
///   that point.
fn aux_resume(state: &mut LuaState, co: GcRef<lua_types::value::LuaThread>, narg: i32) -> i32 {
    let co_id = co.id;
    // Clone the coroutine's state handle so the global borrow ends here.
    let entry_rc = {
        let g = state.global();
        match g.threads.get(&co_id) {
            Some(e) => e.state.clone(),
            None => {
                drop(g);
                push_lit_or_nil(state, b"cannot resume dead coroutine");
                return -1;
            }
        }
    };
    let parent_thread_id = state.global().current_thread_id;
    let top_before = state.get_top();
    if top_before < narg {
        push_lit_or_nil(state, b"not enough arguments to resume");
        return -1;
    }
    // Copy the top `narg` values off the parent stack, then drop them there.
    // From here on the arguments are gone from the parent, even when a later
    // check fails.
    let first_arg_idx = top_before - narg + 1;
    let args: Vec<LuaValue> = (first_arg_idx..=top_before)
        .map(|i| state.value_at(i))
        .collect();
    lua_vm::api::set_top(state, (top_before - narg) as i32).ok();
    // Stack slots still referenced by the parent's open upvalues.
    let parent_open_upval_slots: Vec<(u64, lua_vm::state::StackIdx)> = state
        .openupval
        .iter()
        .filter_map(|uv| match &*uv.slot() {
            lua_types::UpValState::Open { thread_id, idx } => {
                Some((*thread_id as u64, *idx))
            }
            lua_types::UpValState::Closed(_) => None,
        })
        .collect();
    // Stash the current value of each such slot in the shared map.
    {
        let mut g = state.global_mut();
        for (tid, idx) in &parent_open_upval_slots {
            let val = state.get_at(*idx);
            g.cross_thread_upvals.insert((*tid, *idx), val);
        }
    }
    push_parent_gc_snapshot(state);
    let (status, results_or_err): (LuaStatus, Vec<LuaValue>) = {
        // The state is already borrowed while this thread is itself resuming
        // (running or normal), so it cannot be resumed again.
        let mut co_state = match entry_rc.try_borrow_mut() {
            Ok(b) => b,
            Err(_) => {
                // Undo the snapshot and the upvalue stash before bailing out.
                pop_parent_gc_snapshot(state);
                let mut g = state.global_mut();
                for (tid, idx) in &parent_open_upval_slots {
                    g.cross_thread_upvals.remove(&(*tid, *idx));
                }
                drop(g);
                push_lit_or_nil(state, b"cannot resume non-suspended coroutine");
                return -1;
            }
        };
        if co_state.check_stack(narg + 1).is_err() {
            drop(co_state);
            pop_parent_gc_snapshot(state);
            let mut g = state.global_mut();
            for (tid, idx) in &parent_open_upval_slots {
                g.cross_thread_upvals.remove(&(*tid, *idx));
            }
            drop(g);
            push_lit_or_nil(state, b"too many arguments to resume");
            return -1;
        }
        for v in args {
            co_state.push(v);
        }
        // Mark the coroutine as current only for the duration of the resume.
        co_state.global_mut().current_thread_id = co_id;
        let mut nres: i32 = 0;
        let status = lua_vm::do_::lua_resume(&mut *co_state, Some(state), narg, &mut nres);
        co_state.global_mut().current_thread_id = parent_thread_id;
        let co_top = co_state.top_idx().0 as i32;
        let ci_func = co_state.current_call_info().func.0 as i32;
        // On Ok/Yield take `nres` values from the coroutine's top; on error
        // take the single error value.
        let count = if status == LuaStatus::Ok || status == LuaStatus::Yield {
            nres
        } else {
            1
        };
        // NOTE(review): if `count` exceeded `co_top`, `start` would be negative
        // and the `as u32` cast below would wrap; presumably `lua_resume`
        // guarantees enough values. Confirm.
        let start = co_top - count;
        let vals: Vec<LuaValue> = (start..co_top)
            .map(|i| co_state.get_at(lua_vm::state::StackIdx(i as u32)))
            .collect();
        // Pop the transferred values from the coroutine. On Ok/Yield never drop
        // below one slot past the current call-info's function slot.
        let new_co_top = if status == LuaStatus::Ok || status == LuaStatus::Yield {
            (co_top - count).max(ci_func + 1)
        } else {
            co_top - count
        };
        co_state.set_top(lua_vm::state::StackIdx(new_co_top.max(0) as u32));
        (status, vals)
    };
    pop_parent_gc_snapshot(state);
    // Write back any upvalue values the coroutine may have updated. Collect
    // them first so the global borrow is released before touching the stack.
    {
        let mut g = state.global_mut();
        let mut flush: Vec<(lua_vm::state::StackIdx, LuaValue)> = Vec::new();
        for (tid, idx) in &parent_open_upval_slots {
            if let Some(v) = g.cross_thread_upvals.remove(&(*tid, *idx)) {
                flush.push((*idx, v));
            }
        }
        drop(g);
        for (idx, v) in flush {
            state.set_at(idx, v);
        }
    }
    match status {
        LuaStatus::Ok | LuaStatus::Yield => {
            // One extra slot so the caller can also push a status boolean.
            if state.check_stack(results_or_err.len() as i32 + 1).is_err() {
                push_lit_or_nil(state, b"too many results to resume");
                return -1;
            }
            let n = results_or_err.len();
            for v in results_or_err {
                state.push(v);
            }
            n as i32
        }
        _ => {
            // Error status: `results_or_err` holds the single error value.
            for v in results_or_err {
                state.push(v);
            }
            -1
        }
    }
}
/// Records the parent thread's live stack (slots below its top) and a clone of
/// its open-upvalue list.
///
/// The two copies are pushed onto `GlobalState::suspended_parent_stacks` and
/// `suspended_parent_open_upvals`. Each push must be paired with a later
/// `pop_parent_gc_snapshot`. Presumably the snapshot keeps the suspended
/// parent's values visible to the collector; confirm.
fn push_parent_gc_snapshot(state: &mut LuaState) {
    let live_len = state.top_idx().0;
    let mut stack_copy: Vec<LuaValue> = Vec::with_capacity(live_len as usize);
    for slot in 0..live_len {
        stack_copy.push(state.get_at(lua_vm::state::StackIdx(slot)));
    }
    let upvals_copy = state.openupval.clone();

    let mut global = state.global_mut();
    global.suspended_parent_stacks.push(stack_copy);
    global.suspended_parent_open_upvals.push(upvals_copy);
}
/// Discards the most recent snapshot pushed by `push_parent_gc_snapshot`,
/// dropping both the open-upvalue copy and the stack copy.
fn pop_parent_gc_snapshot(state: &mut LuaState) {
    let mut global = state.global_mut();
    let _ = global.suspended_parent_open_upvals.pop();
    let _ = global.suspended_parent_stacks.pop();
}
/// Pushes `bytes` as an interned Lua string. If interning fails, pushes nil
/// instead, so exactly one value is always pushed.
fn push_lit_or_nil(state: &mut LuaState, bytes: &[u8]) {
    let value = state
        .intern_str(bytes)
        .map(LuaValue::Str)
        .unwrap_or(LuaValue::Nil);
    state.push(value);
}
/// Implements `coroutine.resume(co, ...)`.
///
/// Every argument after `co` is forwarded to the coroutine. The results are:
/// - `true` followed by the coroutine's results, on success;
/// - `false` followed by the error value, on failure.
pub fn co_resume(state: &mut LuaState) -> Result<usize, LuaError> {
    let co = get_co(state)?;
    let forwarded = state.get_top() - 1;
    match aux_resume(state, co, forwarded) {
        nres if nres >= 0 => {
            // Slide the status flag underneath the `nres` results.
            state.push(LuaValue::Bool(true));
            state.insert(-(nres + 1))?;
            Ok((nres + 1) as usize)
        }
        _ => {
            // Place the flag beneath the single error value.
            state.push(LuaValue::Bool(false));
            state.insert(-2)?;
            Ok(2)
        }
    }
}
/// Body of the closure returned by `coroutine.wrap`.
///
/// Resumes the coroutine stored in upvalue 1, passing all of the call's
/// arguments. On success it returns the coroutine's results.
///
/// On failure the error value on top of the stack is raised as a `LuaError`.
/// If the coroutine is dead afterwards (per `aux_status`), it is first closed
/// with `close_suspended_or_dead`. When that close pushes a second value (its
/// error object), that value replaces the original error.
///
/// NOTE(review): reference Lua's `auxwrap` also prefixes string errors with
/// position information (`luaL_where(L, 1)`). No such step appears here;
/// confirm this is intentional.
///
/// # Errors
/// - A runtime error if upvalue 1 is not a thread.
/// - The coroutine's error, or the close error that replaces it, when the
///   resume fails.
/// - Anything propagated from `close_suspended_or_dead`.
fn aux_wrap(state: &mut LuaState) -> Result<usize, LuaError> {
    let up = state.value_at(upvalue_index(1));
    let co = match up {
        LuaValue::Thread(t) => t,
        _ => {
            return Err(LuaError::runtime(format_args!(
                "coroutine.wrap: upvalue is not a thread"
            )))
        }
    };
    // The closure's whole argument list is forwarded to the coroutine.
    let narg = state.get_top();
    let r = aux_resume(state, co.clone(), narg);
    if r < 0 {
        // `aux_resume` left exactly one error value on top.
        let top = state.get_top();
        let mut err_val = state.value_at(top);
        if aux_status(state, &co) == COS_DEAD {
            let old_err = state.pop();
            // Pushes `true` (1 value) or `false` + error object (2 values).
            let nclose = close_suspended_or_dead(state, co)?;
            err_val = if nclose >= 2 {
                let top = state.get_top();
                state.value_at(top)
            } else {
                old_err
            };
            // Remove whatever the close pushed.
            state.pop_n(nclose);
        }
        Err(LuaError::from_value(err_val))
    } else {
        Ok(r as usize)
    }
}
/// Implements `coroutine.create(f)`.
///
/// Requires argument 1 to be a function and creates a new thread whose body
/// is that function. Reports one result, presumably the new thread left on the
/// stack by `new_thread`.
///
/// # Errors
///
/// Fails if argument 1 is not a function or if thread creation fails.
pub fn co_create(state: &mut LuaState) -> Result<usize, LuaError> {
    state.check_arg_type(1, LuaType::Function)?;
    let _thread = {
        let entry_fn = state.value_at(1);
        state.new_thread(Some(entry_fn))?
    };
    Ok(1)
}
/// Implements `coroutine.wrap(f)`.
///
/// Creates a coroutine exactly as `coroutine.create` does. It then captures
/// that coroutine as upvalue 1 of an `aux_wrap` C closure, which is the single
/// result.
pub fn co_wrap(state: &mut LuaState) -> Result<usize, LuaError> {
    let _created = co_create(state)?;
    state.push_cclosure(aux_wrap, 1)?;
    Ok(1)
}
/// Implements `coroutine.yield(...)`.
///
/// Yields every value currently on the stack, with no continuation, and
/// reports the count returned by `lua_yieldk`.
pub fn co_yield(state: &mut LuaState) -> Result<usize, LuaError> {
    let yielded = state.get_top();
    let nresults = lua_vm::do_::lua_yieldk(state, yielded, 0, None)?;
    Ok(nresults as usize)
}
/// Implements `coroutine.status(co)`.
///
/// Pushes one of "running", "dead", "suspended" or "normal", chosen by
/// `aux_status`.
pub fn co_status(state: &mut LuaState) -> Result<usize, LuaError> {
    let co = get_co(state)?;
    let code = aux_status(state, &co);
    let label = state.intern_str(STAT_NAMES[code as usize])?;
    state.push(LuaValue::Str(label));
    Ok(1)
}
/// Implements `coroutine.isyieldable([co])` and pushes one boolean.
///
/// With no argument, it reports whether the running thread can yield. With a
/// thread argument:
/// - the main thread is never yieldable;
/// - the running thread defers to `state.is_yieldable()`;
/// - any other registered thread is asked directly;
/// - a thread whose state is currently borrowed reports `false`;
/// - a thread no longer registered in `GlobalState::threads` reports `true`.
///
/// The last case used to panic via `expect`. Unregistered ids are reachable:
/// `aux_status` reports them as dead, and `aux_resume` and
/// `close_suspended_or_dead` handle them too. The Lua 5.4 manual defines a
/// coroutine as yieldable when it is not the main thread and is not inside a
/// non-yieldable C function, which holds for a dead coroutine.
///
/// # Errors
///
/// Propagates `get_co`'s argument error when argument 1 is present but is
/// not a thread.
pub fn co_isyieldable(state: &mut LuaState) -> Result<usize, LuaError> {
    let is_yieldable = if matches!(state.type_at(1), LuaType::None) {
        state.is_yieldable()
    } else {
        let co = get_co(state)?;
        let co_id = co.id;
        let (is_main, is_current) = {
            let g = state.global();
            (co_id == g.main_thread_id, co_id == g.current_thread_id)
        };
        if is_main {
            false
        } else if is_current {
            state.is_yieldable()
        } else {
            // Clone the handle out so the global borrow ends before the
            // thread's own state is borrowed.
            let entry_rc = {
                let g = state.global();
                g.threads.get(&co_id).map(|e| e.state.clone())
            };
            match entry_rc {
                // `aux_resume` holds a mutable borrow for the whole resume, so
                // a failed borrow means the thread is mid-resume. Its flags
                // can't be read, so this reports `false` as before.
                Some(rc) => rc.try_borrow().map(|b| b.is_yieldable()).unwrap_or(false),
                // Unregistered means dead (see `aux_status`): answer rather
                // than abort the interpreter.
                None => true,
            }
        }
    };
    state.push(LuaValue::Bool(is_yieldable));
    Ok(1)
}
/// Implements `coroutine.running()`.
///
/// Returns two values: the running thread, and whether it is the main thread.
pub fn co_running(state: &mut LuaState) -> Result<usize, LuaError> {
    // `push_thread` pushes the running thread and reports whether it is main.
    let running_is_main = state.push_thread()?;
    state.push(LuaValue::Bool(running_is_main));
    Ok(2)
}
pub fn co_close(state: &mut LuaState) -> Result<usize, LuaError> {
lua_vm::state::inc_c_stack(state)?;
let result = (|| {
let co = get_co(state)?;
let status = aux_status(state, &co);
match status {
COS_DEAD | COS_YIELD => close_suspended_or_dead(state, co),
_ => {
let name = if status == COS_RUN { "running" } else { "normal" };
Err(LuaError::runtime(format_args!(
"cannot close a {} coroutine",
name
)))
}
}
})();
state.n_ccalls -= 1;
result
}
/// Closes a dead or suspended coroutine `co` by resetting its thread state
/// with `lua_vm::state::reset_thread`.
///
/// NOTE(review): presumably the reset is where pending to-be-closed variables
/// run; confirm.
///
/// Pushes onto `state` and returns:
/// - `true`, returning `Ok(1)`, when the reset reports `LuaStatus::Ok` or the
///   thread is no longer registered;
/// - `false` followed by the error value from the top of the coroutine's stack
///   (nil if that stack is empty), returning `Ok(2)`.
///
/// Every path currently returns `Ok`. The coroutine's `n_ccalls` is set to the
/// caller's C-call count before the reset. As in `aux_resume`, the parent's
/// open-upvalue slots are stashed in `cross_thread_upvals` and written back
/// afterwards, and a parent GC snapshot is pushed and popped around the work.
///
/// # Panics
///
/// Uses `borrow_mut` on the coroutine's state, which panics if it is already
/// borrowed. Callers in this file only reach it after `aux_status` returned
/// `COS_DEAD` or `COS_YIELD`. `aux_status` returns `COS_NORM` whenever that
/// borrow is unavailable.
fn close_suspended_or_dead(
    state: &mut LuaState,
    co: GcRef<lua_types::value::LuaThread>,
) -> Result<usize, LuaError> {
    let co_id = co.id;
    let entry_rc_opt = {
        let g = state.global();
        g.threads.get(&co_id).map(|e| e.state.clone())
    };
    // An unregistered thread has nothing left to close.
    let entry_rc = match entry_rc_opt {
        Some(rc) => rc,
        None => {
            state.push(LuaValue::Bool(true));
            return Ok(1);
        }
    };
    let parent_thread_id = state.global().current_thread_id;
    let caller_c_calls = state.c_calls();
    // Same upvalue stash as in `aux_resume`: publish the parent's open-upvalue
    // slot values so code running on the coroutine can see them.
    let parent_open_upval_slots: Vec<(u64, lua_vm::state::StackIdx)> = state
        .openupval
        .iter()
        .filter_map(|uv| match &*uv.slot() {
            lua_types::UpValState::Open { thread_id, idx } => {
                Some((*thread_id as u64, *idx))
            }
            lua_types::UpValState::Closed(_) => None,
        })
        .collect();
    {
        let mut g = state.global_mut();
        for (tid, idx) in &parent_open_upval_slots {
            let val = state.get_at(*idx);
            g.cross_thread_upvals.insert((*tid, *idx), val);
        }
    }
    push_parent_gc_snapshot(state);
    let (status, err_value): (i32, Option<LuaValue>) = {
        let mut co_state = entry_rc.borrow_mut();
        // The coroutine is current only while it is being reset.
        co_state.global_mut().current_thread_id = co_id;
        co_state.n_ccalls = caller_c_calls;
        let in_status = co_state.status as i32;
        let s = lua_vm::state::reset_thread(&mut *co_state, in_status);
        co_state.global_mut().current_thread_id = parent_thread_id;
        if s == LuaStatus::Ok as i32 {
            (s, None)
        } else {
            // On failure take the error value from the coroutine's top and
            // pop it.
            let top = co_state.top_idx().0;
            if top > 0 {
                let err = co_state.get_at(lua_vm::state::StackIdx(top - 1));
                co_state.set_top(lua_vm::state::StackIdx(top - 1));
                (s, Some(err))
            } else {
                (s, Some(LuaValue::Nil))
            }
        }
    };
    pop_parent_gc_snapshot(state);
    // Write stashed upvalue values back to the parent stack. Collect first so
    // the global borrow is released before touching the stack.
    {
        let mut g = state.global_mut();
        let mut flush: Vec<(lua_vm::state::StackIdx, LuaValue)> = Vec::new();
        for (tid, idx) in &parent_open_upval_slots {
            if let Some(v) = g.cross_thread_upvals.remove(&(*tid, *idx)) {
                flush.push((*idx, v));
            }
        }
        drop(g);
        for (idx, v) in flush {
            state.set_at(idx, v);
        }
    }
    if status == LuaStatus::Ok as i32 {
        state.push(LuaValue::Bool(true));
        Ok(1)
    } else {
        state.push(LuaValue::Bool(false));
        if let Some(v) = err_value {
            state.push(v);
        } else {
            state.push(LuaValue::Nil);
        }
        Ok(2)
    }
}
/// Opens the `coroutine` library.
///
/// Registers every entry of `CO_FUNCS` via `new_lib` and reports one result,
/// presumably the library table that `new_lib` leaves on the stack.
pub fn open_coroutine(state: &mut LuaState) -> Result<usize, LuaError> {
    let library_results = 1;
    state.new_lib(CO_FUNCS)?;
    Ok(library_results)
}