1use std::os::raw::{c_int, c_void};
2
3use crate::error::{Error, Result};
4#[allow(unused)]
5use crate::state::Lua;
6use crate::state::RawLua;
7use crate::types::ValueRef;
8use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard};
9use crate::value::{FromLuaMulti, IntoLuaMulti};
10
11#[cfg(not(feature = "luau"))]
12use crate::{
13 hook::{Debug, HookTriggers},
14 types::MaybeSend,
15};
16
17#[cfg(feature = "async")]
18use {
19 futures_util::stream::Stream,
20 std::{
21 future::Future,
22 marker::PhantomData,
23 pin::Pin,
24 ptr::NonNull,
25 task::{Context, Poll, Waker},
26 },
27};
28
/// Status of a Lua thread (coroutine), as reported by [`Thread::status`].
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum ThreadStatus {
    /// The thread can be resumed: it is suspended in a yield, or it has not run to completion
    /// and still has values on its stack (e.g. a fresh thread holding its entry function).
    Resumable,
    /// The thread is the one currently executing on the Lua state (its state pointer is the
    /// active state of the Lua instance).
    Running,
    /// The thread has completed: its Lua status is OK and its stack is empty.
    Finished,
    /// The thread stopped because of an error (its Lua status is neither OK nor yield).
    Error,
}
43
/// Handle to a Lua thread (coroutine).
///
/// Field `0` is the value reference for the thread; it also gives access to the owning Lua
/// instance (`self.0.lua`) and is what equality is based on. Presumably it keeps the thread value
/// alive while the handle exists — confirm against `ValueRef`.
/// Field `1` is the thread's raw `lua_State` pointer, used for all stack operations.
#[derive(Clone, Debug)]
pub struct Thread(pub(crate) ValueRef, pub(crate) *mut ffi::lua_State);
47
// SAFETY: `Thread` is `!Send`/`!Sync` by default because it stores a raw `lua_State` pointer.
// With the `send` feature enabled, every method in this module that operates on that pointer first
// acquires the owning Lua instance's lock (`lock()` / `try_lock()`), so access is serialized.
// NOTE(review): this relies on `ValueRef` and the Lua lock being thread-safe when `send` is
// enabled — confirm.
#[cfg(feature = "send")]
unsafe impl Send for Thread {}
#[cfg(feature = "send")]
unsafe impl Sync for Thread {}
52
#[cfg(feature = "async")]
/// Lua thread (coroutine) that can be driven from an async context.
///
/// Created by [`Thread::into_async`]. It implements [`Future`], which ignores intermediate yields
/// and resolves to the thread's final return values. It also implements [`Stream`], which produces
/// one item per resume.
///
/// Requires `feature = "async"`
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct AsyncThread<A, R> {
    /// The underlying Lua thread being driven.
    thread: Thread,
    /// Arguments for the first resume; taken (left as `None`) on the first poll.
    init_args: Option<A>,
    /// Marker for the result type produced by polling.
    ret: PhantomData<R>,
    /// Whether `Drop` should try to hand the thread back to the Lua instance for reuse.
    recycle: bool,
}
68
69impl Thread {
70 #[inline(always)]
71 const fn state(&self) -> *mut ffi::lua_State {
72 self.1
73 }
74
75 pub fn resume<R>(&self, args: impl IntoLuaMulti) -> Result<R>
117 where
118 R: FromLuaMulti,
119 {
120 let lua = self.0.lua.lock();
121 if self.status_inner(&lua) != ThreadStatus::Resumable {
122 return Err(Error::CoroutineUnresumable);
123 }
124
125 let state = lua.state();
126 let thread_state = self.state();
127 unsafe {
128 let _sg = StackGuard::new(state);
129 let _thread_sg = StackGuard::with_top(thread_state, 0);
130
131 let nresults = self.resume_inner(&lua, args)?;
132 check_stack(state, nresults + 1)?;
133 ffi::lua_xmove(thread_state, state, nresults);
134
135 R::from_stack_multi(nresults, &lua)
136 }
137 }
138
139 unsafe fn resume_inner(&self, lua: &RawLua, args: impl IntoLuaMulti) -> Result<c_int> {
143 let state = lua.state();
144 let thread_state = self.state();
145
146 let nargs = args.push_into_stack_multi(&lua)?;
147 if nargs > 0 {
148 check_stack(thread_state, nargs)?;
149 ffi::lua_xmove(state, thread_state, nargs);
150 }
151
152 let mut nresults = 0;
153 let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int);
154 if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD {
155 if ret == ffi::LUA_ERRMEM {
156 return Err(pop_error(thread_state, ret));
158 }
159 check_stack(state, 3)?;
160 protect_lua!(state, 0, 1, |state| error_traceback_thread(state, thread_state))?;
161 return Err(pop_error(state, ret));
162 }
163
164 Ok(nresults)
165 }
166
167 pub fn status(&self) -> ThreadStatus {
169 self.status_inner(&self.0.lua.lock())
170 }
171
172 pub(crate) fn status_inner(&self, lua: &RawLua) -> ThreadStatus {
174 let thread_state = self.state();
175 if thread_state == lua.state() {
176 return ThreadStatus::Running;
178 }
179 let status = unsafe { ffi::lua_status(thread_state) };
180 if status != ffi::LUA_OK && status != ffi::LUA_YIELD {
181 ThreadStatus::Error
182 } else if status == ffi::LUA_YIELD || unsafe { ffi::lua_gettop(thread_state) > 0 } {
183 ThreadStatus::Resumable
184 } else {
185 ThreadStatus::Finished
186 }
187 }
188
189 #[cfg(not(feature = "luau"))]
194 #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
195 pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F)
196 where
197 F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static,
198 {
199 let lua = self.0.lua.lock();
200 unsafe {
201 lua.set_thread_hook(self.state(), triggers, callback);
202 }
203 }
204
205 #[cfg(any(feature = "lua54", feature = "luau"))]
220 #[cfg_attr(docsrs, doc(cfg(any(feature = "lua54", feature = "luau"))))]
221 pub fn reset(&self, func: crate::function::Function) -> Result<()> {
222 let lua = self.0.lua.lock();
223 if self.status_inner(&lua) == ThreadStatus::Running {
224 return Err(Error::runtime("cannot reset a running thread"));
225 }
226
227 let thread_state = self.state();
228 unsafe {
229 #[cfg(all(feature = "lua54", not(feature = "vendored")))]
230 let status = ffi::lua_resetthread(thread_state);
231 #[cfg(all(feature = "lua54", feature = "vendored"))]
232 let status = ffi::lua_closethread(thread_state, lua.state());
233 #[cfg(feature = "lua54")]
234 if status != ffi::LUA_OK {
235 return Err(pop_error(thread_state, status));
236 }
237 #[cfg(feature = "luau")]
238 ffi::lua_resetthread(thread_state);
239
240 ffi::lua_xpush(lua.ref_thread(), thread_state, func.0.index);
242
243 #[cfg(feature = "luau")]
244 {
245 ffi::lua_xpush(lua.main_state(), thread_state, ffi::LUA_GLOBALSINDEX);
247 ffi::lua_replace(thread_state, ffi::LUA_GLOBALSINDEX);
248 }
249
250 Ok(())
251 }
252 }
253
254 #[cfg(feature = "async")]
300 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
301 pub fn into_async<R>(self, args: impl IntoLuaMulti) -> AsyncThread<impl IntoLuaMulti, R>
302 where
303 R: FromLuaMulti,
304 {
305 AsyncThread {
306 thread: self,
307 init_args: Some(args),
308 ret: PhantomData,
309 recycle: false,
310 }
311 }
312
313 #[cfg(any(feature = "luau", docsrs))]
345 #[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
346 #[doc(hidden)]
347 pub fn sandbox(&self) -> Result<()> {
348 let lua = self.0.lua.lock();
349 let state = lua.state();
350 let thread_state = self.state();
351 unsafe {
352 check_stack(thread_state, 3)?;
353 check_stack(state, 3)?;
354 protect_lua!(state, 0, 0, |_| ffi::luaL_sandboxthread(thread_state))
355 }
356 }
357
358 #[inline]
364 pub fn to_pointer(&self) -> *const c_void {
365 self.0.to_pointer()
366 }
367}
368
369impl PartialEq for Thread {
370 fn eq(&self, other: &Self) -> bool {
371 self.0 == other.0
372 }
373}
374
#[cfg(feature = "async")]
impl<A, R> AsyncThread<A, R> {
    /// Marks whether the underlying thread should be handed back for reuse when this
    /// `AsyncThread` is dropped.
    #[inline]
    pub(crate) fn set_recyclable(&mut self, recyclable: bool) {
        let AsyncThread { recycle, .. } = self;
        *recycle = recyclable;
    }
}
382
#[cfg(feature = "async")]
#[cfg(any(feature = "lua54", feature = "luau"))]
// On drop, a thread marked recyclable is offered back to the Lua instance for reuse.
impl<A, R> Drop for AsyncThread<A, R> {
    fn drop(&mut self) {
        // Only threads explicitly marked via `set_recyclable(true)` are considered.
        if self.recycle {
            // `try_lock` rather than `lock`: if the lock cannot be obtained, recycling is
            // skipped instead of blocking inside a destructor.
            if let Some(lua) = self.thread.0.lua.try_lock() {
                unsafe {
                    // A `false` return means the thread was not taken back for reuse.
                    if !lua.recycle_thread(&mut self.thread) {
                        // Lua 5.4 only: a thread left in an error state is reset/closed. Per the
                        // Lua 5.4 manual this also closes pending to-be-closed variables.
                        #[cfg(feature = "lua54")]
                        if self.thread.status_inner(&lua) == ThreadStatus::Error {
                            #[cfg(not(feature = "vendored"))]
                            ffi::lua_resetthread(self.thread.state());
                            #[cfg(feature = "vendored")]
                            ffi::lua_closethread(self.thread.state(), lua.state());
                        }
                    }
                }
            }
        }
    }
}
405
#[cfg(feature = "async")]
impl<A: IntoLuaMulti, R: FromLuaMulti> Stream for AsyncThread<A, R> {
    // Each item holds the values produced by one resume of the thread, or the error raised.
    type Item = Result<R>;

