1use std::cell::RefCell;
12use std::collections::HashMap;
13use std::marker::PhantomData;
14
15use lua_types::{LuaError as VmError, LuaType};
16use lua_vm::api;
17use lua_vm::state::LuaState;
18
19use crate::any::{
20 push_any, push_hashable, read_any, read_map, read_sequence, string_bytes_at, AnyHashableLuaValue,
21 AnyLuaValue,
22};
23use crate::{Lua, LuaError};
24
25pub(crate) type Adapter = Box<dyn Fn(&mut LuaState) -> Result<usize, VmError>>;
27
28thread_local! {
29 static REGISTRY: RefCell<Vec<Option<Adapter>>> = const { RefCell::new(Vec::new()) };
30}
31
32const LUA_REGISTRYINDEX: i32 = -(1_000_000) - 1000;
34
35fn upvalue_index(n: i32) -> i32 {
36 LUA_REGISTRYINDEX - n
37}
38
39pub(crate) fn registry_insert(adapter: Adapter) -> usize {
40 REGISTRY.with(|cell| {
41 let mut slots = cell.borrow_mut();
42 match slots.iter().position(|slot| slot.is_none()) {
43 Some(i) => {
44 slots[i] = Some(adapter);
45 i
46 }
47 None => {
48 slots.push(Some(adapter));
49 slots.len() - 1
50 }
51 }
52 })
53}
54
55pub(crate) fn registry_remove(index: usize) {
56 REGISTRY.with(|cell| {
57 let mut slots = cell.borrow_mut();
58 if index < slots.len() {
59 slots[index] = None;
60 }
61 });
62}
63
64pub(crate) fn trampoline(state: &mut LuaState) -> Result<usize, VmError> {
67 let index = api::to_integer_x(state, upvalue_index(1))
68 .ok_or_else(|| VmError::runtime(format_args!("hlua-shim: closure upvalue missing")))?
69 as usize;
70 REGISTRY.with(|cell| {
71 let slots = cell.borrow();
72 match slots.get(index).and_then(|slot| slot.as_ref()) {
73 Some(adapter) => adapter(state),
74 None => Err(VmError::runtime(format_args!(
75 "hlua-shim: closure {index} not registered"
76 ))),
77 }
78 })
79}
80
81pub trait LuaReadArg: Sized {
85 fn read_arg(state: &mut LuaState, idx: i32) -> Result<Self, VmError>;
86}
87
88impl LuaReadArg for AnyLuaValue {
89 fn read_arg(state: &mut LuaState, idx: i32) -> Result<Self, VmError> {
90 Ok(read_any(state, idx))
91 }
92}
93
94impl LuaReadArg for String {
95 fn read_arg(state: &mut LuaState, idx: i32) -> Result<Self, VmError> {
96 let bytes = string_bytes_at(state, idx)
97 .ok_or_else(|| VmError::runtime(format_args!("expected string argument")))?;
98 String::from_utf8(bytes)
99 .map_err(|_| VmError::runtime(format_args!("string argument is not valid utf-8")))
100 }
101}
102
103impl LuaReadArg for Vec<AnyLuaValue> {
104 fn read_arg(state: &mut LuaState, idx: i32) -> Result<Self, VmError> {
105 Ok(read_sequence(state, idx))
106 }
107}
108
109impl LuaReadArg for HashMap<AnyHashableLuaValue, AnyLuaValue> {
110 fn read_arg(state: &mut LuaState, idx: i32) -> Result<Self, VmError> {
111 Ok(read_map(state, idx))
112 }
113}
114
115macro_rules! read_arg_int {
116 ($($ty:ty),*) => {$(
117 impl LuaReadArg for $ty {
118 fn read_arg(state: &mut LuaState, idx: i32) -> Result<Self, VmError> {
119 let n = api::to_integer_x(state, idx)
120 .ok_or_else(|| VmError::runtime(format_args!("expected integer argument")))?;
121 Ok(n as $ty)
122 }
123 }
124 )*};
125}
126read_arg_int!(i32, u32, u16, i64, u64);
127
128pub trait PushToLua {
132 fn push_to(self, state: &mut LuaState) -> Result<usize, VmError>;
133}
134
135impl PushToLua for () {
136 fn push_to(self, _state: &mut LuaState) -> Result<usize, VmError> {
137 Ok(0)
138 }
139}
140
141impl PushToLua for bool {
142 fn push_to(self, state: &mut LuaState) -> Result<usize, VmError> {
143 api::push_boolean(state, self);
144 Ok(1)
145 }
146}
147
148impl PushToLua for String {
149 fn push_to(self, state: &mut LuaState) -> Result<usize, VmError> {
150 api::push_lstring(state, self.as_bytes())?;
151 Ok(1)
152 }
153}
154
155impl PushToLua for AnyLuaValue {
156 fn push_to(self, state: &mut LuaState) -> Result<usize, VmError> {
157 push_any(state, &self)?;
158 Ok(1)
159 }
160}
161
162impl PushToLua for Vec<AnyLuaValue> {
163 fn push_to(self, state: &mut LuaState) -> Result<usize, VmError> {
164 state.create_table(self.len() as i32, 0)?;
165 let table = api::get_top(state);
166 for (i, value) in self.iter().enumerate() {
167 push_any(state, value)?;
168 state.raw_seti(table, (i + 1) as i64)?;
169 }
170 Ok(1)
171 }
172}
173
174impl PushToLua for HashMap<AnyHashableLuaValue, AnyLuaValue> {
175 fn push_to(self, state: &mut LuaState) -> Result<usize, VmError> {
176 state.create_table(0, self.len() as i32)?;
177 let table = api::get_top(state);
178 for (key, value) in &self {
179 push_hashable(state, key)?;
180 push_any(state, value)?;
181 api::raw_set(state, table)?;
182 }
183 Ok(1)
184 }
185}
186
187macro_rules! push_int {
188 ($($ty:ty),*) => {$(
189 impl PushToLua for $ty {
190 fn push_to(self, state: &mut LuaState) -> Result<usize, VmError> {
191 api::push_integer(state, self as i64);
192 Ok(1)
193 }
194 }
195 )*};
196}
197push_int!(i32, u32, u16, i64);
198
199impl<T: PushToLua, E> PushToLua for Result<T, E> {
204 fn push_to(self, state: &mut LuaState) -> Result<usize, VmError> {
205 match self {
206 Ok(value) => value.push_to(state),
207 Err(_) => {
208 api::push_nil(state);
209 Ok(1)
210 }
211 }
212 }
213}
214
215pub trait SetValue {
219 fn set_into(self, lua: &mut Lua<'_>, name: &str);
220}
221
/// Generates, for one arity, three items that expose a Rust closure to Lua:
/// * a wrapper struct `$wrapper`;
/// * a constructor function `$ctor`;
/// * a [`SetValue`] impl that installs the closure in a [`Lua`] instance.
///
/// The inputs are:
/// * `($binding : $param),*` — one local binding name and one type parameter
///   per argument;
/// * `($slot),*` — the Lua stack index each argument is read from.
macro_rules! define_function {
    (
        $ctor:ident,
        $wrapper:ident,
        ($($binding:ident : $param:ident),*),
        ($($slot:expr),*)
    ) => {
        /// A Rust closure `F` taking the listed argument types and returning `R`,
        /// ready to be installed into Lua through [`SetValue`].
        pub struct $wrapper<F, $($param,)* R> {
            f: F,
            _marker: PhantomData<fn($($param,)*) -> R>,
        }

        /// Wraps `f` so it can be installed into Lua through [`SetValue`].
        pub fn $ctor<F, $($param,)* R>(f: F) -> $wrapper<F, $($param,)* R>
        where
            F: Fn($($param,)*) -> R + 'static,
            $($param: LuaReadArg + 'static,)*
            R: PushToLua + 'static,
        {
            $wrapper {
                f,
                _marker: PhantomData,
            }
        }

        impl<F, $($param,)* R> SetValue for $wrapper<F, $($param,)* R>
        where
            F: Fn($($param,)*) -> R + 'static,
            $($param: LuaReadArg + 'static,)*
            R: PushToLua + 'static,
        {
            fn set_into(self, lua: &mut Lua<'_>, name: &str) {
                let callback = self.f;
                let adapter: Adapter = Box::new(move |state: &mut LuaState| {
                    // Convert every argument first, stopping at the first one
                    // that fails, and only then invoke the closure.
                    $(let $binding = <$param as LuaReadArg>::read_arg(state, $slot)?;)*
                    let returned = callback($($binding,)*);
                    returned.push_to(state)
                });
                lua.install_closure(name, adapter);
            }
        }
    };
}
258
259define_function!(function0, Function0, (), ());
260define_function!(function1, Function1, (a0: A0), (1));
261define_function!(function2, Function2, (a0: A0, a1: A1), (1, 2));
262define_function!(function3, Function3, (a0: A0, a1: A1, a2: A2), (1, 2, 3));
263define_function!(function4, Function4, (a0: A0, a1: A1, a2: A2, a3: A3), (1, 2, 3, 4));
264define_function!(
265 function5,
266 Function5,
267 (a0: A0, a1: A1, a2: A2, a3: A3, a4: A4),
268 (1, 2, 3, 4, 5)
269);
270define_function!(
271 function6,
272 Function6,
273 (a0: A0, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5),
274 (1, 2, 3, 4, 5, 6)
275);
276
/// A handle to a Lua function. `L` is the backing representation, for example
/// an `FnHandle` that names a global and borrows the state it lives in.
pub struct LuaFunction<L> {
    pub(crate) inner: L,
}
284
285pub struct FnHandle<'a> {
288 pub(crate) state: &'a mut LuaState,
289 pub(crate) name: String,
290}
291
292pub trait PushArgs {
294 fn push_args(self, state: &mut LuaState) -> Result<usize, VmError>;
295}
296
297impl PushArgs for (AnyLuaValue, AnyLuaValue) {
298 fn push_args(self, state: &mut LuaState) -> Result<usize, VmError> {
299 push_any(state, &self.0)?;
300 push_any(state, &self.1)?;
301 Ok(2)
302 }
303}
304
305impl PushArgs for (AnyLuaValue,) {
306 fn push_args(self, state: &mut LuaState) -> Result<usize, VmError> {
307 push_any(state, &self.0)?;
308 Ok(1)
309 }
310}
311
312pub trait FromTop: Sized {
314 fn from_top(state: &mut LuaState) -> Self;
315}
316
317impl FromTop for AnyLuaValue {
318 fn from_top(state: &mut LuaState) -> Self {
319 read_any(state, -1)
320 }
321}
322
323impl<'a> LuaFunction<FnHandle<'a>> {
324 pub fn call_with_args<V, A>(&mut self, args: A) -> Result<V, LuaError>
326 where
327 A: PushArgs,
328 V: FromTop,
329 {
330 let state = &mut *self.inner.state;
331 api::get_global(state, self.inner.name.as_bytes()).map_err(LuaError::from_vm)?;
332 let nargs = args.push_args(state).map_err(LuaError::from_vm)?;
333 state
334 .protected_call(nargs as i32, 1, 0)
335 .map_err(LuaError::from_vm)?;
336 let value = V::from_top(state);
337 api::set_top(state, -2).ok();
338 Ok(value)
339 }
340}
341
/// A string read out of Lua into an owned Rust `String`. It dereferences to
/// `str`; `L` is only a marker type parameter.
pub struct StringInLua<L> {
    pub(crate) value: String,
    pub(crate) _marker: PhantomData<L>,
}

impl<L> std::ops::Deref for StringInLua<L> {
    type Target = str;

    fn deref(&self) -> &str {
        self.value.as_str()
    }
}
356
357pub trait FromLuaGlobal<'l>: Sized {
361 fn from_lua_global<'lua>(lua: &'l mut Lua<'lua>, name: &str) -> Option<Self>;
362}
363
364impl<'l> FromLuaGlobal<'l> for AnyLuaValue {
365 fn from_lua_global<'lua>(lua: &'l mut Lua<'lua>, name: &str) -> Option<Self> {
366 let state = lua.state_mut();
367 api::get_global(state, name.as_bytes()).ok()?;
368 let value = read_any(state, -1);
369 api::set_top(state, -2).ok();
370 Some(value)
371 }
372}
373
374impl<'l> FromLuaGlobal<'l> for String {
375 fn from_lua_global<'lua>(lua: &'l mut Lua<'lua>, name: &str) -> Option<Self> {
376 let state = lua.state_mut();
377 api::get_global(state, name.as_bytes()).ok()?;
378 let bytes = string_bytes_at(state, -1);
379 api::set_top(state, -2).ok();
380 String::from_utf8(bytes?).ok()
381 }
382}
383
384impl<'l> FromLuaGlobal<'l> for StringInLua<()> {
385 fn from_lua_global<'lua>(lua: &'l mut Lua<'lua>, name: &str) -> Option<Self> {
386 let value: String = FromLuaGlobal::from_lua_global(lua, name)?;
387 Some(StringInLua {
388 value,
389 _marker: PhantomData,
390 })
391 }
392}
393
394impl<'l> FromLuaGlobal<'l> for LuaFunction<FnHandle<'l>> {
395 fn from_lua_global<'lua>(lua: &'l mut Lua<'lua>, name: &str) -> Option<Self> {
396 let owned_name = name.to_string();
397 let state = lua.state_mut();
398 let ty = api::get_global(state, name.as_bytes()).ok()?;
399 api::set_top(state, -2).ok();
400 if ty != LuaType::Function {
401 return None;
402 }
403 Some(LuaFunction {
404 inner: FnHandle {
405 state,
406 name: owned_name,
407 },
408 })
409 }
410}