Skip to main content

cranpose_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use proc_macro_crate::{crate_name, FoundCrate};
4use quote::quote;
5use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Type};
6
7/// Check if a type is Fn-like (impl FnMut/Fn/FnOnce, Box<dyn FnMut>, generic with Fn bound, etc.)
8/// For generic type parameters (e.g., `F` where F: FnMut()), we need to check the bounds.
9fn is_fn_like_type(ty: &Type) -> bool {
10    match ty {
11        // impl FnMut(...) + 'static, impl Fn(...), etc.
12        Type::ImplTrait(impl_trait) => impl_trait.bounds.iter().any(|bound| {
13            if let syn::TypeParamBound::Trait(trait_bound) = bound {
14                let path = &trait_bound.path;
15                if let Some(segment) = path.segments.last() {
16                    let ident_str = segment.ident.to_string();
17                    return ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce";
18                }
19            }
20            false
21        }),
22        // Box<dyn FnMut(...)>
23        Type::Path(type_path) => {
24            if let Some(segment) = type_path.path.segments.last() {
25                if segment.ident == "Box" {
26                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
27                        if let Some(syn::GenericArgument::Type(Type::TraitObject(trait_obj))) =
28                            args.args.first()
29                        {
30                            return trait_obj.bounds.iter().any(|bound| {
31                                if let syn::TypeParamBound::Trait(trait_bound) = bound {
32                                    let path = &trait_bound.path;
33                                    if let Some(segment) = path.segments.last() {
34                                        let ident_str = segment.ident.to_string();
35                                        return ident_str == "FnMut"
36                                            || ident_str == "Fn"
37                                            || ident_str == "FnOnce";
38                                    }
39                                }
40                                false
41                            });
42                        }
43                    }
44                }
45            }
46            false
47        }
48        // bare fn(...) -> ...
49        Type::BareFn(_) => true,
50        _ => false,
51    }
52}
53
54/// Check if a generic type parameter has Fn-like bounds by looking at the where clause and bounds
55fn is_generic_fn_like(ty: &Type, generics: &syn::Generics) -> bool {
56    // Extract the ident for Type::Path that might be a generic param
57    let type_ident = match ty {
58        Type::Path(type_path) if type_path.path.segments.len() == 1 => {
59            &type_path.path.segments[0].ident
60        }
61        _ => return false,
62    };
63
64    // Check if it's a type parameter with Fn bounds
65    for param in &generics.params {
66        if let syn::GenericParam::Type(type_param) = param {
67            if type_param.ident == *type_ident {
68                // Check the bounds on the type parameter
69                for bound in &type_param.bounds {
70                    if let syn::TypeParamBound::Trait(trait_bound) = bound {
71                        if let Some(segment) = trait_bound.path.segments.last() {
72                            let ident_str = segment.ident.to_string();
73                            if ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce" {
74                                return true;
75                            }
76                        }
77                    }
78                }
79            }
80        }
81    }
82
83    // Also check where clause
84    if let Some(where_clause) = &generics.where_clause {
85        for predicate in &where_clause.predicates {
86            if let syn::WherePredicate::Type(pred) = predicate {
87                if let Type::Path(bounded_type) = &pred.bounded_ty {
88                    if bounded_type.path.segments.len() == 1
89                        && bounded_type.path.segments[0].ident == *type_ident
90                    {
91                        for bound in &pred.bounds {
92                            if let syn::TypeParamBound::Trait(trait_bound) = bound {
93                                if let Some(segment) = trait_bound.path.segments.last() {
94                                    let ident_str = segment.ident.to_string();
95                                    if ident_str == "FnMut"
96                                        || ident_str == "Fn"
97                                        || ident_str == "FnOnce"
98                                    {
99                                        return true;
100                                    }
101                                }
102                            }
103                        }
104                    }
105                }
106            }
107        }
108    }
109
110    false
111}
112
113/// Unified check: is this type Fn-like, either syntactically or via generic bounds?
114fn is_fn_param(ty: &Type, generics: &syn::Generics) -> bool {
115    is_fn_like_type(ty) || is_generic_fn_like(ty, generics)
116}
117
118/// Check if a type is `impl Fn() + ...` or `impl FnMut() + ...` with **zero** arguments.
119/// Only these can be stored through [`CallbackHolder`] (excludes `FnOnce` which can't be
120/// called more than once).
121fn is_zero_arg_fn_impl_trait(ty: &Type) -> bool {
122    if let Type::ImplTrait(impl_trait) = ty {
123        impl_trait.bounds.iter().any(|bound| {
124            if let syn::TypeParamBound::Trait(trait_bound) = bound {
125                if let Some(segment) = trait_bound.path.segments.last() {
126                    let ident_str = segment.ident.to_string();
127                    if ident_str == "Fn" || ident_str == "FnMut" {
128                        if let syn::PathArguments::Parenthesized(args) = &segment.arguments {
129                            return args.inputs.is_empty();
130                        }
131                    }
132                }
133            }
134            false
135        })
136    } else {
137        false
138    }
139}
140
141fn is_node_id_return(ty: &Type) -> bool {
142    matches!(
143        ty,
144        Type::Path(type_path)
145            if type_path
146                .path
147                .segments
148                .last()
149                .is_some_and(|segment| segment.ident == "NodeId")
150    )
151}
152
153fn core_crate_path() -> TokenStream2 {
154    let crate_name = crate_name("cranpose")
155        .ok()
156        .or_else(|| crate_name("cranpose-core").ok());
157
158    match crate_name {
159        Some(FoundCrate::Itself) => quote!(crate),
160        Some(FoundCrate::Name(name)) => {
161            let ident = Ident::new(&name, Span::call_site());
162            quote!(#ident)
163        }
164        None => quote!(cranpose_core),
165    }
166}
167
168#[proc_macro_attribute]
169pub fn composable(attr: TokenStream, item: TokenStream) -> TokenStream {
170    let attr_tokens = TokenStream2::from(attr);
171    let mut enable_skip = true;
172    let core_path = core_crate_path();
173    if !attr_tokens.is_empty() {
174        match syn::parse2::<Ident>(attr_tokens) {
175            Ok(ident) if ident == "no_skip" => enable_skip = false,
176            Ok(other) => {
177                return syn::Error::new_spanned(other, "unsupported composable attribute")
178                    .to_compile_error()
179                    .into();
180            }
181            Err(err) => {
182                return err.to_compile_error().into();
183            }
184        }
185    }
186
187    let mut func = parse_macro_input!(item as ItemFn);
188
189    struct ParamInfo {
190        ident: Ident,
191        pat: Box<Pat>,
192        ty: Type,
193        pat_is_mut: bool,
194        is_impl_trait: bool,
195    }
196
197    let mut param_info: Vec<ParamInfo> = Vec::new();
198
199    for (index, arg) in func.sig.inputs.iter_mut().enumerate() {
200        if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
201            let pat_is_mut = matches!(
202                pat.as_ref(),
203                Pat::Ident(pat_ident) if pat_ident.mutability.is_some()
204            );
205            let is_impl_trait = matches!(**ty, Type::ImplTrait(_));
206
207            if is_impl_trait {
208                let original_pat: Box<Pat> = pat.clone();
209                if let Pat::Ident(pat_ident) = &**pat {
210                    param_info.push(ParamInfo {
211                        ident: pat_ident.ident.clone(),
212                        pat: original_pat,
213                        ty: ty.as_ref().clone(),
214                        pat_is_mut,
215                        is_impl_trait: true,
216                    });
217                } else {
218                    param_info.push(ParamInfo {
219                        ident: Ident::new(&format!("__arg{}", index), Span::call_site()),
220                        pat: original_pat,
221                        ty: ty.as_ref().clone(),
222                        pat_is_mut,
223                        is_impl_trait: true,
224                    });
225                }
226            } else {
227                let ident = Ident::new(&format!("__arg{}", index), Span::call_site());
228                let original_pat: Box<Pat> = pat.clone();
229                **pat = syn::parse_quote! { #ident };
230                param_info.push(ParamInfo {
231                    ident,
232                    pat: original_pat,
233                    ty: ty.as_ref().clone(),
234                    pat_is_mut,
235                    is_impl_trait: false,
236                });
237            }
238        }
239    }
240
241    let scope_label_ident = func.sig.ident.clone();
242    let original_block = func.block.clone();
243    let helper_block = original_block.clone();
244    let recranpose_block = original_block.clone();
245    let key_expr = quote! { #core_path::location_key(file!(), line!(), column!()) };
246
247    // Rebinds will be generated later in the helper_body context where we have access to slots
248    let rebinds_for_no_skip: Vec<_> = param_info
249        .iter()
250        .map(|info| {
251            let ident = &info.ident;
252            let pat = &info.pat;
253            quote! { let #pat = #ident; }
254        })
255        .collect();
256
257    let return_ty: syn::Type = match &func.sig.output {
258        ReturnType::Default => syn::parse_quote! { () },
259        ReturnType::Type(_, ty) => ty.as_ref().clone(),
260    };
261    let returns_unit = match &func.sig.output {
262        ReturnType::Default => true,
263        ReturnType::Type(_, ty) => {
264            matches!(ty.as_ref(), Type::Tuple(tuple) if tuple.elems.is_empty())
265        }
266    };
267    let invalidate_return_consumer = if returns_unit || is_node_id_return(&return_ty) {
268        quote! {}
269    } else {
270        quote! { __composer.__invalidate_return_consumer_scope(); }
271    };
272    let _helper_ident = Ident::new(
273        &format!("__cranpose_impl_{}", func.sig.ident),
274        Span::call_site(),
275    );
276    let generics = func.sig.generics.clone();
277    let (_impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
278
279    let _helper_inputs: Vec<TokenStream2> = param_info
280        .iter()
281        .map(|info| {
282            let ident = &info.ident;
283            let ty = &info.ty;
284            quote! { #ident: #ty }
285        })
286        .collect();
287
288    // Check if any params are impl Trait that we can't store in a slot.
289    // Zero-arg Fn-like impl traits (impl Fn() + 'static) are handled via CallbackHolder.
290    let has_unhandled_impl_trait = param_info
291        .iter()
292        .any(|info| info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty));
293
294    if enable_skip && !has_unhandled_impl_trait {
295        let helper_ident = Ident::new(
296            &format!("__cranpose_impl_{}", func.sig.ident),
297            Span::call_site(),
298        );
299        let generics = func.sig.generics.clone();
300        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
301        let ty_generics_turbofish = ty_generics.as_turbofish();
302
303        // Helper function signature: all params except unhandled impl Trait.
304        // Zero-arg Fn impl traits are included (they become anonymous generics).
305        let helper_inputs: Vec<TokenStream2> = param_info
306            .iter()
307            .filter_map(|info| {
308                if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
309                    None
310                } else {
311                    let ident = &info.ident;
312                    let ty = &info.ty;
313                    Some(quote! { #ident: #ty })
314                }
315            })
316            .collect();
317
318        // Separate Fn-like params from regular params
319        let param_state_slots: Vec<Ident> = (0..param_info.len())
320            .map(|index| Ident::new(&format!("__param_state_slot{}", index), Span::call_site()))
321            .collect();
322
323        let param_setup: Vec<TokenStream2> = param_info
324            .iter()
325            .zip(param_state_slots.iter())
326            .map(|(info, slot_ident)| {
327                // Zero-arg Fn impl traits and generic Fn params → CallbackHolder
328                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
329                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
330                {
331                    let ident = &info.ident;
332                    quote! {
333                        let #slot_ident = __composer
334                            .__use_param_slot(|| #core_path::CallbackHolder::new());
335                        __composer.with_slot_value::<#core_path::CallbackHolder, _>(
336                            #slot_ident,
337                            |holder| {
338                                holder.update(#ident);
339                            },
340                        );
341                        __changed = true;
342                    }
343                } else if info.is_impl_trait {
344                    // Non-Fn impl trait – cannot store, always mark changed
345                    quote! { __changed = true; }
346                } else {
347                    let ident = &info.ident;
348                    let ty = &info.ty;
349                    quote! {
350                        let #slot_ident = __composer
351                            .__use_param_slot(|| #core_path::ParamState::<#ty>::default());
352                        if __composer.with_slot_value_mut::<#core_path::ParamState<#ty>, _>(
353                            #slot_ident,
354                            |state| state.update(&#ident),
355                        )
356                        {
357                            __changed = true;
358                        }
359                    }
360                }
361            })
362            .collect();
363
364        let param_setup_recompose: Vec<TokenStream2> = param_info
365            .iter()
366            .zip(param_state_slots.iter())
367            .map(|(info, slot_ident)| {
368                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
369                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
370                {
371                    quote! {
372                        let #slot_ident = __composer
373                            .__use_param_slot(|| #core_path::CallbackHolder::new());
374                    }
375                } else if info.is_impl_trait {
376                    quote! {}
377                } else {
378                    let ty = &info.ty;
379                    quote! {
380                        let #slot_ident = __composer
381                            .__use_param_slot(|| #core_path::ParamState::<#ty>::default());
382                    }
383                }
384            })
385            .collect();
386
387        let rebinds: Vec<TokenStream2> = param_info
388            .iter()
389            .zip(param_state_slots.iter())
390            .map(|(info, slot_ident)| {
391                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
392                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
393                {
394                    let pat = &info.pat;
395                    let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
396                    if can_add_mut && !info.pat_is_mut {
397                        quote! {
398                            #[allow(unused_mut)]
399                            let mut #pat = __composer
400                                .with_slot_value::<#core_path::CallbackHolder, _>(
401                                    #slot_ident,
402                                    |holder| holder.clone_rc(),
403                                );
404                        }
405                    } else {
406                        quote! {
407                            #[allow(unused_mut)]
408                            let #pat = __composer
409                                .with_slot_value::<#core_path::CallbackHolder, _>(
410                                    #slot_ident,
411                                    |holder| holder.clone_rc(),
412                                );
413                        }
414                    }
415                } else if info.is_impl_trait {
416                    quote! {}
417                } else {
418                    let pat = &info.pat;
419                    let ident = &info.ident;
420                    quote! {
421                        let #pat = #ident;
422                    }
423                }
424            })
425            .collect();
426
427        let rebinds_for_recompose: Vec<TokenStream2> = param_info
428            .iter()
429            .zip(param_state_slots.iter())
430            .map(|(info, slot_ident)| {
431                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
432                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
433                {
434                    let pat = &info.pat;
435                    let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
436                    if can_add_mut && !info.pat_is_mut {
437                        quote! {
438                            #[allow(unused_mut)]
439                            let mut #pat = __composer
440                                .with_slot_value::<#core_path::CallbackHolder, _>(
441                                    #slot_ident,
442                                    |holder| holder.clone_rc(),
443                                );
444                        }
445                    } else {
446                        quote! {
447                            #[allow(unused_mut)]
448                            let #pat = __composer
449                                .with_slot_value::<#core_path::CallbackHolder, _>(
450                                    #slot_ident,
451                                    |holder| holder.clone_rc(),
452                                );
453                        }
454                    }
455                } else if info.is_impl_trait {
456                    quote! {}
457                } else {
458                    let pat = &info.pat;
459                    let ty = &info.ty;
460                    quote! {
461                        let #pat = __composer
462                            .with_slot_value::<#core_path::ParamState<#ty>, _>(
463                                #slot_ident,
464                                |state| {
465                                    state
466                                        .value()
467                                        .expect("composable parameter missing for recomposition")
468                                },
469                            );
470                    }
471                }
472            })
473            .collect();
474
475        let recranpose_fn_ident = Ident::new(
476            &format!("__cranpose_recranpose_{}", func.sig.ident),
477            Span::call_site(),
478        );
479
480        let recranpose_setter = quote! {
481            {
482                __composer.set_recranpose_callback(move |
483                    __composer: &#core_path::Composer|
484                {
485                    #recranpose_fn_ident #ty_generics_turbofish (
486                        __composer
487                    );
488                });
489            }
490        };
491
492        let helper_body = if returns_unit {
493            quote! {
494                #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
495                let __current_scope = __composer
496                    .current_recranpose_scope()
497                    .expect("missing recompose scope");
498                let mut __changed = __current_scope.should_recompose();
499                #(#param_setup)*
500                #recranpose_setter
501                if !__changed && __current_scope.has_composed_once() {
502                    __composer.skip_current_group();
503                    return;
504                }
505                #(#rebinds)*
506                #helper_block
507            }
508        } else {
509            quote! {
510                #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
511                let __current_scope = __composer
512                    .current_recranpose_scope()
513                    .expect("missing recompose scope");
514                let mut __changed = __current_scope.should_recompose();
515                #(#param_setup)*
516                #recranpose_setter
517                let __result_slot_index = __composer
518                    .__use_return_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
519                let __has_previous = __composer
520                    .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
521                        __result_slot_index,
522                        |slot| slot.get().is_some(),
523                    );
524                if !__changed && __has_previous {
525                    __composer.skip_current_group();
526                    let __result = __composer
527                        .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
528                            __result_slot_index,
529                            |slot| {
530                                slot.get()
531                                    .expect("composable return value missing during skip")
532                            },
533                        );
534                    return __result;
535                }
536                let __value: #return_ty = {
537                    #(#rebinds)*
538                    #helper_block
539                };
540                __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
541                    __result_slot_index,
542                    |slot| {
543                        slot.store(__value.clone());
544                    },
545                );
546                __value
547            }
548        };
549
550        let recranpose_fn_body = if returns_unit {
551            quote! {
552                #(#param_setup_recompose)*
553                #(#rebinds_for_recompose)*
554                #recranpose_block
555                #recranpose_setter
556            }
557        } else {
558            quote! {
559                #(#param_setup_recompose)*
560                let __result_slot_index = __composer
561                    .__use_return_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
562                #(#rebinds_for_recompose)*
563                let __value: #return_ty = {
564                    #recranpose_block
565                };
566                __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
567                    __result_slot_index,
568                    |slot| {
569                        slot.store(__value.clone());
570                    },
571                );
572                #recranpose_setter
573                #invalidate_return_consumer
574                __value
575            }
576        };
577
578        let recranpose_fn = quote! {
579            #[allow(non_snake_case)]
580            fn #recranpose_fn_ident #impl_generics (
581                __composer: &#core_path::Composer
582            ) -> #return_ty #where_clause {
583                #recranpose_fn_body
584            }
585        };
586
587        let helper_fn = quote! {
588            #[allow(non_snake_case, clippy::too_many_arguments)]
589            fn #helper_ident #impl_generics (
590                __composer: &#core_path::Composer
591                #(, #helper_inputs)*
592            ) -> #return_ty #where_clause {
593                #helper_body
594            }
595        };
596
597        // Wrapper args: pass all params except unhandled impl Trait on initial call
598        let wrapper_args: Vec<TokenStream2> = param_info
599            .iter()
600            .filter_map(|info| {
601                if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
602                    None
603                } else {
604                    let ident = &info.ident;
605                    Some(quote! { #ident })
606                }
607            })
608            .collect();
609
610        let wrapped = quote!({
611            #core_path::with_current_composer(|__composer: &#core_path::Composer| {
612                __composer.with_group(#key_expr, |__composer: &#core_path::Composer| {
613                    #helper_ident(__composer #(, #wrapper_args)*)
614                })
615            })
616        });
617        *func.block = syn::parse2(wrapped).expect("failed to build block");
618        TokenStream::from(quote! {
619            #recranpose_fn
620            #helper_fn
621            #func
622        })
623    } else {
624        // no_skip path: still uses simple rebinds
625        let wrapped = quote!({
626            #core_path::with_current_composer(|__composer: &#core_path::Composer| {
627                __composer.with_group(#key_expr, |__scope: &#core_path::Composer| {
628                    #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
629                    #(#rebinds_for_no_skip)*
630                    #original_block
631                })
632            })
633        });
634        *func.block = syn::parse2(wrapped).expect("failed to build block");
635        TokenStream::from(quote! { #func })
636    }
637}