1use lua_types::{
19 error::LuaError,
20 value::LuaValue,
21 LuaType,
22 LuaStatus,
23 gc::GcRef,
24};
25use crate::state_stub::{LuaState, LuaStateStubExt as _, lua_CFunction, upvalue_index};
26
27const COS_RUN: i32 = 0;
32
33const COS_DEAD: i32 = 1;
35
36const COS_YIELD: i32 = 2;
38
39const COS_NORM: i32 = 3;
41
42const STAT_NAMES: [&[u8]; 4] = [b"running", b"dead", b"suspended", b"normal"];
46
47pub const CO_FUNCS: &[(&[u8], lua_CFunction)] = &[
55 (b"create", co_create),
56 (b"resume", co_resume),
57 (b"running", co_running),
58 (b"status", co_status),
59 (b"wrap", co_wrap),
60 (b"yield", co_yield),
61 (b"isyieldable", co_isyieldable),
62 (b"close", co_close),
63];
64
65fn get_co(state: &mut LuaState) -> Result<GcRef<lua_types::value::LuaThread>, LuaError> {
71 let co = state.to_thread(1);
72 if co.is_none() {
73 let got = state.arg(1);
74 return Err(LuaError::type_arg_error(1, "thread", &got));
75 }
76 Ok(co.expect("checked above"))
77}
78
79fn aux_status(state: &mut LuaState, co: &GcRef<lua_types::value::LuaThread>) -> i32 {
91 let co_id = co.id;
92 let entry_rc = {
93 let g = state.global();
94 if co_id == g.current_thread_id {
95 return COS_RUN;
96 }
97 if co_id == g.main_thread_id {
98 return COS_NORM;
99 }
100 match g.threads.get(&co_id) {
101 Some(e) => e.state.clone(),
102 None => return COS_DEAD,
103 }
104 };
105 let co_state = match entry_rc.try_borrow() {
106 Ok(state) => state,
107 Err(_) => {
108 return COS_NORM;
113 }
114 };
115 let raw_status = co_state.status;
116 if raw_status == LuaStatus::Yield as u8 {
117 return COS_YIELD;
118 }
119 if raw_status != LuaStatus::Ok as u8 {
120 return COS_DEAD;
121 }
122 let has_frames = co_state.ci.as_usize() > 0;
123 if has_frames {
124 return COS_NORM;
125 }
126 let ci_func = co_state.call_info[0].func.0;
127 let top = co_state.top.0;
128 let lua_gettop = top as i64 - ci_func as i64 - 1;
129 if lua_gettop == 0 {
130 COS_DEAD
131 } else {
132 COS_YIELD
133 }
134}
135
136fn aux_resume(state: &mut LuaState, co: GcRef<lua_types::value::LuaThread>, narg: i32) -> i32 {
153 let co_id = co.id;
154 let entry_rc = {
155 let g = state.global();
156 match g.threads.get(&co_id) {
157 Some(e) => e.state.clone(),
158 None => {
159 drop(g);
160 push_lit_or_nil(state, b"cannot resume dead coroutine");
161 return -1;
162 }
163 }
164 };
165 let parent_thread_id = state.global().current_thread_id;
166 let top_before = state.get_top();
167 if top_before < narg {
168 push_lit_or_nil(state, b"not enough arguments to resume");
169 return -1;
170 }
171 let first_arg_idx = top_before - narg + 1;
172 let args: Vec<LuaValue> = (first_arg_idx..=top_before)
173 .map(|i| state.value_at(i))
174 .collect();
175 lua_vm::api::set_top(state, (top_before - narg) as i32).ok();
176
177 let parent_open_upval_slots: Vec<(u64, lua_vm::state::StackIdx)> = state
178 .openupval
179 .iter()
180 .filter_map(|uv| match &*uv.slot() {
181 lua_types::UpValState::Open { thread_id, idx } => {
182 Some((*thread_id as u64, *idx))
183 }
184 lua_types::UpValState::Closed(_) => None,
185 })
186 .collect();
187 {
188 let mut g = state.global_mut();
189 for (tid, idx) in &parent_open_upval_slots {
190 let val = state.get_at(*idx);
191 g.cross_thread_upvals.insert((*tid, *idx), val);
192 }
193 }
194
195 push_parent_gc_snapshot(state);
196
197 let (status, results_or_err): (LuaStatus, Vec<LuaValue>) = {
198 let mut co_state = match entry_rc.try_borrow_mut() {
199 Ok(b) => b,
200 Err(_) => {
201 pop_parent_gc_snapshot(state);
202 let mut g = state.global_mut();
203 for (tid, idx) in &parent_open_upval_slots {
204 g.cross_thread_upvals.remove(&(*tid, *idx));
205 }
206 drop(g);
207 push_lit_or_nil(state, b"cannot resume non-suspended coroutine");
208 return -1;
209 }
210 };
211 if co_state.check_stack(narg + 1).is_err() {
212 drop(co_state);
213 pop_parent_gc_snapshot(state);
214 let mut g = state.global_mut();
215 for (tid, idx) in &parent_open_upval_slots {
216 g.cross_thread_upvals.remove(&(*tid, *idx));
217 }
218 drop(g);
219 push_lit_or_nil(state, b"too many arguments to resume");
220 return -1;
221 }
222 for v in args {
223 co_state.push(v);
224 }
225 co_state.global_mut().current_thread_id = co_id;
226 let mut nres: i32 = 0;
227 let status = lua_vm::do_::lua_resume(&mut *co_state, Some(state), narg, &mut nres);
228 co_state.global_mut().current_thread_id = parent_thread_id;
229 let co_top = co_state.top_idx().0 as i32;
230 let ci_func = co_state.current_call_info().func.0 as i32;
231 let count = if status == LuaStatus::Ok || status == LuaStatus::Yield {
232 nres
233 } else {
234 1
235 };
236 let start = co_top - count;
237 let vals: Vec<LuaValue> = (start..co_top)
238 .map(|i| co_state.get_at(lua_vm::state::StackIdx(i as u32)))
239 .collect();
240 let new_co_top = if status == LuaStatus::Ok || status == LuaStatus::Yield {
241 (co_top - count).max(ci_func + 1)
242 } else {
243 co_top - count
244 };
245 co_state.set_top(lua_vm::state::StackIdx(new_co_top.max(0) as u32));
246 (status, vals)
247 };
248
249 pop_parent_gc_snapshot(state);
251
252 {
253 let mut g = state.global_mut();
254 let mut flush: Vec<(lua_vm::state::StackIdx, LuaValue)> = Vec::new();
255 for (tid, idx) in &parent_open_upval_slots {
256 if let Some(v) = g.cross_thread_upvals.remove(&(*tid, *idx)) {
257 flush.push((*idx, v));
258 }
259 }
260 drop(g);
261 for (idx, v) in flush {
262 state.set_at(idx, v);
263 }
264 }
265
266 match status {
267 LuaStatus::Ok | LuaStatus::Yield => {
268 if state.check_stack(results_or_err.len() as i32 + 1).is_err() {
269 push_lit_or_nil(state, b"too many results to resume");
270 return -1;
271 }
272 let n = results_or_err.len();
273 for v in results_or_err {
274 state.push(v);
275 }
276 n as i32
277 }
278 _ => {
279 for v in results_or_err {
280 state.push(v);
281 }
282 -1
283 }
284 }
285}
286
287fn push_parent_gc_snapshot(state: &mut LuaState) {
288 let top = state.top_idx();
289 let stack_snapshot: Vec<LuaValue> = (0..top.0)
290 .map(|i| state.get_at(lua_vm::state::StackIdx(i)))
291 .collect();
292 let open_upval_snapshot = state.openupval.clone();
293 let mut g = state.global_mut();
294 g.suspended_parent_stacks.push(stack_snapshot);
295 g.suspended_parent_open_upvals.push(open_upval_snapshot);
296}
297
298fn pop_parent_gc_snapshot(state: &mut LuaState) {
299 let mut g = state.global_mut();
300 g.suspended_parent_open_upvals.pop();
301 g.suspended_parent_stacks.pop();
302}
303
304fn push_lit_or_nil(state: &mut LuaState, bytes: &[u8]) {
306 match state.intern_str(bytes) {
307 Ok(s) => state.push(LuaValue::Str(s)),
308 Err(_) => state.push(LuaValue::Nil),
309 }
310}
311
312pub fn co_resume(state: &mut LuaState) -> Result<usize, LuaError> {
320 let co = get_co(state)?;
321 let narg = state.get_top() - 1;
324 let r = aux_resume(state, co, narg);
325 if r < 0 {
326 if state.sandbox_aborting() {
330 let top = state.get_top();
331 let err_val = state.value_at(top);
332 return Err(LuaError::from_value(err_val));
333 }
334 state.push(LuaValue::Bool(false));
335 state.insert(-2)?;
336 Ok(2)
337 } else {
338 state.push(LuaValue::Bool(true));
339 state.insert(-(r + 1))?;
340 Ok((r + 1) as usize)
341 }
342}
343
344fn aux_wrap(state: &mut LuaState) -> Result<usize, LuaError> {
352 let up = state.value_at(upvalue_index(1));
353 let co = match up {
354 LuaValue::Thread(t) => t,
355 _ => {
356 return Err(LuaError::runtime(format_args!(
357 "coroutine.wrap: upvalue is not a thread"
358 )))
359 }
360 };
361 let narg = state.get_top();
362 let r = aux_resume(state, co.clone(), narg);
363 if r < 0 {
364 let top = state.get_top();
365 let mut err_val = state.value_at(top);
366 if aux_status(state, &co) == COS_DEAD {
367 let old_err = state.pop();
368 let nclose = close_suspended_or_dead(state, co)?;
369 err_val = if nclose >= 2 {
370 let top = state.get_top();
371 state.value_at(top)
372 } else {
373 old_err
374 };
375 state.pop_n(nclose);
376 }
377 Err(LuaError::from_value(err_val))
378 } else {
379 Ok(r as usize)
380 }
381}
382
383pub fn co_create(state: &mut LuaState) -> Result<usize, LuaError> {
396 state.check_arg_type(1, LuaType::Function)?;
397 let body = state.value_at(1);
398 let _nl = state.new_thread(Some(body))?;
399 Ok(1)
400}
401
402pub fn co_wrap(state: &mut LuaState) -> Result<usize, LuaError> {
410 co_create(state)?;
411 state.push_cclosure(aux_wrap, 1)?;
412 Ok(1)
413}
414
415pub fn co_yield(state: &mut LuaState) -> Result<usize, LuaError> {
422 let n = state.get_top();
423 let r = lua_vm::do_::lua_yieldk(state, n, 0, None)?;
424 Ok(r as usize)
425}
426
427pub fn co_status(state: &mut LuaState) -> Result<usize, LuaError> {
432 let co = get_co(state)?;
433 let idx = aux_status(state, &co) as usize;
434 let name: &[u8] = STAT_NAMES[idx];
435 let interned = state.intern_str(name)?;
436 state.push(LuaValue::Str(interned));
437 Ok(1)
438}
439
440pub fn co_isyieldable(state: &mut LuaState) -> Result<usize, LuaError> {
444 let is_yieldable = if matches!(state.type_at(1), LuaType::None) {
445 state.is_yieldable()
446 } else {
447 let co = get_co(state)?;
448 let co_id = co.id;
449 let (is_main, is_current) = {
450 let g = state.global();
451 (co_id == g.main_thread_id, co_id == g.current_thread_id)
452 };
453 if is_main {
454 false
455 } else if is_current {
456 state.is_yieldable()
457 } else {
458 let entry_rc = {
459 let g = state.global();
460 g.threads
461 .get(&co_id)
462 .expect("thread value carries an id that must resolve in GlobalState::threads")
463 .state
464 .clone()
465 };
466 let target_is_yieldable = match entry_rc.try_borrow() {
467 Ok(b) => b.is_yieldable(),
468 Err(_) => false,
469 };
470 target_is_yieldable
471 }
472 };
473 state.push(LuaValue::Bool(is_yieldable));
474 Ok(1)
475}
476
477pub fn co_running(state: &mut LuaState) -> Result<usize, LuaError> {
482 let is_main = state.push_thread()?;
485 state.push(LuaValue::Bool(is_main));
486 Ok(2)
487}
488
489pub fn co_close(state: &mut LuaState) -> Result<usize, LuaError> {
497 lua_vm::state::inc_c_stack(state)?;
498 let result = (|| {
499 let co = get_co(state)?;
500 let status = aux_status(state, &co);
501 match status {
502 COS_DEAD | COS_YIELD => close_suspended_or_dead(state, co),
503 _ => {
504 let name = if status == COS_RUN { "running" } else { "normal" };
505 Err(LuaError::runtime(format_args!(
506 "cannot close a {} coroutine",
507 name
508 )))
509 }
510 }
511 })();
512 state.n_ccalls -= 1;
513 result
514}
515
516fn close_suspended_or_dead(
518 state: &mut LuaState,
519 co: GcRef<lua_types::value::LuaThread>,
520) -> Result<usize, LuaError> {
521 let co_id = co.id;
522 let entry_rc_opt = {
523 let g = state.global();
524 g.threads.get(&co_id).map(|e| e.state.clone())
525 };
526 let entry_rc = match entry_rc_opt {
527 Some(rc) => rc,
528 None => {
529 state.push(LuaValue::Bool(true));
530 return Ok(1);
531 }
532 };
533 let parent_thread_id = state.global().current_thread_id;
534 let caller_c_calls = state.c_calls();
535
536 let parent_open_upval_slots: Vec<(u64, lua_vm::state::StackIdx)> = state
537 .openupval
538 .iter()
539 .filter_map(|uv| match &*uv.slot() {
540 lua_types::UpValState::Open { thread_id, idx } => {
541 Some((*thread_id as u64, *idx))
542 }
543 lua_types::UpValState::Closed(_) => None,
544 })
545 .collect();
546 {
547 let mut g = state.global_mut();
548 for (tid, idx) in &parent_open_upval_slots {
549 let val = state.get_at(*idx);
550 g.cross_thread_upvals.insert((*tid, *idx), val);
551 }
552 }
553
554 push_parent_gc_snapshot(state);
555
556 let (status, err_value): (i32, Option<LuaValue>) = {
557 let mut co_state = entry_rc.borrow_mut();
558 co_state.global_mut().current_thread_id = co_id;
559 co_state.n_ccalls = caller_c_calls;
560 let in_status = co_state.status as i32;
561 let s = lua_vm::state::reset_thread(&mut *co_state, in_status);
562 co_state.global_mut().current_thread_id = parent_thread_id;
563 if s == LuaStatus::Ok as i32 {
564 (s, None)
565 } else {
566 let top = co_state.top_idx().0;
567 if top > 0 {
568 let err = co_state.get_at(lua_vm::state::StackIdx(top - 1));
569 co_state.set_top(lua_vm::state::StackIdx(top - 1));
570 (s, Some(err))
571 } else {
572 (s, Some(LuaValue::Nil))
573 }
574 }
575 };
576
577 pop_parent_gc_snapshot(state);
578
579 {
580 let mut g = state.global_mut();
581 let mut flush: Vec<(lua_vm::state::StackIdx, LuaValue)> = Vec::new();
582 for (tid, idx) in &parent_open_upval_slots {
583 if let Some(v) = g.cross_thread_upvals.remove(&(*tid, *idx)) {
584 flush.push((*idx, v));
585 }
586 }
587 drop(g);
588 for (idx, v) in flush {
589 state.set_at(idx, v);
590 }
591 }
592
593 if status == LuaStatus::Ok as i32 {
594 state.push(LuaValue::Bool(true));
595 Ok(1)
596 } else {
597 state.push(LuaValue::Bool(false));
598 if let Some(v) = err_value {
599 state.push(v);
600 } else {
601 state.push(LuaValue::Nil);
602 }
603 Ok(2)
604 }
605}
606
607pub fn open_coroutine(state: &mut LuaState) -> Result<usize, LuaError> {
613 state.new_lib(CO_FUNCS)?;
616 Ok(1)
617}
618
619