Skip to main content

luaur_rt/
interrupt.rs

1//! Luau interrupt support. Mirrors `mlua::Lua::set_interrupt` / `VmState`.
2//!
//! Luau's VM calls a single VM-wide `interrupt` callback at safepoints (loop
//! back-edges, calls/returns) and during GC steps. mlua exposes this as
//! `Lua::set_interrupt`, taking a Rust closure that returns a [`VmState`]
//! telling the VM whether to continue or to **yield** the current coroutine.
//! GC-step interrupts are not forwarded to the closure.
7//!
8//! luaur's `lua_callbacks().interrupt` is a plain C function pointer, so we
9//! install a fixed trampoline ([`interrupt_trampoline`]) and keep the Rust
10//! closure in a thread-local keyed by the VM's *global* pointer (shared by all
11//! threads of one `Lua`). The trampoline looks up the closure, runs it with a
12//! borrowed [`Lua`], and:
13//!
14//! * `Ok(VmState::Continue)`  — returns normally; the VM keeps executing.
//! * `Ok(VmState::Yield)`     — if the running thread is yieldable
//!   (`lua_isyieldable`), calls `lua_break`, which sets its status so the VM
//!   unwinds back to `lua_resume`. At a non-yieldable point (e.g. across a
//!   metamethod / C-call boundary) the request is silently ignored, as in mlua.
18//! * `Err(e)`                 — raises `e` as a Lua error via `lua_error`.
19
20use std::cell::RefCell;
21use std::collections::HashMap;
22
23use crate::error::{Error, Result};
24use crate::state::Lua;
25use crate::sys::*;
26
/// What an interrupt callback wants the VM to do next; the counterpart of
/// `mlua::VmState`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VmState {
    /// Carry on executing normally.
    Continue,
    /// Yield the running coroutine; has no effect where yielding is not allowed.
    Yield,
}
36
37type InterruptFn = Box<dyn Fn(&Lua) -> Result<VmState> + 'static>;
38
39thread_local! {
40    /// Per-VM interrupt closure, keyed by the `global_State` pointer (stable for
41    /// the lifetime of the VM and shared by all of its threads).
42    static INTERRUPTS: RefCell<HashMap<*mut core::ffi::c_void, InterruptFn>> =
43        RefCell::new(HashMap::new());
44}
45
46/// The `global_State` pointer for `state` — the per-VM key shared by all
47/// threads of one `Lua`.
48unsafe fn vm_key(state: *mut lua_State) -> *mut core::ffi::c_void {
49    unsafe { (*state).global as *mut core::ffi::c_void }
50}
51
52impl Lua {
53    /// Install an interrupt callback. Mirrors `mlua::Lua::set_interrupt`.
54    ///
55    /// The callback runs at VM safepoints; returning [`VmState::Yield`] yields
56    /// the running coroutine, and returning `Err` raises a Lua error.
57    pub fn set_interrupt<F>(&self, callback: F)
58    where
59        F: Fn(&Lua) -> Result<VmState> + crate::sync::MaybeSend + 'static,
60    {
61        let state = self.state();
62        unsafe {
63            let key = vm_key(state);
64            INTERRUPTS.with(|m| {
65                m.borrow_mut().insert(key, Box::new(callback));
66            });
67            let cb = lua_callbacks(state);
68            (*cb).interrupt = Some(interrupt_trampoline);
69        }
70    }
71
72    /// Remove a previously installed interrupt callback. Mirrors
73    /// `mlua::Lua::remove_interrupt`.
74    pub fn remove_interrupt(&self) {
75        let state = self.state();
76        unsafe {
77            let key = vm_key(state);
78            INTERRUPTS.with(|m| {
79                m.borrow_mut().remove(&key);
80            });
81            let cb = lua_callbacks(state);
82            (*cb).interrupt = None;
83        }
84    }
85}
86
87/// The fixed C trampoline installed as `lua_callbacks().interrupt`.
88///
89/// `gc` is non-negative only for GC interrupts; mlua ignores GC interrupts in
90/// the user callback path, and so do we (return immediately) so the user
91/// closure only sees real instruction safepoints.
92unsafe extern "C-unwind" fn interrupt_trampoline(state: *mut lua_State, gc: c_int) {
93    if gc >= 0 {
94        // GC step interrupt — not surfaced to the user callback.
95        return;
96    }
97    let key = unsafe { vm_key(state) };
98    // Take the closure out of the map for the duration of the call so a
99    // re-entrant `set_interrupt` from inside the callback can't alias the
100    // borrow. Put it back afterwards (unless the callback replaced it).
101    let cb = INTERRUPTS.with(|m| m.borrow_mut().remove(&key));
102    let Some(cb) = cb else { return };
103
104    let lua = unsafe { Lua::from_borrowed(state) };
105    let result = cb(&lua);
106
107    // Restore the closure if the callback didn't install a new one.
108    INTERRUPTS.with(|m| {
109        let mut map = m.borrow_mut();
110        map.entry(key).or_insert(cb);
111    });
112
113    match result {
114        Ok(VmState::Continue) => {}
115        Ok(VmState::Yield) => unsafe {
116            // Request a yield — but only at a yieldable point. Inside a
117            // metamethod / C-call boundary Luau's `lua_break` would raise
118            // "attempt to break across metamethod/C-call boundary"; upstream
119            // (and mlua) silently ignore the yield request there, so we gate it
120            // on `lua_isyieldable` and otherwise just continue.
121            if lua_isyieldable(state) != 0 {
122                let _ = luaur_vm::functions::lua_break::lua_break(state);
123            }
124        },
125        Err(e) => unsafe {
126            // Raise the error as a Lua error. Push the message and longjmp.
127            raise_error(state, &e);
128        },
129    }
130}
131
132/// Push `e`'s message as a string error object and `lua_error` it (does not
133/// return).
134unsafe fn raise_error(state: *mut lua_State, e: &Error) -> ! {
135    // Use the bare message for a runtime error (so it round-trips back through
136    // `pop_error` as `RuntimeError(msg)` without a doubled "runtime error: "
137    // prefix); fall back to the full Display for other error kinds.
138    let msg = match e {
139        Error::RuntimeError(m) => m.clone(),
140        other => other.to_string(),
141    };
142    unsafe {
143        // The interrupt fires at an arbitrary VM safepoint where `L->top` may be
144        // flush against the call-info top; make room before pushing so the
145        // `api_incr_top` stack invariant in `lua_pushlstring` holds.
146        lua_rawcheckstack(state, 1);
147        lua_pushlstring(state, msg.as_ptr() as *const c_char, msg.len());
148        lua_error(state)
149    }
150}