nvim_oxi_types/
function.rs

1use std::cell::RefCell;
2use std::error::Error as StdError;
3use std::ffi::c_int;
4use std::fmt;
5use std::marker::PhantomData;
6
7use luajit::{self as lua, ffi, IntoResult, Poppable, Pushable};
8
9use crate::{Error, LuaRef};
10
11/// A wrapper around a Lua reference to a function stored in the Lua registry.
12#[derive(Clone, Eq, PartialEq, Hash)]
13pub struct Function<A, R> {
14    pub(crate) lua_ref: LuaRef,
15    _pd: (PhantomData<A>, PhantomData<R>),
16}
17
18impl<A, R> fmt::Debug for Function<A, R> {
19    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20        write!(
21            f,
22            "<function {}: {} -> {}>",
23            self.lua_ref,
24            std::any::type_name::<A>(),
25            std::any::type_name::<R>()
26        )
27    }
28}
29
30impl<A, R, F, O> From<F> for Function<A, R>
31where
32    F: Fn(A) -> O + 'static,
33    A: Poppable,
34    O: IntoResult<R>,
35    R: Pushable,
36    O::Error: StdError + 'static,
37{
38    fn from(fun: F) -> Function<A, R> {
39        Function::from_fn_mut(fun)
40    }
41}
42
43impl<A, R> Poppable for Function<A, R> {
44    unsafe fn pop(
45        state: *mut lua::ffi::lua_State,
46    ) -> Result<Self, lua::Error> {
47        if ffi::lua_gettop(state) == 0 {
48            return Err(lua::Error::PopEmptyStack);
49        }
50
51        match ffi::lua_type(state, -1) {
52            ffi::LUA_TFUNCTION => {
53                let lua_ref = ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX);
54                // TODO: check `lua_ref`.
55                Ok(Self::from_ref(lua_ref))
56            },
57
58            other => Err(lua::Error::pop_wrong_type::<Self>(
59                ffi::LUA_TFUNCTION,
60                other,
61            )),
62        }
63    }
64}
65
66impl<A, R> Pushable for Function<A, R> {
67    unsafe fn push(
68        self,
69        state: *mut lua::ffi::lua_State,
70    ) -> Result<c_int, lua::Error> {
71        ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, self.lua_ref);
72        Ok(1)
73    }
74}
75
76impl<A, R> Function<A, R> {
77    pub(crate) fn from_ref(lua_ref: LuaRef) -> Self {
78        Self { lua_ref, _pd: (PhantomData, PhantomData) }
79    }
80
81    #[doc(hidden)]
82    pub fn lua_ref(&self) -> LuaRef {
83        self.lua_ref
84    }
85
86    pub fn from_fn<F, O>(fun: F) -> Self
87    where
88        F: Fn(A) -> O + 'static,
89        A: Poppable,
90        O: IntoResult<R>,
91        R: Pushable,
92        O::Error: StdError + 'static,
93    {
94        Self::from_ref(lua::function::store(fun))
95    }
96
97    pub fn from_fn_mut<F, O>(fun: F) -> Self
98    where
99        F: FnMut(A) -> O + 'static,
100        A: Poppable,
101        O: IntoResult<R>,
102        R: Pushable,
103        O::Error: StdError + 'static,
104    {
105        let fun = RefCell::new(fun);
106
107        Self::from_fn(move |args| {
108            let fun = &mut *fun.try_borrow_mut().map_err(Error::from_err)?;
109
110            fun(args).into_result().map_err(Error::from_err)
111        })
112    }
113
114    pub fn from_fn_once<F, O>(fun: F) -> Self
115    where
116        F: FnOnce(A) -> O + 'static,
117        A: Poppable,
118        O: IntoResult<R>,
119        R: Pushable,
120        O::Error: StdError + 'static,
121    {
122        let fun = RefCell::new(Some(fun));
123
124        Self::from_fn(move |args| {
125            let fun = fun
126                .try_borrow_mut()
127                .map_err(Error::from_err)?
128                .take()
129                .ok_or_else(|| {
130                    Error::from_str("Cannot call function twice")
131                })?;
132
133            fun(args).into_result().map_err(Error::from_err)
134        })
135    }
136
137    pub fn call(&self, args: A) -> Result<R, lua::Error>
138    where
139        A: Pushable,
140        R: Poppable,
141    {
142        lua::function::call(self.lua_ref, args)
143    }
144
145    /// Consumes the `Function`, removing the reference stored in the Lua
146    /// registry.
147    #[doc(hidden)]
148    pub fn remove_from_lua_registry(self) {
149        lua::function::remove(self.lua_ref)
150    }
151}
152
#[cfg(feature = "serde")]
mod serde {
    //! `serde` support for [`Function`]: the registry reference is encoded as
    //! an `f32`.
    //!
    //! NOTE(review): `f32` represents integers exactly only up to 2^24, so a
    //! larger `LuaRef` would not round-trip. The `f32` encoding is presumably
    //! what the crate's own (de)serializers use to recognise Lua references,
    //! so confirm before changing it.

    use std::fmt;
    use std::marker::PhantomData;

    use serde::de::{self, Deserialize, Deserializer, Visitor};
    use serde::ser::{Serialize, Serializer};

    use super::Function;
    use crate::LuaRef;

    impl<A, R> Serialize for Function<A, R> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let encoded = self.lua_ref as f32;
            serializer.serialize_f32(encoded)
        }
    }

    /// Visitor that turns an `f32` back into a [`Function`].
    struct LuaRefVisitor<A, R> {
        _marker: PhantomData<(A, R)>,
    }

    impl<'de, A, R> Visitor<'de> for LuaRefVisitor<A, R> {
        type Value = Function<A, R>;

        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("an f32 representing a Lua reference")
        }

        fn visit_f32<E>(self, value: f32) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            // Assuming `LuaRef` is an integer type, this `as` cast saturates
            // out-of-range values and maps NaN to 0 rather than failing.
            let lua_ref = value as LuaRef;
            Ok(Function::from_ref(lua_ref))
        }
    }

    impl<'de, A, R> Deserialize<'de> for Function<A, R> {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let visitor = LuaRefVisitor { _marker: PhantomData };
            deserializer.deserialize_f32(visitor)
        }
    }
}