    // Resumes the thread once per poll. The stream ends (`None`) once the thread is no longer
    // resumable (finished, errored, or currently running). Errors surface as `Some(Err(..))`
    // through `?`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let lua = self.thread.0.lua.lock();
        if self.thread.status_inner(&lua) != ThreadStatus::Resumable {
            return Poll::Ready(None);
        }

        let state = lua.state();
        let thread_state = self.thread.state();
        unsafe {
            // Guards drop in reverse order: first the previous waker is restored, then both
            // stacks are reset (`with_top(.., 0)` presumably clears the thread's stack).
            let _sg = StackGuard::new(state);
            let _thread_sg = StackGuard::with_top(thread_state, 0);
            // Install this task's waker for the duration of the resume (presumably consulted by
            // async Rust functions called from Lua — confirm).
            let _wg = WakerGuard::new(&lua, cx.waker());

            // SAFETY: only fields are mutated in place below; the struct is never moved out of
            // its pin.
            let this = self.get_unchecked_mut();
            // The initial arguments go to the first resume only; later resumes pass nothing.
            let nresults = if let Some(args) = this.init_args.take() {
                this.thread.resume_inner(&lua, args)?
            } else {
                this.thread.resume_inner(&lua, ())?
            };

            // A lone poll-pending marker means the thread is waiting on a pending Rust future.
            // No item is produced; the installed waker is expected to reschedule this task.
            if nresults == 1 && is_poll_pending(thread_state) {
                return Poll::Pending;
            }

            check_stack(state, nresults + 1)?;
            ffi::lua_xmove(thread_state, state, nresults);

            // Request another poll so the stream proceeds to the next resume.
            cx.waker().wake_by_ref();
            Poll::Ready(Some(R::from_stack_multi(nresults, &lua)))
        }
    }
}
443
#[cfg(feature = "async")]
impl<A: IntoLuaMulti, R: FromLuaMulti> Future for AsyncThread<A, R> {
    // The thread's final return values, or the error it raised.
    type Output = Result<R>;

