feather_macro/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: 2025 Fundament Research Institute <https://fundament.institute>
3
4use core::panic;
5
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use quote::{format_ident, quote};
9use syn::{Data, DataEnum, DeriveInput, Meta, parse_macro_input};
10
11fn derive_base_prop(input: TokenStream, prop: &str, source: &str, result: &str) -> TokenStream {
12    let ast = parse_macro_input!(input as DeriveInput);
13
14    let result: syn::Path = syn::parse_str(result).unwrap();
15    let source: syn::Path = syn::parse_str(source).unwrap();
16    let prop = format_ident!("{}", prop);
17    let name = ast.ident;
18    quote! {
19        impl #source for #name {
20            fn #prop(&self) -> &#result {
21                &self.#prop
22            }
23        }
24    }
25    .into()
26}
27
28#[proc_macro_derive(Empty)]
29pub fn derive_empty(input: TokenStream) -> TokenStream {
30    let ast = parse_macro_input!(input as DeriveInput);
31
32    let sname = ast.ident;
33    quote! {
34        impl feather_ui::layout::base::Empty for #sname {}
35    }
36    .into()
37}
38
39#[proc_macro_derive(Area)]
40pub fn derive_area(input: TokenStream) -> TokenStream {
41    derive_base_prop(
42        input,
43        "area",
44        "feather_ui::layout::base::Area",
45        "feather_ui::DRect",
46    )
47}
48
49#[proc_macro_derive(Padding)]
50pub fn derive_padding(input: TokenStream) -> TokenStream {
51    derive_base_prop(
52        input,
53        "padding",
54        "feather_ui::layout::base::Padding",
55        "feather_ui::DAbsRect",
56    )
57}
58
59#[proc_macro_derive(Margin)]
60pub fn derive_margin(input: TokenStream) -> TokenStream {
61    derive_base_prop(
62        input,
63        "margin",
64        "feather_ui::layout::base::Margin",
65        "feather_ui::DRect",
66    )
67}
68
69#[proc_macro_derive(Limits)]
70pub fn derive_limits(input: TokenStream) -> TokenStream {
71    derive_base_prop(
72        input,
73        "limits",
74        "feather_ui::layout::base::Limits",
75        "feather_ui::DLimits",
76    )
77}
78
79#[proc_macro_derive(RLimits)]
80pub fn derive_rlimits(input: TokenStream) -> TokenStream {
81    derive_base_prop(
82        input,
83        "rlimits",
84        "feather_ui::layout::base::RLimits",
85        "feather_ui::RelLimits",
86    )
87}
88
89#[proc_macro_derive(Anchor)]
90pub fn derive_anchor(input: TokenStream) -> TokenStream {
91    derive_base_prop(
92        input,
93        "anchor",
94        "feather_ui::layout::base::Anchor",
95        "feather_ui::DPoint",
96    )
97}
98
99#[proc_macro_derive(TextEdit)]
100pub fn derive_textedit(input: TokenStream) -> TokenStream {
101    derive_base_prop(
102        input,
103        "textedit",
104        "feather_ui::layout::base::TextEdit",
105        "feather_ui::text::EditView",
106    )
107}
108
109#[proc_macro_derive(FlexProp)]
110pub fn derive_flex_prop(input: TokenStream) -> TokenStream {
111    let ast = parse_macro_input!(input as DeriveInput);
112
113    let name = ast.ident;
114    quote! {
115        impl feather_ui::layout::flex::Prop for #name {
116        fn wrap(&self) -> bool { self.wrap }
117        fn justify(&self) -> feather_ui::layout::flex::FlexJustify { self.justify }
118        fn align(&self) -> feather_ui::layout::flex::FlexJustify { self.align }
119        }
120    }
121    .into()
122}
123
124#[proc_macro_derive(FlexChild)]
125pub fn derive_flex_child(input: TokenStream) -> TokenStream {
126    let ast = parse_macro_input!(input as DeriveInput);
127
128    let name = ast.ident;
129    quote! {
130        impl feather_ui::layout::flex::Child for #name {
131            fn grow(&self) -> f32 { self.grow }
132            fn shrink(&self) -> f32 { self.shrink }
133            fn basis(&self) -> feather_ui::DValue { self.basis }
134        }
135    }
136    .into()
137}
138
139#[proc_macro_derive(ZIndex)]
140pub fn derive_zindex(input: TokenStream) -> TokenStream {
141    let ast = parse_macro_input!(input as DeriveInput);
142
143    let sname = ast.ident;
144    quote! {
145        impl feather_ui::layout::base::ZIndex for #sname {
146            fn zindex(&self) -> i32 {
147                self.zindex
148            }
149        }
150    }
151    .into()
152}
153
154#[proc_macro_derive(Direction)]
155pub fn derive_direction(input: TokenStream) -> TokenStream {
156    let ast = parse_macro_input!(input as DeriveInput);
157
158    let sname = ast.ident;
159    quote! {
160        impl feather_ui::layout::base::Direction for #sname {
161            fn direction(&self) -> feather_ui::RowDirection {
162                self.direction
163            }
164        }
165    }
166    .into()
167}
168
169#[proc_macro_derive(RootProp)]
170pub fn derive_root_prop(input: TokenStream) -> TokenStream {
171    derive_base_prop(
172        input,
173        "dim",
174        "feather_ui::layout::root::Prop",
175        "feather_ui::AbsDim",
176    )
177}
178
179fn data_enum(ast: &DeriveInput) -> &DataEnum {
180    if let Data::Enum(data_enum) = &ast.data {
181        data_enum
182    } else {
183        panic!("`Dispatch` derive can only be used on an enum.");
184    }
185}
186
187fn find_enum_module(attrs: &[syn::Attribute]) -> syn::Result<String> {
188    // Extract EnumVariantType's module, since this has to be used in conjunction
189    // with our derive
190    for attr in attrs.iter() {
191        if attr.path().is_ident("evt") {
192            let nested = attr
193                .parse_args_with(
194                    syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
195                )
196                .unwrap();
197
198            for meta in nested {
199                if let Meta::NameValue(name_value) = meta {
200                    if let (true, syn::Expr::Lit(lit_str)) =
201                        (name_value.path.is_ident("module"), name_value.value)
202                    {
203                        if let syn::Lit::Str(s) = lit_str.lit {
204                            return Ok(s.value());
205                        } else {
206                            return Err(syn::Error::new(Span::call_site(), ""));
207                        }
208                    } else {
209                        return Err(syn::Error::new(Span::call_site(), ""));
210                    }
211                }
212            }
213
214            // This would be a lot easier but it doesn't seem to work for
215            // #[evt(derive(Clone), module = "mouse_area_event")]
216            /*let _ = attr.parse_nested_meta(|meta| {
217                if meta.path.is_ident("module") {
218                    let value = meta.value()?;
219                    let s: LitStr = value.parse()?;
220                    enum_module = Some(s.value());
221                }
222
223                Ok(())
224            });*/
225        }
226    }
227
228    // Error here doesn't matter, we transform it into another error message upon
229    // return
230    Err(syn::Error::new(Span::call_site(), ""))
231}
232
233#[proc_macro_derive(Dispatch)]
234pub fn dispatchable(input: TokenStream) -> TokenStream {
235    let crate_name = std::env::var("CARGO_PKG_NAME").unwrap();
236
237    let crate_name = format_ident!(
238        "{}",
239        if crate_name == "feather-ui" {
240            "crate"
241        } else {
242            "feather_ui"
243        }
244    );
245
246    let ast = parse_macro_input!(input as DeriveInput);
247    let enum_module = format_ident!(
248        "{}",
249        find_enum_module(&ast.attrs).expect(
250        "Expected `evt` attribute argument in the form: `#[evt(module = \"some_module_name\")]`",
251    ));
252
253    let enum_name = &ast.ident;
254    let data_enum = data_enum(&ast);
255    let variants = &data_enum.variants;
256
257    let mut extract_declarations = proc_macro2::TokenStream::new();
258    let mut restore_declarations = proc_macro2::TokenStream::new();
259
260    for (counter, variant) in variants.iter().enumerate() {
261        let variant_name = &variant.ident;
262
263        let idx = (1_u64)
264            .checked_shl(counter as u32)
265            .expect("Too many variants! Can't handle more than 64!");
266
267        if variant.fields.is_empty() {
268            extract_declarations.extend(quote! {
269                #enum_name::#variant_name => (
270                    #idx,
271                    Box::new(#enum_module::#variant_name::try_from(self).unwrap()),
272                ),
273            });
274        } else if variant.fields.iter().next().unwrap().ident.is_none() {
275            let underscores = variant.fields.iter().map(|_| format_ident!("_"));
276            extract_declarations.extend(quote! {
277                #enum_name::#variant_name(#(#underscores),*) => (
278                    #idx,
279                    Box::new(#enum_module::#variant_name::try_from(self).unwrap()),
280                ),
281            });
282        } else {
283            extract_declarations.extend(quote! {
284                #enum_name::#variant_name { .. } => (
285                    #idx,
286                    Box::new(#enum_module::#variant_name::try_from(self).unwrap()),
287                ),
288            });
289        }
290
291        restore_declarations.extend(quote! {
292            #idx => Ok(#enum_name::from(
293                *pair
294                    .1
295                    .downcast::<#enum_module::#variant_name>()
296                    .map_err(|_| {
297                        #crate_name::Error::MismatchedEnumTag(
298                            pair.0,
299                            std::any::TypeId::of::<#enum_module::#variant_name>(),
300                            typeid,
301                        )
302                    })?,
303            )),
304        });
305    }
306
307    let counter = variants.len();
308    quote! {
309        impl #crate_name::Dispatchable for #enum_name {
310            const SIZE: usize = #counter;
311
312            fn extract(self) -> #crate_name::DispatchPair {
313                match self {
314                    #extract_declarations
315                }
316            }
317
318            fn restore(pair: #crate_name::DispatchPair) -> Result<Self, #crate_name::Error> {
319                let typeid = (*pair.1).type_id();
320                match pair.0 {
321                    #restore_declarations
322                    _ => Err(#crate_name::Error::InvalidEnumTag(pair.0)),
323                }
324            }
325        }
326    }
327    .into()
328}
329
330#[proc_macro_derive(StateMachineChild)]
331pub fn state_machine_child(input: TokenStream) -> TokenStream {
332    let crate_name = std::env::var("CARGO_PKG_NAME").unwrap();
333
334    let crate_name = format_ident!(
335        "{}",
336        if crate_name == "feather-ui" {
337            "crate"
338        } else {
339            "feather_ui"
340        }
341    );
342
343    let ast = parse_macro_input!(input as DeriveInput);
344    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
345
346    let data = if let Data::Struct(data_enum) = &ast.data {
347        data_enum
348    } else {
349        panic!("`StateMachineChild` derive can only be used on a struct.");
350    };
351
352    let has_children = data.fields.members().any(|x| {
353        if let syn::Member::Named(f) = x {
354            f == "children"
355        } else {
356            false
357        }
358    });
359
360    let apply_children = if has_children {
361        quote! {
362            fn apply_children(
363                    &self,
364                    f: &mut dyn FnMut(&dyn #crate_name::StateMachineChild) -> eyre::Result<()>,
365                ) -> eyre::Result<()> {
366                    self.children
367                        .iter()
368                        .try_for_each(|x| f(x.as_ref().unwrap().as_ref()))
369                }
370        }
371    } else {
372        quote! {}
373    };
374
375    let sname = ast.ident;
376    quote! {
377        impl #impl_generics #crate_name::StateMachineChild for #sname #ty_generics #where_clause {
378            fn id(&self) -> std::sync::Arc<SourceID> {
379                self.id.clone()
380            }
381
382            #apply_children
383        }
384    }
385    .into()
386}
387
388#[proc_macro_derive(UserData)]
389pub fn lua_user_data(input: TokenStream) -> TokenStream {
390    /*let crate_name = std::env::var("CARGO_PKG_NAME").unwrap();
391
392    let crate_name = format_ident!(
393        "{}",
394        if crate_name == "feather-ui" {
395            "crate"
396        } else {
397            "feather_ui"
398        }
399    );*/
400    let crate_name = format_ident!("feather_ui");
401
402    let ast = parse_macro_input!(input as DeriveInput);
403    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
404
405    let data = if let Data::Struct(data_enum) = &ast.data {
406        data_enum
407    } else {
408        panic!("`UserData` derive can only be used on a struct.");
409    };
410
411    let mut field_methods = proc_macro2::TokenStream::new();
412    for m in data.fields.members() {
413        match m {
414            syn::Member::Named(ident) => {
415                field_methods.extend(quote! {
416                    f.add_field_method_get(stringify!(#ident), |_, this| Ok(this.#ident.clone()));
417                    f.add_field_method_set(stringify!(#ident), |_, this, v| Ok(this.#ident = v));
418                });
419            }
420            syn::Member::Unnamed(_) => panic!(
421                "You can't use a UserData derive on a tuple, because mlua knows how to parse tuples already!"
422            ),
423        }
424    }
425
426    let sname = ast.ident;
427    quote! {
428        impl #impl_generics #crate_name::mlua::UserData for #sname #ty_generics #where_clause {
429            fn add_fields<F: #crate_name::mlua::UserDataFields<Self>>(f: &mut F) {
430                #field_methods
431            }
432        }
433
434        impl #impl_generics #crate_name::mlua::FromLua for #sname #ty_generics #where_clause {
435            #[inline]
436            fn from_lua(value: #crate_name::mlua::Value, _: &#crate_name::mlua::Lua) -> #crate_name::mlua::Result<Self> {
437                match value {
438                    #crate_name::mlua::Value::UserData(ud) => Ok(ud.borrow::<Self>()?.clone()),
439                    _ => Err(#crate_name::mlua::Error::FromLuaConversionError {
440                        from: value.type_name(),
441                        to: stringify!(#sname).to_string(),
442                        message: None,
443                    }),
444                }
445            }
446        }
447    }
448    .into()
449}