Skip to main content

lua_rs_hlua_shim/
func.rs

//! Closure hosting and the typed read/push traits.
//!
//! lua-rs native functions are bare `fn` pointers (an index into a per-state
//! registry), exactly like C's `lua_CFunction`. To host arbitrary Rust closures
//! that capture state, we reproduce the trick `hlua` itself uses on the C API: a
//! single shared trampoline `fn` is registered as the C closure, the boxed
//! closure adapter lives in a thread-local registry, and the adapter's registry
//! index travels as the C closure's first upvalue. Dispatching a call never takes
//! a mutable borrow of the registry, so nested native calls cannot trip the
//! `RefCell`'s runtime borrow check.
10
11use std::cell::RefCell;
12use std::collections::HashMap;
13use std::marker::PhantomData;
14
15use lua_types::{LuaError as VmError, LuaType};
16use lua_vm::api;
17use lua_vm::state::LuaState;
18
19use crate::any::{
20    push_any, push_hashable, read_any, read_map, read_sequence, string_bytes_at, AnyHashableLuaValue,
21    AnyLuaValue,
22};
23use crate::{Lua, LuaError};
24
25/// A type-erased native call: read args, invoke the user closure, push results.
26pub(crate) type Adapter = Box<dyn Fn(&mut LuaState) -> Result<usize, VmError>>;
27
28thread_local! {
29    static REGISTRY: RefCell<Vec<Option<Adapter>>> = const { RefCell::new(Vec::new()) };
30}
31
32/// `LUA_REGISTRYINDEX`, mirrored from `lua-vm`; upvalue 1 sits one below it.
33const LUA_REGISTRYINDEX: i32 = -(1_000_000) - 1000;
34
35fn upvalue_index(n: i32) -> i32 {
36    LUA_REGISTRYINDEX - n
37}
38
39pub(crate) fn registry_insert(adapter: Adapter) -> usize {
40    REGISTRY.with(|cell| {
41        let mut slots = cell.borrow_mut();
42        match slots.iter().position(|slot| slot.is_none()) {
43            Some(i) => {
44                slots[i] = Some(adapter);
45                i
46            }
47            None => {
48                slots.push(Some(adapter));
49                slots.len() - 1
50            }
51        }
52    })
53}
54
55pub(crate) fn registry_remove(index: usize) {
56    REGISTRY.with(|cell| {
57        let mut slots = cell.borrow_mut();
58        if index < slots.len() {
59            slots[index] = None;
60        }
61    });
62}
63
64/// The one C function every hosted closure is registered as. It recovers its
65/// adapter index from upvalue 1 and dispatches.
66pub(crate) fn trampoline(state: &mut LuaState) -> Result<usize, VmError> {
67    let index = api::to_integer_x(state, upvalue_index(1))
68        .ok_or_else(|| VmError::runtime(format_args!("hlua-shim: closure upvalue missing")))?
69        as usize;
70    REGISTRY.with(|cell| {
71        let slots = cell.borrow();
72        match slots.get(index).and_then(|slot| slot.as_ref()) {
73            Some(adapter) => adapter(state),
74            None => Err(VmError::runtime(format_args!(
75                "hlua-shim: closure {index} not registered"
76            ))),
77        }
78    })
79}
80
81// ── reading native-function arguments off the stack ───────────────────────────
82
83/// Convert a single positional argument at stack index `idx` into a Rust value.
84pub trait LuaReadArg: Sized {
85    fn read_arg(state: &mut LuaState, idx: i32) -> Result<Self, VmError>;
86}
87
88impl LuaReadArg for AnyLuaValue {
89    fn read_arg(state: &mut LuaState, idx: i32) -> Result<Self, VmError> {
90        Ok(read_any(state, idx))
91    }
92}
93
94impl LuaReadArg for String {
95    fn read_arg(state: &mut LuaState, idx: i32) -> Result<Self, VmError> {
96        let bytes = string_bytes_at(state, idx)
97            .ok_or_else(|| VmError::runtime(format_args!("expected string argument")))?;
98        String::from_utf8(bytes)
99            .map_err(|_| VmError::runtime(format_args!("string argument is not valid utf-8")))
100    }
101}
102
103impl LuaReadArg for Vec<AnyLuaValue> {
104    fn read_arg(state: &mut LuaState, idx: i32) -> Result<Self, VmError> {
105        Ok(read_sequence(state, idx))
106    }
107}
108
109impl LuaReadArg for HashMap<AnyHashableLuaValue, AnyLuaValue> {
110    fn read_arg(state: &mut LuaState, idx: i32) -> Result<Self, VmError> {
111        Ok(read_map(state, idx))
112    }
113}
114
115macro_rules! read_arg_int {
116    ($($ty:ty),*) => {$(
117        impl LuaReadArg for $ty {
118            fn read_arg(state: &mut LuaState, idx: i32) -> Result<Self, VmError> {
119                let n = api::to_integer_x(state, idx)
120                    .ok_or_else(|| VmError::runtime(format_args!("expected integer argument")))?;
121                Ok(n as $ty)
122            }
123        }
124    )*};
125}
126read_arg_int!(i32, u32, u16, i64, u64);
127
128// ── pushing native-function results onto the stack ────────────────────────────
129
130/// Push a Rust return value onto the stack, yielding the number of results.
131pub trait PushToLua {
132    fn push_to(self, state: &mut LuaState) -> Result<usize, VmError>;
133}
134
135impl PushToLua for () {
136    fn push_to(self, _state: &mut LuaState) -> Result<usize, VmError> {
137        Ok(0)
138    }
139}
140
141impl PushToLua for bool {
142    fn push_to(self, state: &mut LuaState) -> Result<usize, VmError> {
143        api::push_boolean(state, self);
144        Ok(1)
145    }
146}
147
148impl PushToLua for String {
149    fn push_to(self, state: &mut LuaState) -> Result<usize, VmError> {
150        api::push_lstring(state, self.as_bytes())?;
151        Ok(1)
152    }
153}
154
155impl PushToLua for AnyLuaValue {
156    fn push_to(self, state: &mut LuaState) -> Result<usize, VmError> {
157        push_any(state, &self)?;
158        Ok(1)
159    }
160}
161
162impl PushToLua for Vec<AnyLuaValue> {
163    fn push_to(self, state: &mut LuaState) -> Result<usize, VmError> {
164        state.create_table(self.len() as i32, 0)?;
165        let table = api::get_top(state);
166        for (i, value) in self.iter().enumerate() {
167            push_any(state, value)?;
168            state.raw_seti(table, (i + 1) as i64)?;
169        }
170        Ok(1)
171    }
172}
173
174impl PushToLua for HashMap<AnyHashableLuaValue, AnyLuaValue> {
175    fn push_to(self, state: &mut LuaState) -> Result<usize, VmError> {
176        state.create_table(0, self.len() as i32)?;
177        let table = api::get_top(state);
178        for (key, value) in &self {
179            push_hashable(state, key)?;
180            push_any(state, value)?;
181            api::raw_set(state, table)?;
182        }
183        Ok(1)
184    }
185}
186
187macro_rules! push_int {
188    ($($ty:ty),*) => {$(
189        impl PushToLua for $ty {
190            fn push_to(self, state: &mut LuaState) -> Result<usize, VmError> {
191                api::push_integer(state, self as i64);
192                Ok(1)
193            }
194        }
195    )*};
196}
197push_int!(i32, u32, u16, i64);
198
199/// A closure returning `Result` never raises a Lua error: consumers record
200/// failures out of band (e.g. authoscope's `state.error`) and the script keeps
201/// running with `nil` substituted for the result. This matches hlua-badtouch's
202/// observed behaviour and is what makes recorded-error handling work.
203impl<T: PushToLua, E> PushToLua for Result<T, E> {
204    fn push_to(self, state: &mut LuaState) -> Result<usize, VmError> {
205        match self {
206            Ok(value) => value.push_to(state),
207            Err(_) => {
208                api::push_nil(state);
209                Ok(1)
210            }
211        }
212    }
213}
214
215// ── the functionN family ──────────────────────────────────────────────────────
216
217/// Installs a hosted closure as a global of the given name on `lua`.
218pub trait SetValue {
219    fn set_into(self, lua: &mut Lua<'_>, name: &str);
220}
221
/// Generates a `functionN` constructor, its `FunctionN` wrapper type and the
/// wrapper's [`SetValue`] impl for one closure arity.
///
/// `$arg`/`$argty` name each parameter and its type; `$idx` is the stack index
/// each parameter is read from.
macro_rules! define_function {
    ($name:ident, $wrapper:ident, ($($arg:ident : $argty:ident),*), ($($idx:expr),*)) => {
        /// Closure wrapper returned by the matching `functionN`, mirroring hlua's API.
        pub struct $wrapper<F, $($argty,)* R> {
            f: F,
            _marker: PhantomData<fn($($argty,)*) -> R>,
        }

        /// Wraps a Rust closure so `Lua::set` can store it as a Lua global.
        pub fn $name<F, $($argty,)* R>(f: F) -> $wrapper<F, $($argty,)* R>
        where
            F: Fn($($argty,)*) -> R + 'static,
            $($argty: LuaReadArg + 'static,)*
            R: PushToLua + 'static,
        {
            $wrapper {
                _marker: PhantomData,
                f,
            }
        }

        impl<F, $($argty,)* R> SetValue for $wrapper<F, $($argty,)* R>
        where
            F: Fn($($argty,)*) -> R + 'static,
            $($argty: LuaReadArg + 'static,)*
            R: PushToLua + 'static,
        {
            fn set_into(self, lua: &mut Lua<'_>, name: &str) {
                let $wrapper { f, .. } = self;
                let adapter: Adapter = Box::new(move |state: &mut LuaState| {
                    // Decode arguments left to right; the first failure aborts the call.
                    $(let $arg = <$argty as LuaReadArg>::read_arg(state, $idx)?;)*
                    f($($arg,)*).push_to(state)
                });
                lua.install_closure(name, adapter);
            }
        }
    };
}
258
259define_function!(function0, Function0, (), ());
260define_function!(function1, Function1, (a0: A0), (1));
261define_function!(function2, Function2, (a0: A0, a1: A1), (1, 2));
262define_function!(function3, Function3, (a0: A0, a1: A1, a2: A2), (1, 2, 3));
263define_function!(function4, Function4, (a0: A0, a1: A1, a2: A2, a3: A3), (1, 2, 3, 4));
264define_function!(
265    function5,
266    Function5,
267    (a0: A0, a1: A1, a2: A2, a3: A3, a4: A4),
268    (1, 2, 3, 4, 5)
269);
270define_function!(
271    function6,
272    Function6,
273    (a0: A0, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5),
274    (1, 2, 3, 4, 5, 6)
275);
276
277// ── reading Lua functions back out, and calling them ──────────────────────────
278
279/// A handle to a global Lua function, kept by name so it can be re-fetched and
280/// called. Mirrors `hlua::LuaFunction<L>` (single type parameter).
281pub struct LuaFunction<L> {
282    pub(crate) inner: L,
283}
284
285/// The concrete `L` we instantiate `LuaFunction` with: a live borrow of the VM
286/// plus the global name to call.
287pub struct FnHandle<'a> {
288    pub(crate) state: &'a mut LuaState,
289    pub(crate) name: String,
290}
291
292/// Arguments pushed for a Lua call, returning how many were pushed.
293pub trait PushArgs {
294    fn push_args(self, state: &mut LuaState) -> Result<usize, VmError>;
295}
296
297impl PushArgs for (AnyLuaValue, AnyLuaValue) {
298    fn push_args(self, state: &mut LuaState) -> Result<usize, VmError> {
299        push_any(state, &self.0)?;
300        push_any(state, &self.1)?;
301        Ok(2)
302    }
303}
304
305impl PushArgs for (AnyLuaValue,) {
306    fn push_args(self, state: &mut LuaState) -> Result<usize, VmError> {
307        push_any(state, &self.0)?;
308        Ok(1)
309    }
310}
311
312/// Read a single returned value off the top of the stack.
313pub trait FromTop: Sized {
314    fn from_top(state: &mut LuaState) -> Self;
315}
316
317impl FromTop for AnyLuaValue {
318    fn from_top(state: &mut LuaState) -> Self {
319        read_any(state, -1)
320    }
321}
322
323impl<'a> LuaFunction<FnHandle<'a>> {
324    /// Call the function with the given argument tuple and read one result.
325    pub fn call_with_args<V, A>(&mut self, args: A) -> Result<V, LuaError>
326    where
327        A: PushArgs,
328        V: FromTop,
329    {
330        let state = &mut *self.inner.state;
331        api::get_global(state, self.inner.name.as_bytes()).map_err(LuaError::from_vm)?;
332        let nargs = args.push_args(state).map_err(LuaError::from_vm)?;
333        state
334            .protected_call(nargs as i32, 1, 0)
335            .map_err(LuaError::from_vm)?;
336        let value = V::from_top(state);
337        api::set_top(state, -2).ok();
338        Ok(value)
339    }
340}
341
342// ── borrowing a Lua string without copying out of the VM ──────────────────────
343
/// Mirror of `hlua::StringInLua<L>`. Unlike hlua's version it owns a copy of
/// the string and dereferences to `str`.
pub struct StringInLua<L> {
    /// Ties the value to the `L` it was read from, without storing one.
    pub(crate) _marker: PhantomData<L>,
    /// The string as read from Lua.
    pub(crate) value: String,
}

impl<L> std::ops::Deref for StringInLua<L> {
    type Target = str;

    fn deref(&self) -> &Self::Target {
        self.value.as_str()
    }
}
356
357// ── reading named globals (backs `Lua::get`) ──────────────────────────────────
358
359/// Read a global named `name` from `lua`, mirroring `hlua`'s `get`.
360pub trait FromLuaGlobal<'l>: Sized {
361    fn from_lua_global<'lua>(lua: &'l mut Lua<'lua>, name: &str) -> Option<Self>;
362}
363
364impl<'l> FromLuaGlobal<'l> for AnyLuaValue {
365    fn from_lua_global<'lua>(lua: &'l mut Lua<'lua>, name: &str) -> Option<Self> {
366        let state = lua.state_mut();
367        api::get_global(state, name.as_bytes()).ok()?;
368        let value = read_any(state, -1);
369        api::set_top(state, -2).ok();
370        Some(value)
371    }
372}
373
374impl<'l> FromLuaGlobal<'l> for String {
375    fn from_lua_global<'lua>(lua: &'l mut Lua<'lua>, name: &str) -> Option<Self> {
376        let state = lua.state_mut();
377        api::get_global(state, name.as_bytes()).ok()?;
378        let bytes = string_bytes_at(state, -1);
379        api::set_top(state, -2).ok();
380        String::from_utf8(bytes?).ok()
381    }
382}
383
384impl<'l> FromLuaGlobal<'l> for StringInLua<()> {
385    fn from_lua_global<'lua>(lua: &'l mut Lua<'lua>, name: &str) -> Option<Self> {
386        let value: String = FromLuaGlobal::from_lua_global(lua, name)?;
387        Some(StringInLua {
388            value,
389            _marker: PhantomData,
390        })
391    }
392}
393
394impl<'l> FromLuaGlobal<'l> for LuaFunction<FnHandle<'l>> {
395    fn from_lua_global<'lua>(lua: &'l mut Lua<'lua>, name: &str) -> Option<Self> {
396        let owned_name = name.to_string();
397        let state = lua.state_mut();
398        let ty = api::get_global(state, name.as_bytes()).ok()?;
399        api::set_top(state, -2).ok();
400        if ty != LuaType::Function {
401            return None;
402        }
403        Some(LuaFunction {
404            inner: FnHandle {
405                state,
406                name: owned_name,
407            },
408        })
409    }
410}