mlua_codemp_patch/
thread.rs

1use std::os::raw::{c_int, c_void};
2
3use crate::error::{Error, Result};
4#[allow(unused)]
5use crate::state::Lua;
6use crate::state::RawLua;
7use crate::types::ValueRef;
8use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard};
9use crate::value::{FromLuaMulti, IntoLuaMulti};
10
11#[cfg(not(feature = "luau"))]
12use crate::{
13    hook::{Debug, HookTriggers},
14    types::MaybeSend,
15};
16
17#[cfg(feature = "async")]
18use {
19    futures_util::stream::Stream,
20    std::{
21        future::Future,
22        marker::PhantomData,
23        pin::Pin,
24        ptr::NonNull,
25        task::{Context, Poll, Waker},
26    },
27};
28
/// Status of a Lua thread (coroutine).
///
/// Reported by [`Thread::status`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ThreadStatus {
    /// The thread was just created or is suspended (yielded).
    ///
    /// A thread in this state can be resumed with [`Thread::resume`].
    Resumable,
    /// The thread is currently running.
    Running,
    /// The thread has finished executing.
    Finished,
    /// The thread has raised a Lua error during execution.
    Error,
}
43
44/// Handle to an internal Lua thread (coroutine).
45#[derive(Clone, Debug)]
46pub struct Thread(pub(crate) ValueRef, pub(crate) *mut ffi::lua_State);
47
48#[cfg(feature = "send")]
49unsafe impl Send for Thread {}
50#[cfg(feature = "send")]
51unsafe impl Sync for Thread {}
52
/// Thread (coroutine) representation as an async [`Future`] or [`Stream`].
///
/// Created by [`Thread::into_async`]. Each poll resumes the underlying coroutine; the initial
/// arguments are handed over on the first resume only.
///
/// Requires `feature = "async"`
///
/// [`Future`]: std::future::Future
/// [`Stream`]: futures_util::stream::Stream
#[cfg(feature = "async")]
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub struct AsyncThread<A, R> {
    // The coroutine being driven.
    thread: Thread,
    // Arguments for the first resume; taken (left as `None`) once consumed.
    init_args: Option<A>,
    // Marks the result type without storing a value of it.
    ret: PhantomData<R>,
    // Whether the thread should be recycled when this value is dropped.
    recycle: bool,
}
68
69impl Thread {
70    #[inline(always)]
71    const fn state(&self) -> *mut ffi::lua_State {
72        self.1
73    }
74
75    /// Resumes execution of this thread.
76    ///
77    /// Equivalent to `coroutine.resume`.
78    ///
79    /// Passes `args` as arguments to the thread. If the coroutine has called `coroutine.yield`, it
80    /// will return these arguments. Otherwise, the coroutine wasn't yet started, so the arguments
81    /// are passed to its main function.
82    ///
83    /// If the thread is no longer in `Active` state (meaning it has finished execution or
84    /// encountered an error), this will return `Err(CoroutineInactive)`, otherwise will return `Ok`
85    /// as follows:
86    ///
87    /// If the thread calls `coroutine.yield`, returns the values passed to `yield`. If the thread
88    /// `return`s values from its main function, returns those.
89    ///
90    /// # Examples
91    ///
92    /// ```
93    /// # use mlua::{Error, Lua, Result, Thread};
94    /// # fn main() -> Result<()> {
95    /// # let lua = Lua::new();
96    /// let thread: Thread = lua.load(r#"
97    ///     coroutine.create(function(arg)
98    ///         assert(arg == 42)
99    ///         local yieldarg = coroutine.yield(123)
100    ///         assert(yieldarg == 43)
101    ///         return 987
102    ///     end)
103    /// "#).eval()?;
104    ///
105    /// assert_eq!(thread.resume::<u32>(42)?, 123);
106    /// assert_eq!(thread.resume::<u32>(43)?, 987);
107    ///
108    /// // The coroutine has now returned, so `resume` will fail
109    /// match thread.resume::<u32>(()) {
110    ///     Err(Error::CoroutineUnresumable) => {},
111    ///     unexpected => panic!("unexpected result {:?}", unexpected),
112    /// }
113    /// # Ok(())
114    /// # }
115    /// ```
116    pub fn resume<R>(&self, args: impl IntoLuaMulti) -> Result<R>
117    where
118        R: FromLuaMulti,
119    {
120        let lua = self.0.lua.lock();
121        if self.status_inner(&lua) != ThreadStatus::Resumable {
122            return Err(Error::CoroutineUnresumable);
123        }
124
125        let state = lua.state();
126        let thread_state = self.state();
127        unsafe {
128            let _sg = StackGuard::new(state);
129            let _thread_sg = StackGuard::with_top(thread_state, 0);
130
131            let nresults = self.resume_inner(&lua, args)?;
132            check_stack(state, nresults + 1)?;
133            ffi::lua_xmove(thread_state, state, nresults);
134
135            R::from_stack_multi(nresults, &lua)
136        }
137    }
138
139    /// Resumes execution of this thread.
140    ///
141    /// It's similar to `resume()` but leaves `nresults` values on the thread stack.
142    unsafe fn resume_inner(&self, lua: &RawLua, args: impl IntoLuaMulti) -> Result<c_int> {
143        let state = lua.state();
144        let thread_state = self.state();
145
146        let nargs = args.push_into_stack_multi(&lua)?;
147        if nargs > 0 {
148            check_stack(thread_state, nargs)?;
149            ffi::lua_xmove(state, thread_state, nargs);
150        }
151
152        let mut nresults = 0;
153        let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int);
154        if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD {
155            if ret == ffi::LUA_ERRMEM {
156                // Don't call error handler for memory errors
157                return Err(pop_error(thread_state, ret));
158            }
159            check_stack(state, 3)?;
160            protect_lua!(state, 0, 1, |state| error_traceback_thread(state, thread_state))?;
161            return Err(pop_error(state, ret));
162        }
163
164        Ok(nresults)
165    }
166
167    /// Gets the status of the thread.
168    pub fn status(&self) -> ThreadStatus {
169        self.status_inner(&self.0.lua.lock())
170    }
171
172    /// Gets the status of the thread (internal implementation).
173    pub(crate) fn status_inner(&self, lua: &RawLua) -> ThreadStatus {
174        let thread_state = self.state();
175        if thread_state == lua.state() {
176            // The thread is currently running
177            return ThreadStatus::Running;
178        }
179        let status = unsafe { ffi::lua_status(thread_state) };
180        if status != ffi::LUA_OK && status != ffi::LUA_YIELD {
181            ThreadStatus::Error
182        } else if status == ffi::LUA_YIELD || unsafe { ffi::lua_gettop(thread_state) > 0 } {
183            ThreadStatus::Resumable
184        } else {
185            ThreadStatus::Finished
186        }
187    }
188
189    /// Sets a 'hook' function that will periodically be called as Lua code executes.
190    ///
191    /// This function is similar or [`Lua::set_hook()`] except that it sets for the thread.
192    /// To remove a hook call [`Lua::remove_hook()`].
193    #[cfg(not(feature = "luau"))]
194    #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
195    pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F)
196    where
197        F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static,
198    {
199        let lua = self.0.lua.lock();
200        unsafe {
201            lua.set_thread_hook(self.state(), triggers, callback);
202        }
203    }
204
205    /// Resets a thread
206    ///
207    /// In [Lua 5.4]: cleans its call stack and closes all pending to-be-closed variables.
208    /// Returns a error in case of either the original error that stopped the thread or errors
209    /// in closing methods.
210    ///
211    /// In Luau: resets to the initial state of a newly created Lua thread.
212    /// Lua threads in arbitrary states (like yielded or errored) can be reset properly.
213    ///
214    /// Sets a Lua function for the thread afterwards.
215    ///
216    /// Requires `feature = "lua54"` OR `feature = "luau"`.
217    ///
218    /// [Lua 5.4]: https://www.lua.org/manual/5.4/manual.html#lua_closethread
219    #[cfg(any(feature = "lua54", feature = "luau"))]
220    #[cfg_attr(docsrs, doc(cfg(any(feature = "lua54", feature = "luau"))))]
221    pub fn reset(&self, func: crate::function::Function) -> Result<()> {
222        let lua = self.0.lua.lock();
223        if self.status_inner(&lua) == ThreadStatus::Running {
224            return Err(Error::runtime("cannot reset a running thread"));
225        }
226
227        let thread_state = self.state();
228        unsafe {
229            #[cfg(all(feature = "lua54", not(feature = "vendored")))]
230            let status = ffi::lua_resetthread(thread_state);
231            #[cfg(all(feature = "lua54", feature = "vendored"))]
232            let status = ffi::lua_closethread(thread_state, lua.state());
233            #[cfg(feature = "lua54")]
234            if status != ffi::LUA_OK {
235                return Err(pop_error(thread_state, status));
236            }
237            #[cfg(feature = "luau")]
238            ffi::lua_resetthread(thread_state);
239
240            // Push function to the top of the thread stack
241            ffi::lua_xpush(lua.ref_thread(), thread_state, func.0.index);
242
243            #[cfg(feature = "luau")]
244            {
245                // Inherit `LUA_GLOBALSINDEX` from the main thread
246                ffi::lua_xpush(lua.main_state(), thread_state, ffi::LUA_GLOBALSINDEX);
247                ffi::lua_replace(thread_state, ffi::LUA_GLOBALSINDEX);
248            }
249
250            Ok(())
251        }
252    }
253
254    /// Converts Thread to an AsyncThread which implements [`Future`] and [`Stream`] traits.
255    ///
256    /// `args` are passed as arguments to the thread function for first call.
257    /// The object calls [`resume()`] while polling and also allows to run rust futures
258    /// to completion using an executor.
259    ///
260    /// Using AsyncThread as a Stream allows to iterate through `coroutine.yield()`
261    /// values whereas Future version discards that values and poll until the final
262    /// one (returned from the thread function).
263    ///
264    /// Requires `feature = "async"`
265    ///
266    /// [`Future`]: std::future::Future
267    /// [`Stream`]: futures_util::stream::Stream
268    /// [`resume()`]: https://www.lua.org/manual/5.4/manual.html#lua_resume
269    ///
270    /// # Examples
271    ///
272    /// ```
273    /// # use mlua::{Lua, Result, Thread};
274    /// use futures_util::stream::TryStreamExt;
275    /// # #[tokio::main]
276    /// # async fn main() -> Result<()> {
277    /// # let lua = Lua::new();
278    /// let thread: Thread = lua.load(r#"
279    ///     coroutine.create(function (sum)
280    ///         for i = 1,10 do
281    ///             sum = sum + i
282    ///             coroutine.yield(sum)
283    ///         end
284    ///         return sum
285    ///     end)
286    /// "#).eval()?;
287    ///
288    /// let mut stream = thread.into_async::<i64>(1);
289    /// let mut sum = 0;
290    /// while let Some(n) = stream.try_next().await? {
291    ///     sum += n;
292    /// }
293    ///
294    /// assert_eq!(sum, 286);
295    ///
296    /// # Ok(())
297    /// # }
298    /// ```
299    #[cfg(feature = "async")]
300    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
301    pub fn into_async<R>(self, args: impl IntoLuaMulti) -> AsyncThread<impl IntoLuaMulti, R>
302    where
303        R: FromLuaMulti,
304    {
305        AsyncThread {
306            thread: self,
307            init_args: Some(args),
308            ret: PhantomData,
309            recycle: false,
310        }
311    }
312
313    /// Enables sandbox mode on this thread.
314    ///
315    /// Under the hood replaces the global environment table with a new table,
316    /// that performs writes locally and proxies reads to caller's global environment.
317    ///
318    /// This mode ideally should be used together with the global sandbox mode [`Lua::sandbox()`].
319    ///
320    /// Please note that Luau links environment table with chunk when loading it into Lua state.
321    /// Therefore you need to load chunks into a thread to link with the thread environment.
322    ///
323    /// # Examples
324    ///
325    /// ```
326    /// # use mlua::{Lua, Result};
327    /// # fn main() -> Result<()> {
328    /// let lua = Lua::new();
329    /// let thread = lua.create_thread(lua.create_function(|lua2, ()| {
330    ///     lua2.load("var = 123").exec()?;
331    ///     assert_eq!(lua2.globals().get::<u32>("var")?, 123);
332    ///     Ok(())
333    /// })?)?;
334    /// thread.sandbox()?;
335    /// thread.resume(())?;
336    ///
337    /// // The global environment should be unchanged
338    /// assert_eq!(lua.globals().get::<Option<u32>>("var")?, None);
339    /// # Ok(())
340    /// # }
341    /// ```
342    ///
343    /// Requires `feature = "luau"`
344    #[cfg(any(feature = "luau", docsrs))]
345    #[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
346    #[doc(hidden)]
347    pub fn sandbox(&self) -> Result<()> {
348        let lua = self.0.lua.lock();
349        let state = lua.state();
350        let thread_state = self.state();
351        unsafe {
352            check_stack(thread_state, 3)?;
353            check_stack(state, 3)?;
354            protect_lua!(state, 0, 0, |_| ffi::luaL_sandboxthread(thread_state))
355        }
356    }
357
358    /// Converts this thread to a generic C pointer.
359    ///
360    /// There is no way to convert the pointer back to its original value.
361    ///
362    /// Typically this function is used only for hashing and debug information.
363    #[inline]
364    pub fn to_pointer(&self) -> *const c_void {
365        self.0.to_pointer()
366    }
367}
368
369impl PartialEq for Thread {
370    fn eq(&self, other: &Self) -> bool {
371        self.0 == other.0
372    }
373}
374
#[cfg(feature = "async")]
impl<A, R> AsyncThread<A, R> {
    /// Marks whether the underlying thread should be recycled when this value is dropped.
    ///
    /// The flag is only acted upon under `lua54`/`luau`, where `AsyncThread` has a `Drop` impl.
    #[inline]
    pub(crate) fn set_recyclable(&mut self, recyclable: bool) {
        let flag = &mut self.recycle;
        *flag = recyclable;
    }
}
382
#[cfg(feature = "async")]
#[cfg(any(feature = "lua54", feature = "luau"))]
impl<A, R> Drop for AsyncThread<A, R> {
    /// Hands the thread back to the Lua state for reuse when it was marked recyclable.
    ///
    /// Nothing happens if recycling was not requested or if `try_lock` on the Lua state fails.
    /// Under Lua 5.4, a thread that could not be recycled and ended in an error is reset
    /// (`lua_resetthread`, or `lua_closethread` with the vendored build).
    fn drop(&mut self) {
        if !self.recycle {
            return;
        }
        let lua = match self.thread.0.lua.try_lock() {
            Some(lua) => lua,
            None => return,
        };
        unsafe {
            // For Lua 5.4 this also closes all pending to-be-closed variables
            if !lua.recycle_thread(&mut self.thread) {
                #[cfg(feature = "lua54")]
                if self.thread.status_inner(&lua) == ThreadStatus::Error {
                    #[cfg(not(feature = "vendored"))]
                    ffi::lua_resetthread(self.thread.state());
                    #[cfg(feature = "vendored")]
                    ffi::lua_closethread(self.thread.state(), lua.state());
                }
            }
        }
    }
}
405
#[cfg(feature = "async")]
impl<A: IntoLuaMulti, R: FromLuaMulti> Stream for AsyncThread<A, R> {
    type Item = Result<R>;

