Skip to main content

lua_rs_derive/
lib.rs

//! Derive macros for the lua-rs embedding API.
//!
//! - `#[derive(LuaUserData)]` on a struct generates the `UserData` impl that exposes
//!   the struct's fields to Lua (`obj.field` reads/writes), with the field attributes
//!   `#[lua(skip)]`, `#[lua(readonly)]`, and `#[lua(name = "...")]`. `IntoLua` comes for
//!   free from the runtime's blanket `impl<T: UserData> IntoLua for T`.
//! - The struct attribute `#[lua_impl(Display, PartialEq, PartialOrd)]` wires up the matching
//!   metamethods (`__tostring`, `__eq`, `__lt`/`__le`) from the type's Rust trait impls.
//! - The struct attribute `#[lua(methods)]` makes the generated `UserData` impl also register
//!   the methods declared by `#[lua_methods]` on an `impl` block of the same type.
//! - `#[lua_methods]` on an `impl` block exposes each `pub fn(&self/&mut self, ...)` to
//!   Lua as `obj:method(args)`.
13
14use proc_macro::TokenStream;
15use quote::quote;
16use syn::{
17    parse_macro_input, Data, DeriveInput, Fields, FnArg, ImplItem, ItemImpl, LitStr, ReturnType,
18    Type,
19};
20
21// ---------------------------------------------------------------------------
22// #[derive(LuaUserData)]
23// ---------------------------------------------------------------------------
24
25struct FieldCfg {
26    ident: syn::Ident,
27    ty: Type,
28    lua_name: String,
29    skip: bool,
30    readonly: bool,
31}
32
33fn parse_field_cfg(field: &syn::Field) -> syn::Result<FieldCfg> {
34    let ident = field
35        .ident
36        .clone()
37        .ok_or_else(|| syn::Error::new_spanned(field, "LuaUserData requires named fields"))?;
38    let mut cfg = FieldCfg {
39        lua_name: ident.to_string(),
40        ident,
41        ty: field.ty.clone(),
42        skip: false,
43        readonly: false,
44    };
45    for attr in &field.attrs {
46        if !attr.path().is_ident("lua") {
47            continue;
48        }
49        attr.parse_nested_meta(|meta| {
50            if meta.path.is_ident("skip") {
51                cfg.skip = true;
52                Ok(())
53            } else if meta.path.is_ident("readonly") {
54                cfg.readonly = true;
55                Ok(())
56            } else if meta.path.is_ident("name") {
57                let lit: LitStr = meta.value()?.parse()?;
58                cfg.lua_name = lit.value();
59                Ok(())
60            } else {
61                Err(meta.error("unknown #[lua(...)] attribute; expected skip, readonly, or name"))
62            }
63        })?;
64    }
65    Ok(cfg)
66}
67
/// Struct-level options gathered from `#[lua(methods)]` and `#[lua_impl(...)]`.
struct StructCfg {
    /// `#[lua_impl(Display)]`: register the `ToString` metamethod using the type's `ToString`.
    impl_display: bool,
    /// `#[lua_impl(PartialEq)]`: register the `Eq` metamethod.
    impl_partial_eq: bool,
    /// `#[lua_impl(PartialOrd)]`: register the `Lt` and `Le` metamethods.
    impl_partial_ord: bool,
    /// `#[lua(methods)]`: also call the `__lua_register_methods` emitted by `#[lua_methods]`.
    register_methods: bool,
}
75
76fn parse_struct_cfg(input: &DeriveInput) -> syn::Result<StructCfg> {
77    let mut cfg = StructCfg {
78        register_methods: false,
79        impl_display: false,
80        impl_partial_eq: false,
81        impl_partial_ord: false,
82    };
83    for attr in &input.attrs {
84        if attr.path().is_ident("lua") {
85            attr.parse_nested_meta(|meta| {
86                if meta.path.is_ident("methods") {
87                    cfg.register_methods = true;
88                    Ok(())
89                } else {
90                    Err(meta.error("unknown #[lua(...)] attribute on struct; expected methods"))
91                }
92            })?;
93        } else if attr.path().is_ident("lua_impl") {
94            attr.parse_nested_meta(|meta| {
95                if meta.path.is_ident("Display") {
96                    cfg.impl_display = true;
97                    Ok(())
98                } else if meta.path.is_ident("PartialEq") {
99                    cfg.impl_partial_eq = true;
100                    Ok(())
101                } else if meta.path.is_ident("PartialOrd") {
102                    cfg.impl_partial_ord = true;
103                    Ok(())
104                } else {
105                    Err(meta.error(
106                        "unknown #[lua_impl(...)] trait; expected Display, PartialEq, or PartialOrd",
107                    ))
108                }
109            })?;
110        }
111    }
112    Ok(cfg)
113}
114
115/// Derive `UserData` for a struct: field access plus optional methods/metamethods.
116#[proc_macro_derive(LuaUserData, attributes(lua, lua_impl))]
117pub fn derive_lua_user_data(input: TokenStream) -> TokenStream {
118    let input = parse_macro_input!(input as DeriveInput);
119    expand_derive(input).unwrap_or_else(|e| e.to_compile_error().into())
120}
121
122fn expand_derive(input: DeriveInput) -> syn::Result<TokenStream> {
123    let name = &input.ident;
124
125    if !input.generics.params.is_empty() {
126        return Err(syn::Error::new_spanned(
127            &input.generics,
128            "LuaUserData does not yet support generic types",
129        ));
130    }
131
132    let scfg = parse_struct_cfg(&input)?;
133
134    let fields = match &input.data {
135        Data::Struct(s) => match &s.fields {
136            Fields::Named(named) => &named.named,
137            _ => {
138                return Err(syn::Error::new_spanned(
139                    &input.ident,
140                    "LuaUserData currently supports only structs with named fields",
141                ))
142            }
143        },
144        _ => {
145            return Err(syn::Error::new_spanned(
146                &input.ident,
147                "LuaUserData currently supports only structs",
148            ))
149        }
150    };
151
152    let mut field_regs = Vec::new();
153    for field in fields {
154        let cfg = parse_field_cfg(field)?;
155        if cfg.skip {
156            continue;
157        }
158        let ident = &cfg.ident;
159        let ty = &cfg.ty;
160        let lua_name = &cfg.lua_name;
161        field_regs.push(quote! {
162            __m.add_field_method_get(#lua_name, |_, __this| {
163                ::core::result::Result::Ok(::core::clone::Clone::clone(&__this.#ident))
164            });
165        });
166        if !cfg.readonly {
167            field_regs.push(quote! {
168                __m.add_field_method_set(#lua_name, |_, __this, __value: #ty| {
169                    __this.#ident = __value;
170                    ::core::result::Result::Ok(())
171                });
172            });
173        }
174    }
175
176    let methods_call = if scfg.register_methods {
177        quote! { <Self>::__lua_register_methods(__m); }
178    } else {
179        quote! {}
180    };
181
182    let mut meta_regs = Vec::new();
183    if scfg.impl_display {
184        meta_regs.push(quote! {
185            __m.add_meta_method(::lua_rs_runtime::MetaMethod::ToString, |_, __this, ()| {
186                ::core::result::Result::Ok(::std::string::ToString::to_string(__this))
187            });
188        });
189    }
190    if scfg.impl_partial_eq {
191        meta_regs.push(quote! {
192            __m.add_meta_method(
193                ::lua_rs_runtime::MetaMethod::Eq,
194                |_, __this, __other: ::lua_rs_runtime::Value| {
195                    if let ::lua_rs_runtime::Value::UserData(__ud) = __other {
196                        if let ::core::result::Result::Ok(__o) = __ud.borrow::<#name>() {
197                            return ::core::result::Result::Ok(*__this == *__o);
198                        }
199                    }
200                    ::core::result::Result::Ok(false)
201                },
202            );
203        });
204    }
205    if scfg.impl_partial_ord {
206        meta_regs.push(quote! {
207            __m.add_meta_method(
208                ::lua_rs_runtime::MetaMethod::Lt,
209                |_, __this, __other: ::lua_rs_runtime::Value| {
210                    if let ::lua_rs_runtime::Value::UserData(__ud) = __other {
211                        if let ::core::result::Result::Ok(__o) = __ud.borrow::<#name>() {
212                            return ::core::result::Result::Ok(*__this < *__o);
213                        }
214                    }
215                    ::core::result::Result::Ok(false)
216                },
217            );
218            __m.add_meta_method(
219                ::lua_rs_runtime::MetaMethod::Le,
220                |_, __this, __other: ::lua_rs_runtime::Value| {
221                    if let ::lua_rs_runtime::Value::UserData(__ud) = __other {
222                        if let ::core::result::Result::Ok(__o) = __ud.borrow::<#name>() {
223                            return ::core::result::Result::Ok(*__this <= *__o);
224                        }
225                    }
226                    ::core::result::Result::Ok(false)
227                },
228            );
229        });
230    }
231
232    let add_meta_methods = if meta_regs.is_empty() {
233        quote! {}
234    } else {
235        quote! {
236            fn add_meta_methods<__M: ::lua_rs_runtime::UserDataMethods<Self>>(__m: &mut __M) {
237                #(#meta_regs)*
238            }
239        }
240    };
241
242    let expanded = quote! {
243        impl ::lua_rs_runtime::UserData for #name {
244            fn add_methods<__M: ::lua_rs_runtime::UserDataMethods<Self>>(__m: &mut __M) {
245                #(#field_regs)*
246                #methods_call
247            }
248            #add_meta_methods
249        }
250    };
251
252    Ok(expanded.into())
253}
254
255// ---------------------------------------------------------------------------
256// #[lua_methods]
257// ---------------------------------------------------------------------------
258
259/// Expose an `impl` block's public methods to Lua as `obj:method(args)`.
260#[proc_macro_attribute]
261pub fn lua_methods(_attr: TokenStream, item: TokenStream) -> TokenStream {
262    let item = parse_macro_input!(item as ItemImpl);
263    expand_methods(item).unwrap_or_else(|e| e.to_compile_error().into())
264}
265
266fn expand_methods(item: ItemImpl) -> syn::Result<TokenStream> {
267    let self_ty = &item.self_ty;
268    let mut regs = Vec::new();
269
270    for impl_item in &item.items {
271        let ImplItem::Fn(method) = impl_item else {
272            continue;
273        };
274        if !matches!(method.vis, syn::Visibility::Public(_)) {
275            continue;
276        }
277
278        // Must have a self receiver to be callable as obj:method(...).
279        let receiver = method.sig.inputs.first().and_then(|arg| match arg {
280            FnArg::Receiver(r) => Some(r),
281            _ => None,
282        });
283        let Some(receiver) = receiver else {
284            continue;
285        };
286        let is_mut = receiver.mutability.is_some();
287
288        let name = &method.sig.ident;
289        let lua_name = name.to_string();
290
291        // Collect the non-self arguments: names + types.
292        let mut arg_names = Vec::new();
293        let mut arg_types = Vec::new();
294        for (i, arg) in method.sig.inputs.iter().enumerate().skip(1) {
295            let FnArg::Typed(pat) = arg else {
296                return Err(syn::Error::new_spanned(
297                    arg,
298                    "#[lua_methods] does not support a second receiver",
299                ));
300            };
301            let ident = syn::Ident::new(&format!("__a{i}"), proc_macro2::Span::call_site());
302            arg_names.push(ident);
303            arg_types.push((*pat.ty).clone());
304        }
305
306        // Closure argument binding: () for none, `name: T` for one, `(..): (..)` for many.
307        let arg_binding = match arg_names.len() {
308            0 => quote! { () },
309            1 => {
310                let n = &arg_names[0];
311                let t = &arg_types[0];
312                quote! { #n: #t }
313            }
314            _ => {
315                quote! { ( #(#arg_names),* ): ( #(#arg_types),* ) }
316            }
317        };
318
319        let call = quote! { <#self_ty>::#name(__this #(, #arg_names)*) };
320        let returns_unit = matches!(&method.sig.output, ReturnType::Default)
321            || matches!(&method.sig.output, ReturnType::Type(_, ty) if is_unit_type(ty));
322        let body = if returns_unit {
323            quote! { { #call; ::core::result::Result::Ok(()) } }
324        } else {
325            quote! { ::core::result::Result::Ok(#call) }
326        };
327
328        if is_mut {
329            regs.push(quote! {
330                __m.add_method_mut(#lua_name, |_, __this, #arg_binding| #body);
331            });
332        } else {
333            regs.push(quote! {
334                __m.add_method(#lua_name, |_, __this, #arg_binding| #body);
335            });
336        }
337    }
338
339    let expanded = quote! {
340        #item
341
342        impl #self_ty {
343            #[doc(hidden)]
344            fn __lua_register_methods<__M: ::lua_rs_runtime::UserDataMethods<Self>>(__m: &mut __M) {
345                #(#regs)*
346            }
347        }
348    };
349
350    Ok(expanded.into())
351}
352
353fn is_unit_type(ty: &Type) -> bool {
354    matches!(ty, Type::Tuple(t) if t.elems.is_empty())
355}