1use std::cell::RefCell;
2use std::error::Error as StdError;
3use std::ffi::c_int;
4use std::fmt;
5use std::marker::PhantomData;
6
7use luajit::{self as lua, IntoResult, Poppable, Pushable, ffi};
8
9use crate::{Error, LuaRef};
10
11#[derive(Clone, Eq, PartialEq, Hash)]
13pub struct Function<A, R> {
14 pub(crate) lua_ref: LuaRef,
15 _pd: (PhantomData<A>, PhantomData<R>),
16}
17
18impl<A, R> fmt::Debug for Function<A, R> {
19 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20 write!(
21 f,
22 "<function {}: {} -> {}>",
23 self.lua_ref,
24 std::any::type_name::<A>(),
25 std::any::type_name::<R>()
26 )
27 }
28}
29
30impl<A, R, F, O> From<F> for Function<A, R>
31where
32 F: Fn(A) -> O + 'static,
33 A: Poppable,
34 O: IntoResult<R>,
35 R: Pushable,
36 O::Error: StdError + 'static,
37{
38 fn from(fun: F) -> Function<A, R> {
39 Function::from_fn_mut(fun)
40 }
41}
42
43impl<A, R> Poppable for Function<A, R> {
44 unsafe fn pop(state: *mut lua::ffi::State) -> Result<Self, lua::Error> {
45 if ffi::lua_gettop(state) == 0 {
46 return Err(lua::Error::PopEmptyStack);
47 }
48
49 match ffi::lua_type(state, -1) {
50 ffi::LUA_TFUNCTION => {
51 let lua_ref = ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX);
52 Ok(Self::from_ref(lua_ref))
54 },
55
56 other => Err(lua::Error::pop_wrong_type::<Self>(
57 ffi::LUA_TFUNCTION,
58 other,
59 )),
60 }
61 }
62}
63
64impl<A, R> Pushable for Function<A, R> {
65 unsafe fn push(
66 self,
67 state: *mut lua::ffi::State,
68 ) -> Result<c_int, lua::Error> {
69 ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, self.lua_ref);
70 Ok(1)
71 }
72}
73
74impl<A, R> Function<A, R> {
75 pub(crate) fn from_ref(lua_ref: LuaRef) -> Self {
76 Self { lua_ref, _pd: (PhantomData, PhantomData) }
77 }
78
79 #[doc(hidden)]
80 pub fn lua_ref(&self) -> LuaRef {
81 self.lua_ref
82 }
83
84 pub fn from_fn<F, O>(fun: F) -> Self
85 where
86 F: Fn(A) -> O + 'static,
87 A: Poppable,
88 O: IntoResult<R>,
89 R: Pushable,
90 O::Error: StdError + 'static,
91 {
92 Self::from_ref(lua::function::store(fun))
93 }
94
95 pub fn from_fn_mut<F, O>(fun: F) -> Self
96 where
97 F: FnMut(A) -> O + 'static,
98 A: Poppable,
99 O: IntoResult<R>,
100 R: Pushable,
101 O::Error: StdError + 'static,
102 {
103 let fun = RefCell::new(fun);
104
105 Self::from_fn(move |args| {
106 let fun = &mut *fun.try_borrow_mut().map_err(Error::from_err)?;
107
108 fun(args).into_result().map_err(Error::from_err)
109 })
110 }
111
112 pub fn from_fn_once<F, O>(fun: F) -> Self
113 where
114 F: FnOnce(A) -> O + 'static,
115 A: Poppable,
116 O: IntoResult<R>,
117 R: Pushable,
118 O::Error: StdError + 'static,
119 {
120 let fun = RefCell::new(Some(fun));
121
122 Self::from_fn(move |args| {
123 let fun = fun
124 .try_borrow_mut()
125 .map_err(Error::from_err)?
126 .take()
127 .ok_or_else(|| {
128 Error::from_str("Cannot call function twice")
129 })?;
130
131 fun(args).into_result().map_err(Error::from_err)
132 })
133 }
134
135 pub fn call(&self, args: A) -> Result<R, lua::Error>
136 where
137 A: Pushable,
138 R: Poppable,
139 {
140 lua::function::call(self.lua_ref, args)
141 }
142
143 #[doc(hidden)]
146 pub fn remove_from_lua_registry(self) {
147 lua::function::remove(self.lua_ref)
148 }
149}
150
151#[cfg(feature = "serde")]
152mod serde {
153 use std::fmt;
154
155 use serde::de::{self, Deserialize, Deserializer, Visitor};
156 use serde::ser::{Serialize, Serializer};
157
158 use super::Function;
159 use crate::LuaRef;
160
161 impl<A, R> Serialize for Function<A, R> {
162 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
163 where
164 S: Serializer,
165 {
166 serializer.serialize_f32(self.lua_ref as f32)
167 }
168 }
169
170 impl<'de, A, R> Deserialize<'de> for Function<A, R> {
171 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
172 where
173 D: Deserializer<'de>,
174 {
175 use std::marker::PhantomData;
176
177 struct FunctionVisitor<A, R>(PhantomData<A>, PhantomData<R>);
178
179 impl<A, R> Visitor<'_> for FunctionVisitor<A, R> {
180 type Value = Function<A, R>;
181
182 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
183 f.write_str("an f32 representing a Lua reference")
184 }
185
186 fn visit_f32<E>(self, value: f32) -> Result<Self::Value, E>
187 where
188 E: de::Error,
189 {
190 Ok(Function::from_ref(value as LuaRef))
191 }
192 }
193
194 deserializer
195 .deserialize_f32(FunctionVisitor(PhantomData, PhantomData))
196 }
197 }
198}