    /// Resumes the coroutine once and yields whatever it produced.
    ///
    /// Ends the stream (`None`) once the thread is no longer resumable. When the coroutine
    /// returns the `poll_pending` sentinel, returns `Poll::Pending` without waking (the waker
    /// installed by `WakerGuard` is left to do that); otherwise the task is woken so the stream
    /// gets polled again.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let lua = self.thread.0.lua.lock();
        if self.thread.status_inner(&lua) != ThreadStatus::Resumable {
            return Poll::Ready(None);
        }

        let main_state = lua.state();
        let co_state = self.thread.state();
        unsafe {
            let _main_guard = StackGuard::new(main_state);
            let _co_guard = StackGuard::with_top(co_state, 0);
            let _waker_guard = WakerGuard::new(&lua, cx.waker());

            // SAFETY: no field is structurally pinned; we only mutate fields in place and
            // never move the struct out of the pin.
            let me = self.get_unchecked_mut();
            let nresults = match me.init_args.take() {
                Some(args) => me.thread.resume_inner(&lua, args)?,
                None => me.thread.resume_inner(&lua, ())?,
            };

            // A lone `poll_pending` sentinel means the coroutine is waiting, not yielding a value.
            if nresults == 1 && is_poll_pending(co_state) {
                return Poll::Pending;
            }

            check_stack(main_state, nresults + 1)?;
            ffi::lua_xmove(co_state, main_state, nresults);

            cx.waker().wake_by_ref();
            Poll::Ready(Some(R::from_stack_multi(nresults, &lua)))
        }
    }
}
443
#[cfg(feature = "async")]
impl<A: IntoLuaMulti, R: FromLuaMulti> Future for AsyncThread<A, R> {
    type Output = Result<R>;

