Skip to main content

lua_stdlib/
coro_lib.rs

1//! Coroutine library — port of `lcorolib.c`.
2//!
3//! Provides the `coroutine.*` standard-library table: `create`, `resume`,
4//! `running`, `status`, `wrap`, `yield`, `isyieldable`, and `close`.
5//!
//! # Porting status
//!
//! In Phases A–D, every function that required actual coroutine execution
//! (`resume`, `yield`, cross-thread `xmove`, `new_thread`, `close_thread`) was
//! a stub that panicked at runtime. The argument-checking and result-packaging
//! logic was translated faithfully so that Phase E could drop in the real
//! implementations without restructuring. In this file, `resume` now calls
//! `lua_vm::do_::lua_resume`, `yield` calls `lua_vm::do_::lua_yieldk`,
//! `create` calls `state.new_thread`, and `close` calls
//! `lua_vm::state::reset_thread`. Phase E wires real stackful coroutines via
//! `corosensei`. See PORTING.md §2 #6.
14//!
15//! Translated from: `reference/lua-5.4.7/src/lcorolib.c` (210 lines, 12 functions)
16//! Target crate: `lua-stdlib`
17
18use lua_types::{
19    error::LuaError,
20    value::LuaValue,
21    LuaType,
22    LuaStatus,
23    gc::GcRef,
24};
25use crate::state_stub::{LuaState, LuaStateStubExt as _, lua_CFunction, upvalue_index};
26
27// ── Coroutine status codes ────────────────────────────────────────────────────
28
29
/// Coroutine is the currently running thread.
///
/// Each `COS_*` code doubles as an index into [`STAT_NAMES`].
const COS_RUN: i32 = 0;

/// Coroutine has finished execution or encountered an error.
const COS_DEAD: i32 = 1;

/// Coroutine is suspended — either yielded or not yet started.
const COS_YIELD: i32 = 2;

/// Coroutine is normal — it resumed another coroutine and is waiting.
const COS_NORM: i32 = 3;

