1use std::cmp;
2use std::os::raw::c_int;
3
4use crate::error::{Error, Result};
5use crate::ffi;
6use crate::types::LuaRef;
7use crate::util::{check_stack, error_traceback, pop_error, StackGuard};
8use crate::value::{FromLuaMulti, ToLuaMulti};
9
10#[cfg(any(
11 feature = "lua54",
12 all(feature = "luajit", feature = "vendored"),
13 feature = "luau",
14))]
15use crate::function::Function;
16
17#[cfg(feature = "async")]
18use {
19 crate::{
20 lua::{Lua, ASYNC_POLL_PENDING},
21 value::{MultiValue, Value},
22 },
23 futures_core::{future::Future, stream::Stream},
24 std::{
25 cell::RefCell,
26 marker::PhantomData,
27 pin::Pin,
28 task::{Context, Poll, Waker},
29 },
30};
31
/// Execution state of a Lua thread, as computed by [`Thread::status`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ThreadStatus {
    /// The thread has yielded, or still has values (such as a function to run) on its
    /// stack, so it can be resumed.
    Resumable,
    /// The thread's status is OK and its stack is empty: there is nothing left to resume.
    Unresumable,
    /// The thread's Lua status is neither OK nor yielded, i.e. it stopped with an error.
    Error,
}
46
47#[derive(Clone, Debug)]
49pub struct Thread<'lua>(pub(crate) LuaRef<'lua>);
50
/// A [`Thread`] driven as a `Future` or `Stream`; created by [`Thread::into_async`].
///
/// Every poll resumes the underlying thread. A resume whose only result is the
/// `ASYNC_POLL_PENDING` marker is reported as `Poll::Pending`.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
#[derive(Debug)]
pub struct AsyncThread<'lua, R> {
    /// The coroutine being driven.
    thread: Thread<'lua>,
    /// Arguments for the very first resume (or the error from converting them);
    /// taken, leaving `None`, on the first poll.
    args0: RefCell<Option<Result<MultiValue<'lua>>>>,
    /// Records the result type `R`; no `R` value is stored.
    ret: PhantomData<R>,
    /// When `true`, the thread is handed to `Lua::recycle_thread` on drop (on builds where
    /// that `Drop` impl is compiled in).
    recycle: bool,
}
66
impl<'lua> Thread<'lua> {
    /// Resumes this thread, passing `args` to it, and converts what it produces into `R`.
    ///
    /// Whether the thread yields or finishes, the values it hands back are moved onto the
    /// main state's stack and collected (in order) before conversion.
    ///
    /// # Errors
    ///
    /// - [`Error::CoroutineInactive`] if the thread has not yielded and its stack is empty.
    /// - The Lua error raised by the thread if `lua_resume` returns neither `LUA_OK` nor
    ///   `LUA_YIELD`.
    /// - Errors from `check_stack`, from converting `args`, or from `R::from_lua_multi`.
    #[cfg(any(not(feature = "lua-factorio"), doc))]
    pub fn resume<A, R>(&self, args: A) -> Result<R>
    where
        A: ToLuaMulti<'lua>,
        R: FromLuaMulti<'lua>,
    {
        let lua = self.0.lua;
        let mut args = args.to_lua_multi(lua)?;
        let nargs = args.len() as c_int;
        let results = unsafe {
            // Guard for the main state's stack for the duration of this block.
            let _sg = StackGuard::new(lua.state);
            check_stack(lua.state, cmp::max(nargs + 1, 3))?;

            let thread_state =
                lua.ref_thread_exec(|ref_thread| ffi::lua_tothread(ref_thread, self.0.index));

            // Resumable only if the thread has yielded or has something on its stack to run.
            let status = ffi::lua_status(thread_state);
            if status != ffi::LUA_YIELD && ffi::lua_gettop(thread_state) == 0 {
                return Err(Error::CoroutineInactive);
            }

            // Push the arguments on the main state, then move all of them to the thread.
            check_stack(thread_state, nargs)?;
            for arg in args.drain_all() {
                lua.push_value(arg)?;
            }
            ffi::lua_xmove(lua.state, thread_state, nargs);

            let mut nresults = 0;

            let ret = ffi::lua_resume(thread_state, lua.state, nargs, &mut nresults as *mut c_int);
            if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD {
                // Let `error_traceback` process the error value on the thread (run through
                // `protect_lua!` on the main state) before turning it into a Rust error.
                protect_lua!(lua.state, 0, 0, |_| error_traceback(thread_state))?;
                return Err(pop_error(thread_state, ret));
            }

            // Reuse the argument container (emptied by `drain_all` above) for the results.
            let mut results = args;
            // NOTE(review): the extra 2 slots are presumably headroom for `lua.pop_value()`
            // below — confirm.
            check_stack(lua.state, nresults + 2)?;
            ffi::lua_xmove(thread_state, lua.state, nresults);

            // The top of the stack holds the last result, so popping and pushing to the
            // front restores the original order.
            for _ in 0..nresults {
                results.push_front(lua.pop_value());
            }
            results
        };
        R::from_lua_multi(results, lua)
    }