    // Resumes the thread once per poll until it returns. Polling a thread that is not
    // resumable yields `CoroutineUnresumable`. Errors from resuming or converting propagate
    // through `?` as `Ready(Err(..))`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let lua = self.thread.0.lua.lock();
        if self.thread.status_inner(&lua) != ThreadStatus::Resumable {
            return Poll::Ready(Err(Error::CoroutineUnresumable));
        }

        let state = lua.state();
        let thread_state = self.thread.state();
        unsafe {
            // Guards drop in reverse order: first the previous waker is restored, then both
            // stacks are reset (`with_top(.., 0)` presumably clears the thread's stack).
            let _sg = StackGuard::new(state);
            let _thread_sg = StackGuard::with_top(thread_state, 0);
            // Install this task's waker for the duration of the resume.
            let _wg = WakerGuard::new(&lua, cx.waker());

            // SAFETY: only fields are mutated in place below; the struct is never moved out of
            // its pin.
            let this = self.get_unchecked_mut();
            // The initial arguments go to the first resume only; later resumes pass nothing.
            let nresults = if let Some(args) = this.init_args.take() {
                this.thread.resume_inner(&lua, args)?
            } else {
                this.thread.resume_inner(&lua, ())?
            };

            // A lone poll-pending marker means the thread is waiting on a pending Rust future;
            // the installed waker is expected to reschedule this task.
            if nresults == 1 && is_poll_pending(thread_state) {
                return Poll::Pending;
            }

            // A plain `coroutine.yield`: its values are discarded (the thread guard clears them)
            // and the task immediately asks to be polled again.
            if ffi::lua_status(thread_state) == ffi::LUA_YIELD {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            // The thread returned: move its results to the calling state and decode them.
            check_stack(state, nresults + 1)?;
            ffi::lua_xmove(thread_state, state, nresults);

            Poll::Ready(R::from_stack_multi(nresults, &lua))
        }
    }
}
486
/// Reports whether the value on top of `state`'s stack is the poll-pending marker.
///
/// The marker is the light userdata returned by `Lua::poll_pending()`.
///
/// # Safety
///
/// `state` must be a valid Lua state with a non-empty stack. Callers in this module only call
/// this after a resume produced exactly one value.
#[cfg(feature = "async")]
#[inline(always)]
unsafe fn is_poll_pending(state: *mut ffi::lua_State) -> bool {
    let top_value = ffi::lua_tolightuserdata(state, -1);
    top_value == Lua::poll_pending().0
}
492
#[cfg(feature = "async")]
/// Guard that installs a task [`Waker`] on a [`RawLua`] instance and restores the previously
/// installed waker when dropped.
struct WakerGuard<'lua, 'a> {
    /// The Lua instance whose waker was replaced.
    lua: &'lua RawLua,
    /// The waker that was installed before this guard; put back in `Drop`.
    prev: NonNull<Waker>,
    /// Ties the guard to the lifetime `'a` of the borrowed waker it installed.
    _phantom: PhantomData<&'a ()>,
}
499
#[cfg(feature = "async")]
impl<'lua, 'a> WakerGuard<'lua, 'a> {
    /// Installs `waker` on `lua` and remembers the previously installed waker.
    ///
    /// The previous waker is restored when the returned guard is dropped. The guard must
    /// therefore be bound to a named variable (e.g. `let _wg = ...`), not to `_`.
    ///
    /// Installing the waker cannot fail, so the guard is returned directly instead of being
    /// wrapped in a `Result`. Existing `let _wg = WakerGuard::new(..)` bindings hold the guard
    /// itself.
    #[inline]
    #[must_use = "the previous waker is restored as soon as the guard is dropped"]
    pub fn new(lua: &'lua RawLua, waker: &'a Waker) -> WakerGuard<'lua, 'a> {
        // SAFETY: the guard carries lifetime `'a`, so it cannot outlive the borrowed `waker`.
        // Its `Drop` restores the previous waker, provided the guard is dropped rather than
        // forgotten, so Lua is not left holding a dangling waker pointer.
        let prev = unsafe { lua.set_waker(NonNull::from(waker)) };
        WakerGuard {
            lua,
            prev,
            _phantom: PhantomData,
        }
    }
}
512
#[cfg(feature = "async")]
impl<'lua, 'a> Drop for WakerGuard<'lua, 'a> {
    /// Puts back the waker that was installed before this guard was created.
    fn drop(&mut self) {
        let (lua, previous) = (self.lua, self.prev);
        unsafe {
            lua.set_waker(previous);
        }
    }
}
519
#[cfg(test)]
mod assertions {
    use super::*;

    // With the `send` feature, thread handles (and their async wrappers) may cross threads.
    #[cfg(feature = "send")]
    static_assertions::assert_impl_all!(Thread: Send, Sync);
    #[cfg(all(feature = "async", feature = "send"))]
    static_assertions::assert_impl_all!(AsyncThread<(), ()>: Send, Sync);

    // Without it, they must stay on the thread that owns the Lua state.
    #[cfg(not(feature = "send"))]
    static_assertions::assert_not_impl_any!(Thread: Send);
    #[cfg(all(feature = "async", not(feature = "send")))]
    static_assertions::assert_not_impl_any!(AsyncThread<(), ()>: Send);
}