mlua_codemp_patch/
function.rs

1use std::cell::RefCell;
2use std::os::raw::{c_int, c_void};
3use std::{mem, ptr, slice};
4
5use crate::error::{Error, Result};
6use crate::state::Lua;
7use crate::table::Table;
8use crate::types::{Callback, MaybeSend, ValueRef};
9use crate::util::{
10    assert_stack, check_stack, linenumber_to_usize, pop_error, ptr_to_lossy_str, ptr_to_str, StackGuard,
11};
12use crate::value::{FromLuaMulti, IntoLua, IntoLuaMulti, Value};
13
14#[cfg(feature = "async")]
15use {
16    crate::types::AsyncCallback,
17    std::future::{self, Future},
18};
19
20/// Handle to an internal Lua function.
21#[derive(Clone, Debug)]
22pub struct Function(pub(crate) ValueRef);
23
/// Debug information describing a single function.
///
/// The fields mirror the ones filled in by `lua_getinfo`; see the [`Lua Debug Interface`] for
/// their exact meaning.
///
/// [`Lua Debug Interface`]: https://www.lua.org/manual/5.4/manual.html#4.7
#[derive(Clone, Debug)]
pub struct FunctionInfo {
    /// A reasonable name for the function, or `None` when no name could be determined.
    pub name: Option<String>,
    /// How `name` was resolved: `global`, `local`, `method`, `field`, `upvalue`, and so on.
    ///
    /// Luau does not provide this information, so it is always `None` there.
    pub name_what: Option<&'static str>,
    /// `Lua` for a Lua function, `C` for a C function, or `main` for the main part of a chunk.
    pub what: &'static str,
    /// The source of the chunk in which the function was created.
    pub source: Option<String>,
    /// A printable form of `source`, suitable for error messages.
    pub short_src: Option<String>,
    /// Line on which the function definition begins.
    pub line_defined: Option<usize>,
    /// Line on which the function definition ends (Luau never sets it).
    pub last_line_defined: Option<usize>,
}
49
/// Luau function coverage snapshot.
///
/// One snapshot is produced per executed function (including inner functions) when coverage
/// is requested through `Function::coverage`.
#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CoverageInfo {
    /// Name of the function as reported by Luau, or `None` if Luau reported no name.
    pub function: Option<String>,
    /// Line number on which the function is defined, as reported by Luau.
    pub line_defined: i32,
    /// Nesting depth of the function as reported by Luau
    /// (presumably 0 for the function coverage was requested on — TODO confirm).
    pub depth: i32,
    /// Per-line hit counters copied from the buffer Luau passes to the coverage callback.
    ///
    /// NOTE(review): the indexing (by line number) and the value used for lines without code
    /// are determined by Luau's `lua_getcoverage`; confirm against the Luau sources.
    pub hits: Vec<i32>,
}
60
61impl Function {
62    /// Calls the function, passing `args` as function arguments.
63    ///
64    /// The function's return values are converted to the generic type `R`.
65    ///
66    /// # Examples
67    ///
68    /// Call Lua's built-in `tostring` function:
69    ///
70    /// ```
71    /// # use mlua::{Function, Lua, Result};
72    /// # fn main() -> Result<()> {
73    /// # let lua = Lua::new();
74    /// let globals = lua.globals();
75    ///
76    /// let tostring: Function = globals.get("tostring")?;
77    ///
78    /// assert_eq!(tostring.call::<String>(123)?, "123");
79    ///
80    /// # Ok(())
81    /// # }
82    /// ```
83    ///
84    /// Call a function with multiple arguments:
85    ///
86    /// ```
87    /// # use mlua::{Function, Lua, Result};
88    /// # fn main() -> Result<()> {
89    /// # let lua = Lua::new();
90    /// let sum: Function = lua.load(
91    ///     r#"
92    ///         function(a, b)
93    ///             return a + b
94    ///         end
95    /// "#).eval()?;
96    ///
97    /// assert_eq!(sum.call::<u32>((3, 4))?, 3 + 4);
98    ///
99    /// # Ok(())
100    /// # }
101    /// ```
102    pub fn call<R: FromLuaMulti>(&self, args: impl IntoLuaMulti) -> Result<R> {
103        let lua = self.0.lua.lock();
104        let state = lua.state();
105        unsafe {
106            let _sg = StackGuard::new(state);
107            check_stack(state, 2)?;
108
109            // Push error handler
110            lua.push_error_traceback();
111            let stack_start = ffi::lua_gettop(state);
112            // Push function and the arguments
113            lua.push_ref(&self.0);
114            let nargs = args.push_into_stack_multi(&lua)?;
115            // Call the function
116            let ret = ffi::lua_pcall(state, nargs, ffi::LUA_MULTRET, stack_start);
117            if ret != ffi::LUA_OK {
118                return Err(pop_error(state, ret));
119            }
120            // Get the results
121            let nresults = ffi::lua_gettop(state) - stack_start;
122            R::from_stack_multi(nresults, &lua)
123        }
124    }
125
126    /// Returns a future that, when polled, calls `self`, passing `args` as function arguments,
127    /// and drives the execution.
128    ///
129    /// Internally it wraps the function to an [`AsyncThread`].
130    ///
131    /// Requires `feature = "async"`
132    ///
133    /// # Examples
134    ///
135    /// ```
136    /// use std::time::Duration;
137    /// # use mlua::{Lua, Result};
138    /// # #[tokio::main]
139    /// # async fn main() -> Result<()> {
140    /// # let lua = Lua::new();
141    ///
142    /// let sleep = lua.create_async_function(move |_lua, n: u64| async move {
143    ///     tokio::time::sleep(Duration::from_millis(n)).await;
144    ///     Ok(())
145    /// })?;
146    ///
147    /// sleep.call_async(10).await?;
148    ///
149    /// # Ok(())
150    /// # }
151    /// ```
152    ///
153    /// [`AsyncThread`]: crate::AsyncThread
154    #[cfg(feature = "async")]
155    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
156    pub fn call_async<R>(&self, args: impl IntoLuaMulti) -> impl Future<Output = Result<R>>
157    where
158        R: FromLuaMulti,
159    {
160        let lua = self.0.lua.lock();
161        let thread_res = unsafe {
162            lua.create_recycled_thread(self).map(|th| {
163                let mut th = th.into_async(args);
164                th.set_recyclable(true);
165                th
166            })
167        };
168        async move { thread_res?.await }
169    }
170
171    /// Returns a function that, when called, calls `self`, passing `args` as the first set of
172    /// arguments.
173    ///
174    /// If any arguments are passed to the returned function, they will be passed after `args`.
175    ///
176    /// # Examples
177    ///
178    /// ```
179    /// # use mlua::{Function, Lua, Result};
180    /// # fn main() -> Result<()> {
181    /// # let lua = Lua::new();
182    /// let sum: Function = lua.load(
183    ///     r#"
184    ///         function(a, b)
185    ///             return a + b
186    ///         end
187    /// "#).eval()?;
188    ///
189    /// let bound_a = sum.bind(1)?;
190    /// assert_eq!(bound_a.call::<u32>(2)?, 1 + 2);
191    ///
192    /// let bound_a_and_b = sum.bind(13)?.bind(57)?;
193    /// assert_eq!(bound_a_and_b.call::<u32>(())?, 13 + 57);
194    ///
195    /// # Ok(())
196    /// # }
197    /// ```
198    pub fn bind(&self, args: impl IntoLuaMulti) -> Result<Function> {
199        unsafe extern "C-unwind" fn args_wrapper_impl(state: *mut ffi::lua_State) -> c_int {
200            let nargs = ffi::lua_gettop(state);
201            let nbinds = ffi::lua_tointeger(state, ffi::lua_upvalueindex(1)) as c_int;
202            ffi::luaL_checkstack(state, nbinds, ptr::null());
203
204            for i in 0..nbinds {
205                ffi::lua_pushvalue(state, ffi::lua_upvalueindex(i + 2));
206            }
207            if nargs > 0 {
208                ffi::lua_rotate(state, 1, nbinds);
209            }
210
211            nargs + nbinds
212        }
213
214        let lua = self.0.lua.lock();
215        let state = lua.state();
216
217        let args = args.into_lua_multi(lua.lua())?;
218        let nargs = args.len() as c_int;
219
220        if nargs == 0 {
221            return Ok(self.clone());
222        }
223
224        if nargs + 1 > ffi::LUA_MAX_UPVALUES {
225            return Err(Error::BindError);
226        }
227
228        let args_wrapper = unsafe {
229            let _sg = StackGuard::new(state);
230            check_stack(state, nargs + 3)?;
231
232            ffi::lua_pushinteger(state, nargs as ffi::lua_Integer);
233            for arg in &args {
234                lua.push_value(arg)?;
235            }
236            protect_lua!(state, nargs + 1, 1, fn(state) {
237                ffi::lua_pushcclosure(state, args_wrapper_impl, ffi::lua_gettop(state));
238            })?;
239
240            Function(lua.pop_ref())
241        };
242
243        let lua = lua.lua();
244        lua.load(
245            r#"
246            local func, args_wrapper = ...
247            return function(...)
248                return func(args_wrapper(...))
249            end
250            "#,
251        )
252        .try_cache()
253        .set_name("__mlua_bind")
254        .call((self, args_wrapper))
255    }
256
257    /// Returns the environment of the Lua function.
258    ///
259    /// By default Lua functions shares a global environment.
260    ///
261    /// This function always returns `None` for Rust/C functions.
262    pub fn environment(&self) -> Option<Table> {
263        let lua = self.0.lua.lock();
264        let state = lua.state();
265        unsafe {
266            let _sg = StackGuard::new(state);
267            assert_stack(state, 1);
268
269            lua.push_ref(&self.0);
270            if ffi::lua_iscfunction(state, -1) != 0 {
271                return None;
272            }
273
274            #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))]
275            ffi::lua_getfenv(state, -1);
276            #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
277            for i in 1..=255 {
278                // Traverse upvalues until we find the _ENV one
279                match ffi::lua_getupvalue(state, -1, i) {
280                    s if s.is_null() => break,
281                    s if std::ffi::CStr::from_ptr(s as _).to_bytes() == b"_ENV" => break,
282                    _ => ffi::lua_pop(state, 1),
283                }
284            }
285
286            if ffi::lua_type(state, -1) != ffi::LUA_TTABLE {
287                return None;
288            }
289            Some(Table(lua.pop_ref()))
290        }
291    }
292
293    /// Sets the environment of the Lua function.
294    ///
295    /// The environment is a table that is used as the global environment for the function.
296    /// Returns `true` if environment successfully changed, `false` otherwise.
297    ///
298    /// This function does nothing for Rust/C functions.
299    pub fn set_environment(&self, env: Table) -> Result<bool> {
300        let lua = self.0.lua.lock();
301        let state = lua.state();
302        unsafe {
303            let _sg = StackGuard::new(state);
304            check_stack(state, 2)?;
305
306            lua.push_ref(&self.0);
307            if ffi::lua_iscfunction(state, -1) != 0 {
308                return Ok(false);
309            }
310
311            #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))]
312            {
313                lua.push_ref(&env.0);
314                ffi::lua_setfenv(state, -2);
315            }
316            #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
317            for i in 1..=255 {
318                match ffi::lua_getupvalue(state, -1, i) {
319                    s if s.is_null() => return Ok(false),
320                    s if std::ffi::CStr::from_ptr(s as _).to_bytes() == b"_ENV" => {
321                        ffi::lua_pop(state, 1);
322                        // Create an anonymous function with the new environment
323                        let f_with_env = lua
324                            .lua()
325                            .load("return _ENV")
326                            .set_environment(env)
327                            .try_cache()
328                            .into_function()?;
329                        lua.push_ref(&f_with_env.0);
330                        ffi::lua_upvaluejoin(state, -2, i, -1, 1);
331                        break;
332                    }
333                    _ => ffi::lua_pop(state, 1),
334                }
335            }
336
337            Ok(true)
338        }
339    }
340
341    /// Returns information about the function.
342    ///
343    /// Corresponds to the `>Sn` what mask for [`lua_getinfo`] when applied to the function.
344    ///
345    /// [`lua_getinfo`]: https://www.lua.org/manual/5.4/manual.html#lua_getinfo
346    pub fn info(&self) -> FunctionInfo {
347        let lua = self.0.lua.lock();
348        let state = lua.state();
349        unsafe {
350            let _sg = StackGuard::new(state);
351            assert_stack(state, 1);
352
353            let mut ar: ffi::lua_Debug = mem::zeroed();
354            lua.push_ref(&self.0);
355            #[cfg(not(feature = "luau"))]
356            let res = ffi::lua_getinfo(state, cstr!(">Sn"), &mut ar);
357            #[cfg(feature = "luau")]
358            let res = ffi::lua_getinfo(state, -1, cstr!("sn"), &mut ar);
359            mlua_assert!(res != 0, "lua_getinfo failed with `>Sn`");
360
361            FunctionInfo {
362                name: ptr_to_lossy_str(ar.name).map(|s| s.into_owned()),
363                #[cfg(not(feature = "luau"))]
364                name_what: match ptr_to_str(ar.namewhat) {
365                    Some("") => None,
366                    val => val,
367                },
368                #[cfg(feature = "luau")]
369                name_what: None,
370                what: ptr_to_str(ar.what).unwrap_or("main"),
371                source: ptr_to_lossy_str(ar.source).map(|s| s.into_owned()),
372                #[cfg(not(feature = "luau"))]
373                short_src: ptr_to_lossy_str(ar.short_src.as_ptr()).map(|s| s.into_owned()),
374                #[cfg(feature = "luau")]
375                short_src: ptr_to_lossy_str(ar.short_src).map(|s| s.into_owned()),
376                line_defined: linenumber_to_usize(ar.linedefined),
377                #[cfg(not(feature = "luau"))]
378                last_line_defined: linenumber_to_usize(ar.lastlinedefined),
379                #[cfg(feature = "luau")]
380                last_line_defined: None,
381            }
382        }
383    }
384
385    /// Dumps the function as a binary chunk.
386    ///
387    /// If `strip` is true, the binary representation may not include all debug information
388    /// about the function, to save space.
389    ///
390    /// For Luau a [Compiler] can be used to compile Lua chunks to bytecode.
391    ///
392    /// [Compiler]: crate::chunk::Compiler
393    #[cfg(not(feature = "luau"))]
394    #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
395    pub fn dump(&self, strip: bool) -> Vec<u8> {
396        unsafe extern "C-unwind" fn writer(
397            _state: *mut ffi::lua_State,
398            buf: *const c_void,
399            buf_len: usize,
400            data: *mut c_void,
401        ) -> c_int {
402            let data = &mut *(data as *mut Vec<u8>);
403            let buf = slice::from_raw_parts(buf as *const u8, buf_len);
404            data.extend_from_slice(buf);
405            0
406        }
407
408        let lua = self.0.lua.lock();
409        let state = lua.state();
410        let mut data: Vec<u8> = Vec::new();
411        unsafe {
412            let _sg = StackGuard::new(state);
413            assert_stack(state, 1);
414
415            lua.push_ref(&self.0);
416            let data_ptr = &mut data as *mut Vec<u8> as *mut c_void;
417            ffi::lua_dump(state, writer, data_ptr, strip as i32);
418            ffi::lua_pop(state, 1);
419        }
420
421        data
422    }
423
424    /// Retrieves recorded coverage information about this Lua function including inner calls.
425    ///
426    /// This function takes a callback as an argument and calls it providing [`CoverageInfo`]
427    /// snapshot per each executed inner function.
428    ///
429    /// Recording of coverage information is controlled by [`Compiler::set_coverage_level`] option.
430    ///
431    /// Requires `feature = "luau"`
432    ///
433    /// [`Compiler::set_coverage_level`]: crate::chunk::Compiler::set_coverage_level
434    #[cfg(any(feature = "luau", doc))]
435    #[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
436    pub fn coverage<F>(&self, mut func: F)
437    where
438        F: FnMut(CoverageInfo),
439    {
440        use std::ffi::CStr;
441        use std::os::raw::c_char;
442
443        unsafe extern "C-unwind" fn callback<F: FnMut(CoverageInfo)>(
444            data: *mut c_void,
445            function: *const c_char,
446            line_defined: c_int,
447            depth: c_int,
448            hits: *const c_int,
449            size: usize,
450        ) {
451            let function = if !function.is_null() {
452                Some(CStr::from_ptr(function).to_string_lossy().to_string())
453            } else {
454                None
455            };
456            let rust_callback = &mut *(data as *mut F);
457            rust_callback(CoverageInfo {
458                function,
459                line_defined,
460                depth,
461                hits: slice::from_raw_parts(hits, size).to_vec(),
462            });
463        }
464
465        let lua = self.0.lua.lock();
466        let state = lua.state();
467        unsafe {
468            let _sg = StackGuard::new(state);
469            assert_stack(state, 1);
470
471            lua.push_ref(&self.0);
472            let func_ptr = &mut func as *mut F as *mut c_void;
473            ffi::lua_getcoverage(state, -1, func_ptr, callback::<F>);
474        }
475    }
476
477    /// Converts this function to a generic C pointer.
478    ///
479    /// There is no way to convert the pointer back to its original value.
480    ///
481    /// Typically this function is used only for hashing and debug information.
482    #[inline]
483    pub fn to_pointer(&self) -> *const c_void {
484        self.0.to_pointer()
485    }
486
487    /// Creates a deep clone of the Lua function.
488    ///
489    /// Copies the function prototype and all its upvalues to the
490    /// newly created function.
491    ///
492    /// This function returns shallow clone (same handle) for Rust/C functions.
493    /// Requires `feature = "luau"`
494    #[cfg(feature = "luau")]
495    #[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
496    pub fn deep_clone(&self) -> Self {
497        let lua = self.0.lua.lock();
498        let ref_thread = lua.ref_thread();
499        unsafe {
500            if ffi::lua_iscfunction(ref_thread, self.0.index) != 0 {
501                return self.clone();
502            }
503
504            ffi::lua_clonefunction(ref_thread, self.0.index);
505            Function(lua.pop_ref_thread())
506        }
507    }
508}
509
510impl PartialEq for Function {
511    fn eq(&self, other: &Self) -> bool {
512        self.0 == other.0
513    }
514}
515
516pub(crate) struct WrappedFunction(pub(crate) Callback);
517
518#[cfg(feature = "async")]
519pub(crate) struct WrappedAsyncFunction(pub(crate) AsyncCallback);
520
521impl Function {
522    /// Wraps a Rust function or closure, returning an opaque type that implements [`IntoLua`]
523    /// trait.
524    #[inline]
525    pub fn wrap<A, R, F>(func: F) -> impl IntoLua
526    where
527        A: FromLuaMulti,
528        R: IntoLuaMulti,
529        F: Fn(&Lua, A) -> Result<R> + MaybeSend + 'static,
530    {
531        WrappedFunction(Box::new(move |lua, nargs| unsafe {
532            let args = A::from_stack_args(nargs, 1, None, lua)?;
533            func(lua.lua(), args)?.push_into_stack_multi(lua)
534        }))
535    }
536
537    /// Wraps a Rust mutable closure, returning an opaque type that implements [`IntoLua`] trait.
538    #[inline]
539    pub fn wrap_mut<A, R, F>(func: F) -> impl IntoLua
540    where
541        A: FromLuaMulti,
542        R: IntoLuaMulti,
543        F: FnMut(&Lua, A) -> Result<R> + MaybeSend + 'static,
544    {
545        let func = RefCell::new(func);
546        WrappedFunction(Box::new(move |lua, nargs| unsafe {
547            let mut func = func.try_borrow_mut().map_err(|_| Error::RecursiveMutCallback)?;
548            let args = A::from_stack_args(nargs, 1, None, lua)?;
549            func(lua.lua(), args)?.push_into_stack_multi(lua)
550        }))
551    }
552
553    /// Wraps a Rust async function or closure, returning an opaque type that implements [`IntoLua`]
554    /// trait.
555    #[cfg(feature = "async")]
556    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
557    pub fn wrap_async<A, R, F, FR>(func: F) -> impl IntoLua
558    where
559        A: FromLuaMulti,
560        R: IntoLuaMulti,
561        F: Fn(Lua, A) -> FR + MaybeSend + 'static,
562        FR: Future<Output = Result<R>> + MaybeSend + 'static,
563    {
564        WrappedAsyncFunction(Box::new(move |rawlua, nargs| unsafe {
565            let args = match A::from_stack_args(nargs, 1, None, rawlua) {
566                Ok(args) => args,
567                Err(e) => return Box::pin(future::ready(Err(e))),
568            };
569            let lua = rawlua.lua().clone();
570            let fut = func(lua.clone(), args);
571            Box::pin(async move { fut.await?.push_into_stack_multi(lua.raw_lua()) })
572        }))
573    }
574}
575
576impl IntoLua for WrappedFunction {
577    #[inline]
578    fn into_lua(self, lua: &Lua) -> Result<Value> {
579        lua.lock().create_callback(self.0).map(Value::Function)
580    }
581}
582
#[cfg(feature = "async")]
impl IntoLua for WrappedAsyncFunction {
    /// Registers the boxed async callback with the Lua state and returns it as a function value.
    #[inline]
    fn into_lua(self, lua: &Lua) -> Result<Value> {
        let WrappedAsyncFunction(callback) = self;
        let function = lua.lock().create_async_callback(callback)?;
        Ok(Value::Function(function))
    }
}
590
#[cfg(test)]
mod assertions {
    use super::Function;

    // Compile-time check: without the `send` feature, `Function` must not implement `Send`.
    #[cfg(not(feature = "send"))]
    static_assertions::assert_not_impl_any!(Function: Send);
    // Compile-time check: with the `send` feature, `Function` must be both `Send` and `Sync`.
    #[cfg(feature = "send")]
    static_assertions::assert_impl_all!(Function: Send, Sync);
}