async_component_macro/
lib.rs

1use proc_macro2::TokenStream;
2use quote::{format_ident, quote, quote_spanned, IdentFragment};
3use syn::{parse_macro_input, Attribute, Data, DeriveInput, ExprPath, Fields, Index};
4
5#[proc_macro_derive(AsyncComponent, attributes(component, state))]
6pub fn component_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
7    let input = parse_macro_input!(input as DeriveInput);
8
9    proc_macro::TokenStream::from(impl_component_stream(&input))
10}
11
12fn extract_attribute<'a>(ident: &str, attrs: &'a [Attribute]) -> Option<&'a Attribute> {
13    for attr in attrs {
14        if !attr.path.is_ident(ident) {
15            continue;
16        }
17        return Some(attr);
18    }
19
20    None
21}
22
23fn extract_path_attribute(attr: &Attribute) -> Option<ExprPath> {
24    attr.parse_args::<ExprPath>().ok()
25}
26
27fn impl_component_stream(input: &DeriveInput) -> TokenStream {
28    let name = &input.ident;
29
30    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
31
32    let update_component_body = match input.data {
33        Data::Struct(ref data) => {
34            let state_update_call = extract_attribute("component", &input.attrs)
35                .map(|attr| extract_path_attribute(attr))
36                .map(|path| {
37                    quote! {
38                        #path(self);
39                    }
40                });
41
42            let state_poll = update_state_body(&data.fields);
43            let component_poll = component_update_body(&data.fields);
44
45            quote! {
46                #component_poll
47
48                #state_poll
49
50                #state_update_call
51            }
52        }
53        Data::Enum(_) => unimplemented!("Derive cannot be applied to enum"),
54        Data::Union(_) => unimplemented!("Derive cannot be applied to union"),
55    };
56
57    quote! {
58        impl #impl_generics ::async_component::AsyncComponent for #name #ty_generics #where_clause {
59            fn update_component(&mut self) {
60                #update_component_body
61            }
62        }
63    }
64}
65
66fn update_state_body(fields: &Fields) -> TokenStream {
67    match fields {
68        Fields::Named(fields) => {
69            let iter = fields.named.iter().filter_map(|field| {
70                let method_attr = extract_attribute("state", &field.attrs)?;
71                let method_path = extract_path_attribute(method_attr);
72
73                let name = field.ident.as_ref().unwrap();
74
75                Some(field_state_update_body(name, method_path))
76            });
77
78            quote! {
79                #(#iter)*
80            }
81        }
82        Fields::Unnamed(fields) => {
83            let iter = fields
84                .unnamed
85                .iter()
86                .enumerate()
87                .filter_map(|(index, field)| {
88                    let method_attr = extract_attribute("state", &field.attrs)?;
89                    let method_path = extract_path_attribute(method_attr);
90
91                    let index = Index::from(index);
92
93                    Some(field_state_update_body(index, method_path))
94                });
95
96            quote! {
97                #(#iter)*
98            }
99        }
100        Fields::Unit => {
101            quote!()
102        }
103    }
104}
105
106fn field_state_update_body(name: impl IdentFragment, method_path: Option<ExprPath>) -> TokenStream {
107    let name = format_ident!("{}", name);
108
109    let method_call = method_path.map(|path| quote! { #path(self, _recv); });
110
111    quote_spanned! { name.span() =>
112        if let Some(_recv) = ::async_component::State::update(&mut self.#name) {
113            #method_call
114        }
115    }
116}
117
118fn component_update_body(fields: &Fields) -> TokenStream {
119    match fields {
120        Fields::Named(fields) => {
121            let iter = fields.named.iter().filter_map(|field| {
122                let _ = extract_attribute("component", &field.attrs)?;
123                let name = field.ident.as_ref().unwrap();
124
125                Some(field_component_update_body(name))
126            });
127
128            quote! {
129                #(#iter)*
130            }
131        }
132        Fields::Unnamed(fields) => {
133            let iter = fields
134                .unnamed
135                .iter()
136                .enumerate()
137                .filter_map(|(index, field)| {
138                    let _ = extract_attribute("component", &field.attrs)?;
139                    let index = Index::from(index);
140
141                    Some(field_component_update_body(index))
142                });
143
144            quote! {
145                #(#iter)*
146            }
147        }
148        Fields::Unit => {
149            quote!()
150        }
151    }
152}
153
154fn field_component_update_body(name: impl IdentFragment) -> TokenStream {
155    let name = format_ident!("{}", name);
156
157    quote_spanned! { name.span() =>
158        ::async_component::AsyncComponent::update_component(&mut self.#name);
159    }
160}