1use lua_types::{
21 error::LuaError,
22 value::LuaValue,
23 LuaType,
24 LuaStatus,
25 gc::GcRef,
26};
27use crate::state_stub::{LuaState, LuaStateStubExt as _, lua_CFunction, upvalue_index};
28
/// Coroutine status code: `co` is the thread currently executing.
const COS_RUN: i32 = 0;
/// Coroutine status code: `co` cannot be resumed any more (no registered
/// state, an error status, or no frames and an empty stack).
const COS_DEAD: i32 = 1;
/// Coroutine status code: `co` is suspended in a yield or was never started.
const COS_YIELD: i32 = 2;
/// Coroutine status code: `co` is active but not the running thread.
const COS_NORM: i32 = 3;

/// Strings reported by `coroutine.status`, indexed by the `COS_*` codes.
const STAT_NAMES: [&[u8]; 4] = [
    b"running",   // COS_RUN
    b"dead",      // COS_DEAD
    b"suspended", // COS_YIELD
    b"normal",    // COS_NORM
];
53
54pub const CO_FUNCS: &[(&[u8], lua_CFunction)] = &[
63 (b"create", co_create),
64 (b"resume", co_resume),
65 (b"running", co_running),
66 (b"status", co_status),
67 (b"wrap", co_wrap),
68 (b"yield", co_yield),
69 (b"isyieldable", co_isyieldable),
70 (b"close", co_close),
71];
72
73fn get_co(state: &mut LuaState) -> Result<GcRef<lua_types::value::LuaThread>, LuaError> {
80 let co = state.to_thread(1);
81 if co.is_none() {
82 let got = state.arg(1);
83 return Err(LuaError::type_arg_error(1, "thread", &got));
84 }
85 Ok(co.expect("checked above"))
86}
87
88fn aux_status(state: &mut LuaState, co: &GcRef<lua_types::value::LuaThread>) -> i32 {
101 let co_id = co.id;
102 let entry_rc = {
103 let g = state.global();
104 if co_id == g.current_thread_id {
105 return COS_RUN;
106 }
107 if co_id == g.main_thread_id {
108 return COS_NORM;
109 }
110 match g.threads.get(&co_id) {
111 Some(e) => e.state.clone(),
112 None => return COS_DEAD,
113 }
114 };
115 let co_state = match entry_rc.try_borrow() {
116 Ok(state) => state,
117 Err(_) => {
118 return COS_NORM;
123 }
124 };
125 let raw_status = co_state.status;
126 if raw_status == LuaStatus::Yield as u8 {
127 return COS_YIELD;
128 }
129 if raw_status != LuaStatus::Ok as u8 {
130 return COS_DEAD;
131 }
132 let has_frames = co_state.ci.as_usize() > 0;
133 if has_frames {
134 return COS_NORM;
135 }
136 let ci_func = co_state.call_info[0].func.0;
137 let top = co_state.top.0;
138 let lua_gettop = top as i64 - ci_func as i64 - 1;
139 if lua_gettop == 0 {
140 COS_DEAD
141 } else {
142 COS_YIELD
143 }
144}
145
/// Resumes coroutine `co`, passing it the top `narg` values of `state`'s stack.
///
/// Returns the number of values pushed onto `state` on success (the
/// coroutine yielded or returned). Returns `-1` on failure, with one error
/// value (usually a message string) pushed on top of `state`.
///
/// Protocol around the actual `lua_resume` call:
/// 1. The arguments are copied out and removed from the parent stack.
/// 2. The current values of the parent's open-upvalue stack slots are
///    published in `cross_thread_upvals`. They are written back afterwards.
///    NOTE(review): presumably this lets the coroutine read and write
///    upvalues that still point into the parent stack while the parent
///    state is otherwise inaccessible. Confirm against the VM's upvalue
///    access path.
/// 3. A snapshot of the parent stack and open upvalues is pushed onto global
///    state for the duration of the resume. It is popped on every exit path.
///    NOTE(review): presumably these act as GC roots. Confirm.
fn aux_resume(state: &mut LuaState, co: GcRef<lua_types::value::LuaThread>, narg: i32) -> i32 {
    let co_id = co.id;
    let entry_rc = {
        let g = state.global();
        match g.threads.get(&co_id) {
            Some(e) => e.state.clone(),
            None => {
                // No registered state for this thread: report it as dead.
                // `g` must be released before pushing onto `state`.
                drop(g);
                push_lit_or_nil(state, b"cannot resume dead coroutine");
                return -1;
            }
        }
    };
    let parent_thread_id = state.global().current_thread_id;
    let top_before = state.get_top();
    // Defensive check. The visible callers (co_resume: top - 1, aux_wrap:
    // top) never pass more than the stack holds.
    if top_before < narg {
        push_lit_or_nil(state, b"not enough arguments to resume");
        return -1;
    }
    // Copy the arguments out, then drop them from the parent stack.
    let first_arg_idx = top_before - narg + 1;
    let args: Vec<LuaValue> = (first_arg_idx..=top_before)
        .map(|i| state.value_at(i))
        .collect();
    // NOTE(review): any error from set_top is discarded here.
    lua_vm::api::set_top(state, (top_before - narg) as i32).ok();

    // Stack slots (owner thread id, index) that the parent's still-open
    // upvalues refer to.
    let parent_open_upval_slots: Vec<(u64, lua_vm::state::StackIdx)> = state
        .openupval
        .iter()
        .filter_map(|uv| match &*uv.slot() {
            lua_types::UpValState::Open { thread_id, idx } => {
                Some((*thread_id as u64, *idx))
            }
            lua_types::UpValState::Closed(_) => None,
        })
        .collect();
    // Publish their current values so they can be reached through global
    // state while the coroutine runs.
    {
        let mut g = state.global_mut();
        for (tid, idx) in &parent_open_upval_slots {
            let val = state.get_at(*idx);
            g.cross_thread_upvals.insert((*tid, *idx), val);
        }
    }

    push_parent_gc_snapshot(state);

    let (status, results_or_err): (LuaStatus, Vec<LuaValue>) = {
        let mut co_state = match entry_rc.try_borrow_mut() {
            Ok(b) => b,
            Err(_) => {
                // The coroutine's state is already borrowed, so it is active
                // (running or normal). Undo steps 2 and 3 before failing.
                pop_parent_gc_snapshot(state);
                let mut g = state.global_mut();
                for (tid, idx) in &parent_open_upval_slots {
                    g.cross_thread_upvals.remove(&(*tid, *idx));
                }
                drop(g);
                push_lit_or_nil(state, b"cannot resume non-suspended coroutine");
                return -1;
            }
        };
        if co_state.check_stack(narg + 1).is_err() {
            // Release the coroutine borrow and undo steps 2 and 3. The
            // already-removed arguments are discarded.
            drop(co_state);
            pop_parent_gc_snapshot(state);
            let mut g = state.global_mut();
            for (tid, idx) in &parent_open_upval_slots {
                g.cross_thread_upvals.remove(&(*tid, *idx));
            }
            drop(g);
            push_lit_or_nil(state, b"too many arguments to resume");
            return -1;
        }
        for v in args {
            co_state.push(v);
        }
        // Mark the coroutine as the current thread only for the duration of
        // lua_resume, then restore the parent.
        co_state.global_mut().current_thread_id = co_id;
        let mut nres: i32 = 0;
        let status = lua_vm::do_::lua_resume(&mut *co_state, Some(state), narg, &mut nres);
        co_state.global_mut().current_thread_id = parent_thread_id;
        let co_top = co_state.top_idx().0 as i32;
        let ci_func = co_state.current_call_info().func.0 as i32;
        // On success lua_resume reports `nres` values on top of the
        // coroutine stack. On error only the error object (one value) is taken.
        let count = if status == LuaStatus::Ok || status == LuaStatus::Yield {
            nres
        } else {
            1
        };
        let start = co_top - count;
        let vals: Vec<LuaValue> = (start..co_top)
            .map(|i| co_state.get_at(lua_vm::state::StackIdx(i as u32)))
            .collect();
        // Pop the transferred values from the coroutine. On success the new
        // top is never lowered below the current frame's first slot.
        let new_co_top = if status == LuaStatus::Ok || status == LuaStatus::Yield {
            (co_top - count).max(ci_func + 1)
        } else {
            co_top - count
        };
        co_state.set_top(lua_vm::state::StackIdx(new_co_top.max(0) as u32));
        (status, vals)
    };

    pop_parent_gc_snapshot(state);

    // Write back any published upvalue values (possibly modified by the
    // coroutine) into the parent's stack slots.
    {
        let mut g = state.global_mut();
        let mut flush: Vec<(lua_vm::state::StackIdx, LuaValue)> = Vec::new();
        for (tid, idx) in &parent_open_upval_slots {
            if let Some(v) = g.cross_thread_upvals.remove(&(*tid, *idx)) {
                flush.push((*idx, v));
            }
        }
        drop(g);
        for (idx, v) in flush {
            state.set_at(idx, v);
        }
    }

    match status {
        LuaStatus::Ok | LuaStatus::Yield => {
            // Room for the results plus one extra slot (co_resume pushes a
            // boolean afterwards).
            if state.check_stack(results_or_err.len() as i32 + 1).is_err() {
                push_lit_or_nil(state, b"too many results to resume");
                return -1;
            }
            let n = results_or_err.len();
            for v in results_or_err {
                state.push(v);
            }
            n as i32
        }
        _ => {
            // Error status: results_or_err holds the single error object.
            for v in results_or_err {
                state.push(v);
            }
            -1
        }
    }
}
297
298fn push_parent_gc_snapshot(state: &mut LuaState) {
299 let top = state.top_idx();
300 let stack_snapshot: Vec<LuaValue> = (0..top.0)
301 .map(|i| state.get_at(lua_vm::state::StackIdx(i)))
302 .collect();
303 let open_upval_snapshot = state.openupval.clone();
304 let mut g = state.global_mut();
305 g.suspended_parent_stacks.push(stack_snapshot);
306 g.suspended_parent_open_upvals.push(open_upval_snapshot);
307}
308
309fn pop_parent_gc_snapshot(state: &mut LuaState) {
310 let mut g = state.global_mut();
311 g.suspended_parent_open_upvals.pop();
312 g.suspended_parent_stacks.pop();
313}
314
315fn push_lit_or_nil(state: &mut LuaState, bytes: &[u8]) {
317 match state.intern_str(bytes) {
318 Ok(s) => state.push(LuaValue::Str(s)),
319 Err(_) => state.push(LuaValue::Nil),
320 }
321}
322
323pub fn co_resume(state: &mut LuaState) -> Result<usize, LuaError> {
332 let co = get_co(state)?;
334 let narg = state.get_top() - 1;
338 let r = aux_resume(state, co, narg);
339 if r < 0 {
340 state.push(LuaValue::Bool(false));
342 state.insert(-2);
343 Ok(2)
344 } else {
345 state.push(LuaValue::Bool(true));
347 state.insert(-(r + 1));
348 Ok((r + 1) as usize)
349 }
350}
351
352fn aux_wrap(state: &mut LuaState) -> Result<usize, LuaError> {
361 let up = state.value_at(upvalue_index(1));
362 let co = match up {
363 LuaValue::Thread(t) => t,
364 _ => {
365 return Err(LuaError::runtime(format_args!(
366 "coroutine.wrap: upvalue is not a thread"
367 )))
368 }
369 };
370 let narg = state.get_top();
371 let r = aux_resume(state, co.clone(), narg);
372 if r < 0 {
373 let top = state.get_top();
374 let mut err_val = state.value_at(top);
375 if aux_status(state, &co) == COS_DEAD {
376 let old_err = state.pop();
377 let nclose = close_suspended_or_dead(state, co)?;
378 err_val = if nclose >= 2 {
379 let top = state.get_top();
380 state.value_at(top)
381 } else {
382 old_err
383 };
384 state.pop_n(nclose);
385 }
386 Err(LuaError::from_value(err_val))
387 } else {
388 Ok(r as usize)
389 }
390}
391
392pub fn co_create(state: &mut LuaState) -> Result<usize, LuaError> {
406 state.check_arg_type(1, LuaType::Function)?;
407 let body = state.value_at(1);
408 let _nl = state.new_thread(Some(body))?;
409 Ok(1)
410}
411
412pub fn co_wrap(state: &mut LuaState) -> Result<usize, LuaError> {
421 co_create(state)?;
422 state.push_cclosure(aux_wrap, 1)?;
423 Ok(1)
424}
425
426pub fn co_yield(state: &mut LuaState) -> Result<usize, LuaError> {
434 let n = state.get_top();
435 let r = lua_vm::do_::lua_yieldk(state, n, 0, None)?;
436 Ok(r as usize)
437}
438
439pub fn co_status(state: &mut LuaState) -> Result<usize, LuaError> {
445 let co = get_co(state)?;
447 let idx = aux_status(state, &co) as usize;
449 let name: &[u8] = STAT_NAMES[idx];
450 let interned = state.intern_str(name)?;
451 state.push(LuaValue::Str(interned));
452 Ok(1)
453}
454
455pub fn co_isyieldable(state: &mut LuaState) -> Result<usize, LuaError> {
460 let is_yieldable = if matches!(state.type_at(1), LuaType::None) {
461 state.is_yieldable()
462 } else {
463 let co = get_co(state)?;
464 let co_id = co.id;
465 let (is_main, is_current) = {
466 let g = state.global();
467 (co_id == g.main_thread_id, co_id == g.current_thread_id)
468 };
469 if is_main {
470 false
471 } else if is_current {
472 state.is_yieldable()
473 } else {
474 let entry_rc = {
475 let g = state.global();
476 g.threads
477 .get(&co_id)
478 .expect("thread value carries an id that must resolve in GlobalState::threads")
479 .state
480 .clone()
481 };
482 let target_is_yieldable = match entry_rc.try_borrow() {
483 Ok(b) => b.is_yieldable(),
484 Err(_) => false,
485 };
486 target_is_yieldable
487 }
488 };
489 state.push(LuaValue::Bool(is_yieldable));
490 Ok(1)
491}
492
493pub fn co_running(state: &mut LuaState) -> Result<usize, LuaError> {
499 let is_main = state.push_thread()?;
503 state.push(LuaValue::Bool(is_main));
505 Ok(2)
506}
507
508pub fn co_close(state: &mut LuaState) -> Result<usize, LuaError> {
517 lua_vm::state::inc_c_stack(state)?;
518 let result = (|| {
519 let co = get_co(state)?;
520 let status = aux_status(state, &co);
521 match status {
522 COS_DEAD | COS_YIELD => close_suspended_or_dead(state, co),
523 _ => {
524 let name = if status == COS_RUN { "running" } else { "normal" };
525 Err(LuaError::runtime(format_args!(
526 "cannot close a {} coroutine",
527 name
528 )))
529 }
530 }
531 })();
532 state.nCcalls -= 1;
533 result
534}
535
/// Closes a dead or suspended coroutine by resetting its thread state.
///
/// Pushes `true` and returns `Ok(1)` when the reset reports `Ok`, or when the
/// thread has no registered state. Otherwise pushes `false` plus the error
/// value taken from the top of the coroutine's stack (or `nil` if that stack
/// is empty) and returns `Ok(2)`.
///
/// Uses the same upvalue-publishing and parent-snapshot protocol as
/// `aux_resume` around the call to `reset_thread`.
///
/// Uses `borrow_mut`, which panics if the state cell is already borrowed. The
/// visible callers (`co_close`, `aux_wrap`) only get here after `aux_status`
/// returned `COS_DEAD` or `COS_YIELD`. `aux_status` reports a thread whose
/// cell is already borrowed as `COS_NORM`, so those callers never reach this
/// point with a borrowed cell.
fn close_suspended_or_dead(
    state: &mut LuaState,
    co: GcRef<lua_types::value::LuaThread>,
) -> Result<usize, LuaError> {
    let co_id = co.id;
    let entry_rc_opt = {
        let g = state.global();
        g.threads.get(&co_id).map(|e| e.state.clone())
    };
    let entry_rc = match entry_rc_opt {
        Some(rc) => rc,
        None => {
            // No registered state: nothing to close, report success.
            state.push(LuaValue::Bool(true));
            return Ok(1);
        }
    };
    let parent_thread_id = state.global().current_thread_id;
    // The coroutine inherits the caller's C-call count while it closes.
    let caller_c_calls = state.c_calls();

    // Stack slots (owner thread id, index) that the parent's still-open
    // upvalues refer to.
    let parent_open_upval_slots: Vec<(u64, lua_vm::state::StackIdx)> = state
        .openupval
        .iter()
        .filter_map(|uv| match &*uv.slot() {
            lua_types::UpValState::Open { thread_id, idx } => {
                Some((*thread_id as u64, *idx))
            }
            lua_types::UpValState::Closed(_) => None,
        })
        .collect();
    // Publish their current values in global state while the coroutine's
    // reset runs; they are written back below.
    {
        let mut g = state.global_mut();
        for (tid, idx) in &parent_open_upval_slots {
            let val = state.get_at(*idx);
            g.cross_thread_upvals.insert((*tid, *idx), val);
        }
    }

    push_parent_gc_snapshot(state);

    let (status, err_value): (i32, Option<LuaValue>) = {
        let mut co_state = entry_rc.borrow_mut();
        // The coroutine counts as current only for the duration of reset_thread.
        co_state.global_mut().current_thread_id = co_id;
        co_state.nCcalls = caller_c_calls;
        let in_status = co_state.status as i32;
        let s = lua_vm::state::reset_thread(&mut *co_state, in_status);
        co_state.global_mut().current_thread_id = parent_thread_id;
        if s == LuaStatus::Ok as i32 {
            (s, None)
        } else {
            // Take the error object from the top of the coroutine stack and pop it.
            let top = co_state.top_idx().0;
            if top > 0 {
                let err = co_state.get_at(lua_vm::state::StackIdx(top - 1));
                co_state.set_top(lua_vm::state::StackIdx(top - 1));
                (s, Some(err))
            } else {
                (s, Some(LuaValue::Nil))
            }
        }
    };

    pop_parent_gc_snapshot(state);

    // Write back any published upvalue values into the parent's stack slots.
    {
        let mut g = state.global_mut();
        let mut flush: Vec<(lua_vm::state::StackIdx, LuaValue)> = Vec::new();
        for (tid, idx) in &parent_open_upval_slots {
            if let Some(v) = g.cross_thread_upvals.remove(&(*tid, *idx)) {
                flush.push((*idx, v));
            }
        }
        drop(g);
        for (idx, v) in flush {
            state.set_at(idx, v);
        }
    }

    if status == LuaStatus::Ok as i32 {
        state.push(LuaValue::Bool(true));
        Ok(1)
    } else {
        state.push(LuaValue::Bool(false));
        if let Some(v) = err_value {
            state.push(v);
        } else {
            state.push(LuaValue::Nil);
        }
        Ok(2)
    }
}
626
627pub fn open_coroutine(state: &mut LuaState) -> Result<usize, LuaError> {
634 state.new_lib(CO_FUNCS)?;
638 Ok(1)
639}
640
641