nvim_types/
function.rs

1use std::cell::RefCell;
2use std::error::Error as StdError;
3use std::ffi::c_int;
4use std::fmt;
5use std::marker::PhantomData;
6
7use luajit_bindings::{self as lua, ffi, Poppable, Pushable};
8
9use crate::{Error, LuaRef};
10
11/// A wrapper around a Lua reference to a function stored in the Lua registry.
12#[derive(Clone, Eq, PartialEq, Hash)]
13pub struct Function<A, R> {
14    pub(crate) lua_ref: LuaRef,
15    _pd: (PhantomData<A>, PhantomData<R>),
16}
17
18impl<A, R> fmt::Debug for Function<A, R> {
19    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20        write!(
21            f,
22            "<function {}: {} -> {}>",
23            self.lua_ref,
24            std::any::type_name::<A>(),
25            std::any::type_name::<R>()
26        )
27    }
28}
29
30impl<A, R, F, E> From<F> for Function<A, R>
31where
32    F: FnMut(A) -> Result<R, E> + 'static,
33    A: Poppable,
34    R: Pushable,
35    E: StdError + 'static,
36{
37    fn from(fun: F) -> Function<A, R> {
38        Function::from_fn_mut(fun)
39    }
40}
41
42impl<A, R> Poppable for Function<A, R> {
43    unsafe fn pop(
44        state: *mut lua::ffi::lua_State,
45    ) -> Result<Self, lua::Error> {
46        if ffi::lua_gettop(state) == 0 {
47            return Err(lua::Error::PopEmptyStack);
48        }
49
50        match ffi::lua_type(state, -1) {
51            ffi::LUA_TFUNCTION => {
52                let lua_ref = ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX);
53                // TODO: check `lua_ref`.
54                Ok(Self::from_ref(lua_ref))
55            },
56
57            other => Err(lua::Error::pop_wrong_type::<Self>(
58                ffi::LUA_TFUNCTION,
59                other,
60            )),
61        }
62    }
63}
64
65impl<A, R> Pushable for Function<A, R> {
66    unsafe fn push(
67        self,
68        state: *mut lua::ffi::lua_State,
69    ) -> Result<c_int, lua::Error> {
70        ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, self.lua_ref);
71        Ok(1)
72    }
73}
74
75impl<A, R> Function<A, R> {
76    pub(crate) fn from_ref(lua_ref: LuaRef) -> Self {
77        Self { lua_ref, _pd: (PhantomData, PhantomData) }
78    }
79
80    #[doc(hidden)]
81    pub fn lua_ref(&self) -> LuaRef {
82        self.lua_ref
83    }
84
85    pub fn from_fn<F, E>(fun: F) -> Self
86    where
87        F: Fn(A) -> Result<R, E> + 'static,
88        A: Poppable,
89        R: Pushable,
90        E: StdError + 'static,
91    {
92        Self::from_ref(lua::function::store(fun))
93    }
94
95    pub fn from_fn_mut<F, E>(fun: F) -> Self
96    where
97        F: FnMut(A) -> Result<R, E> + 'static,
98        A: Poppable,
99        R: Pushable,
100        E: StdError + 'static,
101    {
102        let fun = RefCell::new(fun);
103
104        Self::from_fn(move |args| {
105            let fun = &mut *fun.try_borrow_mut().map_err(Error::from_err)?;
106
107            fun(args).map_err(Error::from_err)
108        })
109    }
110
111    pub fn from_fn_once<F, E>(fun: F) -> Self
112    where
113        F: FnOnce(A) -> Result<R, E> + 'static,
114        A: Poppable,
115        R: Pushable,
116        E: StdError + 'static,
117    {
118        let fun = RefCell::new(Some(fun));
119
120        Self::from_fn(move |args| {
121            let fun = fun
122                .try_borrow_mut()
123                .map_err(Error::from_err)?
124                .take()
125                .ok_or_else(|| {
126                    Error::from_str("Cannot call function twice")
127                })?;
128
129            fun(args).map_err(Error::from_err)
130        })
131    }
132
133    pub fn call(&self, args: A) -> Result<R, lua::Error>
134    where
135        A: Pushable,
136        R: Poppable,
137    {
138        lua::function::call(self.lua_ref, args)
139    }
140
141    /// Consumes the `Function`, removing the reference stored in the Lua
142    /// registry.
143    #[doc(hidden)]
144    pub fn remove_from_lua_registry(self) {
145        lua::function::remove(self.lua_ref)
146    }
147}
148
#[cfg(feature = "serde")]
mod serde {
    //! (De)serialization of a `Function` as its registry reference.
    //!
    //! The reference travels through serde as an `f32`.
    // NOTE(review): the f32 encoding is presumably what the crate's own
    // `Object` (de)serializer uses to carry Lua references. Confirm before
    // changing it. An f32 represents integers exactly only up to 2^24 in
    // magnitude, so larger references would not round-trip.

    use std::fmt;
    use std::marker::PhantomData;

    use serde::de::{self, Deserialize, Deserializer, Visitor};
    use serde::ser::{Serialize, Serializer};

    use super::Function;
    use crate::LuaRef;

    impl<A, R> Serialize for Function<A, R> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let encoded = self.lua_ref as f32;
            serializer.serialize_f32(encoded)
        }
    }

    /// Rebuilds a `Function` from the `f32` written by `serialize`.
    struct RefVisitor<A, R>(PhantomData<(A, R)>);

    impl<'de, A, R> Visitor<'de> for RefVisitor<A, R> {
        type Value = Function<A, R>;

        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("an f32 representing a Lua reference")
        }

        fn visit_f32<E>(self, value: f32) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            let lua_ref = value as LuaRef;
            Ok(Function::from_ref(lua_ref))
        }
    }

    impl<'de, A, R> Deserialize<'de> for Function<A, R> {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            deserializer.deserialize_f32(RefVisitor(PhantomData))
        }
    }
}