Skip to main content

luars_derive/
lib.rs

//! Procedural macros for the luars userdata system.
//!
//! Provides `#[derive(LuaUserData)]`, which generates a `UserDataTrait` implementation
//! for a Rust struct and exposes its public fields to Lua (methods are not exposed).
//!
//! # Attributes
//!
//! - `#[lua(skip)]` on a field — exclude it from Lua access (non-`pub` fields are
//!   always excluded)
//! - `#[lua(readonly)]` on a field — allow reading it from Lua, but reject assignment
//! - `#[lua(name = "...")]` on a field — expose it under a custom Lua-visible name
//!
//! # Trait-backed metamethods
//!
//! Use `#[lua_impl(...)]` on the struct to declare which Rust traits should map to Lua
//! metamethods. The traits are not detected automatically: every trait listed here must
//! actually be implemented (or derived) for the struct.
//! - `Display` → `__tostring`
//! - `PartialEq` → `__eq`
//! - `PartialOrd` → `__lt`, `__le`
//!
//! # Example
//!
//! ```ignore
//! use luars_derive::LuaUserData;
//!
//! #[derive(LuaUserData, PartialEq, PartialOrd)]
//! #[lua_impl(Display, PartialEq, PartialOrd)]
//! struct Point {
//!     pub x: f64,
//!     pub y: f64,
//!     #[lua(skip)]
//!     internal_id: u32,
//! }
//!
//! impl std::fmt::Display for Point {
//!     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
//!         write!(f, "Point({}, {})", self.x, self.y)
//!     }
//! }
//! ```
33
34use proc_macro::TokenStream;
35use quote::quote;
36use syn::{Data, DeriveInput, Fields, Ident, Meta, parse_macro_input};
37
38/// Derive `UserDataTrait` for a struct, exposing public fields to Lua.
39///
40/// # Supported field types (auto-converted to/from UdValue)
41/// - `i8`, `i16`, `i32`, `i64`, `isize` → `UdValue::Integer`
42/// - `u8`, `u16`, `u32`, `u64`, `usize` → `UdValue::Integer`
43/// - `f32`, `f64` → `UdValue::Number`
44/// - `bool` → `UdValue::Boolean`
45/// - `String` → `UdValue::Str`
46#[proc_macro_derive(LuaUserData, attributes(lua, lua_impl))]
47pub fn derive_lua_userdata(input: TokenStream) -> TokenStream {
48    let input = parse_macro_input!(input as DeriveInput);
49    let name = &input.ident;
50
51    // Parse #[lua_impl(...)] attribute for trait detection
52    let trait_impls = parse_lua_impl_attrs(&input);
53
54    // Only works on structs with named fields
55    let fields = match &input.data {
56        Data::Struct(data) => match &data.fields {
57            Fields::Named(fields) => &fields.named,
58            _ => {
59                return syn::Error::new_spanned(
60                    &input.ident,
61                    "LuaUserData can only be derived for structs with named fields",
62                )
63                .to_compile_error()
64                .into();
65            }
66        },
67        _ => {
68            return syn::Error::new_spanned(
69                &input.ident,
70                "LuaUserData can only be derived for structs",
71            )
72            .to_compile_error()
73            .into();
74        }
75    };
76
77    // Collect field info
78    let mut field_infos: Vec<FieldInfo> = Vec::new();
79    for field in fields.iter() {
80        let ident = field.ident.as_ref().unwrap();
81        let ty = &field.ty;
82        let is_pub = matches!(field.vis, syn::Visibility::Public(_));
83
84        // Parse field attributes
85        let mut skip = false;
86        let mut readonly = false;
87        let mut lua_name: Option<String> = None;
88
89        for attr in &field.attrs {
90            if attr.path().is_ident("lua") {
91                if let Ok(list) = attr.meta.require_list() {
92                    let _ = list.parse_nested_meta(|meta| {
93                        if meta.path.is_ident("skip") {
94                            skip = true;
95                        } else if meta.path.is_ident("readonly") {
96                            readonly = true;
97                        } else if meta.path.is_ident("name") {
98                            if let Ok(value) = meta.value() {
99                                if let Ok(lit) = value.parse::<syn::LitStr>() {
100                                    lua_name = Some(lit.value());
101                                }
102                            }
103                        }
104                        Ok(())
105                    });
106                }
107            }
108        }
109
110        if skip || !is_pub {
111            continue;
112        }
113
114        let name_str = lua_name.unwrap_or_else(|| ident.to_string());
115        field_infos.push(FieldInfo {
116            ident: ident.clone(),
117            ty: ty.clone(),
118            lua_name: name_str,
119            readonly,
120        });
121    }
122
123    // Generate get_field match arms
124    let get_field_arms = field_infos.iter().map(|f| {
125        let ident = &f.ident;
126        let lua_name = &f.lua_name;
127        let conversion = field_to_udvalue(&f.ty, quote!(self.#ident));
128        quote! { #lua_name => Some(#conversion), }
129    });
130
131    // Generate set_field match arms (writable fields)
132    let set_field_arms = field_infos.iter().filter(|f| !f.readonly).map(|f| {
133        let ident = &f.ident;
134        let lua_name = &f.lua_name;
135        let assign = udvalue_to_field(&f.ty, quote!(self.#ident), lua_name);
136        quote! { #lua_name => { #assign } }
137    });
138
139    // Generate set_field match arms (readonly fields → error)
140    let readonly_set_arms = field_infos.iter().filter(|f| f.readonly).map(|f| {
141        let lua_name = &f.lua_name;
142        quote! { #lua_name => Some(Err(format!("field '{}' is read-only", #lua_name))), }
143    });
144
145    // Generate field_names list
146    let field_name_strs: Vec<&String> = field_infos.iter().map(|f| &f.lua_name).collect();
147
148    // Generate metamethod impls based on #[lua_impl(...)]
149    let tostring_impl = if trait_impls.contains(&"Display".to_string()) {
150        quote! {
151            fn lua_tostring(&self) -> Option<String> {
152                Some(format!("{}", self))
153            }
154        }
155    } else {
156        quote! {}
157    };
158
159    let eq_impl = if trait_impls.contains(&"PartialEq".to_string()) {
160        quote! {
161            fn lua_eq(&self, other: &dyn luars::lua_value::userdata_trait::UserDataTrait) -> Option<bool> {
162                other.as_any().downcast_ref::<#name>().map(|o| self == o)
163            }
164        }
165    } else {
166        quote! {}
167    };
168
169    let ord_impl = if trait_impls.contains(&"PartialOrd".to_string()) {
170        quote! {
171            fn lua_lt(&self, other: &dyn luars::lua_value::userdata_trait::UserDataTrait) -> Option<bool> {
172                other.as_any().downcast_ref::<#name>()
173                    .and_then(|o| self.partial_cmp(o))
174                    .map(|c| c == std::cmp::Ordering::Less)
175            }
176            fn lua_le(&self, other: &dyn luars::lua_value::userdata_trait::UserDataTrait) -> Option<bool> {
177                other.as_any().downcast_ref::<#name>()
178                    .and_then(|o| self.partial_cmp(o))
179                    .map(|c| c != std::cmp::Ordering::Greater)
180            }
181        }
182    } else {
183        quote! {}
184    };
185
186    let type_name_str = name.to_string();
187
188    let expanded = quote! {
189        impl luars::lua_value::userdata_trait::UserDataTrait for #name {
190            fn type_name(&self) -> &'static str {
191                #type_name_str
192            }
193
194            fn get_field(&self, key: &str) -> Option<luars::lua_value::userdata_trait::UdValue> {
195                match key {
196                    #(#get_field_arms)*
197                    _ => None,
198                }
199            }
200
201            fn set_field(&mut self, key: &str, value: luars::lua_value::userdata_trait::UdValue) -> Option<Result<(), String>> {
202                match key {
203                    #(#set_field_arms)*
204                    #(#readonly_set_arms)*
205                    _ => None,
206                }
207            }
208
209            fn field_names(&self) -> &'static [&'static str] {
210                &[#(#field_name_strs),*]
211            }
212
213            #tostring_impl
214            #eq_impl
215            #ord_impl
216
217            fn as_any(&self) -> &dyn std::any::Any {
218                self
219            }
220
221            fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
222                self
223            }
224        }
225    };
226
227    expanded.into()
228}
229
230// ==================== Internal types ====================
231
232struct FieldInfo {
233    ident: Ident,
234    ty: syn::Type,
235    lua_name: String,
236    readonly: bool,
237}
238
239// ==================== Attribute parsing ====================
240
241/// Parse `#[lua_impl(Display, PartialEq, PartialOrd, ...)]` attributes
242fn parse_lua_impl_attrs(input: &DeriveInput) -> Vec<String> {
243    let mut impls = Vec::new();
244    for attr in &input.attrs {
245        if attr.path().is_ident("lua_impl") {
246            if let Meta::List(list) = &attr.meta {
247                let _ = list.parse_nested_meta(|meta| {
248                    if let Some(ident) = meta.path.get_ident() {
249                        impls.push(ident.to_string());
250                    }
251                    Ok(())
252                });
253            }
254        }
255    }
256    impls
257}
258
259// ==================== Code generation helpers ====================
260
261/// Generate code to convert a Rust field value → UdValue
262fn field_to_udvalue(
263    ty: &syn::Type,
264    accessor: proc_macro2::TokenStream,
265) -> proc_macro2::TokenStream {
266    let type_str = normalize_type(ty);
267
268    match type_str.as_str() {
269        // Integers → UdValue::Integer
270        "i8" | "i16" | "i32" | "i64" | "isize" | "u8" | "u16" | "u32" | "u64" | "usize" => {
271            quote! { luars::lua_value::userdata_trait::UdValue::Integer(#accessor as i64) }
272        }
273        // Floats → UdValue::Number
274        "f32" | "f64" => {
275            quote! { luars::lua_value::userdata_trait::UdValue::Number(#accessor as f64) }
276        }
277        // Bool → UdValue::Boolean
278        "bool" => {
279            quote! { luars::lua_value::userdata_trait::UdValue::Boolean(#accessor) }
280        }
281        // String → UdValue::Str (cloned)
282        "String" => {
283            quote! { luars::lua_value::userdata_trait::UdValue::Str(#accessor.clone()) }
284        }
285        // Fallback: try Into<UdValue>
286        _ => {
287            quote! { luars::lua_value::userdata_trait::UdValue::from(#accessor.clone()) }
288        }
289    }
290}
291
292/// Generate code to convert UdValue → Rust type and assign to a field
293fn udvalue_to_field(
294    ty: &syn::Type,
295    target: proc_macro2::TokenStream,
296    field_name: &str,
297) -> proc_macro2::TokenStream {
298    let type_str = normalize_type(ty);
299
300    match type_str.as_str() {
301        // Integer types
302        "i8" | "i16" | "i32" | "i64" | "isize" => {
303            quote! {
304                match value.to_integer() {
305                    Some(i) => { #target = i as #ty; Some(Ok(())) }
306                    None => Some(Err(format!("expected integer for field '{}'", #field_name)))
307                }
308            }
309        }
310        "u8" | "u16" | "u32" | "u64" | "usize" => {
311            quote! {
312                match value.to_integer() {
313                    Some(i) if i >= 0 => { #target = i as #ty; Some(Ok(())) }
314                    Some(_) => Some(Err(format!("expected non-negative integer for field '{}'", #field_name))),
315                    None => Some(Err(format!("expected integer for field '{}'", #field_name)))
316                }
317            }
318        }
319        // Float types
320        "f32" | "f64" => {
321            quote! {
322                match value.to_number() {
323                    Some(n) => { #target = n as #ty; Some(Ok(())) }
324                    None => Some(Err(format!("expected number for field '{}'", #field_name)))
325                }
326            }
327        }
328        // Bool
329        "bool" => {
330            quote! {
331                {
332                    #target = value.to_bool();
333                    Some(Ok(()))
334                }
335            }
336        }
337        // String
338        "String" => {
339            quote! {
340                match value.to_str() {
341                    Some(s) => { #target = s.to_owned(); Some(Ok(())) }
342                    None => Some(Err(format!("expected string for field '{}'", #field_name)))
343                }
344            }
345        }
346        // Unsupported type
347        _ => {
348            quote! {
349                Some(Err(format!("cannot set field '{}': unsupported type", #field_name)))
350            }
351        }
352    }
353}
354
355/// Normalize a syn::Type to a simple string for matching
356fn normalize_type(ty: &syn::Type) -> String {
357    quote!(#ty).to_string().replace(" ", "")
358}