concoct_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote, ToTokens};
3use syn::{
4    fold::{self, Fold},
5    parse_macro_input, parse_quote,
6    punctuated::Punctuated,
7    token::Comma,
8    Expr, FnArg, GenericParam, ItemFn, Macro, ReturnType, Stmt, Type,
9};
10
11#[proc_macro_attribute]
12pub fn composable(_attr: TokenStream, item: TokenStream) -> TokenStream {
13    let item = parse_macro_input!(item as ItemFn);
14
15    let ident = item.sig.ident;
16    let vis = item.vis;
17    let mut generics = Vec::new();
18    for param in &item.sig.generics.params {
19        match param {
20            GenericParam::Type(type_param) => {
21                generics.push(type_param.ident.clone());
22            }
23            _ => todo!(),
24        }
25    }
26
27    let generics_clause = item.sig.generics.params;
28    let where_clause = item.sig.generics.where_clause;
29
30    let output = match item.sig.output {
31        ReturnType::Type(_, ty) => Some(*ty),
32        ReturnType::Default => None,
33    };
34    let output_ty = output.clone().unwrap_or(parse_quote!(()));
35
36    let block = Folder {
37        is_nested: false,
38        is_replaceable: false,
39        pos: 0,
40    }
41    .fold_block(*item.block);
42
43    let mut input_pats = Vec::new();
44    let mut input_types = Vec::new();
45    for input in item.sig.inputs {
46        match input {
47            FnArg::Typed(typed) => {
48                input_pats.push(typed.pat);
49                input_types.push(typed.ty);
50            }
51            _ => todo!(),
52        }
53    }
54
55    let struct_ident = format_ident!("{}_composable", ident);
56    let inputs: Vec<_> = input_pats
57        .iter()
58        .zip(&input_types)
59        .map(|(pat, ty)| quote!(#pat: #ty))
60        .collect();
61
62    let mut struct_fields = inputs.clone();
63
64    let input_generics: Vec<_> = input_types
65        .iter()
66        .filter_map(|ty| {
67            match &**ty {
68                Type::Path(type_path) => {
69                    if let Some(ident) = type_path.path.get_ident() {
70                        return Some(ident);
71                    }
72                }
73                _ => {}
74            }
75
76            None
77        })
78        .collect();
79
80    let mut struct_markers = Vec::new();
81    for (idx, generic) in generics.iter().enumerate() {
82        if !input_generics.contains(&generic) {
83            let ident = format_ident!("_marker{}", idx);
84            struct_fields.push(parse_quote!(#ident: std::marker::PhantomData<#generic>));
85            struct_markers.push(quote!(#ident: std::marker::PhantomData));
86        }
87    }
88
89    let group_id = quote!(std::any::TypeId::of::<#struct_ident::<#(#generics,)*>>());
90    let group = if output.is_some() {
91        quote! {
92            composer.replaceable_group(#group_id, move |composer| #block)
93        }
94    } else {
95        if inputs.is_empty() {
96            quote! {
97                composer.restart_group(#group_id, move |composer| {
98                    /*
99                    if changed == 0 && composer.is_skipping() {
100                        composer.skip_to_group_end();
101                    } else {
102                        #block
103                    }
104                     */
105
106                     #block
107                });
108            }
109        } else {
110            let checks = input_pats.iter().enumerate().map(|(idx, input)| {
111                let i: u32 = 0b111 << (idx * 3 + 1);
112                quote! {
113                    if changed & #i == 0 {
114                        dirty = changed | if composer.changed(&x) { 4 } else { 2 };
115                    }
116                }
117            });
118
119            let mut mask = 1u32;
120            let mut value = 0u32;
121            for idx in 0..input_pats.len() {
122                mask |= 0b101 << (idx * 3 + 1);
123                value |= 0b10 << (idx * 3);
124            }
125
126            quote! {
127                composer.restart_group(#group_id, move |composer| {
128                    /*
129                        let mut dirty = changed;
130
131                        #(#checks)*
132
133                        if dirty & #mask == #value  && composer.is_skipping() {
134                            composer.skip_to_group_end();
135                        } else {
136                            #block
137                        }
138                     */
139
140                    #block
141                });
142            }
143        }
144    };
145
146    let mut constructor_fields = Punctuated::<_, Comma>::new();
147    constructor_fields.extend(input_pats.iter().map(|pat| pat.to_token_stream()));
148    constructor_fields.extend(struct_markers.clone());
149
150    let mut struct_pattern = Punctuated::<_, Comma>::new();
151    struct_pattern.extend(input_pats.iter().map(|pat| pat.to_token_stream()));
152    struct_pattern.push(quote!(..));
153
154    let expanded = quote! {
155        #[must_use]
156        #vis fn #ident <#generics_clause> (#(#inputs),*) -> impl concoct::Composable<Output = #output_ty>  #where_clause {
157            #[allow(non_camel_case_types)]
158            struct #struct_ident <#(#generics),*> {
159                #(#struct_fields),*
160            }
161
162            impl<#generics_clause> concoct::Composable<> for #struct_ident <#(#generics),*> #where_clause {
163                type Output = #output_ty;
164
165                fn compose(self, composer: &mut concoct::Composer, changed: u32) -> Self::Output {
166                    compose!(());
167
168                    let Self { #struct_pattern } = self;
169
170                    #group
171                }
172            }
173
174            #struct_ident {
175                #constructor_fields
176            }
177        }
178    };
179
180    TokenStream::from(expanded)
181}
182
/// State threaded through the `Fold` pass that rewrites a composable's body.
struct Folder {
    /// How many `GroupN` marker types have been emitted so far; the next
    /// wrapped block uses this number in its marker name.
    pos: usize,
    /// True while folding inside an `if` expression. Only blocks folded in this
    /// state are candidates for their own replaceable group.
    is_nested: bool,
    /// Set when the block currently being folded is found to need its own
    /// replaceable group (for example, it contains a `compose!` call).
    is_replaceable: bool,
}
188
189impl Fold for Folder {
190    fn fold_stmt(&mut self, mut i: syn::Stmt) -> syn::Stmt {
191        if let Stmt::Macro(stmt_macro) = &i {
192            if let Some(expr) = get_compose_macro(&stmt_macro.mac) {
193                self.is_replaceable = true;
194                i = parse_quote! {
195                    (#expr).compose(composer, 0);
196                };
197            }
198        }
199
200        fold::fold_stmt(self, i)
201    }
202
203    fn fold_expr(&mut self, mut i: Expr) -> Expr {
204        match &mut i {
205            Expr::Macro(expr_macro) => {
206                self.is_replaceable = true;
207                if let Some(expr) = get_compose_macro(&expr_macro.mac) {
208                    i = parse_quote! {
209                        (#expr).compose(composer, 0)
210                    };
211                }
212            }
213            Expr::If(expr_if) => {
214                let old = self.is_nested;
215                self.is_nested = true;
216
217                *expr_if = fold::fold_expr_if(self, expr_if.clone());
218                self.is_nested = old;
219            }
220            _ => {}
221        }
222
223        fold::fold_expr(self, i)
224    }
225
226    fn fold_block(&mut self, i: syn::Block) -> syn::Block {
227        if self.is_nested {
228            let old = self.is_replaceable;
229            self.is_replaceable = false;
230
231            let mut block = fold::fold_block(self, i);
232            if self.is_replaceable {
233                let ident = format_ident!("Group{}", self.pos);
234                self.pos += 1;
235
236                block = parse_quote!({
237                    struct #ident;
238                    composer.replaceable_group(std::any::TypeId::of::<#ident>(), |composer| #block)
239                });
240            }
241
242            self.is_replaceable = old;
243
244            block
245        } else {
246            fold::fold_block(self, i)
247        }
248    }
249}
250
251fn get_compose_macro(mac: &Macro) -> Option<Expr> {
252    if mac.path.get_ident().map(ToString::to_string).as_deref() == Some("compose") {
253        let body = mac.parse_body().unwrap();
254        Some(body)
255    } else {
256        None
257    }
258}