    /// Resumes the coroutine until it returns from its main function.
    ///
    /// Fails with `Error::CoroutineUnresumable` if the thread is not resumable. Values passed
    /// to `coroutine.yield` are discarded (the task is woken and polled again); only the final
    /// return values are converted into `R`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let lua = self.thread.0.lua.lock();
        if self.thread.status_inner(&lua) != ThreadStatus::Resumable {
            return Poll::Ready(Err(Error::CoroutineUnresumable));
        }

        let main_state = lua.state();
        let co_state = self.thread.state();
        unsafe {
            let _main_guard = StackGuard::new(main_state);
            let _co_guard = StackGuard::with_top(co_state, 0);
            let _waker_guard = WakerGuard::new(&lua, cx.waker());

            // SAFETY: no field is structurally pinned; we only mutate fields in place and
            // never move the struct out of the pin.
            let me = self.get_unchecked_mut();
            let nresults = match me.init_args.take() {
                Some(args) => me.thread.resume_inner(&lua, args)?,
                None => me.thread.resume_inner(&lua, ())?,
            };

            // A lone `poll_pending` sentinel means the coroutine is waiting on something.
            if nresults == 1 && is_poll_pending(co_state) {
                return Poll::Pending;
            }

            // An ordinary `coroutine.yield` is not the final result: ignore the yielded
            // values and request another poll right away.
            if ffi::lua_status(co_state) == ffi::LUA_YIELD {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            check_stack(main_state, nresults + 1)?;
            ffi::lua_xmove(co_state, main_state, nresults);

            Poll::Ready(R::from_stack_multi(nresults, &lua))
        }
    }
}
486
/// Returns `true` when the value on top of `state`'s stack is the `Lua::poll_pending()`
/// light-userdata sentinel.
///
/// # Safety
///
/// `state` must be a valid Lua state with at least one value on its stack (callers here check
/// `nresults == 1` first).
#[cfg(feature = "async")]
#[inline(always)]
unsafe fn is_poll_pending(state: *mut ffi::lua_State) -> bool {
    let top_ptr = ffi::lua_tolightuserdata(state, -1);
    top_ptr == Lua::poll_pending().0
}
492
/// Guard that installs a task [`Waker`] on the Lua state for the duration of a poll and puts
/// the previously installed waker back when dropped.
#[cfg(feature = "async")]
struct WakerGuard<'lua, 'a> {
    lua: &'lua RawLua,
    // Waker that was installed before this guard; restored on drop.
    prev: NonNull<Waker>,
    // Ties the guard to the borrow of the installed waker.
    _phantom: PhantomData<&'a ()>,
}

