nvim_oxi_types/
function.rs

1use std::cell::RefCell;
2use std::error::Error as StdError;
3use std::ffi::c_int;
4use std::fmt;
5use std::marker::PhantomData;
6
7use luajit::{self as lua, IntoResult, Poppable, Pushable, ffi};
8
9use crate::{Error, LuaRef};
10
11/// A wrapper around a Lua reference to a function stored in the Lua registry.
12#[derive(Clone, Eq, PartialEq, Hash)]
13pub struct Function<A, R> {
14    pub(crate) lua_ref: LuaRef,
15    _pd: (PhantomData<A>, PhantomData<R>),
16}
17
18impl<A, R> fmt::Debug for Function<A, R> {
19    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20        write!(
21            f,
22            "<function {}: {} -> {}>",
23            self.lua_ref,
24            std::any::type_name::<A>(),
25            std::any::type_name::<R>()
26        )
27    }
28}
29
30impl<A, R, F, O> From<F> for Function<A, R>
31where
32    F: Fn(A) -> O + 'static,
33    A: Poppable,
34    O: IntoResult<R>,
35    R: Pushable,
36    O::Error: StdError + 'static,
37{
38    fn from(fun: F) -> Function<A, R> {
39        Function::from_fn_mut(fun)
40    }
41}
42
43impl<A, R> Poppable for Function<A, R> {
44    unsafe fn pop(state: *mut lua::ffi::State) -> Result<Self, lua::Error> {
45        if ffi::lua_gettop(state) == 0 {
46            return Err(lua::Error::PopEmptyStack);
47        }
48
49        match ffi::lua_type(state, -1) {
50            ffi::LUA_TFUNCTION => {
51                let lua_ref = ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX);
52                // TODO: check `lua_ref`.
53                Ok(Self::from_ref(lua_ref))
54            },
55
56            other => Err(lua::Error::pop_wrong_type::<Self>(
57                ffi::LUA_TFUNCTION,
58                other,
59            )),
60        }
61    }
62}
63
64impl<A, R> Pushable for Function<A, R> {
65    unsafe fn push(
66        self,
67        state: *mut lua::ffi::State,
68    ) -> Result<c_int, lua::Error> {
69        ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, self.lua_ref);
70        Ok(1)
71    }
72}
73
74impl<A, R> Function<A, R> {
75    pub(crate) fn from_ref(lua_ref: LuaRef) -> Self {
76        Self { lua_ref, _pd: (PhantomData, PhantomData) }
77    }
78
79    #[doc(hidden)]
80    pub fn lua_ref(&self) -> LuaRef {
81        self.lua_ref
82    }
83
84    pub fn from_fn<F, O>(fun: F) -> Self
85    where
86        F: Fn(A) -> O + 'static,
87        A: Poppable,
88        O: IntoResult<R>,
89        R: Pushable,
90        O::Error: StdError + 'static,
91    {
92        Self::from_ref(lua::function::store(fun))
93    }
94
95    pub fn from_fn_mut<F, O>(fun: F) -> Self
96    where
97        F: FnMut(A) -> O + 'static,
98        A: Poppable,
99        O: IntoResult<R>,
100        R: Pushable,
101        O::Error: StdError + 'static,
102    {
103        let fun = RefCell::new(fun);
104
105        Self::from_fn(move |args| {
106            let fun = &mut *fun.try_borrow_mut().map_err(Error::from_err)?;
107
108            fun(args).into_result().map_err(Error::from_err)
109        })
110    }
111
112    pub fn from_fn_once<F, O>(fun: F) -> Self
113    where
114        F: FnOnce(A) -> O + 'static,
115        A: Poppable,
116        O: IntoResult<R>,
117        R: Pushable,
118        O::Error: StdError + 'static,
119    {
120        let fun = RefCell::new(Some(fun));
121
122        Self::from_fn(move |args| {
123            let fun = fun
124                .try_borrow_mut()
125                .map_err(Error::from_err)?
126                .take()
127                .ok_or_else(|| {
128                    Error::from_str("Cannot call function twice")
129                })?;
130
131            fun(args).into_result().map_err(Error::from_err)
132        })
133    }
134
135    pub fn call(&self, args: A) -> Result<R, lua::Error>
136    where
137        A: Pushable,
138        R: Poppable,
139    {
140        lua::function::call(self.lua_ref, args)
141    }
142
143    /// Consumes the `Function`, removing the reference stored in the Lua
144    /// registry.
145    #[doc(hidden)]
146    pub fn remove_from_lua_registry(self) {
147        lua::function::remove(self.lua_ref)
148    }
149}
150
#[cfg(feature = "serde")]
mod serde {
    //! `serde` support for [`Function`]: the registry reference is written
    //! out, and read back, as an `f32`.
    //!
    //! NOTE(review): an `f32` only represents integers exactly up to 2^24.
    //! The `f32` channel is presumably a convention shared with the
    //! (de)serializers this type is used with; confirm before changing it.

    use std::fmt;
    use std::marker::PhantomData;

    use serde::de::{self, Deserialize, Deserializer, Visitor};
    use serde::ser::{Serialize, Serializer};

    use super::Function;
    use crate::LuaRef;

    /// Visitor that builds a `Function<A, R>` out of a single `f32`.
    struct RefVisitor<A, R>(PhantomData<A>, PhantomData<R>);

    impl<A, R> Visitor<'_> for RefVisitor<A, R> {
        type Value = Function<A, R>;

        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("an f32 representing a Lua reference")
        }

        /// Casts the float back to a reference and wraps it.
        fn visit_f32<E>(self, value: f32) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            let lua_ref = value as LuaRef;
            Ok(Function::from_ref(lua_ref))
        }
    }

    impl<A, R> Serialize for Function<A, R> {
        /// Serializes the registry reference as an `f32`.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let encoded = self.lua_ref as f32;
            serializer.serialize_f32(encoded)
        }
    }

    impl<'de, A, R> Deserialize<'de> for Function<A, R> {
        /// Asks the deserializer for an `f32` and converts it through
        /// `RefVisitor`.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let visitor = RefVisitor(PhantomData, PhantomData);
            deserializer.deserialize_f32(visitor)
        }
    }
}