1use std::cell::RefCell;
2use std::error::Error as StdError;
3use std::ffi::c_int;
4use std::fmt;
5use std::marker::PhantomData;
6
7use luajit_bindings::{self as lua, ffi, Poppable, Pushable};
8
9use crate::{Error, LuaRef};
10
11#[derive(Clone, Eq, PartialEq, Hash)]
13pub struct Function<A, R> {
14 pub(crate) lua_ref: LuaRef,
15 _pd: (PhantomData<A>, PhantomData<R>),
16}
17
18impl<A, R> fmt::Debug for Function<A, R> {
19 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20 write!(
21 f,
22 "<function {}: {} -> {}>",
23 self.lua_ref,
24 std::any::type_name::<A>(),
25 std::any::type_name::<R>()
26 )
27 }
28}
29
30impl<A, R, F, E> From<F> for Function<A, R>
31where
32 F: FnMut(A) -> Result<R, E> + 'static,
33 A: Poppable,
34 R: Pushable,
35 E: StdError + 'static,
36{
37 fn from(fun: F) -> Function<A, R> {
38 Function::from_fn_mut(fun)
39 }
40}
41
42impl<A, R> Poppable for Function<A, R> {
43 unsafe fn pop(
44 state: *mut lua::ffi::lua_State,
45 ) -> Result<Self, lua::Error> {
46 if ffi::lua_gettop(state) == 0 {
47 return Err(lua::Error::PopEmptyStack);
48 }
49
50 match ffi::lua_type(state, -1) {
51 ffi::LUA_TFUNCTION => {
52 let lua_ref = ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX);
53 Ok(Self::from_ref(lua_ref))
55 },
56
57 other => Err(lua::Error::pop_wrong_type::<Self>(
58 ffi::LUA_TFUNCTION,
59 other,
60 )),
61 }
62 }
63}
64
65impl<A, R> Pushable for Function<A, R> {
66 unsafe fn push(
67 self,
68 state: *mut lua::ffi::lua_State,
69 ) -> Result<c_int, lua::Error> {
70 ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, self.lua_ref);
71 Ok(1)
72 }
73}
74
75impl<A, R> Function<A, R> {
76 pub(crate) fn from_ref(lua_ref: LuaRef) -> Self {
77 Self { lua_ref, _pd: (PhantomData, PhantomData) }
78 }
79
80 #[doc(hidden)]
81 pub fn lua_ref(&self) -> LuaRef {
82 self.lua_ref
83 }
84
85 pub fn from_fn<F, E>(fun: F) -> Self
86 where
87 F: Fn(A) -> Result<R, E> + 'static,
88 A: Poppable,
89 R: Pushable,
90 E: StdError + 'static,
91 {
92 Self::from_ref(lua::function::store(fun))
93 }
94
95 pub fn from_fn_mut<F, E>(fun: F) -> Self
96 where
97 F: FnMut(A) -> Result<R, E> + 'static,
98 A: Poppable,
99 R: Pushable,
100 E: StdError + 'static,
101 {
102 let fun = RefCell::new(fun);
103
104 Self::from_fn(move |args| {
105 let fun = &mut *fun.try_borrow_mut().map_err(Error::from_err)?;
106
107 fun(args).map_err(Error::from_err)
108 })
109 }
110
111 pub fn from_fn_once<F, E>(fun: F) -> Self
112 where
113 F: FnOnce(A) -> Result<R, E> + 'static,
114 A: Poppable,
115 R: Pushable,
116 E: StdError + 'static,
117 {
118 let fun = RefCell::new(Some(fun));
119
120 Self::from_fn(move |args| {
121 let fun = fun
122 .try_borrow_mut()
123 .map_err(Error::from_err)?
124 .take()
125 .ok_or_else(|| {
126 Error::from_str("Cannot call function twice")
127 })?;
128
129 fun(args).map_err(Error::from_err)
130 })
131 }
132
133 pub fn call(&self, args: A) -> Result<R, lua::Error>
134 where
135 A: Pushable,
136 R: Poppable,
137 {
138 lua::function::call(self.lua_ref, args)
139 }
140
141 #[doc(hidden)]
144 pub fn remove_from_lua_registry(self) {
145 lua::function::remove(self.lua_ref)
146 }
147}
148
#[cfg(feature = "serde")]
mod serde {
    //! `serde` support for [`Function`]. A handle is encoded as its registry
    //! reference, serialized as an `f32`.
    //!
    //! NOTE(review): `f32` represents integers exactly only up to 2^24, and
    //! the float-to-integer `as` cast on the way back saturates. This looks
    //! like an intentional marker understood by the crate's own
    //! (de)serializers; confirm before changing the encoding.

    use std::fmt;
    use std::marker::PhantomData;

    use serde::de::{self, Deserialize, Deserializer, Visitor};
    use serde::ser::{Serialize, Serializer};

    use super::Function;
    use crate::LuaRef;

    impl<A, R> Serialize for Function<A, R> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let encoded = self.lua_ref as f32;
            serializer.serialize_f32(encoded)
        }
    }

    /// Rebuilds a [`Function`] handle from an `f32`-encoded registry reference.
    struct FunctionVisitor<A, R> {
        _marker: PhantomData<(A, R)>,
    }

    impl<'de, A, R> Visitor<'de> for FunctionVisitor<A, R> {
        type Value = Function<A, R>;

        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("an f32 representing a Lua reference")
        }

        fn visit_f32<E>(self, value: f32) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            let lua_ref = value as LuaRef;
            Ok(Function::from_ref(lua_ref))
        }
    }

    impl<'de, A, R> Deserialize<'de> for Function<A, R> {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let visitor = FunctionVisitor { _marker: PhantomData };
            deserializer.deserialize_f32(visitor)
        }
    }
}