Skip to main content

cranpose_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use proc_macro_crate::{crate_name, FoundCrate};
4use quote::quote;
5use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Type};
6
7/// Check if a type is Fn-like (impl FnMut/Fn/FnOnce, Box<dyn FnMut>, generic with Fn bound, etc.)
8/// For generic type parameters (e.g., `F` where F: FnMut()), we need to check the bounds.
9fn is_fn_like_type(ty: &Type) -> bool {
10    match ty {
11        // impl FnMut(...) + 'static, impl Fn(...), etc.
12        Type::ImplTrait(impl_trait) => impl_trait.bounds.iter().any(|bound| {
13            if let syn::TypeParamBound::Trait(trait_bound) = bound {
14                let path = &trait_bound.path;
15                if let Some(segment) = path.segments.last() {
16                    let ident_str = segment.ident.to_string();
17                    return ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce";
18                }
19            }
20            false
21        }),
22        // Box<dyn FnMut(...)>
23        Type::Path(type_path) => {
24            if let Some(segment) = type_path.path.segments.last() {
25                if segment.ident == "Box" {
26                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
27                        if let Some(syn::GenericArgument::Type(Type::TraitObject(trait_obj))) =
28                            args.args.first()
29                        {
30                            return trait_obj.bounds.iter().any(|bound| {
31                                if let syn::TypeParamBound::Trait(trait_bound) = bound {
32                                    let path = &trait_bound.path;
33                                    if let Some(segment) = path.segments.last() {
34                                        let ident_str = segment.ident.to_string();
35                                        return ident_str == "FnMut"
36                                            || ident_str == "Fn"
37                                            || ident_str == "FnOnce";
38                                    }
39                                }
40                                false
41                            });
42                        }
43                    }
44                }
45            }
46            false
47        }
48        // bare fn(...) -> ...
49        Type::BareFn(_) => true,
50        _ => false,
51    }
52}
53
54/// Check if a generic type parameter has Fn-like bounds by looking at the where clause and bounds
55fn is_generic_fn_like(ty: &Type, generics: &syn::Generics) -> bool {
56    // Extract the ident for Type::Path that might be a generic param
57    let type_ident = match ty {
58        Type::Path(type_path) if type_path.path.segments.len() == 1 => {
59            &type_path.path.segments[0].ident
60        }
61        _ => return false,
62    };
63
64    // Check if it's a type parameter with Fn bounds
65    for param in &generics.params {
66        if let syn::GenericParam::Type(type_param) = param {
67            if type_param.ident == *type_ident {
68                // Check the bounds on the type parameter
69                for bound in &type_param.bounds {
70                    if let syn::TypeParamBound::Trait(trait_bound) = bound {
71                        if let Some(segment) = trait_bound.path.segments.last() {
72                            let ident_str = segment.ident.to_string();
73                            if ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce" {
74                                return true;
75                            }
76                        }
77                    }
78                }
79            }
80        }
81    }
82
83    // Also check where clause
84    if let Some(where_clause) = &generics.where_clause {
85        for predicate in &where_clause.predicates {
86            if let syn::WherePredicate::Type(pred) = predicate {
87                if let Type::Path(bounded_type) = &pred.bounded_ty {
88                    if bounded_type.path.segments.len() == 1
89                        && bounded_type.path.segments[0].ident == *type_ident
90                    {
91                        for bound in &pred.bounds {
92                            if let syn::TypeParamBound::Trait(trait_bound) = bound {
93                                if let Some(segment) = trait_bound.path.segments.last() {
94                                    let ident_str = segment.ident.to_string();
95                                    if ident_str == "FnMut"
96                                        || ident_str == "Fn"
97                                        || ident_str == "FnOnce"
98                                    {
99                                        return true;
100                                    }
101                                }
102                            }
103                        }
104                    }
105                }
106            }
107        }
108    }
109
110    false
111}
112
113/// Unified check: is this type Fn-like, either syntactically or via generic bounds?
114fn is_fn_param(ty: &Type, generics: &syn::Generics) -> bool {
115    is_fn_like_type(ty) || is_generic_fn_like(ty, generics)
116}
117
118/// Check if a type is `impl Fn() + ...` or `impl FnMut() + ...` with **zero** arguments.
119/// Only these can be stored through [`CallbackHolder`] (excludes `FnOnce` which can't be
120/// called more than once).
121fn is_zero_arg_fn_impl_trait(ty: &Type) -> bool {
122    if let Type::ImplTrait(impl_trait) = ty {
123        impl_trait.bounds.iter().any(|bound| {
124            if let syn::TypeParamBound::Trait(trait_bound) = bound {
125                if let Some(segment) = trait_bound.path.segments.last() {
126                    let ident_str = segment.ident.to_string();
127                    if ident_str == "Fn" || ident_str == "FnMut" {
128                        if let syn::PathArguments::Parenthesized(args) = &segment.arguments {
129                            return args.inputs.is_empty();
130                        }
131                    }
132                }
133            }
134            false
135        })
136    } else {
137        false
138    }
139}
140
141fn core_crate_path() -> TokenStream2 {
142    let crate_name = crate_name("cranpose")
143        .ok()
144        .or_else(|| crate_name("cranpose-core").ok());
145
146    match crate_name {
147        Some(FoundCrate::Itself) => quote!(crate),
148        Some(FoundCrate::Name(name)) => {
149            let ident = Ident::new(&name, Span::call_site());
150            quote!(#ident)
151        }
152        None => quote!(cranpose_core),
153    }
154}
155
156#[proc_macro_attribute]
157pub fn composable(attr: TokenStream, item: TokenStream) -> TokenStream {
158    let attr_tokens = TokenStream2::from(attr);
159    let mut enable_skip = true;
160    let core_path = core_crate_path();
161    if !attr_tokens.is_empty() {
162        match syn::parse2::<Ident>(attr_tokens) {
163            Ok(ident) if ident == "no_skip" => enable_skip = false,
164            Ok(other) => {
165                return syn::Error::new_spanned(other, "unsupported composable attribute")
166                    .to_compile_error()
167                    .into();
168            }
169            Err(err) => {
170                return err.to_compile_error().into();
171            }
172        }
173    }
174
175    let mut func = parse_macro_input!(item as ItemFn);
176
177    struct ParamInfo {
178        ident: Ident,
179        pat: Box<Pat>,
180        ty: Type,
181        pat_is_mut: bool,
182        is_impl_trait: bool,
183    }
184
185    let mut param_info: Vec<ParamInfo> = Vec::new();
186
187    for (index, arg) in func.sig.inputs.iter_mut().enumerate() {
188        if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
189            let pat_is_mut = matches!(
190                pat.as_ref(),
191                Pat::Ident(pat_ident) if pat_ident.mutability.is_some()
192            );
193            let is_impl_trait = matches!(**ty, Type::ImplTrait(_));
194
195            if is_impl_trait {
196                let original_pat: Box<Pat> = pat.clone();
197                if let Pat::Ident(pat_ident) = &**pat {
198                    param_info.push(ParamInfo {
199                        ident: pat_ident.ident.clone(),
200                        pat: original_pat,
201                        ty: ty.as_ref().clone(),
202                        pat_is_mut,
203                        is_impl_trait: true,
204                    });
205                } else {
206                    param_info.push(ParamInfo {
207                        ident: Ident::new(&format!("__arg{}", index), Span::call_site()),
208                        pat: original_pat,
209                        ty: ty.as_ref().clone(),
210                        pat_is_mut,
211                        is_impl_trait: true,
212                    });
213                }
214            } else {
215                let ident = Ident::new(&format!("__arg{}", index), Span::call_site());
216                let original_pat: Box<Pat> = pat.clone();
217                **pat = syn::parse_quote! { #ident };
218                param_info.push(ParamInfo {
219                    ident,
220                    pat: original_pat,
221                    ty: ty.as_ref().clone(),
222                    pat_is_mut,
223                    is_impl_trait: false,
224                });
225            }
226        }
227    }
228
229    let scope_label_ident = func.sig.ident.clone();
230    let original_block = func.block.clone();
231    let helper_block = original_block.clone();
232    let recranpose_block = original_block.clone();
233    let key_expr = quote! { #core_path::location_key(file!(), line!(), column!()) };
234
235    // Rebinds will be generated later in the helper_body context where we have access to slots
236    let rebinds_for_no_skip: Vec<_> = param_info
237        .iter()
238        .map(|info| {
239            let ident = &info.ident;
240            let pat = &info.pat;
241            quote! { let #pat = #ident; }
242        })
243        .collect();
244
245    let return_ty: syn::Type = match &func.sig.output {
246        ReturnType::Default => syn::parse_quote! { () },
247        ReturnType::Type(_, ty) => ty.as_ref().clone(),
248    };
249    let returns_unit = match &func.sig.output {
250        ReturnType::Default => true,
251        ReturnType::Type(_, ty) => {
252            matches!(ty.as_ref(), Type::Tuple(tuple) if tuple.elems.is_empty())
253        }
254    };
255    let _helper_ident = Ident::new(
256        &format!("__cranpose_impl_{}", func.sig.ident),
257        Span::call_site(),
258    );
259    let generics = func.sig.generics.clone();
260    let (_impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
261
262    let _helper_inputs: Vec<TokenStream2> = param_info
263        .iter()
264        .map(|info| {
265            let ident = &info.ident;
266            let ty = &info.ty;
267            quote! { #ident: #ty }
268        })
269        .collect();
270
271    // Check if any params are impl Trait that we can't store in a slot.
272    // Zero-arg Fn-like impl traits (impl Fn() + 'static) are handled via CallbackHolder.
273    let has_unhandled_impl_trait = param_info
274        .iter()
275        .any(|info| info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty));
276
277    if enable_skip && !has_unhandled_impl_trait {
278        let helper_ident = Ident::new(
279            &format!("__cranpose_impl_{}", func.sig.ident),
280            Span::call_site(),
281        );
282        let generics = func.sig.generics.clone();
283        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
284        let ty_generics_turbofish = ty_generics.as_turbofish();
285
286        // Helper function signature: all params except unhandled impl Trait.
287        // Zero-arg Fn impl traits are included (they become anonymous generics).
288        let helper_inputs: Vec<TokenStream2> = param_info
289            .iter()
290            .filter_map(|info| {
291                if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
292                    None
293                } else {
294                    let ident = &info.ident;
295                    let ty = &info.ty;
296                    Some(quote! { #ident: #ty })
297                }
298            })
299            .collect();
300
301        // Separate Fn-like params from regular params
302        let param_state_slots: Vec<Ident> = (0..param_info.len())
303            .map(|index| Ident::new(&format!("__param_state_slot{}", index), Span::call_site()))
304            .collect();
305
306        let param_setup: Vec<TokenStream2> = param_info
307            .iter()
308            .zip(param_state_slots.iter())
309            .map(|(info, slot_ident)| {
310                // Zero-arg Fn impl traits and generic Fn params → CallbackHolder
311                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
312                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
313                {
314                    let ident = &info.ident;
315                    quote! {
316                        let #slot_ident = __composer
317                            .__use_param_slot(|| #core_path::CallbackHolder::new());
318                        __composer.with_slot_value::<#core_path::CallbackHolder, _>(
319                            #slot_ident,
320                            |holder| {
321                                holder.update(#ident);
322                            },
323                        );
324                        __changed = true;
325                    }
326                } else if info.is_impl_trait {
327                    // Non-Fn impl trait – cannot store, always mark changed
328                    quote! { __changed = true; }
329                } else {
330                    let ident = &info.ident;
331                    let ty = &info.ty;
332                    quote! {
333                        let #slot_ident = __composer
334                            .__use_param_slot(|| #core_path::ParamState::<#ty>::default());
335                        if __composer.with_slot_value_mut::<#core_path::ParamState<#ty>, _>(
336                            #slot_ident,
337                            |state| state.update(&#ident),
338                        )
339                        {
340                            __changed = true;
341                        }
342                    }
343                }
344            })
345            .collect();
346
347        let param_setup_recompose: Vec<TokenStream2> = param_info
348            .iter()
349            .zip(param_state_slots.iter())
350            .map(|(info, slot_ident)| {
351                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
352                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
353                {
354                    quote! {
355                        let #slot_ident = __composer
356                            .__use_param_slot(|| #core_path::CallbackHolder::new());
357                    }
358                } else if info.is_impl_trait {
359                    quote! {}
360                } else {
361                    let ty = &info.ty;
362                    quote! {
363                        let #slot_ident = __composer
364                            .__use_param_slot(|| #core_path::ParamState::<#ty>::default());
365                    }
366                }
367            })
368            .collect();
369
370        let rebinds: Vec<TokenStream2> = param_info
371            .iter()
372            .zip(param_state_slots.iter())
373            .map(|(info, slot_ident)| {
374                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
375                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
376                {
377                    let pat = &info.pat;
378                    let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
379                    if can_add_mut && !info.pat_is_mut {
380                        quote! {
381                            #[allow(unused_mut)]
382                            let mut #pat = __composer
383                                .with_slot_value::<#core_path::CallbackHolder, _>(
384                                    #slot_ident,
385                                    |holder| holder.clone_rc(),
386                                );
387                        }
388                    } else {
389                        quote! {
390                            #[allow(unused_mut)]
391                            let #pat = __composer
392                                .with_slot_value::<#core_path::CallbackHolder, _>(
393                                    #slot_ident,
394                                    |holder| holder.clone_rc(),
395                                );
396                        }
397                    }
398                } else if info.is_impl_trait {
399                    quote! {}
400                } else {
401                    let pat = &info.pat;
402                    let ident = &info.ident;
403                    quote! {
404                        let #pat = #ident;
405                    }
406                }
407            })
408            .collect();
409
410        let rebinds_for_recompose: Vec<TokenStream2> = param_info
411            .iter()
412            .zip(param_state_slots.iter())
413            .map(|(info, slot_ident)| {
414                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
415                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
416                {
417                    let pat = &info.pat;
418                    let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
419                    if can_add_mut && !info.pat_is_mut {
420                        quote! {
421                            #[allow(unused_mut)]
422                            let mut #pat = __composer
423                                .with_slot_value::<#core_path::CallbackHolder, _>(
424                                    #slot_ident,
425                                    |holder| holder.clone_rc(),
426                                );
427                        }
428                    } else {
429                        quote! {
430                            #[allow(unused_mut)]
431                            let #pat = __composer
432                                .with_slot_value::<#core_path::CallbackHolder, _>(
433                                    #slot_ident,
434                                    |holder| holder.clone_rc(),
435                                );
436                        }
437                    }
438                } else if info.is_impl_trait {
439                    quote! {}
440                } else {
441                    let pat = &info.pat;
442                    let ty = &info.ty;
443                    quote! {
444                        let #pat = __composer
445                            .with_slot_value::<#core_path::ParamState<#ty>, _>(
446                                #slot_ident,
447                                |state| {
448                                    state
449                                        .value()
450                                        .expect("composable parameter missing for recomposition")
451                                },
452                            );
453                    }
454                }
455            })
456            .collect();
457
458        let recranpose_fn_ident = Ident::new(
459            &format!("__cranpose_recranpose_{}", func.sig.ident),
460            Span::call_site(),
461        );
462
463        let recranpose_setter = quote! {
464            {
465                __composer.set_recranpose_callback(move |
466                    __composer: &#core_path::Composer|
467                {
468                    #recranpose_fn_ident #ty_generics_turbofish (
469                        __composer
470                    );
471                });
472            }
473        };
474
475        let helper_body = if returns_unit {
476            quote! {
477                #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
478                let __current_scope = __composer
479                    .current_recranpose_scope()
480                    .expect("missing recompose scope");
481                let mut __changed = __current_scope.should_recompose();
482                #(#param_setup)*
483                #recranpose_setter
484                if !__changed && __current_scope.has_composed_once() {
485                    __composer.skip_current_group();
486                    return;
487                }
488                #(#rebinds)*
489                #helper_block
490            }
491        } else {
492            quote! {
493                #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
494                let __current_scope = __composer
495                    .current_recranpose_scope()
496                    .expect("missing recompose scope");
497                let mut __changed = __current_scope.should_recompose();
498                #(#param_setup)*
499                #recranpose_setter
500                let __result_slot_index = __composer
501                    .__use_return_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
502                let __has_previous = __composer
503                    .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
504                        __result_slot_index,
505                        |slot| slot.get().is_some(),
506                    );
507                if !__changed && __has_previous {
508                    __composer.skip_current_group();
509                    let __result = __composer
510                        .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
511                            __result_slot_index,
512                            |slot| {
513                                slot.get()
514                                    .expect("composable return value missing during skip")
515                            },
516                        );
517                    return __result;
518                }
519                let __value: #return_ty = {
520                    #(#rebinds)*
521                    #helper_block
522                };
523                __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
524                    __result_slot_index,
525                    |slot| {
526                        slot.store(__value.clone());
527                    },
528                );
529                __value
530            }
531        };
532
533        let recranpose_fn_body = if returns_unit {
534            quote! {
535                #(#param_setup_recompose)*
536                #(#rebinds_for_recompose)*
537                #recranpose_block
538                #recranpose_setter
539            }
540        } else {
541            quote! {
542                #(#param_setup_recompose)*
543                let __result_slot_index = __composer
544                    .__use_return_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
545                #(#rebinds_for_recompose)*
546                let __value: #return_ty = {
547                    #recranpose_block
548                };
549                __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
550                    __result_slot_index,
551                    |slot| {
552                        slot.store(__value.clone());
553                    },
554                );
555                #recranpose_setter
556                __value
557            }
558        };
559
560        let recranpose_fn = quote! {
561            #[allow(non_snake_case)]
562            fn #recranpose_fn_ident #impl_generics (
563                __composer: &#core_path::Composer
564            ) -> #return_ty #where_clause {
565                #recranpose_fn_body
566            }
567        };
568
569        let helper_fn = quote! {
570            #[allow(non_snake_case, clippy::too_many_arguments)]
571            fn #helper_ident #impl_generics (
572                __composer: &#core_path::Composer
573                #(, #helper_inputs)*
574            ) -> #return_ty #where_clause {
575                #helper_body
576            }
577        };
578
579        // Wrapper args: pass all params except unhandled impl Trait on initial call
580        let wrapper_args: Vec<TokenStream2> = param_info
581            .iter()
582            .filter_map(|info| {
583                if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
584                    None
585                } else {
586                    let ident = &info.ident;
587                    Some(quote! { #ident })
588                }
589            })
590            .collect();
591
592        let wrapped = quote!({
593            #core_path::with_current_composer(|__composer: &#core_path::Composer| {
594                __composer.with_group(#key_expr, |__composer: &#core_path::Composer| {
595                    #helper_ident(__composer #(, #wrapper_args)*)
596                })
597            })
598        });
599        *func.block = syn::parse2(wrapped).expect("failed to build block");
600        TokenStream::from(quote! {
601            #recranpose_fn
602            #helper_fn
603            #func
604        })
605    } else {
606        // no_skip path: still uses simple rebinds
607        let wrapped = quote!({
608            #core_path::with_current_composer(|__composer: &#core_path::Composer| {
609                __composer.with_group(#key_expr, |__scope: &#core_path::Composer| {
610                    #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
611                    #(#rebinds_for_no_skip)*
612                    #original_block
613                })
614            })
615        });
616        *func.block = syn::parse2(wrapped).expect("failed to build block");
617        TokenStream::from(quote! { #func })
618    }
619}