/// Human-readable status strings indexed by the `COS_*` constants above:
/// `COS_RUN` → `"running"`, `COS_DEAD` → `"dead"`, `COS_YIELD` →
/// `"suspended"`, `COS_NORM` → `"normal"`. `co_status` interns the selected
/// entry and pushes it onto the Lua stack; `co_close` uses the same words in
/// its error message.
const STAT_NAMES: [&[u8]; 4] = [b"running", b"dead", b"suspended", b"normal"];
46
47// ── Registration table ────────────────────────────────────────────────────────
48
/// Registration table for the `coroutine` standard library.
///
/// Each entry is `(name_bytes, function_pointer)`. `open_coroutine` passes
/// this slice to `state.new_lib` to build the `coroutine` table.
///
/// Port note: `lua_CFunction` is currently imported from `crate::state_stub`;
/// per the original port notes, Phase B is expected to resolve it to the
/// canonical type alias from `lua-types`.
pub const CO_FUNCS: &[(&[u8], lua_CFunction)] = &[
    (b"create",      co_create),
    (b"resume",      co_resume),
    (b"running",     co_running),
    (b"status",      co_status),
    (b"wrap",        co_wrap),
    (b"yield",       co_yield),
    (b"isyieldable", co_isyieldable),
    (b"close",       co_close),
];
64
65// ── Internal helpers ──────────────────────────────────────────────────────────
66
67/// Retrieves the coroutine thread at stack index 1, raising a type error if
68/// the argument is absent or not a thread.
69///
70fn get_co(state: &mut LuaState) -> Result<GcRef<lua_types::value::LuaThread>, LuaError> {
71    let co = state.to_thread(1);
72    if co.is_none() {
73        let got = state.arg(1);
74        return Err(LuaError::type_arg_error(1, "thread", &got));
75    }
76    Ok(co.expect("checked above"))
77}
78
/// Returns one of the `COS_*` status codes describing `co` relative to the
/// calling thread `state`. Mirrors `auxstatus` in `lcorolib.c`, reading the
/// target coroutine's `status`, call-frame depth, and stack top through
/// `GlobalState::threads`.
///
/// Decision order:
/// 1. `co` is the current thread → `COS_RUN`.
/// 2. `co` is the main thread (and not current) → `COS_NORM`. The main
///    thread is never stored in the registry, so it cannot be inspected
///    further here.
/// 3. `co` is absent from `GlobalState::threads` → `COS_DEAD`.
/// 4. `co`'s state is already mutably borrowed (a nested resume is in
///    progress through it) → `COS_NORM`.
/// 5. Otherwise, as in C: `Yield` status → `COS_YIELD`; any other non-`Ok`
///    status (an error) → `COS_DEAD`; active call frames → `COS_NORM`; no
///    frames and an empty stack → `COS_DEAD`; no frames but values on the
///    stack (the initial state, with the body staged by `coroutine.create`)
///    → `COS_YIELD`.
fn aux_status(state: &mut LuaState, co: &GcRef<lua_types::value::LuaThread>) -> i32 {
    let co_id = co.id;
    // Clone the thread's `Rc` out so the `GlobalState` borrow ends before
    // the thread's own `RefCell` is borrowed below.
    let entry_rc = {
        let g = state.global();
        if co_id == g.current_thread_id {
            return COS_RUN;
        }
        if co_id == g.main_thread_id {
            return COS_NORM;
        }
        match g.threads.get(&co_id) {
            Some(e) => e.state.clone(),
            None => return COS_DEAD,
        }
    };
    let co_state = match entry_rc.try_borrow() {
        Ok(state) => state,
        Err(_) => {
            // Nested resumes can hold a mutable borrow of a parent coroutine.
            // In that case, the safest fallback is to report the target as
            // "normal" (active but not suspended/dead), which matches the
            // common nested-resume status for the parent thread.
            return COS_NORM;
        }
    };
    let raw_status = co_state.status;
    if raw_status == LuaStatus::Yield as u8 {
        return COS_YIELD;
    }
    // Any status other than Ok/Yield means the coroutine died with an error.
    if raw_status != LuaStatus::Ok as u8 {
        return COS_DEAD;
    }
    // Call-info index 0 is the base frame; anything above it is an active
    // frame (C: `lua_getstack(co, 0, &ar)` succeeds).
    let has_frames = co_state.ci.as_usize() > 0;
    if has_frames {
        return COS_NORM;
    }
    // Equivalent of C `lua_gettop(co)` for the base frame: top - (func + 1).
    let ci_func = co_state.call_info[0].func.0;
    let top = co_state.top.0;
    let lua_gettop = top as i64 - ci_func as i64 - 1;
    if lua_gettop == 0 {
        COS_DEAD
    } else {
        COS_YIELD
    }
}
135
136/// Transfers `narg` arguments from `state` to `co`, resumes the coroutine,
137/// then transfers results (or error message) back to `state`.
138///
139/// Returns the number of result values (≥ 0) on success, or `-1` on error
140/// with the error object left on top of `state`'s stack.
141///
142/// Phase E-3 adds cross-thread open-upvalue mirroring around the resume
143/// boundary: before yielding control, the parent's open-upvalue values
144/// are snapshotted into `GlobalState::cross_thread_upvals` so the
145/// coroutine body can read and write them through
146/// `LuaState::upvalue_get` / `upvalue_set`. On resume return, the
147/// (possibly mutated) cache entries are flushed back into the parent's
148/// stack. This is the alternative to a stack-refactor that would let
149/// the parent's `LuaState` be reached through `Rc<RefCell<_>>` while it
150/// is held by `&mut` further up the call stack.
151///
152fn aux_resume(state: &mut LuaState, co: GcRef<lua_types::value::LuaThread>, narg: i32) -> i32 {
153    let co_id = co.id;
154    let entry_rc = {
155        let g = state.global();
156        match g.threads.get(&co_id) {
157            Some(e) => e.state.clone(),
158            None => {
159                drop(g);
160                push_lit_or_nil(state, b"cannot resume dead coroutine");
161                return -1;
162            }
163        }
164    };
165    let parent_thread_id = state.global().current_thread_id;
166    let top_before = state.get_top();
167    if top_before < narg {
168        push_lit_or_nil(state, b"not enough arguments to resume");
169        return -1;
170    }
171    let first_arg_idx = top_before - narg + 1;
172    let args: Vec<LuaValue> = (first_arg_idx..=top_before)
173        .map(|i| state.value_at(i))
174        .collect();
175    lua_vm::api::set_top(state, (top_before - narg) as i32).ok();
176
177    let parent_open_upval_slots: Vec<(u64, lua_vm::state::StackIdx)> = state
178        .openupval
179        .iter()
180        .filter_map(|uv| match &*uv.slot() {
181            lua_types::UpValState::Open { thread_id, idx } => {
182                Some((*thread_id as u64, *idx))
183            }
184            lua_types::UpValState::Closed(_) => None,
185        })
186        .collect();
187    {
188        let mut g = state.global_mut();
189        for (tid, idx) in &parent_open_upval_slots {
190            let val = state.get_at(*idx);
191            g.cross_thread_upvals.insert((*tid, *idx), val);
192        }
193    }
194
195    push_parent_gc_snapshot(state);
196
197    let (status, results_or_err): (LuaStatus, Vec<LuaValue>) = {
198        let mut co_state = match entry_rc.try_borrow_mut() {
199            Ok(b) => b,
200            Err(_) => {
201                pop_parent_gc_snapshot(state);
202                let mut g = state.global_mut();
203                for (tid, idx) in &parent_open_upval_slots {
204                    g.cross_thread_upvals.remove(&(*tid, *idx));
205                }
206                drop(g);
207                push_lit_or_nil(state, b"cannot resume non-suspended coroutine");
208                return -1;
209            }
210        };
211        if co_state.check_stack(narg + 1).is_err() {
212            drop(co_state);
213            pop_parent_gc_snapshot(state);
214            let mut g = state.global_mut();
215            for (tid, idx) in &parent_open_upval_slots {
216                g.cross_thread_upvals.remove(&(*tid, *idx));
217            }
218            drop(g);
219            push_lit_or_nil(state, b"too many arguments to resume");
220            return -1;
221        }
222        for v in args {
223            co_state.push(v);
224        }
225        co_state.global_mut().current_thread_id = co_id;
226        let mut nres: i32 = 0;
227        let status = lua_vm::do_::lua_resume(&mut *co_state, Some(state), narg, &mut nres);
228        co_state.global_mut().current_thread_id = parent_thread_id;
229        let co_top = co_state.top_idx().0 as i32;
230        let ci_func = co_state.current_call_info().func.0 as i32;
231        let count = if status == LuaStatus::Ok || status == LuaStatus::Yield {
232            nres
233        } else {
234            1
235        };
236        let start = co_top - count;
237        let vals: Vec<LuaValue> = (start..co_top)
238            .map(|i| co_state.get_at(lua_vm::state::StackIdx(i as u32)))
239            .collect();
240        let new_co_top = if status == LuaStatus::Ok || status == LuaStatus::Yield {
241            (co_top - count).max(ci_func + 1)
242        } else {
243            co_top - count
244        };
245        co_state.set_top(lua_vm::state::StackIdx(new_co_top.max(0) as u32));
246        (status, vals)
247    };
248
249    // Pop the parent stack snapshot — the coroutine has yielded or returned.
250    pop_parent_gc_snapshot(state);
251
252    {
253        let mut g = state.global_mut();
254        let mut flush: Vec<(lua_vm::state::StackIdx, LuaValue)> = Vec::new();
255        for (tid, idx) in &parent_open_upval_slots {
256            if let Some(v) = g.cross_thread_upvals.remove(&(*tid, *idx)) {
257                flush.push((*idx, v));
258            }
259        }
260        drop(g);
261        for (idx, v) in flush {
262            state.set_at(idx, v);
263        }
264    }
265
266    match status {
267        LuaStatus::Ok | LuaStatus::Yield => {
268            if state.check_stack(results_or_err.len() as i32 + 1).is_err() {
269                push_lit_or_nil(state, b"too many results to resume");
270                return -1;
271            }
272            let n = results_or_err.len();
273            for v in results_or_err {
274                state.push(v);
275            }
276            n as i32
277        }
278        _ => {
279            for v in results_or_err {
280                state.push(v);
281            }
282            -1
283        }
284    }
285}
286
287fn push_parent_gc_snapshot(state: &mut LuaState) {
288    let top = state.top_idx();
289    let stack_snapshot: Vec<LuaValue> = (0..top.0)
290        .map(|i| state.get_at(lua_vm::state::StackIdx(i)))
291        .collect();
292    let open_upval_snapshot = state.openupval.clone();
293    let mut g = state.global_mut();
294    g.suspended_parent_stacks.push(stack_snapshot);
295    g.suspended_parent_open_upvals.push(open_upval_snapshot);
296}
297
298fn pop_parent_gc_snapshot(state: &mut LuaState) {
299    let mut g = state.global_mut();
300    g.suspended_parent_open_upvals.pop();
301    g.suspended_parent_stacks.pop();
302}
303
304/// Helper: push a string literal or fall back to Nil on intern failure.
305fn push_lit_or_nil(state: &mut LuaState, bytes: &[u8]) {
306    match state.intern_str(bytes) {
307        Ok(s) => state.push(LuaValue::Str(s)),
308        Err(_) => state.push(LuaValue::Nil),
309    }
310}
311
312// ── Public library functions ──────────────────────────────────────────────────
313
314/// `coroutine.resume(co [, val1, ...])` — attempt to resume coroutine `co`.
315///
316/// On success pushes `true` followed by all values yielded or returned by `co`.
317/// On failure pushes `false` followed by the error object.
318///
319pub fn co_resume(state: &mut LuaState) -> Result<usize, LuaError> {
320    let co = get_co(state)?;
321    // PORT NOTE: lua_gettop returns the argument count; -1 excludes the coroutine
322    // itself which sits at index 1.
323    let narg = state.get_top() - 1;
324    let r = aux_resume(state, co, narg);
325    if r < 0 {
326        // A sandbox budget trip is uncatchable: re-raise into the caller frame
327        // instead of returning `false, msg`, so code cannot keep a runaway
328        // coroutine alive by resuming it in a loop.
329        if state.sandbox_aborting() {
330            let top = state.get_top();
331            let err_val = state.value_at(top);
332            return Err(LuaError::from_value(err_val));
333        }
334        state.push(LuaValue::Bool(false));
335        state.insert(-2)?;
336        Ok(2)
337    } else {
338        state.push(LuaValue::Bool(true));
339        state.insert(-(r + 1))?;
340        Ok((r + 1) as usize)
341    }
342}
343
/// Closure body installed by `coroutine.wrap`. The wrapped coroutine
/// thread is stored in upvalue slot 1 as a `LuaValue::Thread`.
///
/// On call: forwards all args to `aux_resume` on the captured thread. On
/// success returns the yielded/returned values; on coroutine error raises
/// the error (matching `select(2, assert(resume(co, ...)))` semantics).
///
/// When the resume fails and `aux_status` then reports the coroutine as
/// `COS_DEAD`, it is closed via `close_suspended_or_dead`. If that close
/// itself reports an error (it returned 2 values, `false, err`), that error
/// replaces the original resume error.
///
/// NOTE(review): C `auxwrap` decides whether to close by checking
/// `lua_status(co)` for an error status, and prepends `luaL_where(L, 1)` to
/// string error objects. Here the check is `aux_status(..) == COS_DEAD`
/// (also true for a coroutine that finished normally), and no position
/// prefix is added in this function — confirm whether `LuaError::from_value`
/// or a caller supplies it.
///
/// # Errors
/// Returns a runtime error if upvalue 1 is not a thread; otherwise, on a
/// failed resume, returns the coroutine's error object as a `LuaError`.
fn aux_wrap(state: &mut LuaState) -> Result<usize, LuaError> {
    let up = state.value_at(upvalue_index(1));
    let co = match up {
        LuaValue::Thread(t) => t,
        _ => {
            return Err(LuaError::runtime(format_args!(
                "coroutine.wrap: upvalue is not a thread"
            )))
        }
    };
    // Every argument of the wrapper call is forwarded to the coroutine.
    let narg = state.get_top();
    let r = aux_resume(state, co.clone(), narg);
    if r < 0 {
        // aux_resume leaves the error object on top of the stack.
        let top = state.get_top();
        let mut err_val = state.value_at(top);
        if aux_status(state, &co) == COS_DEAD {
            // Remove the resume error before the close pushes its results.
            let old_err = state.pop();
            let nclose = close_suspended_or_dead(state, co)?;
            // Two results means `false, err` from a failing close: prefer that error.
            err_val = if nclose >= 2 {
                let top = state.get_top();
                state.value_at(top)
            } else {
                old_err
            };
            // Discard the values pushed by close_suspended_or_dead.
            state.pop_n(nclose);
        }
        Err(LuaError::from_value(err_val))
    } else {
        Ok(r as usize)
    }
}
382
383/// `coroutine.create(f)` — create a new coroutine that will run function `f`.
384///
385/// Pushes the new thread value and returns 1.
386///
387/// Phase E-1: allocates a real `LuaState` registered in
388/// `GlobalState::threads`, with `f` staged on the new thread's stack so
389/// `coroutine.status` reports `"suspended"`. The full `xmove` from the
390/// caller's stack arrives in slice 02b; for this slice the body is
391/// cloned via `value_at(1)`, which has the same net stack effect since
392/// `lua_newthread` in C also leaves only the thread value on the
393/// caller's stack.
394///
395pub fn co_create(state: &mut LuaState) -> Result<usize, LuaError> {
396    state.check_arg_type(1, LuaType::Function)?;
397    let body = state.value_at(1);
398    let _nl = state.new_thread(Some(body))?;
399    Ok(1)
400}
401
402/// `coroutine.wrap(f)` — create a coroutine and return a resuming function.
403///
404/// The returned function, when called, resumes the coroutine as if by
405/// `coroutine.resume`, but raises an error rather than returning `false`.
406///
407///
408/// Captures the new coroutine thread as upvalue 1 of `aux_wrap`.
409pub fn co_wrap(state: &mut LuaState) -> Result<usize, LuaError> {
410    co_create(state)?;
411    state.push_cclosure(aux_wrap, 1)?;
412    Ok(1)
413}
414
415/// `coroutine.yield([...])` — suspend the running coroutine.
416///
417/// All arguments are passed back as results of the corresponding `resume`.
418///
419/// → `return lua_yield(L, lua_gettop(L));`
420/// → `lua_yield(L,n)` is `lua_yieldk(L, n, 0, NULL)` (lua.h:316)
421pub fn co_yield(state: &mut LuaState) -> Result<usize, LuaError> {
422    let n = state.get_top();
423    let r = lua_vm::do_::lua_yieldk(state, n, 0, None)?;
424    Ok(r as usize)
425}
426
427/// `coroutine.status(co)` — return a string describing `co`'s current status.
428///
429/// Returns one of `"running"`, `"dead"`, `"suspended"`, or `"normal"`.
430///
431pub fn co_status(state: &mut LuaState) -> Result<usize, LuaError> {
432    let co = get_co(state)?;
433    let idx = aux_status(state, &co) as usize;
434    let name: &[u8] = STAT_NAMES[idx];
435    let interned = state.intern_str(name)?;
436    state.push(LuaValue::Str(interned));
437    Ok(1)
438}
439
440/// `coroutine.isyieldable([co])` — test whether a coroutine (default: current)
441/// is in a yieldable state.
442///
443pub fn co_isyieldable(state: &mut LuaState) -> Result<usize, LuaError> {
444    let is_yieldable = if matches!(state.type_at(1), LuaType::None) {
445        state.is_yieldable()
446    } else {
447        let co = get_co(state)?;
448        let co_id = co.id;
449        let (is_main, is_current) = {
450            let g = state.global();
451            (co_id == g.main_thread_id, co_id == g.current_thread_id)
452        };
453        if is_main {
454            false
455        } else if is_current {
456            state.is_yieldable()
457        } else {
458            let entry_rc = {
459                let g = state.global();
460                g.threads
461                    .get(&co_id)
462                    .expect("thread value carries an id that must resolve in GlobalState::threads")
463                    .state
464                    .clone()
465            };
466            let target_is_yieldable = match entry_rc.try_borrow() {
467                Ok(b) => b.is_yieldable(),
468                Err(_) => false,
469            };
470            target_is_yieldable
471        }
472    };
473    state.push(LuaValue::Bool(is_yieldable));
474    Ok(1)
475}
476
477/// `coroutine.running()` — return the current coroutine plus a boolean.
478///
479/// The boolean is `true` when the current coroutine is the main thread.
480///
481pub fn co_running(state: &mut LuaState) -> Result<usize, LuaError> {
482    // TODO(port): push_thread pushes a Thread value for the current LuaState and
483    // returns true iff it is the main thread; Phase B wire-up needed.
484    let is_main = state.push_thread()?;
485    state.push(LuaValue::Bool(is_main));
486    Ok(2)
487}
488
489/// `coroutine.close(co)` — close a dead or suspended coroutine.
490///
491/// Closes a coroutine, running any pending to-be-closed variables via
492/// `__close` and resetting its status. Valid only when the target is
493/// suspended (`Yield`) or dead (`Ok` with no active frames).
494/// Calling on a running or normal coroutine raises an error.
495///
496pub fn co_close(state: &mut LuaState) -> Result<usize, LuaError> {
497    lua_vm::state::inc_c_stack(state)?;
498    let result = (|| {
499        let co = get_co(state)?;
500        let status = aux_status(state, &co);
501        match status {
502            COS_DEAD | COS_YIELD => close_suspended_or_dead(state, co),
503            _ => {
504                let name = if status == COS_RUN { "running" } else { "normal" };
505                Err(LuaError::runtime(format_args!(
506                    "cannot close a {} coroutine",
507                    name
508                )))
509            }
510        }
511    })();
512    state.n_ccalls -= 1;
513    result
514}
515
/// Performs the actual close for a suspended or dead coroutine: the
/// counterpart of C `lua_closethread(co, L)` plus the result packaging in
/// `luaB_close`.
///
/// Pushes onto `state` and returns the number of values pushed:
/// - `true` (returns 1) if `co` is absent from `GlobalState::threads`, or if
///   `lua_vm::state::reset_thread` returns `LuaStatus::Ok`;
/// - `false, err` (returns 2) otherwise. `err` is the value that was on top
///   of the coroutine's stack after the reset (it is popped from there), or
///   `nil` if that stack was empty.
///
/// Around the reset, it temporarily sets `current_thread_id` to `co` and
/// copies the caller's C-call count into the coroutine. It also applies the
/// same cross-thread open-upvalue mirroring and parent-stack GC snapshot
/// that `aux_resume` uses.
///
/// # Panics
/// Panics if `co`'s state is already borrowed (`RefCell::borrow_mut`).
/// Callers in this file invoke it only after `aux_status` returned
/// `COS_DEAD` or `COS_YIELD`, and `aux_status` reports a mutably borrowed
/// thread as `COS_NORM`.
///
/// # Errors
/// None are produced by this function itself; every path returns `Ok`.
fn close_suspended_or_dead(
    state: &mut LuaState,
    co: GcRef<lua_types::value::LuaThread>,
) -> Result<usize, LuaError> {
    let co_id = co.id;
    let entry_rc_opt = {
        let g = state.global();
        g.threads.get(&co_id).map(|e| e.state.clone())
    };
    let entry_rc = match entry_rc_opt {
        Some(rc) => rc,
        // Not in the registry: nothing left to close, report success.
        None => {
            state.push(LuaValue::Bool(true));
            return Ok(1);
        }
    };
    let parent_thread_id = state.global().current_thread_id;
    let caller_c_calls = state.c_calls();

    // Mirror the caller's open upvalues so `__close` handlers running on the
    // coroutine can reach them (same scheme as `aux_resume`).
    let parent_open_upval_slots: Vec<(u64, lua_vm::state::StackIdx)> = state
        .openupval
        .iter()
        .filter_map(|uv| match &*uv.slot() {
            lua_types::UpValState::Open { thread_id, idx } => {
                Some((*thread_id as u64, *idx))
            }
            lua_types::UpValState::Closed(_) => None,
        })
        .collect();
    {
        let mut g = state.global_mut();
        for (tid, idx) in &parent_open_upval_slots {
            let val = state.get_at(*idx);
            g.cross_thread_upvals.insert((*tid, *idx), val);
        }
    }

    push_parent_gc_snapshot(state);

    let (status, err_value): (i32, Option<LuaValue>) = {
        let mut co_state = entry_rc.borrow_mut();
        co_state.global_mut().current_thread_id = co_id;
        // C: `L->nCcalls = getCcalls(from)` in `lua_closethread`.
        co_state.n_ccalls = caller_c_calls;
        let in_status = co_state.status as i32;
        let s = lua_vm::state::reset_thread(&mut *co_state, in_status);
        co_state.global_mut().current_thread_id = parent_thread_id;
        if s == LuaStatus::Ok as i32 {
            (s, None)
        } else {
            // The error object is expected on top of the coroutine's stack;
            // take it off so it can be moved to the caller.
            let top = co_state.top_idx().0;
            if top > 0 {
                let err = co_state.get_at(lua_vm::state::StackIdx(top - 1));
                co_state.set_top(lua_vm::state::StackIdx(top - 1));
                (s, Some(err))
            } else {
                (s, Some(LuaValue::Nil))
            }
        }
    };

    pop_parent_gc_snapshot(state);

    // Flush the (possibly mutated) mirrored upvalue values back to the caller.
    {
        let mut g = state.global_mut();
        let mut flush: Vec<(lua_vm::state::StackIdx, LuaValue)> = Vec::new();
        for (tid, idx) in &parent_open_upval_slots {
            if let Some(v) = g.cross_thread_upvals.remove(&(*tid, *idx)) {
                flush.push((*idx, v));
            }
        }
        drop(g);
        for (idx, v) in flush {
            state.set_at(idx, v);
        }
    }

    if status == LuaStatus::Ok as i32 {
        state.push(LuaValue::Bool(true));
        Ok(1)
    } else {
        state.push(LuaValue::Bool(false));
        if let Some(v) = err_value {
            state.push(v);
        } else {
            state.push(LuaValue::Nil);
        }
        Ok(2)
    }
}
606
607// ── Module entry point ────────────────────────────────────────────────────────
608
609/// Opens the `coroutine` standard library by pushing a new table containing
610/// all `coroutine.*` functions.
611///
612pub fn open_coroutine(state: &mut LuaState) -> Result<usize, LuaError> {
613    // TODO(port): state.new_lib(CO_FUNCS) creates a table from the registration
614    // slice and leaves it on the stack; Phase B wire-up needed.
615    state.new_lib(CO_FUNCS)?;
616    Ok(1)
617}
618
619// ──────────────────────────────────────────────────────────────────────────────
620// PORT STATUS
621//   source:        src/lcorolib.c  (210 lines, 12 functions)
622//   target_crate:  lua-stdlib
623//   confidence:    medium
624//   todos:         21
625//   port_notes:    2
626//   unsafe_blocks: 0
//   notes:         Coroutine execution is delegated to the VM: resume via
//                  lua_vm::do_::lua_resume, yield via lua_vm::do_::lua_yieldk,
//                  close via lua_vm::state::reset_thread, create via
//                  state.new_thread. (In Phases A–D these were stubs that
//                  panicked.) Argument-checking / result-packaging logic is
//                  faithfully translated from lcorolib.c. The CO_FUNCS table
//                  type references lua_CFunction, currently imported from
//                  crate::state_stub together with LuaState; GcRef and
//                  LuaStatus are imported from lua_types.
634// ──────────────────────────────────────────────────────────────────────────────