factorio_mlua/
function.rs

1use std::mem;
2use std::os::raw::{c_int, c_void};
3use std::ptr;
4use std::slice;
5
6use crate::error::{Error, Result};
7use crate::ffi;
8use crate::types::LuaRef;
9use crate::util::{
10    assert_stack, check_stack, error_traceback, pop_error, ptr_to_cstr_bytes, StackGuard,
11};
12use crate::value::{FromLuaMulti, ToLuaMulti};
13
14#[cfg(feature = "async")]
15use {futures_core::future::LocalBoxFuture, futures_util::future};
16
17/// Handle to an internal Lua function.
18#[derive(Clone, Debug)]
19pub struct Function<'lua>(pub(crate) LuaRef<'lua>);
20
/// Debug information about a Lua function, as produced by `Function::info`.
///
/// The byte-string fields are raw copies of the strings in the `lua_Debug` record filled in by
/// `lua_getinfo`; a field is `None` when no string was available for it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FunctionInfo {
    /// Name of the function (`lua_Debug::name`), if one could be determined.
    pub name: Option<Vec<u8>>,
    /// Explanation of `name` (`lua_Debug::namewhat`). Always `None` on Luau.
    pub name_what: Option<Vec<u8>>,
    /// Kind of the function (`lua_Debug::what`).
    pub what: Option<Vec<u8>>,
    /// Source of the chunk that created the function (`lua_Debug::source`).
    pub source: Option<Vec<u8>>,
    /// Printable form of `source` (`lua_Debug::short_src`).
    pub short_src: Option<Vec<u8>>,
    /// Line where the function definition starts (`lua_Debug::linedefined`).
    pub line_defined: i32,
    /// Line where the function definition ends (`lua_Debug::lastlinedefined`).
    #[cfg(not(feature = "luau"))]
    pub last_line_defined: i32,
}
32
/// Luau function coverage snapshot.
///
/// `Function::coverage` produces one snapshot for each executed function, including inner
/// functions.
#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CoverageInfo {
    /// Name of the function as reported by Luau, or `None` when Luau passes no name.
    pub function: Option<std::string::String>,
    /// Line on which the function is defined, as reported by Luau.
    pub line_defined: i32,
    /// Depth value reported by Luau for this function.
    pub depth: i32,
    /// Per-line hit counters copied from the buffer that Luau passes to the coverage callback.
    ///
    /// NOTE(review): presumably indexed by line number, with negative entries marking lines
    /// without code. Confirm against Luau's `lua_getcoverage`.
    pub hits: Vec<i32>,
}
43
44impl<'lua> Function<'lua> {
45    /// Calls the function, passing `args` as function arguments.
46    ///
47    /// The function's return values are converted to the generic type `R`.
48    ///
49    /// # Examples
50    ///
51    /// Call Lua's built-in `tostring` function:
52    ///
53    /// ```
54    /// # use mlua::{Function, Lua, Result};
55    /// # fn main() -> Result<()> {
56    /// # let lua = Lua::new();
57    /// let globals = lua.globals();
58    ///
59    /// let tostring: Function = globals.get("tostring")?;
60    ///
61    /// assert_eq!(tostring.call::<_, String>(123)?, "123");
62    ///
63    /// # Ok(())
64    /// # }
65    /// ```
66    ///
67    /// Call a function with multiple arguments:
68    ///
69    /// ```
70    /// # use mlua::{Function, Lua, Result};
71    /// # fn main() -> Result<()> {
72    /// # let lua = Lua::new();
73    /// let sum: Function = lua.load(
74    ///     r#"
75    ///         function(a, b)
76    ///             return a + b
77    ///         end
78    /// "#).eval()?;
79    ///
80    /// assert_eq!(sum.call::<_, u32>((3, 4))?, 3 + 4);
81    ///
82    /// # Ok(())
83    /// # }
84    /// ```
85    pub fn call<A: ToLuaMulti<'lua>, R: FromLuaMulti<'lua>>(&self, args: A) -> Result<R> {
86        let lua = self.0.lua;
87
88        let mut args = args.to_lua_multi(lua)?;
89        let nargs = args.len() as c_int;
90
91        let results = unsafe {
92            let _sg = StackGuard::new(lua.state);
93            check_stack(lua.state, nargs + 3)?;
94
95            ffi::lua_pushcfunction(lua.state, error_traceback);
96            let stack_start = ffi::lua_gettop(lua.state);
97            lua.push_ref(&self.0);
98            for arg in args.drain_all() {
99                lua.push_value(arg)?;
100            }
101            let ret = ffi::lua_pcall(lua.state, nargs, ffi::LUA_MULTRET, stack_start);
102            if ret != ffi::LUA_OK {
103                return Err(pop_error(lua.state, ret));
104            }
105            let nresults = ffi::lua_gettop(lua.state) - stack_start;
106            let mut results = args; // Reuse MultiValue container
107            assert_stack(lua.state, 2);
108            for _ in 0..nresults {
109                results.push_front(lua.pop_value());
110            }
111            ffi::lua_pop(lua.state, 1);
112            results
113        };
114        R::from_lua_multi(results, lua)
115    }
116
117    /// Returns a Feature that, when polled, calls `self`, passing `args` as function arguments,
118    /// and drives the execution.
119    ///
120    /// Internally it wraps the function to an [`AsyncThread`].
121    ///
122    /// Requires `feature = "async"`
123    ///
124    /// # Examples
125    ///
126    /// ```
127    /// use std::time::Duration;
128    /// use futures_timer::Delay;
129    /// # use mlua::{Lua, Result};
130    /// # #[tokio::main]
131    /// # async fn main() -> Result<()> {
132    /// # let lua = Lua::new();
133    ///
134    /// let sleep = lua.create_async_function(move |_lua, n: u64| async move {
135    ///     Delay::new(Duration::from_millis(n)).await;
136    ///     Ok(())
137    /// })?;
138    ///
139    /// sleep.call_async(10).await?;
140    ///
141    /// # Ok(())
142    /// # }
143    /// ```
144    ///
145    /// [`AsyncThread`]: crate::AsyncThread
146    #[cfg(feature = "async")]
147    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
148    pub fn call_async<'fut, A, R>(&self, args: A) -> LocalBoxFuture<'fut, Result<R>>
149    where
150        'lua: 'fut,
151        A: ToLuaMulti<'lua>,
152        R: FromLuaMulti<'lua> + 'fut,
153    {
154        let lua = self.0.lua;
155        match lua.create_recycled_thread(self.clone()) {
156            Ok(t) => {
157                let mut t = t.into_async(args);
158                t.set_recyclable(true);
159                Box::pin(t)
160            }
161            Err(e) => Box::pin(future::err(e)),
162        }
163    }
164
165    /// Returns a function that, when called, calls `self`, passing `args` as the first set of
166    /// arguments.
167    ///
168    /// If any arguments are passed to the returned function, they will be passed after `args`.
169    ///
170    /// # Examples
171    ///
172    /// ```
173    /// # use mlua::{Function, Lua, Result};
174    /// # fn main() -> Result<()> {
175    /// # let lua = Lua::new();
176    /// let sum: Function = lua.load(
177    ///     r#"
178    ///         function(a, b)
179    ///             return a + b
180    ///         end
181    /// "#).eval()?;
182    ///
183    /// let bound_a = sum.bind(1)?;
184    /// assert_eq!(bound_a.call::<_, u32>(2)?, 1 + 2);
185    ///
186    /// let bound_a_and_b = sum.bind(13)?.bind(57)?;
187    /// assert_eq!(bound_a_and_b.call::<_, u32>(())?, 13 + 57);
188    ///
189    /// # Ok(())
190    /// # }
191    /// ```
192    pub fn bind<A: ToLuaMulti<'lua>>(&self, args: A) -> Result<Function<'lua>> {
193        unsafe extern "C" fn args_wrapper_impl(state: *mut ffi::lua_State) -> c_int {
194            let nargs = ffi::lua_gettop(state);
195            let nbinds = ffi::lua_tointeger(state, ffi::lua_upvalueindex(1)) as c_int;
196            ffi::luaL_checkstack(state, nbinds, ptr::null());
197
198            for i in 0..nbinds {
199                ffi::lua_pushvalue(state, ffi::lua_upvalueindex(i + 2));
200            }
201            ffi::lua_rotate(state, 1, nbinds);
202
203            nargs + nbinds
204        }
205
206        let lua = self.0.lua;
207
208        let args = args.to_lua_multi(lua)?;
209        let nargs = args.len() as c_int;
210
211        if nargs + 1 > ffi::LUA_MAX_UPVALUES {
212            return Err(Error::BindError);
213        }
214
215        let args_wrapper = unsafe {
216            let _sg = StackGuard::new(lua.state);
217            check_stack(lua.state, nargs + 3)?;
218
219            ffi::lua_pushinteger(lua.state, nargs as ffi::lua_Integer);
220            for arg in args {
221                lua.push_value(arg)?;
222            }
223            protect_lua!(lua.state, nargs + 1, 1, fn(state) {
224                ffi::lua_pushcclosure(state, args_wrapper_impl, ffi::lua_gettop(state));
225            })?;
226
227            Function(lua.pop_ref())
228        };
229
230        lua.load(
231            r#"
232            local func, args_wrapper = ...
233            return function(...)
234                return func(args_wrapper(...))
235            end
236            "#,
237        )
238        .try_cache()
239        .set_name("_mlua_bind")?
240        .call((self.clone(), args_wrapper))
241    }
242
243    /// Returns information about the function.
244    ///
245    /// Corresponds to the `>Sn` what mask for [`lua_getinfo`] when applied to the function.
246    ///
247    /// [`lua_getinfo`]: https://www.lua.org/manual/5.4/manual.html#lua_getinfo
248    pub fn info(&self) -> FunctionInfo {
249        let lua = self.0.lua;
250        unsafe {
251            let _sg = StackGuard::new(lua.state);
252            assert_stack(lua.state, 1);
253
254            let mut ar: ffi::lua_Debug = mem::zeroed();
255            lua.push_ref(&self.0);
256            #[cfg(not(feature = "luau"))]
257            let res = ffi::lua_getinfo(lua.state, cstr!(">Sn"), &mut ar);
258            #[cfg(feature = "luau")]
259            let res = ffi::lua_getinfo(lua.state, -1, cstr!("sn"), &mut ar);
260            mlua_assert!(res != 0, "lua_getinfo failed with `>Sn`");
261
262            FunctionInfo {
263                name: ptr_to_cstr_bytes(ar.name).map(|s| s.to_vec()),
264                #[cfg(not(feature = "luau"))]
265                name_what: ptr_to_cstr_bytes(ar.namewhat).map(|s| s.to_vec()),
266                #[cfg(feature = "luau")]
267                name_what: None,
268                what: ptr_to_cstr_bytes(ar.what).map(|s| s.to_vec()),
269                source: ptr_to_cstr_bytes(ar.source).map(|s| s.to_vec()),
270                short_src: ptr_to_cstr_bytes(&ar.short_src as *const _).map(|s| s.to_vec()),
271                line_defined: ar.linedefined as i32,
272                #[cfg(not(feature = "luau"))]
273                last_line_defined: ar.lastlinedefined as i32,
274            }
275        }
276    }
277
278    /// Dumps the function as a binary chunk.
279    ///
280    /// If `strip` is true, the binary representation may not include all debug information
281    /// about the function, to save space.
282    ///
283    /// For Luau a [Compiler] can be used to compile Lua chunks to bytecode.
284    ///
285    /// [Compiler]: crate::chunk::Compiler
286    #[cfg(not(feature = "luau"))]
287    #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
288    pub fn dump(&self, strip: bool) -> Vec<u8> {
289        unsafe extern "C" fn writer(
290            _state: *mut ffi::lua_State,
291            buf: *const c_void,
292            buf_len: usize,
293            data: *mut c_void,
294        ) -> c_int {
295            let data = &mut *(data as *mut Vec<u8>);
296            let buf = slice::from_raw_parts(buf as *const u8, buf_len);
297            data.extend_from_slice(buf);
298            0
299        }
300
301        let lua = self.0.lua;
302        let mut data: Vec<u8> = Vec::new();
303        unsafe {
304            let _sg = StackGuard::new(lua.state);
305            assert_stack(lua.state, 1);
306
307            lua.push_ref(&self.0);
308            let data_ptr = &mut data as *mut Vec<u8> as *mut c_void;
309            let strip = if strip { 1 } else { 0 };
310            ffi::lua_dump(lua.state, writer, data_ptr, strip);
311            ffi::lua_pop(lua.state, 1);
312        }
313
314        data
315    }
316
317    /// Retrieves recorded coverage information about this Lua function including inner calls.
318    ///
319    /// This function takes a callback as an argument and calls it providing [`CoverageInfo`] snapshot
320    /// per each executed inner function.
321    ///
322    /// Recording of coverage information is controlled by [`Compiler::set_coverage_level`] option.
323    ///
324    /// Requires `feature = "luau"`
325    ///
326    /// [`Compiler::set_coverage_level`]: crate::chunk::Compiler::set_coverage_level
327    #[cfg(any(feature = "luau", docsrs))]
328    #[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
329    pub fn coverage<F>(&self, mut func: F)
330    where
331        F: FnMut(CoverageInfo),
332    {
333        use std::ffi::CStr;
334        use std::os::raw::c_char;
335
336        unsafe extern "C" fn callback<F: FnMut(CoverageInfo)>(
337            data: *mut c_void,
338            function: *const c_char,
339            line_defined: c_int,
340            depth: c_int,
341            hits: *const c_int,
342            size: usize,
343        ) {
344            let function = if !function.is_null() {
345                Some(CStr::from_ptr(function).to_string_lossy().to_string())
346            } else {
347                None
348            };
349            let rust_callback = &mut *(data as *mut F);
350            rust_callback(CoverageInfo {
351                function,
352                line_defined,
353                depth,
354                hits: slice::from_raw_parts(hits, size).to_vec(),
355            });
356        }
357
358        let lua = self.0.lua;
359        unsafe {
360            let _sg = StackGuard::new(lua.state);
361            assert_stack(lua.state, 1);
362
363            lua.push_ref(&self.0);
364            let func_ptr = &mut func as *mut F as *mut c_void;
365            ffi::lua_getcoverage(lua.state, -1, func_ptr, callback::<F>);
366        }
367    }
368}
369
370impl<'lua> PartialEq for Function<'lua> {
371    fn eq(&self, other: &Self) -> bool {
372        self.0 == other.0
373    }
374}