1use std::cell::RefCell;
2use std::error::Error as StdError;
3use std::ffi::c_int;
4use std::fmt;
5use std::marker::PhantomData;
6
7use luajit::{self as lua, ffi, IntoResult, Poppable, Pushable};
8
9use crate::{Error, LuaRef};
10
11#[derive(Clone, Eq, PartialEq, Hash)]
13pub struct Function<A, R> {
14 pub(crate) lua_ref: LuaRef,
15 _pd: (PhantomData<A>, PhantomData<R>),
16}
17
18impl<A, R> fmt::Debug for Function<A, R> {
19 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20 write!(
21 f,
22 "<function {}: {} -> {}>",
23 self.lua_ref,
24 std::any::type_name::<A>(),
25 std::any::type_name::<R>()
26 )
27 }
28}
29
30impl<A, R, F, O> From<F> for Function<A, R>
31where
32 F: Fn(A) -> O + 'static,
33 A: Poppable,
34 O: IntoResult<R>,
35 R: Pushable,
36 O::Error: StdError + 'static,
37{
38 fn from(fun: F) -> Function<A, R> {
39 Function::from_fn_mut(fun)
40 }
41}
42
43impl<A, R> Poppable for Function<A, R> {
44 unsafe fn pop(
45 state: *mut lua::ffi::lua_State,
46 ) -> Result<Self, lua::Error> {
47 if ffi::lua_gettop(state) == 0 {
48 return Err(lua::Error::PopEmptyStack);
49 }
50
51 match ffi::lua_type(state, -1) {
52 ffi::LUA_TFUNCTION => {
53 let lua_ref = ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX);
54 Ok(Self::from_ref(lua_ref))
56 },
57
58 other => Err(lua::Error::pop_wrong_type::<Self>(
59 ffi::LUA_TFUNCTION,
60 other,
61 )),
62 }
63 }
64}
65
66impl<A, R> Pushable for Function<A, R> {
67 unsafe fn push(
68 self,
69 state: *mut lua::ffi::lua_State,
70 ) -> Result<c_int, lua::Error> {
71 ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, self.lua_ref);
72 Ok(1)
73 }
74}
75
76impl<A, R> Function<A, R> {
77 pub(crate) fn from_ref(lua_ref: LuaRef) -> Self {
78 Self { lua_ref, _pd: (PhantomData, PhantomData) }
79 }
80
81 #[doc(hidden)]
82 pub fn lua_ref(&self) -> LuaRef {
83 self.lua_ref
84 }
85
86 pub fn from_fn<F, O>(fun: F) -> Self
87 where
88 F: Fn(A) -> O + 'static,
89 A: Poppable,
90 O: IntoResult<R>,
91 R: Pushable,
92 O::Error: StdError + 'static,
93 {
94 Self::from_ref(lua::function::store(fun))
95 }
96
97 pub fn from_fn_mut<F, O>(fun: F) -> Self
98 where
99 F: FnMut(A) -> O + 'static,
100 A: Poppable,
101 O: IntoResult<R>,
102 R: Pushable,
103 O::Error: StdError + 'static,
104 {
105 let fun = RefCell::new(fun);
106
107 Self::from_fn(move |args| {
108 let fun = &mut *fun.try_borrow_mut().map_err(Error::from_err)?;
109
110 fun(args).into_result().map_err(Error::from_err)
111 })
112 }
113
114 pub fn from_fn_once<F, O>(fun: F) -> Self
115 where
116 F: FnOnce(A) -> O + 'static,
117 A: Poppable,
118 O: IntoResult<R>,
119 R: Pushable,
120 O::Error: StdError + 'static,
121 {
122 let fun = RefCell::new(Some(fun));
123
124 Self::from_fn(move |args| {
125 let fun = fun
126 .try_borrow_mut()
127 .map_err(Error::from_err)?
128 .take()
129 .ok_or_else(|| {
130 Error::from_str("Cannot call function twice")
131 })?;
132
133 fun(args).into_result().map_err(Error::from_err)
134 })
135 }
136
137 pub fn call(&self, args: A) -> Result<R, lua::Error>
138 where
139 A: Pushable,
140 R: Poppable,
141 {
142 lua::function::call(self.lua_ref, args)
143 }
144
145 #[doc(hidden)]
148 pub fn remove_from_lua_registry(self) {
149 lua::function::remove(self.lua_ref)
150 }
151}
152
#[cfg(feature = "serde")]
mod serde {
    //! Serde support for `Function`: a function is (de)serialized as its
    //! registry reference, encoded as an `f32`.

    use std::fmt;
    use std::marker::PhantomData;

    use serde::de::{self, Deserialize, Deserializer, Visitor};
    use serde::ser::{Serialize, Serializer};

    use super::Function;
    use crate::LuaRef;

    // NOTE(review): an `f32` represents integers exactly only up to 2^24.
    // The `f32` encoding presumably acts as a "Lua reference" marker for this
    // crate's own (de)serializers. Confirm that before changing the encoding.

    impl<A, R> Serialize for Function<A, R> {
        /// Emits the registry reference as an `f32`.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let encoded = self.lua_ref as f32;
            serializer.serialize_f32(encoded)
        }
    }

    /// Visitor that turns an `f32` back into a `Function` handle.
    struct RefVisitor<A, R> {
        _types: PhantomData<(A, R)>,
    }

    impl<'de, A, R> Visitor<'de> for RefVisitor<A, R> {
        type Value = Function<A, R>;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("an f32 representing a Lua reference")
        }

        fn visit_f32<E>(self, value: f32) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            let lua_ref = value as LuaRef;
            Ok(Function::from_ref(lua_ref))
        }
    }

    impl<'de, A, R> Deserialize<'de> for Function<A, R> {
        /// Reads an `f32` and reinterprets it as a registry reference.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let visitor = RefVisitor { _types: PhantomData };
            deserializer.deserialize_f32(visitor)
        }
    }
}