1use lua_types::{
21 error::LuaError,
22 value::LuaValue,
23 LuaType,
24 LuaStatus,
25 gc::GcRef,
26};
27use crate::state_stub::{LuaState, LuaStateStubExt as _, lua_CFunction, upvalue_index};
28
/// Coroutine status code: the coroutine is the thread currently executing.
const COS_RUN: i32 = 0;

/// Coroutine status code: the coroutine has finished, stopped with an error,
/// or no longer has an entry in the global thread registry.
const COS_DEAD: i32 = 1;

/// Coroutine status code: the coroutine is suspended — either it yielded, or
/// it has not started yet (its body function is still waiting on its stack).
const COS_YIELD: i32 = 2;

/// Coroutine status code: the coroutine is active but is not the running
/// thread (for example, it is waiting on a coroutine it resumed).
const COS_NORM: i32 = 3;

/// Names reported by `coroutine.status`, indexed by the `COS_*` codes above.
/// The array order must stay in sync with those constant values.
const STAT_NAMES: [&[u8]; 4] = [b"running", b"dead", b"suspended", b"normal"];
48
/// Registration table for the `coroutine` library.
///
/// Each entry pairs a Lua-visible field name with the native function that
/// implements it. `open_coroutine` passes this table to `new_lib`.
pub const CO_FUNCS: &[(&[u8], lua_CFunction)] = &[
    (b"create", co_create),
    (b"resume", co_resume),
    (b"running", co_running),
    (b"status", co_status),
    (b"wrap", co_wrap),
    (b"yield", co_yield),
    (b"isyieldable", co_isyieldable),
    (b"close", co_close),
];
66
67fn get_co(state: &mut LuaState) -> Result<GcRef<lua_types::value::LuaThread>, LuaError> {
73 let co = state.to_thread(1);
74 if co.is_none() {
75 let got = state.arg(1);
76 return Err(LuaError::type_arg_error(1, "thread", &got));
77 }
78 Ok(co.expect("checked above"))
79}
80
81fn aux_status(state: &mut LuaState, co: &GcRef<lua_types::value::LuaThread>) -> i32 {
93 let co_id = co.id;
94 let entry_rc = {
95 let g = state.global();
96 if co_id == g.current_thread_id {
97 return COS_RUN;
98 }
99 if co_id == g.main_thread_id {
100 return COS_NORM;
101 }
102 match g.threads.get(&co_id) {
103 Some(e) => e.state.clone(),
104 None => return COS_DEAD,
105 }
106 };
107 let co_state = match entry_rc.try_borrow() {
108 Ok(state) => state,
109 Err(_) => {
110 return COS_NORM;
115 }
116 };
117 let raw_status = co_state.status;
118 if raw_status == LuaStatus::Yield as u8 {
119 return COS_YIELD;
120 }
121 if raw_status != LuaStatus::Ok as u8 {
122 return COS_DEAD;
123 }
124 let has_frames = co_state.ci.as_usize() > 0;
125 if has_frames {
126 return COS_NORM;
127 }
128 let ci_func = co_state.call_info[0].func.0;
129 let top = co_state.top.0;
130 let lua_gettop = top as i64 - ci_func as i64 - 1;
131 if lua_gettop == 0 {
132 COS_DEAD
133 } else {
134 COS_YIELD
135 }
136}
137
/// Resumes coroutine `co`, passing it the top `narg` values of `state`'s stack
/// as arguments.
///
/// Shared core of `coroutine.resume` (`co_resume`) and the closure built by
/// `coroutine.wrap` (`aux_wrap`).
///
/// Returns the number of values pushed onto `state` when the coroutine
/// returned or yielded (`LuaStatus::Ok` / `LuaStatus::Yield`). Returns `-1`
/// after pushing a single error value onto `state` in every other case:
/// - "cannot resume dead coroutine" when `co` has no registry entry;
/// - "not enough arguments to resume" when fewer than `narg` values are on the stack;
/// - "cannot resume non-suspended coroutine" when `co`'s state is already borrowed;
/// - "too many arguments to resume" when `co`'s stack cannot grow by `narg + 1`;
/// - "too many results to resume" when `state`'s stack cannot hold the results;
/// - the error value left on `co`'s stack by `lua_resume` when it fails.
///
/// Around the actual resume this function also:
/// - copies the current value of each of the parent's open upvalue slots into
///   the global `cross_thread_upvals` map, keyed by `(thread_id, slot)`.
///   Afterwards, any entries still present are written back to the parent
///   stack. NOTE(review): presumably this lets the coroutine read and update
///   upvalues that live on the parent stack — confirm against the upvalue
///   access code.
/// - pushes a snapshot of the parent stack and open-upvalue list
///   (`push_parent_gc_snapshot`) and pops it afterwards. NOTE(review): the
///   name suggests this keeps the parent's roots visible to the GC — confirm.
/// - sets `current_thread_id` to `co` for the duration of `lua_resume` and
///   restores the parent's id afterwards.
fn aux_resume(state: &mut LuaState, co: GcRef<lua_types::value::LuaThread>, narg: i32) -> i32 {
    let co_id = co.id;
    // No registry entry: the thread is treated as dead.
    let entry_rc = {
        let g = state.global();
        match g.threads.get(&co_id) {
            Some(e) => e.state.clone(),
            None => {
                drop(g);
                push_lit_or_nil(state, b"cannot resume dead coroutine");
                return -1;
            }
        }
    };
    let parent_thread_id = state.global().current_thread_id;
    let top_before = state.get_top();
    if top_before < narg {
        push_lit_or_nil(state, b"not enough arguments to resume");
        return -1;
    }
    // Copy the `narg` topmost values (bottom-to-top order) and then drop them
    // from the parent stack; they are pushed onto the coroutine stack below.
    let first_arg_idx = top_before - narg + 1;
    let args: Vec<LuaValue> = (first_arg_idx..=top_before)
        .map(|i| state.value_at(i))
        .collect();
    // NOTE(review): any error from `set_top` is discarded by `.ok()`.
    lua_vm::api::set_top(state, (top_before - narg) as i32).ok();

    // Collect (thread id, stack slot) for every upvalue still open on the
    // parent, then stash each slot's current value in `cross_thread_upvals`.
    let parent_open_upval_slots: Vec<(u64, lua_vm::state::StackIdx)> = state
        .openupval
        .iter()
        .filter_map(|uv| match &*uv.slot() {
            lua_types::UpValState::Open { thread_id, idx } => {
                Some((*thread_id as u64, *idx))
            }
            lua_types::UpValState::Closed(_) => None,
        })
        .collect();
    {
        let mut g = state.global_mut();
        for (tid, idx) in &parent_open_upval_slots {
            let val = state.get_at(*idx);
            g.cross_thread_upvals.insert((*tid, *idx), val);
        }
    }

    push_parent_gc_snapshot(state);

    let (status, results_or_err): (LuaStatus, Vec<LuaValue>) = {
        // A failed mutable borrow means the coroutine's state is already in
        // use further up the resume chain, so it cannot be resumed here.
        let mut co_state = match entry_rc.try_borrow_mut() {
            Ok(b) => b,
            Err(_) => {
                // Undo the snapshot and the upvalue stash before reporting.
                pop_parent_gc_snapshot(state);
                let mut g = state.global_mut();
                for (tid, idx) in &parent_open_upval_slots {
                    g.cross_thread_upvals.remove(&(*tid, *idx));
                }
                drop(g);
                push_lit_or_nil(state, b"cannot resume non-suspended coroutine");
                return -1;
            }
        };
        // Room for the arguments plus one extra slot.
        if co_state.check_stack(narg + 1).is_err() {
            // Release the coroutine borrow first, then undo snapshot + stash.
            drop(co_state);
            pop_parent_gc_snapshot(state);
            let mut g = state.global_mut();
            for (tid, idx) in &parent_open_upval_slots {
                g.cross_thread_upvals.remove(&(*tid, *idx));
            }
            drop(g);
            push_lit_or_nil(state, b"too many arguments to resume");
            return -1;
        }
        for v in args {
            co_state.push(v);
        }
        // The coroutine is the current thread only while `lua_resume` runs.
        co_state.global_mut().current_thread_id = co_id;
        let mut nres: i32 = 0;
        let status = lua_vm::do_::lua_resume(&mut *co_state, Some(state), narg, &mut nres);
        co_state.global_mut().current_thread_id = parent_thread_id;
        // Values to transfer back: `nres` results on return/yield, otherwise
        // the single error value on top of the coroutine stack.
        let co_top = co_state.top_idx().0 as i32;
        let ci_func = co_state.current_call_info().func.0 as i32;
        let count = if status == LuaStatus::Ok || status == LuaStatus::Yield {
            nres
        } else {
            1
        };
        let start = co_top - count;
        let vals: Vec<LuaValue> = (start..co_top)
            .map(|i| co_state.get_at(lua_vm::state::StackIdx(i as u32)))
            .collect();
        // Drop the transferred values from the coroutine stack. On
        // return/yield the new top is clamped to at least one slot above the
        // current frame's function slot.
        let new_co_top = if status == LuaStatus::Ok || status == LuaStatus::Yield {
            (co_top - count).max(ci_func + 1)
        } else {
            co_top - count
        };
        co_state.set_top(lua_vm::state::StackIdx(new_co_top.max(0) as u32));
        (status, vals)
    };

    pop_parent_gc_snapshot(state);

    // Write the stashed upvalue values (presumably possibly updated by the
    // coroutine) back into the parent stack. The global borrow is released
    // before touching the stack.
    {
        let mut g = state.global_mut();
        let mut flush: Vec<(lua_vm::state::StackIdx, LuaValue)> = Vec::new();
        for (tid, idx) in &parent_open_upval_slots {
            if let Some(v) = g.cross_thread_upvals.remove(&(*tid, *idx)) {
                flush.push((*idx, v));
            }
        }
        drop(g);
        for (idx, v) in flush {
            state.set_at(idx, v);
        }
    }

    match status {
        LuaStatus::Ok | LuaStatus::Yield => {
            // Results plus one extra slot must fit on the parent stack.
            if state.check_stack(results_or_err.len() as i32 + 1).is_err() {
                push_lit_or_nil(state, b"too many results to resume");
                return -1;
            }
            let n = results_or_err.len();
            for v in results_or_err {
                state.push(v);
            }
            n as i32
        }
        _ => {
            // Exactly one error value was collected above.
            for v in results_or_err {
                state.push(v);
            }
            -1
        }
    }
}
288
289fn push_parent_gc_snapshot(state: &mut LuaState) {
290 let top = state.top_idx();
291 let stack_snapshot: Vec<LuaValue> = (0..top.0)
292 .map(|i| state.get_at(lua_vm::state::StackIdx(i)))
293 .collect();
294 let open_upval_snapshot = state.openupval.clone();
295 let mut g = state.global_mut();
296 g.suspended_parent_stacks.push(stack_snapshot);
297 g.suspended_parent_open_upvals.push(open_upval_snapshot);
298}
299
300fn pop_parent_gc_snapshot(state: &mut LuaState) {
301 let mut g = state.global_mut();
302 g.suspended_parent_open_upvals.pop();
303 g.suspended_parent_stacks.pop();
304}
305
306fn push_lit_or_nil(state: &mut LuaState, bytes: &[u8]) {
308 match state.intern_str(bytes) {
309 Ok(s) => state.push(LuaValue::Str(s)),
310 Err(_) => state.push(LuaValue::Nil),
311 }
312}
313
314pub fn co_resume(state: &mut LuaState) -> Result<usize, LuaError> {
322 let co = get_co(state)?;
323 let narg = state.get_top() - 1;
326 let r = aux_resume(state, co, narg);
327 if r < 0 {
328 state.push(LuaValue::Bool(false));
329 state.insert(-2);
330 Ok(2)
331 } else {
332 state.push(LuaValue::Bool(true));
333 state.insert(-(r + 1));
334 Ok((r + 1) as usize)
335 }
336}
337
338fn aux_wrap(state: &mut LuaState) -> Result<usize, LuaError> {
346 let up = state.value_at(upvalue_index(1));
347 let co = match up {
348 LuaValue::Thread(t) => t,
349 _ => {
350 return Err(LuaError::runtime(format_args!(
351 "coroutine.wrap: upvalue is not a thread"
352 )))
353 }
354 };
355 let narg = state.get_top();
356 let r = aux_resume(state, co.clone(), narg);
357 if r < 0 {
358 let top = state.get_top();
359 let mut err_val = state.value_at(top);
360 if aux_status(state, &co) == COS_DEAD {
361 let old_err = state.pop();
362 let nclose = close_suspended_or_dead(state, co)?;
363 err_val = if nclose >= 2 {
364 let top = state.get_top();
365 state.value_at(top)
366 } else {
367 old_err
368 };
369 state.pop_n(nclose);
370 }
371 Err(LuaError::from_value(err_val))
372 } else {
373 Ok(r as usize)
374 }
375}
376
377pub fn co_create(state: &mut LuaState) -> Result<usize, LuaError> {
390 state.check_arg_type(1, LuaType::Function)?;
391 let body = state.value_at(1);
392 let _nl = state.new_thread(Some(body))?;
393 Ok(1)
394}
395
396pub fn co_wrap(state: &mut LuaState) -> Result<usize, LuaError> {
404 co_create(state)?;
405 state.push_cclosure(aux_wrap, 1)?;
406 Ok(1)
407}
408
409pub fn co_yield(state: &mut LuaState) -> Result<usize, LuaError> {
416 let n = state.get_top();
417 let r = lua_vm::do_::lua_yieldk(state, n, 0, None)?;
418 Ok(r as usize)
419}
420
421pub fn co_status(state: &mut LuaState) -> Result<usize, LuaError> {
426 let co = get_co(state)?;
427 let idx = aux_status(state, &co) as usize;
428 let name: &[u8] = STAT_NAMES[idx];
429 let interned = state.intern_str(name)?;
430 state.push(LuaValue::Str(interned));
431 Ok(1)
432}
433
434pub fn co_isyieldable(state: &mut LuaState) -> Result<usize, LuaError> {
438 let is_yieldable = if matches!(state.type_at(1), LuaType::None) {
439 state.is_yieldable()
440 } else {
441 let co = get_co(state)?;
442 let co_id = co.id;
443 let (is_main, is_current) = {
444 let g = state.global();
445 (co_id == g.main_thread_id, co_id == g.current_thread_id)
446 };
447 if is_main {
448 false
449 } else if is_current {
450 state.is_yieldable()
451 } else {
452 let entry_rc = {
453 let g = state.global();
454 g.threads
455 .get(&co_id)
456 .expect("thread value carries an id that must resolve in GlobalState::threads")
457 .state
458 .clone()
459 };
460 let target_is_yieldable = match entry_rc.try_borrow() {
461 Ok(b) => b.is_yieldable(),
462 Err(_) => false,
463 };
464 target_is_yieldable
465 }
466 };
467 state.push(LuaValue::Bool(is_yieldable));
468 Ok(1)
469}
470
471pub fn co_running(state: &mut LuaState) -> Result<usize, LuaError> {
476 let is_main = state.push_thread()?;
479 state.push(LuaValue::Bool(is_main));
480 Ok(2)
481}
482
483pub fn co_close(state: &mut LuaState) -> Result<usize, LuaError> {
491 lua_vm::state::inc_c_stack(state)?;
492 let result = (|| {
493 let co = get_co(state)?;
494 let status = aux_status(state, &co);
495 match status {
496 COS_DEAD | COS_YIELD => close_suspended_or_dead(state, co),
497 _ => {
498 let name = if status == COS_RUN { "running" } else { "normal" };
499 Err(LuaError::runtime(format_args!(
500 "cannot close a {} coroutine",
501 name
502 )))
503 }
504 }
505 })();
506 state.nCcalls -= 1;
507 result
508}
509
/// Closes coroutine `co` and pushes the outcome onto `state`.
///
/// Callers in this module only pass coroutines that `aux_status` classified
/// as dead or suspended.
///
/// Returns `Ok(1)` after pushing `true` when `co` has no registry entry, or
/// when `reset_thread` returns `LuaStatus::Ok`. Otherwise it pushes `false`
/// and an error value, then returns `Ok(2)`. The error value is popped from
/// the top of the coroutine's stack, or is `nil` if that stack is empty.
///
/// Uses the same bookkeeping as `aux_resume` while the coroutine's state is
/// reset:
/// - the parent's open upvalue slot values are stashed in `cross_thread_upvals`
///   and written back afterwards;
/// - a parent stack snapshot is pushed and later popped;
/// - `current_thread_id` temporarily points at `co`.
fn close_suspended_or_dead(
    state: &mut LuaState,
    co: GcRef<lua_types::value::LuaThread>,
) -> Result<usize, LuaError> {
    let co_id = co.id;
    let entry_rc_opt = {
        let g = state.global();
        g.threads.get(&co_id).map(|e| e.state.clone())
    };
    // No registry entry: nothing left to close, report success.
    let entry_rc = match entry_rc_opt {
        Some(rc) => rc,
        None => {
            state.push(LuaValue::Bool(true));
            return Ok(1);
        }
    };
    let parent_thread_id = state.global().current_thread_id;
    let caller_c_calls = state.c_calls();

    // Collect (thread id, stack slot) for every upvalue still open on the
    // parent, then stash each slot's current value in `cross_thread_upvals`.
    let parent_open_upval_slots: Vec<(u64, lua_vm::state::StackIdx)> = state
        .openupval
        .iter()
        .filter_map(|uv| match &*uv.slot() {
            lua_types::UpValState::Open { thread_id, idx } => {
                Some((*thread_id as u64, *idx))
            }
            lua_types::UpValState::Closed(_) => None,
        })
        .collect();
    {
        let mut g = state.global_mut();
        for (tid, idx) in &parent_open_upval_slots {
            let val = state.get_at(*idx);
            g.cross_thread_upvals.insert((*tid, *idx), val);
        }
    }

    push_parent_gc_snapshot(state);

    let (status, err_value): (i32, Option<LuaValue>) = {
        // NOTE(review): `borrow_mut` panics if the thread's state is already
        // borrowed. `aux_status` only reports dead/suspended after a
        // successful shared borrow, which is what keeps the callers safe.
        let mut co_state = entry_rc.borrow_mut();
        co_state.global_mut().current_thread_id = co_id;
        // The coroutine inherits the caller's C-call count before the reset.
        co_state.nCcalls = caller_c_calls;
        let in_status = co_state.status as i32;
        let s = lua_vm::state::reset_thread(&mut *co_state, in_status);
        co_state.global_mut().current_thread_id = parent_thread_id;
        if s == LuaStatus::Ok as i32 {
            (s, None)
        } else {
            // On failure, take the top of the coroutine stack as the error
            // value and remove it from that stack.
            let top = co_state.top_idx().0;
            if top > 0 {
                let err = co_state.get_at(lua_vm::state::StackIdx(top - 1));
                co_state.set_top(lua_vm::state::StackIdx(top - 1));
                (s, Some(err))
            } else {
                (s, Some(LuaValue::Nil))
            }
        }
    };

    pop_parent_gc_snapshot(state);

    // Write stashed upvalue values back into the parent stack, releasing the
    // global borrow before touching the stack.
    {
        let mut g = state.global_mut();
        let mut flush: Vec<(lua_vm::state::StackIdx, LuaValue)> = Vec::new();
        for (tid, idx) in &parent_open_upval_slots {
            if let Some(v) = g.cross_thread_upvals.remove(&(*tid, *idx)) {
                flush.push((*idx, v));
            }
        }
        drop(g);
        for (idx, v) in flush {
            state.set_at(idx, v);
        }
    }

    if status == LuaStatus::Ok as i32 {
        state.push(LuaValue::Bool(true));
        Ok(1)
    } else {
        state.push(LuaValue::Bool(false));
        if let Some(v) = err_value {
            state.push(v);
        } else {
            state.push(LuaValue::Nil);
        }
        Ok(2)
    }
}
600
601pub fn open_coroutine(state: &mut LuaState) -> Result<usize, LuaError> {
607 state.new_lib(CO_FUNCS)?;
610 Ok(1)
611}
612
613