use crate::error::{LuaError, LuaResult, RuntimeError};
use crate::vm::callinfo::LUA_MULTRET;
use crate::vm::closure::{Closure, RustClosure, Upvalue};
use crate::vm::execute::{self, CallResult};
use crate::vm::gc::arena::GcRef;
use crate::vm::state::{LuaState, LuaThread, ThreadStatus};
use crate::vm::value::Val;
/// Number of arguments passed to the currently executing call.
///
/// Counts the stack slots between the slot after the current frame's `func`
/// and `state.top`; yields 0 when `top` does not extend past that slot.
fn nargs(state: &LuaState) -> usize {
    let callee_slot = state.call_stack[state.ci].func;
    state.top.saturating_sub(callee_slot + 1)
}
/// Reads the `n`-th (0-based) argument of the currently executing call.
///
/// Arguments start in the slot immediately after the current frame's `func`.
fn arg(state: &LuaState, n: usize) -> Val {
    let first_arg_slot = state.call_stack[state.ci].func + 1;
    state.stack_get(first_arg_slot + n)
}
/// Wraps `msg` in a runtime error with level 0 and an empty traceback.
fn simple_error(msg: String) -> LuaError {
    let runtime = RuntimeError {
        message: msg,
        level: 0,
        traceback: Vec::new(),
    };
    LuaError::Runtime(runtime)
}
/// `coroutine.create(f)`: allocates a new thread whose body is `f` and pushes it.
///
/// # Errors
///
/// Fails when argument #1 is not a function, is a Rust closure rather than a
/// Lua one, or refers to a closure that is not present in the closure arena.
pub fn co_create(state: &mut LuaState) -> LuaResult<u32> {
    const NOT_LUA_FUNCTION: &str = "bad argument #1 to 'create' (Lua function expected)";

    let body = arg(state, 0);
    let Val::Function(fn_ref) = body else {
        return Err(simple_error(NOT_LUA_FUNCTION.into()));
    };

    let closure = state
        .gc
        .closures
        .get(fn_ref)
        .ok_or_else(|| simple_error("invalid function reference".into()))?;
    if let Closure::Rust(_) = closure {
        return Err(simple_error(NOT_LUA_FUNCTION.into()));
    }

    let new_thread = LuaThread::new(body, state.global);
    let handle = state.gc.alloc_thread(new_thread);
    state.push(Val::Thread(handle));
    Ok(1)
}
pub fn co_resume(state: &mut LuaState) -> LuaResult<u32> {
let co_val = arg(state, 0);
let Val::Thread(co_ref) = co_val else {
return Err(simple_error(
"bad argument #1 to 'resume' (coroutine expected)".into(),
));
};
let n_resume_args = if nargs(state) > 1 {
nargs(state) - 1
} else {
0
};
let mut resume_args = Vec::with_capacity(n_resume_args);
for i in 0..n_resume_args {
resume_args.push(arg(state, 1 + i));
}
match auxresume(state, co_ref, &resume_args) {
Ok(results) => {
let base = state.base;
state.stack_set(base, Val::Bool(true));
for (i, val) in results.iter().enumerate() {
state.stack_set(base + 1 + i, *val);
}
state.top = base + 1 + results.len();
Ok((1 + results.len()) as u32)
}
Err(error_val) => {
let base = state.base;
state.stack_set(base, Val::Bool(false));
state.stack_set(base + 1, error_val);
state.top = base + 2;
Ok(2)
}
}
}
fn close_thread_upvalues(state: &mut LuaState) -> Vec<(GcRef<Upvalue>, usize)> {
let mut suspended = Vec::new();
for &uv_ref in &state.open_upvalues {
if let Some(uv) = state.gc.upvalues.get(uv_ref)
&& let Some(idx) = uv.stack_index()
{
suspended.push((uv_ref, idx));
}
}
for &(uv_ref, _) in &suspended {
if let Some(uv) = state.gc.upvalues.get_mut(uv_ref) {
uv.close(&state.stack);
}
}
suspended
}
/// Resumes the coroutine `co_ref`, handing it `args`, and runs it until it
/// yields, returns, or raises.
///
/// Returns `Ok(values)` with the values the coroutine yielded or returned.
/// Returns `Err(value)` with an error value when the coroutine cannot be
/// resumed (its status is dead, running, or normal) or when executing it
/// fails. A `co_ref` that is not found in `state.gc.threads` is treated as
/// dead.
///
/// On every path that ran the coroutine, the resumer's saved state is popped
/// from `state.saved_threads` and restored, and `state.current_thread` is set
/// back to its value from before the call.
pub(crate) fn auxresume(
state: &mut LuaState,
co_ref: GcRef<LuaThread>,
args: &[Val],
) -> Result<Vec<Val>, Val> {
// A thread missing from the arena is reported as dead.
let co_status = state
.gc
.threads
.get(co_ref)
.map_or(ThreadStatus::Dead, |t| t.status);
match co_status {
ThreadStatus::Dead => {
let r = state.gc.intern_string(b"cannot resume dead coroutine");
return Err(Val::Str(r));
}
ThreadStatus::Running | ThreadStatus::Normal => {
let r = state
.gc
.intern_string(b"cannot resume non-suspended coroutine");
return Err(Val::Str(r));
}
ThreadStatus::Initial | ThreadStatus::Suspended => {
// Resumable: continue with the thread switch below.
}
}
// Close the resumer's stack-backed upvalues and keep the returned
// (upvalue, stack index) pairs together with its saved state.
let resumer_suspended = close_thread_upvalues(state);
let saved_current_thread = state.current_thread;
let mut resumer = state.save_thread_state();
resumer.suspended_upvals = resumer_suspended;
state.saved_threads.push(resumer);
// From here on `state` holds the coroutine's execution state.
state.load_thread_by_ref(co_ref, ThreadStatus::Running);
state.current_thread = Some(co_ref);
// A fresh coroutine gets its arguments pushed; a previously suspended one
// has them written starting at `base`, with `top` set just past the last.
if co_status == ThreadStatus::Initial {
for &val in args {
state.push(val);
}
} else {
for (i, &val) in args.iter().enumerate() {
state.stack_set(state.base + i, val);
}
state.top = state.base + args.len();
}
let exec_result = if co_status == ThreadStatus::Initial {
// First run: call the value at stack index 0, accepting any number of
// results. Only a Lua callee needs the interpreter loop.
let func_idx = 0;
(|| -> LuaResult<()> {
match state.precall(func_idx, LUA_MULTRET)? {
CallResult::Lua => execute::execute(state),
CallResult::Rust => Ok(()),
}
})()
} else if state.yielded_in_hook {
// `yielded_in_hook` was set when the thread was suspended: re-sync
// `base` from the current call frame, make sure `top` is not below it,
// clear the flag, and keep executing until frame 0 is reached.
state.base = state.call_stack[state.ci].base;
if state.top < state.base {
state.top = state.base;
}
state.yielded_in_hook = false;
(|| -> LuaResult<()> {
while state.ci > 0 {
execute::execute(state)?;
}
Ok(())
})()
} else {
// Ordinary resume: the resume arguments starting at `base` are handed
// to `poscall` (presumably completing the pending yield call with them
// as its results — confirm against `poscall`). When it returns true,
// `top` is reset to the current frame's recorded top.
let first_result = state.base;
if state.poscall(first_result) {
state.top = state.call_stack[state.ci].top;
}
(|| -> LuaResult<()> {
while state.ci > 0 {
execute::execute(state)?;
}
Ok(())
})()
};
match exec_result {
Ok(()) => {
// Ran to completion: collect the slots from the current frame's
// `func` index up to `top`, then store the thread as dead.
let mut results = Vec::new();
let ci_func = state.call_stack[state.ci].func;
for i in ci_func..state.top {
results.push(state.stack_get(i));
}
// NOTE(review): this early return (and the two identical ones below)
// leaves `state.current_thread` set to the coroutine.
let Some(resumer) = state.saved_threads.pop() else {
let r = state
.gc
.intern_string(b"internal error: missing resumer state");
return Err(Val::Str(r));
};
state.save_and_restore_by_ref(co_ref, ThreadStatus::Dead, resumer);
state.current_thread = saved_current_thread;
Ok(results)
}
Err(LuaError::Yield(n_results)) => {
// Yielded: the topmost `n_results` stack slots are the yielded
// values; the thread is stored as suspended.
let mut results = Vec::new();
let start = state.top.saturating_sub(n_results as usize);
for i in start..state.top {
results.push(state.stack_get(i));
}
let Some(resumer) = state.saved_threads.pop() else {
let r = state
.gc
.intern_string(b"internal error: missing resumer state");
return Err(Val::Str(r));
};
state.save_and_restore_by_ref(co_ref, ThreadStatus::Suspended, resumer);
state.current_thread = saved_current_thread;
Ok(results)
}
Err(err) => {
// Failed: prefer the value left in `state.error_object`; otherwise
// intern the error's display text. The thread is stored as dead.
let error_val = state.error_object.take().unwrap_or_else(|| {
let r = state.gc.intern_string(err.to_string().as_bytes());
Val::Str(r)
});
let Some(resumer) = state.saved_threads.pop() else {
let r = state
.gc
.intern_string(b"internal error: missing resumer state");
return Err(Val::Str(r));
};
state.save_and_restore_by_ref(co_ref, ThreadStatus::Dead, resumer);
state.current_thread = saved_current_thread;
Err(error_val)
}
}
}
/// `coroutine.yield(...)`: suspends the running coroutine.
///
/// The yield is signalled by returning `Err(LuaError::Yield(n))`, where `n`
/// is the number of arguments passed to this call.
///
/// # Errors
///
/// Besides the yield signal itself, returns a runtime error when there is no
/// current coroutine or when `state.n_ccalls` is positive.
pub fn co_yield(state: &mut LuaState) -> LuaResult<u32> {
    let outside_coroutine = state.current_thread.is_none();
    let inside_nested_call = state.n_ccalls > 0;
    if outside_coroutine || inside_nested_call {
        let message = "attempt to yield across metamethod/C-call boundary";
        return Err(simple_error(message.into()));
    }

    let yielded = nargs(state) as u32;
    Err(LuaError::Yield(yielded))
}
/// `coroutine.wrap(f)`: creates a coroutine for `f` and pushes a Rust closure
/// (`wrap_aux`) that resumes it each time it is called.
///
/// The new thread is stored as the first upvalue of the pushed closure.
///
/// # Errors
///
/// Fails when argument #1 is not a function, is a Rust closure rather than a
/// Lua one, or refers to a closure that is not present in the closure arena.
pub fn co_wrap(state: &mut LuaState) -> LuaResult<u32> {
    let body = arg(state, 0);

    let is_lua_function = match body {
        Val::Function(fn_ref) => {
            let closure = state
                .gc
                .closures
                .get(fn_ref)
                .ok_or_else(|| simple_error("invalid function reference".into()))?;
            !matches!(closure, Closure::Rust(_))
        }
        _ => false,
    };
    if !is_lua_function {
        return Err(simple_error(
            "bad argument #1 to 'wrap' (Lua function expected)".into(),
        ));
    }

    let new_thread = LuaThread::new(body, state.global);
    let thread_handle = state.gc.alloc_thread(new_thread);

    let mut resumer_fn = RustClosure::new(wrap_aux, "wrap_aux");
    resumer_fn.upvalues.push(Val::Thread(thread_handle));
    let fn_handle = state.gc.alloc_closure(Closure::Rust(resumer_fn));

    state.push(Val::Function(fn_handle));
    Ok(1)
}
fn wrap_aux(state: &mut LuaState) -> LuaResult<u32> {
let ci_func = state.call_stack[state.ci].func;
let func_val = state.stack_get(ci_func);
let Val::Function(closure_ref) = func_val else {
return Err(simple_error("invalid wrap closure".into()));
};
let co_ref = {
let cl = state
.gc
.closures
.get(closure_ref)
.ok_or_else(|| simple_error("invalid closure reference".into()))?;
match cl {
Closure::Rust(rc) => {
if let Some(Val::Thread(r)) = rc.upvalues.first() {
*r
} else {
return Err(simple_error("wrap: missing thread upvalue".into()));
}
}
_ => return Err(simple_error("wrap: expected Rust closure".into())),
}
};
let n = nargs(state);
let mut args = Vec::with_capacity(n);
for i in 0..n {
args.push(arg(state, i));
}
match auxresume(state, co_ref, &args) {
Ok(results) => {
let base = state.base;
for (i, val) in results.iter().enumerate() {
state.stack_set(base + i, *val);
}
state.top = base + results.len();
Ok(results.len() as u32)
}
Err(error_val) => {
let final_val = if let Val::Str(r) = error_val {
let where_prefix = execute::get_where(state, 1);
if where_prefix.is_empty() {
error_val
} else {
let original = state
.gc
.string_arena
.get(r)
.map(|s| String::from_utf8_lossy(s.data()).to_string())
.unwrap_or_default();
let full = format!("{where_prefix}{original}");
Val::Str(state.gc.intern_string(full.as_bytes()))
}
} else {
error_val
};
state.error_object = Some(final_val);
let display = match final_val {
Val::Str(r) => state
.gc
.string_arena
.get(r)
.map(|s| String::from_utf8_lossy(s.data()).to_string())
.unwrap_or_default(),
_ => format!("{final_val}"),
};
Err(LuaError::Runtime(RuntimeError {
message: display,
level: 0,
traceback: vec![],
}))
}
}
}
/// `coroutine.status(co)`: pushes the status of `co` as a string.
///
/// The result is `"running"` when `co` is the current thread or is stored as
/// running, `"suspended"` for initial or suspended threads, `"normal"` for
/// normal ones, and `"dead"` for dead threads or handles not found in the
/// thread arena.
///
/// # Errors
///
/// Fails when argument #1 is not a thread.
pub fn co_status(state: &mut LuaState) -> LuaResult<u32> {
    let Val::Thread(co_ref) = arg(state, 0) else {
        return Err(simple_error(
            "bad argument #1 to 'status' (coroutine expected)".into(),
        ));
    };

    let label: &str = if state.current_thread == Some(co_ref) {
        "running"
    } else {
        match state.gc.threads.get(co_ref) {
            None => "dead",
            Some(thread) => match thread.status {
                ThreadStatus::Running => "running",
                ThreadStatus::Initial | ThreadStatus::Suspended => "suspended",
                ThreadStatus::Normal => "normal",
                ThreadStatus::Dead => "dead",
            },
        }
    };

    let interned = state.gc.intern_string(label.as_bytes());
    state.push(Val::Str(interned));
    Ok(1)
}
/// `coroutine.running()`: pushes the current coroutine.
///
/// Returns one value (the thread) when `state.current_thread` is set, and no
/// values at all otherwise.
pub fn co_running(state: &mut LuaState) -> LuaResult<u32> {
    let Some(current) = state.current_thread else {
        return Ok(0);
    };
    state.push(Val::Thread(current));
    Ok(1)
}