    /// Returns the current [`ThreadStatus`] of this thread.
    ///
    /// - `Error` if `lua_status` is neither `LUA_OK` nor `LUA_YIELD`;
    /// - `Resumable` if the thread has yielded or its stack is non-empty;
    /// - `Unresumable` otherwise.
    #[cfg(not(feature = "lua-factorio"))]
    pub fn status(&self) -> ThreadStatus {
        let lua = self.0.lua;
        unsafe {
            let thread_state =
                lua.ref_thread_exec(|ref_thread| ffi::lua_tothread(ref_thread, self.0.index));

            let status = ffi::lua_status(thread_state);
            if status != ffi::LUA_OK && status != ffi::LUA_YIELD {
                ThreadStatus::Error
            } else if status == ffi::LUA_YIELD || ffi::lua_gettop(thread_state) > 0 {
                ThreadStatus::Resumable
            } else {
                ThreadStatus::Unresumable
            }
        }
    }

    /// Resets this thread with `lua_resetthread` and pushes `func` onto its stack, ready for
    /// the next resume.
    ///
    /// Only available with Lua 5.4, vendored LuaJIT, or Luau. On Luau the thread's
    /// `LUA_GLOBALSINDEX` slot is also replaced with the main state's.
    ///
    /// # Errors
    ///
    /// On Lua 5.4, returns the error reported by `lua_resetthread` when it is not `LUA_OK`.
    /// Also fails if `check_stack` on the main state fails.
    #[cfg(any(
        feature = "lua54",
        all(feature = "luajit", feature = "vendored"),
        feature = "luau",
    ))]
    pub fn reset(&self, func: Function<'lua>) -> Result<()> {
        let lua = self.0.lua;
        unsafe {
            let _sg = StackGuard::new(lua.state);
            check_stack(lua.state, 2)?;

            lua.push_ref(&self.0);
            let thread_state = ffi::lua_tothread(lua.state, -1);

            // `lua_resetthread` has a different signature/return type per Lua flavour.
            #[cfg(feature = "lua54")]
            let status = ffi::lua_resetthread(thread_state);
            #[cfg(feature = "lua54")]
            if status != ffi::LUA_OK {
                return Err(pop_error(thread_state, status));
            }
            #[cfg(all(feature = "luajit", feature = "vendored"))]
            ffi::lua_resetthread(lua.state, thread_state);
            #[cfg(feature = "luau")]
            ffi::lua_resetthread(thread_state);

            // Move the new body from the main state onto the thread's stack.
            lua.push_ref(&func.0);
            ffi::lua_xmove(lua.state, thread_state, 1);

            #[cfg(feature = "luau")]
            {
                // Copy the main state's globals into the thread's globals slot.
                ffi::lua_xpush(lua.state, thread_state, ffi::LUA_GLOBALSINDEX);
                ffi::lua_replace(thread_state, ffi::LUA_GLOBALSINDEX);
            }

            Ok(())
        }
    }

    /// Wraps this thread in an [`AsyncThread`] that will resume it with `args` on its first
    /// poll (and with no arguments afterwards).
    ///
    /// Argument conversion is done here, but a conversion error is stored and only reported
    /// when the `AsyncThread` is first polled.
    #[cfg(feature = "async")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
    pub fn into_async<A, R>(self, args: A) -> AsyncThread<'lua, R>
    where
        A: ToLuaMulti<'lua>,
        R: FromLuaMulti<'lua>,
    {
        let args = args.to_lua_multi(self.0.lua);
        AsyncThread {
            thread: self,
            args0: RefCell::new(Some(args)),
            ret: PhantomData,
            recycle: false,
        }
    }

    /// (Luau only) Gives this thread the main state's globals, then calls
    /// `luaL_sandboxthread` on it inside `protect_lua!`.
    ///
    /// NOTE(review): the exact sandboxing semantics are defined by Luau's
    /// `luaL_sandboxthread` — confirm against the Luau sources.
    #[cfg(any(feature = "luau", docsrs))]
    #[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
    #[doc(hidden)]
    pub fn sandbox(&self) -> Result<()> {
        let lua = self.0.lua;
        unsafe {
            let thread = lua.ref_thread_exec(|t| ffi::lua_tothread(t, self.0.index));
            check_stack(thread, 1)?;
            check_stack(lua.state, 3)?;
            // Copy the main state's globals into the thread's globals slot.
            ffi::lua_xpush(lua.state, thread, ffi::LUA_GLOBALSINDEX);
            ffi::lua_replace(thread, ffi::LUA_GLOBALSINDEX);
            protect_lua!(lua.state, 0, 0, |_| ffi::luaL_sandboxthread(thread))
        }
    }
}
337
338impl<'lua> PartialEq for Thread<'lua> {
339 fn eq(&self, other: &Self) -> bool {
340 self.0 == other.0
341 }
342}
343
#[cfg(feature = "async")]
impl<'lua, R> AsyncThread<'lua, R> {
    /// Sets whether the underlying thread should be handed back for recycling when this
    /// value is dropped (on builds where the recycling `Drop` impl exists).
    #[inline]
    pub(crate) fn set_recyclable(&mut self, recyclable: bool) {
        let AsyncThread { recycle, .. } = self;
        *recycle = recyclable;
    }
}
351
#[cfg(feature = "async")]
#[cfg(any(
    feature = "lua54",
    all(feature = "luajit", feature = "vendored"),
    feature = "luau",
))]
impl<'lua, R> Drop for AsyncThread<'lua, R> {
    /// Returns the thread to its owning `Lua` for reuse, but only if it was marked
    /// recyclable via `set_recyclable`.
    fn drop(&mut self) {
        if !self.recycle {
            return;
        }
        let lua = self.thread.0.lua;
        unsafe {
            lua.recycle_thread(&mut self.thread);
        }
    }
}
367
#[cfg(feature = "async")]
impl<'lua, R> Stream for AsyncThread<'lua, R>
where
    R: FromLuaMulti<'lua>,
{
    type Item = Result<R>;

    /// Resumes the thread once and yields what it produced.
    ///
    /// The stream ends as soon as the thread is no longer resumable. The task's waker is
    /// installed on the Lua state while the thread runs. A resume that returns only the
    /// poll-pending marker becomes `Poll::Pending`. Any other result is emitted as an item,
    /// and the task is woken to be polled again.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let lua = self.thread.0.lua;

        if self.thread.status() != ThreadStatus::Resumable {
            return Poll::Ready(None);
        }

        let _wg = WakerGuard::new(lua, cx.waker().clone());
        // The stored arguments are used only for the first resume.
        let first_args = self.args0.borrow_mut().take();
        let ret: MultiValue = match first_args {
            Some(args) => self.thread.resume(args?)?,
            None => self.thread.resume(())?,
        };

        if is_poll_pending(&ret) {
            return Poll::Pending;
        }

        // More items may follow; ask to be polled again.
        cx.waker().wake_by_ref();
        Poll::Ready(Some(R::from_lua_multi(ret, lua)))
    }
}
398
#[cfg(feature = "async")]
impl<'lua, R> Future for AsyncThread<'lua, R>
where
    R: FromLuaMulti<'lua>,
{
    type Output = Result<R>;

    /// Drives the thread until it stops being resumable, then converts its final result.
    ///
    /// Polling a thread that is not resumable yields `Error::CoroutineInactive`. The
    /// poll-pending marker maps to `Poll::Pending`. Values from an ordinary yield are
    /// discarded, and the task is woken to continue immediately.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let lua = self.thread.0.lua;

        if self.thread.status() != ThreadStatus::Resumable {
            return Poll::Ready(Err(Error::CoroutineInactive));
        }

        let _wg = WakerGuard::new(lua, cx.waker().clone());
        // The stored arguments are used only for the first resume.
        let pending_args = self.args0.borrow_mut().take();
        let ret: MultiValue = match pending_args {
            Some(args) => self.thread.resume(args?)?,
            None => self.thread.resume(())?,
        };

        if is_poll_pending(&ret) {
            return Poll::Pending;
        }

        match self.thread.status() {
            // A plain `yield` from Lua: ignore the yielded values and get re-polled.
            ThreadStatus::Resumable => {
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            _ => Poll::Ready(R::from_lua_multi(ret, lua)),
        }
    }
}
434
/// Returns `true` when `val` holds exactly one value and that value is a light userdata
/// whose pointer equals the address of `ASYNC_POLL_PENDING`.
#[cfg(feature = "async")]
#[inline(always)]
fn is_poll_pending(val: &MultiValue) -> bool {
    let mut values = val.iter();
    // A single element means: first `next()` is `Some`, second is `None`.
    match (values.next(), values.next()) {
        (Some(Value::LightUserData(ud)), None) => {
            std::ptr::eq(ud.0 as *const u8, &ASYNC_POLL_PENDING as *const u8)
        }
        _ => false,
    }
}
445
/// Installs a waker on the `Lua` state and, when dropped, puts back whichever waker was
/// installed before.
#[cfg(feature = "async")]
struct WakerGuard<'lua> {
    lua: &'lua Lua,
    /// The waker this guard displaced, restored on drop.
    prev: Option<Waker>,
}

#[cfg(feature = "async")]
impl<'lua> WakerGuard<'lua> {
    /// Installs `waker` on `lua` and remembers the waker it replaced.
    ///
    /// The body never produces `Err`.
    #[inline]
    pub fn new(lua: &Lua, waker: Waker) -> Result<WakerGuard> {
        let prev = unsafe { lua.set_waker(Some(waker)) };
        Ok(WakerGuard { lua, prev })
    }
}

#[cfg(feature = "async")]
impl<'lua> Drop for WakerGuard<'lua> {
    fn drop(&mut self) {
        let previous = self.prev.take();
        unsafe {
            self.lua.set_waker(previous);
        }
    }
}