Skip to main content

cranpose_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use proc_macro_crate::{crate_name, FoundCrate};
4use quote::quote;
5use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Type};
6
7/// Check if a type is Fn-like (impl FnMut/Fn/FnOnce, Box<dyn FnMut>, generic with Fn bound, etc.)
8/// For generic type parameters (e.g., `F` where F: FnMut()), we need to check the bounds.
9fn is_fn_like_type(ty: &Type) -> bool {
10    match ty {
11        // impl FnMut(...) + 'static, impl Fn(...), etc.
12        Type::ImplTrait(impl_trait) => impl_trait.bounds.iter().any(|bound| {
13            if let syn::TypeParamBound::Trait(trait_bound) = bound {
14                let path = &trait_bound.path;
15                if let Some(segment) = path.segments.last() {
16                    let ident_str = segment.ident.to_string();
17                    return ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce";
18                }
19            }
20            false
21        }),
22        // Box<dyn FnMut(...)>
23        Type::Path(type_path) => {
24            if let Some(segment) = type_path.path.segments.last() {
25                if segment.ident == "Box" {
26                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
27                        if let Some(syn::GenericArgument::Type(Type::TraitObject(trait_obj))) =
28                            args.args.first()
29                        {
30                            return trait_obj.bounds.iter().any(|bound| {
31                                if let syn::TypeParamBound::Trait(trait_bound) = bound {
32                                    let path = &trait_bound.path;
33                                    if let Some(segment) = path.segments.last() {
34                                        let ident_str = segment.ident.to_string();
35                                        return ident_str == "FnMut"
36                                            || ident_str == "Fn"
37                                            || ident_str == "FnOnce";
38                                    }
39                                }
40                                false
41                            });
42                        }
43                    }
44                }
45            }
46            false
47        }
48        // bare fn(...) -> ...
49        Type::BareFn(_) => true,
50        _ => false,
51    }
52}
53
54/// Check if a generic type parameter has Fn-like bounds by looking at the where clause and bounds
55fn is_generic_fn_like(ty: &Type, generics: &syn::Generics) -> bool {
56    // Extract the ident for Type::Path that might be a generic param
57    let type_ident = match ty {
58        Type::Path(type_path) if type_path.path.segments.len() == 1 => {
59            &type_path.path.segments[0].ident
60        }
61        _ => return false,
62    };
63
64    // Check if it's a type parameter with Fn bounds
65    for param in &generics.params {
66        if let syn::GenericParam::Type(type_param) = param {
67            if type_param.ident == *type_ident {
68                // Check the bounds on the type parameter
69                for bound in &type_param.bounds {
70                    if let syn::TypeParamBound::Trait(trait_bound) = bound {
71                        if let Some(segment) = trait_bound.path.segments.last() {
72                            let ident_str = segment.ident.to_string();
73                            if ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce" {
74                                return true;
75                            }
76                        }
77                    }
78                }
79            }
80        }
81    }
82
83    // Also check where clause
84    if let Some(where_clause) = &generics.where_clause {
85        for predicate in &where_clause.predicates {
86            if let syn::WherePredicate::Type(pred) = predicate {
87                if let Type::Path(bounded_type) = &pred.bounded_ty {
88                    if bounded_type.path.segments.len() == 1
89                        && bounded_type.path.segments[0].ident == *type_ident
90                    {
91                        for bound in &pred.bounds {
92                            if let syn::TypeParamBound::Trait(trait_bound) = bound {
93                                if let Some(segment) = trait_bound.path.segments.last() {
94                                    let ident_str = segment.ident.to_string();
95                                    if ident_str == "FnMut"
96                                        || ident_str == "Fn"
97                                        || ident_str == "FnOnce"
98                                    {
99                                        return true;
100                                    }
101                                }
102                            }
103                        }
104                    }
105                }
106            }
107        }
108    }
109
110    false
111}
112
113/// Unified check: is this type Fn-like, either syntactically or via generic bounds?
114fn is_fn_param(ty: &Type, generics: &syn::Generics) -> bool {
115    is_fn_like_type(ty) || is_generic_fn_like(ty, generics)
116}
117
118fn core_crate_path() -> TokenStream2 {
119    let crate_name = crate_name("cranpose")
120        .ok()
121        .or_else(|| crate_name("cranpose-core").ok());
122
123    match crate_name {
124        Some(FoundCrate::Itself) => quote!(crate),
125        Some(FoundCrate::Name(name)) => {
126            let ident = Ident::new(&name, Span::call_site());
127            quote!(#ident)
128        }
129        None => quote!(cranpose_core),
130    }
131}
132
133#[proc_macro_attribute]
134pub fn composable(attr: TokenStream, item: TokenStream) -> TokenStream {
135    let attr_tokens = TokenStream2::from(attr);
136    let mut enable_skip = true;
137    let core_path = core_crate_path();
138    if !attr_tokens.is_empty() {
139        match syn::parse2::<Ident>(attr_tokens) {
140            Ok(ident) if ident == "no_skip" => enable_skip = false,
141            Ok(other) => {
142                return syn::Error::new_spanned(other, "unsupported composable attribute")
143                    .to_compile_error()
144                    .into();
145            }
146            Err(err) => {
147                return err.to_compile_error().into();
148            }
149        }
150    }
151
152    let mut func = parse_macro_input!(item as ItemFn);
153
154    struct ParamInfo {
155        ident: Ident,
156        pat: Box<Pat>,
157        ty: Type,
158        pat_is_mut: bool,
159        is_impl_trait: bool,
160    }
161
162    let mut param_info: Vec<ParamInfo> = Vec::new();
163
164    for (index, arg) in func.sig.inputs.iter_mut().enumerate() {
165        if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
166            let pat_is_mut = matches!(
167                pat.as_ref(),
168                Pat::Ident(pat_ident) if pat_ident.mutability.is_some()
169            );
170            let is_impl_trait = matches!(**ty, Type::ImplTrait(_));
171
172            if is_impl_trait {
173                let original_pat: Box<Pat> = pat.clone();
174                if let Pat::Ident(pat_ident) = &**pat {
175                    param_info.push(ParamInfo {
176                        ident: pat_ident.ident.clone(),
177                        pat: original_pat,
178                        ty: ty.as_ref().clone(),
179                        pat_is_mut,
180                        is_impl_trait: true,
181                    });
182                } else {
183                    param_info.push(ParamInfo {
184                        ident: Ident::new(&format!("__arg{}", index), Span::call_site()),
185                        pat: original_pat,
186                        ty: ty.as_ref().clone(),
187                        pat_is_mut,
188                        is_impl_trait: true,
189                    });
190                }
191            } else {
192                let ident = Ident::new(&format!("__arg{}", index), Span::call_site());
193                let original_pat: Box<Pat> = pat.clone();
194                *pat = Box::new(syn::parse_quote! { #ident });
195                param_info.push(ParamInfo {
196                    ident,
197                    pat: original_pat,
198                    ty: ty.as_ref().clone(),
199                    pat_is_mut,
200                    is_impl_trait: false,
201                });
202            }
203        }
204    }
205
206    let original_block = func.block.clone();
207    let helper_block = original_block.clone();
208    let recranpose_block = original_block.clone();
209    let key_expr = quote! { #core_path::location_key(file!(), line!(), column!()) };
210
211    // Rebinds will be generated later in the helper_body context where we have access to slots
212    let rebinds_for_no_skip: Vec<_> = param_info
213        .iter()
214        .map(|info| {
215            let ident = &info.ident;
216            let pat = &info.pat;
217            quote! { let #pat = #ident; }
218        })
219        .collect();
220
221    let return_ty: syn::Type = match &func.sig.output {
222        ReturnType::Default => syn::parse_quote! { () },
223        ReturnType::Type(_, ty) => ty.as_ref().clone(),
224    };
225    let _helper_ident = Ident::new(
226        &format!("__cranpose_impl_{}", func.sig.ident),
227        Span::call_site(),
228    );
229    let generics = func.sig.generics.clone();
230    let (_impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
231
232    let _helper_inputs: Vec<TokenStream2> = param_info
233        .iter()
234        .map(|info| {
235            let ident = &info.ident;
236            let ty = &info.ty;
237            quote! { #ident: #ty }
238        })
239        .collect();
240
241    // Check if any params are impl Trait - if so, can't use skip optimization
242    let has_impl_trait = param_info
243        .iter()
244        .any(|info| matches!(info.ty, Type::ImplTrait(_)));
245
246    if enable_skip && !has_impl_trait {
247        let helper_ident = Ident::new(
248            &format!("__cranpose_impl_{}", func.sig.ident),
249            Span::call_site(),
250        );
251        let generics = func.sig.generics.clone();
252        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
253        let ty_generics_turbofish = ty_generics.as_turbofish();
254
255        // Helper function signature: all params except impl Trait (which can't be named)
256        let helper_inputs: Vec<TokenStream2> = param_info
257            .iter()
258            .filter_map(|info| {
259                if info.is_impl_trait {
260                    None
261                } else {
262                    let ident = &info.ident;
263                    let ty = &info.ty;
264                    Some(quote! { #ident: #ty })
265                }
266            })
267            .collect();
268
269        // Separate Fn-like params from regular params
270        let param_state_slots: Vec<Ident> = (0..param_info.len())
271            .map(|index| Ident::new(&format!("__param_state_slot{}", index), Span::call_site()))
272            .collect();
273
274        let param_setup: Vec<TokenStream2> = param_info
275            .iter()
276            .zip(param_state_slots.iter())
277            .map(|(info, slot_ident)| {
278                if info.is_impl_trait {
279                    quote! { __changed = true; }
280                } else if is_fn_param(&info.ty, &generics) {
281                    let ident = &info.ident;
282                    quote! {
283                        let #slot_ident = __composer
284                            .use_value_slot(|| #core_path::CallbackHolder::new());
285                        __composer.with_slot_value::<#core_path::CallbackHolder, _>(
286                            #slot_ident,
287                            |holder| {
288                                holder.update(#ident);
289                            },
290                        );
291                        __changed = true;
292                    }
293                } else {
294                    let ident = &info.ident;
295                    let ty = &info.ty;
296                    quote! {
297                        let #slot_ident = __composer
298                            .use_value_slot(|| #core_path::ParamState::<#ty>::default());
299                        if __composer.with_slot_value_mut::<#core_path::ParamState<#ty>, _>(
300                            #slot_ident,
301                            |state| state.update(&#ident),
302                        )
303                        {
304                            __changed = true;
305                        }
306                    }
307                }
308            })
309            .collect();
310
311        let param_setup_recompose: Vec<TokenStream2> = param_info
312            .iter()
313            .zip(param_state_slots.iter())
314            .map(|(info, slot_ident)| {
315                if info.is_impl_trait {
316                    quote! {}
317                } else if is_fn_param(&info.ty, &generics) {
318                    quote! {
319                        let #slot_ident = __composer
320                            .use_value_slot(|| #core_path::CallbackHolder::new());
321                    }
322                } else {
323                    let ty = &info.ty;
324                    quote! {
325                        let #slot_ident = __composer
326                            .use_value_slot(|| #core_path::ParamState::<#ty>::default());
327                    }
328                }
329            })
330            .collect();
331
332        let rebinds: Vec<TokenStream2> = param_info
333            .iter()
334            .zip(param_state_slots.iter())
335            .map(|(info, slot_ident)| {
336                if info.is_impl_trait {
337                    quote! {}
338                } else if is_fn_param(&info.ty, &generics) {
339                    let pat = &info.pat;
340                    let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
341                    if can_add_mut && !info.pat_is_mut {
342                        quote! {
343                            #[allow(unused_mut)]
344                            let mut #pat = __composer
345                                .with_slot_value::<#core_path::CallbackHolder, _>(
346                                    #slot_ident,
347                                    |holder| holder.clone_rc(),
348                                );
349                        }
350                    } else {
351                        quote! {
352                            #[allow(unused_mut)]
353                            let #pat = __composer
354                                .with_slot_value::<#core_path::CallbackHolder, _>(
355                                    #slot_ident,
356                                    |holder| holder.clone_rc(),
357                                );
358                        }
359                    }
360                } else {
361                    let pat = &info.pat;
362                    let ident = &info.ident;
363                    quote! {
364                        let #pat = #ident;
365                    }
366                }
367            })
368            .collect();
369
370        let rebinds_for_recompose: Vec<TokenStream2> = param_info
371            .iter()
372            .zip(param_state_slots.iter())
373            .map(|(info, slot_ident)| {
374                if info.is_impl_trait {
375                    quote! {}
376                } else if is_fn_param(&info.ty, &generics) {
377                    let pat = &info.pat;
378                    let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
379                    if can_add_mut && !info.pat_is_mut {
380                        quote! {
381                            #[allow(unused_mut)]
382                            let mut #pat = __composer
383                                .with_slot_value::<#core_path::CallbackHolder, _>(
384                                    #slot_ident,
385                                    |holder| holder.clone_rc(),
386                                );
387                        }
388                    } else {
389                        quote! {
390                            #[allow(unused_mut)]
391                            let #pat = __composer
392                                .with_slot_value::<#core_path::CallbackHolder, _>(
393                                    #slot_ident,
394                                    |holder| holder.clone_rc(),
395                                );
396                        }
397                    }
398                } else {
399                    let pat = &info.pat;
400                    let ty = &info.ty;
401                    quote! {
402                        let #pat = __composer
403                            .with_slot_value::<#core_path::ParamState<#ty>, _>(
404                                #slot_ident,
405                                |state| {
406                                    state
407                                        .value()
408                                        .expect("composable parameter missing for recomposition")
409                                },
410                            );
411                    }
412                }
413            })
414            .collect();
415
416        let recranpose_fn_ident = Ident::new(
417            &format!("__cranpose_recranpose_{}", func.sig.ident),
418            Span::call_site(),
419        );
420
421        let recranpose_setter = quote! {
422            {
423                __composer.set_recranpose_callback(move |
424                    __composer: &#core_path::Composer|
425                {
426                    #recranpose_fn_ident #ty_generics_turbofish (
427                        __composer
428                    );
429                });
430            }
431        };
432
433        let helper_body = quote! {
434            let __current_scope = __composer
435                .current_recranpose_scope()
436                .expect("missing recompose scope");
437            let mut __changed = __current_scope.should_recompose();
438            #(#param_setup)*
439            let __result_slot_index = __composer
440                .use_value_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
441            let __has_previous = __composer
442                .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
443                    __result_slot_index,
444                    |slot| slot.get().is_some(),
445                );
446            if !__changed && __has_previous {
447                __composer.skip_current_group();
448                let __result = __composer
449                    .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
450                        __result_slot_index,
451                        |slot| {
452                            slot.get()
453                                .expect("composable return value missing during skip")
454                        },
455                    );
456                return __result;
457            }
458            let __value: #return_ty = {
459                #(#rebinds)*
460                #helper_block
461            };
462            __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
463                __result_slot_index,
464                |slot| {
465                    slot.store(__value.clone());
466                },
467            );
468            #recranpose_setter
469            __value
470        };
471
472        let recranpose_fn_body = quote! {
473            #(#param_setup_recompose)*
474            let __result_slot_index = __composer
475                .use_value_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
476            #(#rebinds_for_recompose)*
477            let __value: #return_ty = {
478                #recranpose_block
479            };
480            __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
481                __result_slot_index,
482                |slot| {
483                    slot.store(__value.clone());
484                },
485            );
486            #recranpose_setter
487            __value
488        };
489
490        let recranpose_fn = quote! {
491            #[allow(non_snake_case)]
492            fn #recranpose_fn_ident #impl_generics (
493                __composer: &#core_path::Composer
494            ) -> #return_ty #where_clause {
495                #recranpose_fn_body
496            }
497        };
498
499        let helper_fn = quote! {
500            #[allow(non_snake_case)]
501            fn #helper_ident #impl_generics (
502                __composer: &#core_path::Composer
503                #(, #helper_inputs)*
504            ) -> #return_ty #where_clause {
505                #helper_body
506            }
507        };
508
509        // Wrapper args: pass all params except impl Trait on initial call
510        let wrapper_args: Vec<TokenStream2> = param_info
511            .iter()
512            .filter_map(|info| {
513                if info.is_impl_trait {
514                    None
515                } else {
516                    let ident = &info.ident;
517                    Some(quote! { #ident })
518                }
519            })
520            .collect();
521
522        let wrapped = quote!({
523            #core_path::with_current_composer(|__composer: &#core_path::Composer| {
524                __composer.with_group(#key_expr, |__composer: &#core_path::Composer| {
525                    #helper_ident(__composer #(, #wrapper_args)*)
526                })
527            })
528        });
529        func.block = Box::new(syn::parse2(wrapped).expect("failed to build block"));
530        TokenStream::from(quote! {
531            #recranpose_fn
532            #helper_fn
533            #func
534        })
535    } else {
536        // no_skip path: still uses simple rebinds
537        let wrapped = quote!({
538            #core_path::with_current_composer(|__composer: &#core_path::Composer| {
539                __composer.with_group(#key_expr, |__scope: &#core_path::Composer| {
540                    #(#rebinds_for_no_skip)*
541                    #original_block
542                })
543            })
544        });
545        func.block = Box::new(syn::parse2(wrapped).expect("failed to build block"));
546        TokenStream::from(quote! { #func })
547    }
548}