factorio_mlua/
thread.rs

1use std::cmp;
2use std::os::raw::c_int;
3
4use crate::error::{Error, Result};
5use crate::ffi;
6use crate::types::LuaRef;
7use crate::util::{check_stack, error_traceback, pop_error, StackGuard};
8use crate::value::{FromLuaMulti, ToLuaMulti};
9
10#[cfg(any(
11    feature = "lua54",
12    all(feature = "luajit", feature = "vendored"),
13    feature = "luau",
14))]
15use crate::function::Function;
16
17#[cfg(feature = "async")]
18use {
19    crate::{
20        lua::{Lua, ASYNC_POLL_PENDING},
21        value::{MultiValue, Value},
22    },
23    futures_core::{future::Future, stream::Stream},
24    std::{
25        cell::RefCell,
26        marker::PhantomData,
27        pin::Pin,
28        task::{Context, Poll, Waker},
29    },
30};
31
/// Execution state of a Lua thread (coroutine), as reported by `Thread::status`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ThreadStatus {
    /// The thread can be resumed: either it has never been started, or it is suspended inside
    /// `coroutine.yield`.
    ///
    /// Resume it with [`Thread::resume`].
    ///
    /// [`Thread::resume`]: crate::Thread::resume
    Resumable,
    /// The thread cannot be resumed right now: it has either run to completion or is the one
    /// currently executing.
    Unresumable,
    /// Execution of the thread stopped because a Lua error was raised.
    Error,
}
46
47/// Handle to an internal Lua thread (or coroutine).
48#[derive(Clone, Debug)]
49pub struct Thread<'lua>(pub(crate) LuaRef<'lua>);
50
/// A Lua thread (coroutine) driven as an async [`Future`] or [`Stream`].
///
/// Created with `Thread::into_async`. Only available with `feature = "async"`.
///
/// [`Future`]: futures_core::future::Future
/// [`Stream`]: futures_core::stream::Stream
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
#[derive(Debug)]
pub struct AsyncThread<'lua, R> {
    // The coroutine being driven.
    thread: Thread<'lua>,
    // Arguments (or the error from converting them) for the very first resume; taken on the
    // first poll, after which the thread is resumed with no arguments.
    args0: RefCell<Option<Result<MultiValue<'lua>>>>,
    // Records the Rust type the final/yielded values are converted into.
    ret: PhantomData<R>,
    // When set, the `Drop` impl (compiled only for lua54 / luajit+vendored / luau) hands the
    // thread back to the `Lua` state via `recycle_thread`.
    recycle: bool,
}
66
67impl<'lua> Thread<'lua> {
68    /// Resumes execution of this thread.
69    ///
70    /// Equivalent to `coroutine.resume`.
71    ///
72    /// Passes `args` as arguments to the thread. If the coroutine has called `coroutine.yield`, it
73    /// will return these arguments. Otherwise, the coroutine wasn't yet started, so the arguments
74    /// are passed to its main function.
75    ///
76    /// If the thread is no longer in `Active` state (meaning it has finished execution or
77    /// encountered an error), this will return `Err(CoroutineInactive)`, otherwise will return `Ok`
78    /// as follows:
79    ///
80    /// If the thread calls `coroutine.yield`, returns the values passed to `yield`. If the thread
81    /// `return`s values from its main function, returns those.
82    ///
83    /// # Examples
84    ///
85    /// ```
86    /// # use mlua::{Error, Lua, Result, Thread};
87    /// # fn main() -> Result<()> {
88    /// # let lua = Lua::new();
89    /// let thread: Thread = lua.load(r#"
90    ///     coroutine.create(function(arg)
91    ///         assert(arg == 42)
92    ///         local yieldarg = coroutine.yield(123)
93    ///         assert(yieldarg == 43)
94    ///         return 987
95    ///     end)
96    /// "#).eval()?;
97    ///
98    /// assert_eq!(thread.resume::<_, u32>(42)?, 123);
99    /// assert_eq!(thread.resume::<_, u32>(43)?, 987);
100    ///
101    /// // The coroutine has now returned, so `resume` will fail
102    /// match thread.resume::<_, u32>(()) {
103    ///     Err(Error::CoroutineInactive) => {},
104    ///     unexpected => panic!("unexpected result {:?}", unexpected),
105    /// }
106    /// # Ok(())
107    /// # }
108    /// ```
109    #[cfg(any(not(feature = "lua-factorio"), doc))]
110    pub fn resume<A, R>(&self, args: A) -> Result<R>
111    where
112        A: ToLuaMulti<'lua>,
113        R: FromLuaMulti<'lua>,
114    {
115        let lua = self.0.lua;
116        let mut args = args.to_lua_multi(lua)?;
117        let nargs = args.len() as c_int;
118        let results = unsafe {
119            let _sg = StackGuard::new(lua.state);
120            check_stack(lua.state, cmp::max(nargs + 1, 3))?;
121
122            let thread_state =
123                lua.ref_thread_exec(|ref_thread| ffi::lua_tothread(ref_thread, self.0.index));
124
125            let status = ffi::lua_status(thread_state);
126            if status != ffi::LUA_YIELD && ffi::lua_gettop(thread_state) == 0 {
127                return Err(Error::CoroutineInactive);
128            }
129
130            check_stack(thread_state, nargs)?;
131            for arg in args.drain_all() {
132                lua.push_value(arg)?;
133            }
134            ffi::lua_xmove(lua.state, thread_state, nargs);
135
136            let mut nresults = 0;
137
138            let ret = ffi::lua_resume(thread_state, lua.state, nargs, &mut nresults as *mut c_int);
139            if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD {
140                protect_lua!(lua.state, 0, 0, |_| error_traceback(thread_state))?;
141                return Err(pop_error(thread_state, ret));
142            }
143
144            let mut results = args; // Reuse MultiValue container
145            check_stack(lua.state, nresults + 2)?; // 2 is extra for `lua.pop_value()` below
146            ffi::lua_xmove(thread_state, lua.state, nresults);
147
148            for _ in 0..nresults {
149                results.push_front(lua.pop_value());
150            }
151            results
152        };
153        R::from_lua_multi(results, lua)
154    }
155
156    /// Gets the status of the thread.
157    #[cfg(not(feature = "lua-factorio"))]
158    pub fn status(&self) -> ThreadStatus {
159        let lua = self.0.lua;
160        unsafe {
161            let thread_state =
162                lua.ref_thread_exec(|ref_thread| ffi::lua_tothread(ref_thread, self.0.index));
163
164            let status = ffi::lua_status(thread_state);
165            if status != ffi::LUA_OK && status != ffi::LUA_YIELD {
166                ThreadStatus::Error
167            } else if status == ffi::LUA_YIELD || ffi::lua_gettop(thread_state) > 0 {
168                ThreadStatus::Resumable
169            } else {
170                ThreadStatus::Unresumable
171            }
172        }
173    }
174
175    /// Resets a thread
176    ///
177    /// In [Lua 5.4]: cleans its call stack and closes all pending to-be-closed variables.
178    /// Returns a error in case of either the original error that stopped the thread or errors
179    /// in closing methods.
180    ///
181    /// In [LuaJIT] and Luau: resets to the initial state of a newly created Lua thread.
182    /// Lua threads in arbitrary states (like yielded or errored) can be reset properly.
183    ///
184    /// Sets a Lua function for the thread afterwards.
185    ///
186    /// Requires `feature = "lua54"` OR `feature = "luajit,vendored"` OR `feature = "luau"`
187    ///
188    /// [Lua 5.4]: https://www.lua.org/manual/5.4/manual.html#lua_resetthread
189    /// [LuaJIT]: https://github.com/openresty/luajit2#lua_resetthread
190    #[cfg(any(
191        feature = "lua54",
192        all(feature = "luajit", feature = "vendored"),
193        feature = "luau",
194    ))]
195    pub fn reset(&self, func: Function<'lua>) -> Result<()> {
196        let lua = self.0.lua;
197        unsafe {
198            let _sg = StackGuard::new(lua.state);
199            check_stack(lua.state, 2)?;
200
201            lua.push_ref(&self.0);
202            let thread_state = ffi::lua_tothread(lua.state, -1);
203
204            #[cfg(feature = "lua54")]
205            let status = ffi::lua_resetthread(thread_state);
206            #[cfg(feature = "lua54")]
207            if status != ffi::LUA_OK {
208                return Err(pop_error(thread_state, status));
209            }
210            #[cfg(all(feature = "luajit", feature = "vendored"))]
211            ffi::lua_resetthread(lua.state, thread_state);
212            #[cfg(feature = "luau")]
213            ffi::lua_resetthread(thread_state);
214
215            lua.push_ref(&func.0);
216            ffi::lua_xmove(lua.state, thread_state, 1);
217
218            #[cfg(feature = "luau")]
219            {
220                // Inherit `LUA_GLOBALSINDEX` from the caller
221                ffi::lua_xpush(lua.state, thread_state, ffi::LUA_GLOBALSINDEX);
222                ffi::lua_replace(thread_state, ffi::LUA_GLOBALSINDEX);
223            }
224
225            Ok(())
226        }
227    }
228
229    /// Converts Thread to an AsyncThread which implements [`Future`] and [`Stream`] traits.
230    ///
231    /// `args` are passed as arguments to the thread function for first call.
232    /// The object calls [`resume()`] while polling and also allows to run rust futures
233    /// to completion using an executor.
234    ///
235    /// Using AsyncThread as a Stream allows to iterate through `coroutine.yield()`
236    /// values whereas Future version discards that values and poll until the final
237    /// one (returned from the thread function).
238    ///
239    /// Requires `feature = "async"`
240    ///
241    /// [`Future`]: futures_core::future::Future
242    /// [`Stream`]: futures_core::stream::Stream
243    /// [`resume()`]: https://www.lua.org/manual/5.4/manual.html#lua_resume
244    ///
245    /// # Examples
246    ///
247    /// ```
248    /// # use mlua::{Lua, Result, Thread};
249    /// use futures::stream::TryStreamExt;
250    /// # #[tokio::main]
251    /// # async fn main() -> Result<()> {
252    /// # let lua = Lua::new();
253    /// let thread: Thread = lua.load(r#"
254    ///     coroutine.create(function (sum)
255    ///         for i = 1,10 do
256    ///             sum = sum + i
257    ///             coroutine.yield(sum)
258    ///         end
259    ///         return sum
260    ///     end)
261    /// "#).eval()?;
262    ///
263    /// let mut stream = thread.into_async::<_, i64>(1);
264    /// let mut sum = 0;
265    /// while let Some(n) = stream.try_next().await? {
266    ///     sum += n;
267    /// }
268    ///
269    /// assert_eq!(sum, 286);
270    ///
271    /// # Ok(())
272    /// # }
273    /// ```
274    #[cfg(feature = "async")]
275    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
276    pub fn into_async<A, R>(self, args: A) -> AsyncThread<'lua, R>
277    where
278        A: ToLuaMulti<'lua>,
279        R: FromLuaMulti<'lua>,
280    {
281        let args = args.to_lua_multi(self.0.lua);
282        AsyncThread {
283            thread: self,
284            args0: RefCell::new(Some(args)),
285            ret: PhantomData,
286            recycle: false,
287        }
288    }
289
290    /// Enables sandbox mode on this thread.
291    ///
292    /// Under the hood replaces the global environment table with a new table,
293    /// that performs writes locally and proxies reads to caller's global environment.
294    ///
295    /// This mode ideally should be used together with the global sandbox mode [`Lua::sandbox()`].
296    ///
297    /// Please note that Luau links environment table with chunk when loading it into Lua state.
298    /// Therefore you need to load chunks into a thread to link with the thread environment.
299    ///
300    /// # Examples
301    ///
302    /// ```
303    /// # use mlua::{Lua, Result};
304    /// # fn main() -> Result<()> {
305    /// let lua = Lua::new();
306    /// let thread = lua.create_thread(lua.create_function(|lua2, ()| {
307    ///     lua2.load("var = 123").exec()?;
308    ///     assert_eq!(lua2.globals().get::<_, u32>("var")?, 123);
309    ///     Ok(())
310    /// })?)?;
311    /// thread.sandbox()?;
312    /// thread.resume(())?;
313    ///
314    /// // The global environment should be unchanged
315    /// assert_eq!(lua.globals().get::<_, Option<u32>>("var")?, None);
316    /// # Ok(())
317    /// # }
318    /// ```
319    ///
320    /// Requires `feature = "luau"`
321    #[cfg(any(feature = "luau", docsrs))]
322    #[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
323    #[doc(hidden)]
324    pub fn sandbox(&self) -> Result<()> {
325        let lua = self.0.lua;
326        unsafe {
327            let thread = lua.ref_thread_exec(|t| ffi::lua_tothread(t, self.0.index));
328            check_stack(thread, 1)?;
329            check_stack(lua.state, 3)?;
330            // Inherit `LUA_GLOBALSINDEX` from the caller
331            ffi::lua_xpush(lua.state, thread, ffi::LUA_GLOBALSINDEX);
332            ffi::lua_replace(thread, ffi::LUA_GLOBALSINDEX);
333            protect_lua!(lua.state, 0, 0, |_| ffi::luaL_sandboxthread(thread))
334        }
335    }
336}
337
338impl<'lua> PartialEq for Thread<'lua> {
339    fn eq(&self, other: &Self) -> bool {
340        self.0 == other.0
341    }
342}
343
#[cfg(feature = "async")]
impl<'lua, R> AsyncThread<'lua, R> {
    /// Sets whether the underlying thread should be handed back for recycling on drop.
    #[inline]
    pub(crate) fn set_recyclable(&mut self, recyclable: bool) {
        let AsyncThread { recycle, .. } = self;
        *recycle = recyclable;
    }
}
351
#[cfg(feature = "async")]
#[cfg(any(
    feature = "lua54",
    all(feature = "luajit", feature = "vendored"),
    feature = "luau",
))]
impl<'lua, R> Drop for AsyncThread<'lua, R> {
    /// Returns the thread to the owning `Lua` via `recycle_thread`, but only when it was
    /// marked recyclable; otherwise the handle is simply dropped.
    fn drop(&mut self) {
        if !self.recycle {
            return;
        }
        let lua = self.thread.0.lua;
        // SAFETY: `recycle_thread` is an unsafe `Lua` method; it is invoked only for threads
        // explicitly flagged recyclable by crate-internal code (NOTE(review): confirm its
        // preconditions at the definition site).
        unsafe {
            lua.recycle_thread(&mut self.thread);
        }
    }
}
367
#[cfg(feature = "async")]
impl<'lua, R> Stream for AsyncThread<'lua, R>
where
    R: FromLuaMulti<'lua>,
{
    type Item = Result<R>;

    /// Resumes the thread once and yields the converted values it produced.
    ///
    /// Ends the stream (`Ready(None)`) once the thread is no longer resumable. Conversion,
    /// resume and Lua errors are reported as `Ready(Some(Err(..)))`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let lua = self.thread.0.lua;

        if self.thread.status() != ThreadStatus::Resumable {
            return Poll::Ready(None);
        }

        // Install this task's waker for the duration of the resume. Bind the guard itself
        // (not the `Result` wrapping it) so its RAII scope is explicit and a failure is
        // propagated instead of silently kept around.
        let _wg = WakerGuard::new(lua, cx.waker().clone())?;
        // The first resume consumes the arguments given to `into_async`; later ones pass none.
        let ret: MultiValue = if let Some(args) = self.args0.borrow_mut().take() {
            self.thread.resume(args?)?
        } else {
            self.thread.resume(())?
        };

        // The thread handed back the `ASYNC_POLL_PENDING` marker rather than a real value.
        if is_poll_pending(&ret) {
            return Poll::Pending;
        }

        // Request another poll: the thread may have more values to yield.
        cx.waker().wake_by_ref();
        Poll::Ready(Some(R::from_lua_multi(ret, lua)))
    }
}
398
#[cfg(feature = "async")]
impl<'lua, R> Future for AsyncThread<'lua, R>
where
    R: FromLuaMulti<'lua>,
{
    type Output = Result<R>;

    /// Drives the thread until its main function returns, discarding any yielded values.
    ///
    /// Resolves to `Err(Error::CoroutineInactive)` if polled while the thread is not resumable.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let lua = self.thread.0.lua;

        if self.thread.status() != ThreadStatus::Resumable {
            return Poll::Ready(Err(Error::CoroutineInactive));
        }

        // Install this task's waker for the duration of the resume. Bind the guard itself
        // (not the `Result` wrapping it) so its RAII scope is explicit and a failure is
        // propagated instead of silently kept around.
        let _wg = WakerGuard::new(lua, cx.waker().clone())?;
        // The first resume consumes the arguments given to `into_async`; later ones pass none.
        let ret: MultiValue = if let Some(args) = self.args0.borrow_mut().take() {
            self.thread.resume(args?)?
        } else {
            self.thread.resume(())?
        };

        // The thread handed back the `ASYNC_POLL_PENDING` marker rather than a real value.
        if is_poll_pending(&ret) {
            return Poll::Pending;
        }

        if let ThreadStatus::Resumable = self.thread.status() {
            // Still resumable means `ret` came from `coroutine.yield`: discard it and ask to be
            // polled again right away.
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        Poll::Ready(R::from_lua_multi(ret, lua))
    }
}
434
/// Returns `true` when `val` is exactly one light userdata pointing at `ASYNC_POLL_PENDING`,
/// the marker value this module treats as "not ready yet".
#[cfg(feature = "async")]
#[inline(always)]
fn is_poll_pending(val: &MultiValue) -> bool {
    // Any count other than one can never be the marker.
    if val.len() != 1 {
        return false;
    }
    match val.iter().next() {
        Some(Value::LightUserData(ud)) => {
            std::ptr::eq(ud.0 as *const u8, &ASYNC_POLL_PENDING as *const u8)
        }
        _ => false,
    }
}
445
/// Scoped override of the waker stored in a `Lua` state.
///
/// Construction installs a new waker; dropping the guard puts back whatever was there before.
#[cfg(feature = "async")]
struct WakerGuard<'lua> {
    // State whose waker is temporarily replaced.
    lua: &'lua Lua,
    // Waker that was installed before this guard, restored on drop.
    prev: Option<Waker>,
}

#[cfg(feature = "async")]
impl<'lua> WakerGuard<'lua> {
    /// Installs `waker` on `lua` and remembers the previously installed one.
    ///
    /// Currently always returns `Ok`.
    #[inline]
    pub fn new(lua: &Lua, waker: Waker) -> Result<WakerGuard> {
        // SAFETY: `set_waker` is an unsafe `Lua` method; the previous value is kept so it can be
        // restored by `Drop` (NOTE(review): confirm its preconditions at the definition site).
        let prev = unsafe { lua.set_waker(Some(waker)) };
        Ok(WakerGuard { lua, prev })
    }
}

#[cfg(feature = "async")]
impl<'lua> Drop for WakerGuard<'lua> {
    /// Restores the waker that was active before the guard was created.
    fn drop(&mut self) {
        let previous = self.prev.take();
        unsafe {
            self.lua.set_waker(previous);
        }
    }
}