dioxus_motion_transitions_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{Attribute, Data, DataEnum, DeriveInput, Fields, Meta, parse_macro_input};
4
5fn get_transition_from_attrs(attrs: &[Attribute]) -> Option<String> {
6    attrs
7        .iter()
8        .find(|attr| attr.path().is_ident("transition"))
9        .and_then(|attr| {
10            if let Ok(Meta::Path(path)) = attr.parse_args::<Meta>() {
11                path.get_ident().map(|ident| ident.to_string())
12            } else {
13                None
14            }
15        })
16}
17
18// Helper to extract layout nesting information from enum variants
19fn get_layout_depth(variants: &[&syn::Variant]) -> Vec<(syn::Ident, usize)> {
20    let mut layout_depth = Vec::new();
21    let mut current_depth = 0;
22
23    for variant in variants {
24        // Check if this variant has a layout attribute
25        if variant
26            .attrs
27            .iter()
28            .any(|attr| attr.path().is_ident("layout"))
29        {
30            current_depth += 1;
31        }
32
33        // Check if this variant ends a layout
34        if variant
35            .attrs
36            .iter()
37            .any(|attr| attr.path().is_ident("end_layout"))
38            && current_depth > 0
39        {
40            current_depth -= 1;
41        }
42
43        // Associate current depth with this variant
44        layout_depth.push((variant.ident.clone(), current_depth));
45    }
46
47    layout_depth
48}
49
50#[proc_macro_derive(MotionTransitions, attributes(transition, layout, end_layout))]
51pub fn derive_route_transitions(input: TokenStream) -> TokenStream {
52    let input = parse_macro_input!(input as DeriveInput);
53    let name = &input.ident;
54    let variants = match input.data {
55        Data::Enum(DataEnum { variants, .. }) => variants,
56        _ => panic!("MotionTransitions can only be derived for enums"),
57    };
58
59    let component_match_arms = variants.iter().map(|variant| {
60        let variant_ident = &variant.ident;
61        let component_name = &variant.ident;
62
63        match &variant.fields {
64            Fields::Named(fields) => {
65                let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
66                quote! {
67                    Self::#variant_ident { #(#field_names,)* } => {
68                        rsx! { #component_name { #(#field_names: #field_names.clone(),)* } }
69                    }
70                }
71            }
72            Fields::Unnamed(_) => {
73                quote! { Self::#variant_ident(..) => rsx! { #component_name {} } }
74            }
75            Fields::Unit => {
76                quote! { Self::#variant_ident {} => rsx! { #component_name {} } }
77            }
78        }
79    });
80
81    let transition_match_arms = variants.iter().map(|variant| {
82        let variant_ident = &variant.ident;
83        let transition = get_transition_from_attrs(&variant.attrs)
84            .map(|t| format_ident!("{}", t))
85            .unwrap_or(format_ident!("Fade"));
86
87        match &variant.fields {
88            Fields::Named(fields) => {
89                let field_patterns = fields.named.iter().map(|f| {
90                    let name = &f.ident;
91                    quote! { #name: _ }
92                });
93                quote! {
94                    Self::#variant_ident { #(#field_patterns,)* } => TransitionVariant::#transition
95                }
96            }
97            Fields::Unnamed(_) => {
98                quote! { Self::#variant_ident(..) => TransitionVariant::#transition }
99            }
100            Fields::Unit => {
101                quote! { Self::#variant_ident {} => TransitionVariant::#transition }
102            }
103        }
104    });
105
106    // Generate layout depth match arms
107    let layout_depths = get_layout_depth(&variants.iter().collect::<Vec<_>>());
108    let layout_depth_match_arms =
109        layout_depths.iter().map(|(variant_ident, depth)| {
110            match &variants
111                .iter()
112                .find(|v| &v.ident == variant_ident)
113                .unwrap()
114                .fields
115            {
116                Fields::Named(fields) => {
117                    let field_patterns = fields.named.iter().map(|f| {
118                        let name = &f.ident;
119                        quote! { #name: _ }
120                    });
121                    quote! {
122                        Self::#variant_ident { #(#field_patterns,)* } => #depth
123                    }
124                }
125                Fields::Unnamed(_) => {
126                    quote! { Self::#variant_ident(..) => #depth }
127                }
128                Fields::Unit => {
129                    quote! { Self::#variant_ident {} => #depth }
130                }
131            }
132        });
133
134    let expanded = quote! {
135        impl AnimatableRoute for  #name {
136            fn get_transition(&self) -> TransitionVariant {
137                match self {
138                    #(#transition_match_arms,)*
139                    _ => TransitionVariant::Fade,
140                }
141            }
142
143            fn get_component(&self) -> Element {
144                match self {
145                    #(#component_match_arms,)*
146                }
147            }
148
149            // New method to get layout depth
150            fn get_layout_depth(&self) -> usize {
151                match self {
152                    #(#layout_depth_match_arms,)*
153                    _ => 0,
154                }
155            }
156        }
157    };
158
159    TokenStream::from(expanded)
160}