#[cfg(feature = "async")]
impl<'lua, 'a> WakerGuard<'lua, 'a> {
    /// Installs `waker` on `lua` and returns a guard that restores the previous waker on drop.
    ///
    /// This cannot fail, so the guard is returned directly rather than wrapped in a `Result`.
    /// Bind it to a named variable (not `_`) so it stays alive for the whole poll.
    #[inline]
    #[must_use = "the previous waker is restored as soon as the guard is dropped"]
    pub fn new(lua: &'lua RawLua, waker: &'a Waker) -> WakerGuard<'lua, 'a> {
        // SAFETY (review): `waker` is borrowed for `'a`, which the guard carries, and the guard
        // reinstalls `prev` before that borrow can end.
        let prev = unsafe { lua.set_waker(NonNull::from(waker)) };
        WakerGuard {
            lua,
            prev,
            _phantom: PhantomData,
        }
    }
}

#[cfg(feature = "async")]
impl Drop for WakerGuard<'_, '_> {
    fn drop(&mut self) {
        // Reinstall whatever waker was active before this guard was created.
        unsafe { self.lua.set_waker(self.prev) };
    }
}
519
#[cfg(test)]
mod assertions {
    use super::*;

    // Without the `send` feature neither handle type may be sent across threads.
    #[cfg(not(feature = "send"))]
    static_assertions::assert_not_impl_any!(Thread: Send);
    #[cfg(all(feature = "async", not(feature = "send")))]
    static_assertions::assert_not_impl_any!(AsyncThread<(), ()>: Send);

    // With `send` enabled, both handle types must be `Send + Sync`.
    #[cfg(feature = "send")]
    static_assertions::assert_impl_all!(Thread: Send, Sync);
    #[cfg(all(feature = "async", feature = "send"))]
    static_assertions::assert_impl_all!(AsyncThread<(), ()>: Send, Sync);
}