Skip to main content

cranpose_macros/
lib.rs

1#![deny(unsafe_code)]
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use proc_macro_crate::{crate_name, FoundCrate};
6use quote::quote;
7use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Type};
8
9/// Check if a type is Fn-like (impl FnMut/Fn/FnOnce, Box<dyn FnMut>, generic with Fn bound, etc.)
10/// For generic type parameters (e.g., `F` where F: FnMut()), we need to check the bounds.
11fn is_fn_like_type(ty: &Type) -> bool {
12    match ty {
13        // impl FnMut(...) + 'static, impl Fn(...), etc.
14        Type::ImplTrait(impl_trait) => impl_trait.bounds.iter().any(|bound| {
15            if let syn::TypeParamBound::Trait(trait_bound) = bound {
16                let path = &trait_bound.path;
17                if let Some(segment) = path.segments.last() {
18                    let ident_str = segment.ident.to_string();
19                    return ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce";
20                }
21            }
22            false
23        }),
24        // Box<dyn FnMut(...)>
25        Type::Path(type_path) => {
26            if let Some(segment) = type_path.path.segments.last() {
27                if segment.ident == "Box" {
28                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
29                        if let Some(syn::GenericArgument::Type(Type::TraitObject(trait_obj))) =
30                            args.args.first()
31                        {
32                            return trait_obj.bounds.iter().any(|bound| {
33                                if let syn::TypeParamBound::Trait(trait_bound) = bound {
34                                    let path = &trait_bound.path;
35                                    if let Some(segment) = path.segments.last() {
36                                        let ident_str = segment.ident.to_string();
37                                        return ident_str == "FnMut"
38                                            || ident_str == "Fn"
39                                            || ident_str == "FnOnce";
40                                    }
41                                }
42                                false
43                            });
44                        }
45                    }
46                }
47            }
48            false
49        }
50        // bare fn(...) -> ...
51        Type::BareFn(_) => true,
52        _ => false,
53    }
54}
55
56/// Check if a generic type parameter has Fn-like bounds by looking at the where clause and bounds
57fn is_generic_fn_like(ty: &Type, generics: &syn::Generics) -> bool {
58    // Extract the ident for Type::Path that might be a generic param
59    let type_ident = match ty {
60        Type::Path(type_path) if type_path.path.segments.len() == 1 => {
61            &type_path.path.segments[0].ident
62        }
63        _ => return false,
64    };
65
66    // Check if it's a type parameter with Fn bounds
67    for param in &generics.params {
68        if let syn::GenericParam::Type(type_param) = param {
69            if type_param.ident == *type_ident {
70                // Check the bounds on the type parameter
71                for bound in &type_param.bounds {
72                    if let syn::TypeParamBound::Trait(trait_bound) = bound {
73                        if let Some(segment) = trait_bound.path.segments.last() {
74                            let ident_str = segment.ident.to_string();
75                            if ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce" {
76                                return true;
77                            }
78                        }
79                    }
80                }
81            }
82        }
83    }
84
85    // Also check where clause
86    if let Some(where_clause) = &generics.where_clause {
87        for predicate in &where_clause.predicates {
88            if let syn::WherePredicate::Type(pred) = predicate {
89                if let Type::Path(bounded_type) = &pred.bounded_ty {
90                    if bounded_type.path.segments.len() == 1
91                        && bounded_type.path.segments[0].ident == *type_ident
92                    {
93                        for bound in &pred.bounds {
94                            if let syn::TypeParamBound::Trait(trait_bound) = bound {
95                                if let Some(segment) = trait_bound.path.segments.last() {
96                                    let ident_str = segment.ident.to_string();
97                                    if ident_str == "FnMut"
98                                        || ident_str == "Fn"
99                                        || ident_str == "FnOnce"
100                                    {
101                                        return true;
102                                    }
103                                }
104                            }
105                        }
106                    }
107                }
108            }
109        }
110    }
111
112    false
113}
114
115/// Unified check: is this type Fn-like, either syntactically or via generic bounds?
116fn is_fn_param(ty: &Type, generics: &syn::Generics) -> bool {
117    is_fn_like_type(ty) || is_generic_fn_like(ty, generics)
118}
119
120/// Check if a type is `impl Fn() + ...` or `impl FnMut() + ...` with **zero** arguments.
121/// Only these can be stored through [`CallbackHolder`] (excludes `FnOnce` which can't be
122/// called more than once).
123fn is_zero_arg_fn_impl_trait(ty: &Type) -> bool {
124    if let Type::ImplTrait(impl_trait) = ty {
125        impl_trait.bounds.iter().any(|bound| {
126            if let syn::TypeParamBound::Trait(trait_bound) = bound {
127                if let Some(segment) = trait_bound.path.segments.last() {
128                    let ident_str = segment.ident.to_string();
129                    if ident_str == "Fn" || ident_str == "FnMut" {
130                        if let syn::PathArguments::Parenthesized(args) = &segment.arguments {
131                            return args.inputs.is_empty();
132                        }
133                    }
134                }
135            }
136            false
137        })
138    } else {
139        false
140    }
141}
142
143fn is_node_id_return(ty: &Type) -> bool {
144    matches!(
145        ty,
146        Type::Path(type_path)
147            if type_path
148                .path
149                .segments
150                .last()
151                .is_some_and(|segment| segment.ident == "NodeId")
152    )
153}
154
155fn core_crate_path() -> TokenStream2 {
156    let crate_name = crate_name("cranpose")
157        .ok()
158        .or_else(|| crate_name("cranpose-core").ok());
159
160    match crate_name {
161        Some(FoundCrate::Itself) => quote!(crate),
162        Some(FoundCrate::Name(name)) => {
163            let ident = Ident::new(&name, Span::call_site());
164            quote!(#ident)
165        }
166        None => quote!(cranpose_core),
167    }
168}
169
170#[proc_macro_attribute]
171pub fn composable(attr: TokenStream, item: TokenStream) -> TokenStream {
172    let attr_tokens = TokenStream2::from(attr);
173    let mut enable_skip = true;
174    let core_path = core_crate_path();
175    if !attr_tokens.is_empty() {
176        match syn::parse2::<Ident>(attr_tokens) {
177            Ok(ident) if ident == "no_skip" => enable_skip = false,
178            Ok(other) => {
179                return syn::Error::new_spanned(other, "unsupported composable attribute")
180                    .to_compile_error()
181                    .into();
182            }
183            Err(err) => {
184                return err.to_compile_error().into();
185            }
186        }
187    }
188
189    let mut func = parse_macro_input!(item as ItemFn);
190
191    struct ParamInfo {
192        ident: Ident,
193        pat: Box<Pat>,
194        ty: Type,
195        pat_is_mut: bool,
196        is_impl_trait: bool,
197    }
198
199    let mut param_info: Vec<ParamInfo> = Vec::new();
200
201    for (index, arg) in func.sig.inputs.iter_mut().enumerate() {
202        if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
203            let pat_is_mut = matches!(
204                pat.as_ref(),
205                Pat::Ident(pat_ident) if pat_ident.mutability.is_some()
206            );
207            let is_impl_trait = matches!(**ty, Type::ImplTrait(_));
208
209            if is_impl_trait {
210                let original_pat: Box<Pat> = pat.clone();
211                if let Pat::Ident(pat_ident) = &**pat {
212                    param_info.push(ParamInfo {
213                        ident: pat_ident.ident.clone(),
214                        pat: original_pat,
215                        ty: ty.as_ref().clone(),
216                        pat_is_mut,
217                        is_impl_trait: true,
218                    });
219                } else {
220                    param_info.push(ParamInfo {
221                        ident: Ident::new(&format!("__arg{}", index), Span::call_site()),
222                        pat: original_pat,
223                        ty: ty.as_ref().clone(),
224                        pat_is_mut,
225                        is_impl_trait: true,
226                    });
227                }
228            } else {
229                let ident = Ident::new(&format!("__arg{}", index), Span::call_site());
230                let original_pat: Box<Pat> = pat.clone();
231                **pat = syn::parse_quote! { #ident };
232                param_info.push(ParamInfo {
233                    ident,
234                    pat: original_pat,
235                    ty: ty.as_ref().clone(),
236                    pat_is_mut,
237                    is_impl_trait: false,
238                });
239            }
240        }
241    }
242
243    let scope_label_ident = func.sig.ident.clone();
244    let original_block = func.block.clone();
245    let helper_block = original_block.clone();
246    let recranpose_block = original_block.clone();
247    let key_expr = quote! { #core_path::location_key(file!(), line!(), column!()) };
248
249    // Rebinds will be generated later in the helper_body context where we have access to slots
250    let rebinds_for_no_skip: Vec<_> = param_info
251        .iter()
252        .map(|info| {
253            let ident = &info.ident;
254            let pat = &info.pat;
255            quote! { let #pat = #ident; }
256        })
257        .collect();
258
259    let return_ty: syn::Type = match &func.sig.output {
260        ReturnType::Default => syn::parse_quote! { () },
261        ReturnType::Type(_, ty) => ty.as_ref().clone(),
262    };
263    let returns_unit = match &func.sig.output {
264        ReturnType::Default => true,
265        ReturnType::Type(_, ty) => {
266            matches!(ty.as_ref(), Type::Tuple(tuple) if tuple.elems.is_empty())
267        }
268    };
269    let invalidate_return_consumer = if returns_unit || is_node_id_return(&return_ty) {
270        quote! {}
271    } else {
272        quote! { __composer.__invalidate_return_consumer_scope(); }
273    };
274    let _helper_ident = Ident::new(
275        &format!("__cranpose_impl_{}", func.sig.ident),
276        Span::call_site(),
277    );
278    let generics = func.sig.generics.clone();
279    let (_impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
280
281    let _helper_inputs: Vec<TokenStream2> = param_info
282        .iter()
283        .map(|info| {
284            let ident = &info.ident;
285            let ty = &info.ty;
286            quote! { #ident: #ty }
287        })
288        .collect();
289
290    // Check if any params are impl Trait that we can't store in a slot.
291    // Zero-arg Fn-like impl traits (impl Fn() + 'static) are handled via CallbackHolder.
292    let has_unhandled_impl_trait = param_info
293        .iter()
294        .any(|info| info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty));
295
296    if enable_skip && !has_unhandled_impl_trait {
297        let helper_ident = Ident::new(
298            &format!("__cranpose_impl_{}", func.sig.ident),
299            Span::call_site(),
300        );
301        let generics = func.sig.generics.clone();
302        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
303        let ty_generics_turbofish = ty_generics.as_turbofish();
304
305        // Helper function signature: all params except unhandled impl Trait.
306        // Zero-arg Fn impl traits are included (they become anonymous generics).
307        let helper_inputs: Vec<TokenStream2> = param_info
308            .iter()
309            .filter_map(|info| {
310                if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
311                    None
312                } else {
313                    let ident = &info.ident;
314                    let ty = &info.ty;
315                    Some(quote! { #ident: #ty })
316                }
317            })
318            .collect();
319
320        // Separate Fn-like params from regular params
321        let param_state_slots: Vec<Ident> = (0..param_info.len())
322            .map(|index| Ident::new(&format!("__param_state_slot{}", index), Span::call_site()))
323            .collect();
324
325        let param_setup: Vec<TokenStream2> = param_info
326            .iter()
327            .zip(param_state_slots.iter())
328            .map(|(info, slot_ident)| {
329                // Zero-arg Fn impl traits and generic Fn params → CallbackHolder
330                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
331                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
332                {
333                    let ident = &info.ident;
334                    quote! {
335                        let #slot_ident = __composer
336                            .__use_param_slot(|| #core_path::CallbackHolder::new());
337                        __composer.with_slot_value::<#core_path::CallbackHolder, _>(
338                            #slot_ident,
339                            |holder| {
340                                holder.update(#ident);
341                            },
342                        );
343                        __changed = true;
344                    }
345                } else if info.is_impl_trait {
346                    // Non-Fn impl trait – cannot store, always mark changed
347                    quote! { __changed = true; }
348                } else {
349                    let ident = &info.ident;
350                    let ty = &info.ty;
351                    quote! {
352                        let #slot_ident = __composer
353                            .__use_param_slot(|| #core_path::ParamState::<#ty>::default());
354                        if __composer.with_slot_value_mut::<#core_path::ParamState<#ty>, _>(
355                            #slot_ident,
356                            |state| state.update(&#ident),
357                        )
358                        {
359                            __changed = true;
360                        }
361                    }
362                }
363            })
364            .collect();
365
366        let param_setup_recompose: Vec<TokenStream2> = param_info
367            .iter()
368            .zip(param_state_slots.iter())
369            .map(|(info, slot_ident)| {
370                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
371                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
372                {
373                    quote! {
374                        let #slot_ident = __composer
375                            .__use_param_slot(|| #core_path::CallbackHolder::new());
376                    }
377                } else if info.is_impl_trait {
378                    quote! {}
379                } else {
380                    let ty = &info.ty;
381                    quote! {
382                        let #slot_ident = __composer
383                            .__use_param_slot(|| #core_path::ParamState::<#ty>::default());
384                    }
385                }
386            })
387            .collect();
388
389        let rebinds: Vec<TokenStream2> = param_info
390            .iter()
391            .zip(param_state_slots.iter())
392            .map(|(info, slot_ident)| {
393                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
394                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
395                {
396                    let pat = &info.pat;
397                    let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
398                    if can_add_mut && !info.pat_is_mut {
399                        quote! {
400                            #[allow(unused_mut)]
401                            let mut #pat = __composer
402                                .with_slot_value::<#core_path::CallbackHolder, _>(
403                                    #slot_ident,
404                                    |holder| holder.clone_rc(),
405                                );
406                        }
407                    } else {
408                        quote! {
409                            #[allow(unused_mut)]
410                            let #pat = __composer
411                                .with_slot_value::<#core_path::CallbackHolder, _>(
412                                    #slot_ident,
413                                    |holder| holder.clone_rc(),
414                                );
415                        }
416                    }
417                } else if info.is_impl_trait {
418                    quote! {}
419                } else {
420                    let pat = &info.pat;
421                    let ident = &info.ident;
422                    quote! {
423                        let #pat = #ident;
424                    }
425                }
426            })
427            .collect();
428
429        let rebinds_for_recompose: Vec<TokenStream2> = param_info
430            .iter()
431            .zip(param_state_slots.iter())
432            .map(|(info, slot_ident)| {
433                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
434                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
435                {
436                    let pat = &info.pat;
437                    let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
438                    if can_add_mut && !info.pat_is_mut {
439                        quote! {
440                            #[allow(unused_mut)]
441                            let mut #pat = __composer
442                                .with_slot_value::<#core_path::CallbackHolder, _>(
443                                    #slot_ident,
444                                    |holder| holder.clone_rc(),
445                                );
446                        }
447                    } else {
448                        quote! {
449                            #[allow(unused_mut)]
450                            let #pat = __composer
451                                .with_slot_value::<#core_path::CallbackHolder, _>(
452                                    #slot_ident,
453                                    |holder| holder.clone_rc(),
454                                );
455                        }
456                    }
457                } else if info.is_impl_trait {
458                    quote! {}
459                } else {
460                    let pat = &info.pat;
461                    let ty = &info.ty;
462                    quote! {
463                        let #pat = __composer
464                            .with_slot_value::<#core_path::ParamState<#ty>, _>(
465                                #slot_ident,
466                                |state| {
467                                    state
468                                        .value()
469                                        .expect("composable parameter missing for recomposition")
470                                },
471                            );
472                    }
473                }
474            })
475            .collect();
476
477        let recranpose_fn_ident = Ident::new(
478            &format!("__cranpose_recranpose_{}", func.sig.ident),
479            Span::call_site(),
480        );
481
482        let recranpose_setter = quote! {
483            {
484                __composer.set_recranpose_callback(move |
485                    __composer: &#core_path::Composer|
486                {
487                    #recranpose_fn_ident #ty_generics_turbofish (
488                        __composer
489                    );
490                });
491            }
492        };
493
494        let helper_body = if returns_unit {
495            quote! {
496                #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
497                let __current_scope = __composer
498                    .current_recranpose_scope()
499                    .expect("missing recompose scope");
500                let mut __changed = __current_scope.should_recompose();
501                #(#param_setup)*
502                #recranpose_setter
503                if !__changed && __current_scope.has_composed_once() {
504                    __composer.skip_current_group();
505                    return;
506                }
507                #(#rebinds)*
508                #helper_block
509            }
510        } else {
511            quote! {
512                #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
513                let __current_scope = __composer
514                    .current_recranpose_scope()
515                    .expect("missing recompose scope");
516                let mut __changed = __current_scope.should_recompose();
517                #(#param_setup)*
518                #recranpose_setter
519                let __result_slot_index = __composer
520                    .__use_return_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
521                let __has_previous = __composer
522                    .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
523                        __result_slot_index,
524                        |slot| slot.get().is_some(),
525                    );
526                if !__changed && __has_previous {
527                    __composer.skip_current_group();
528                    let __result = __composer
529                        .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
530                            __result_slot_index,
531                            |slot| {
532                                slot.get()
533                                    .expect("composable return value missing during skip")
534                            },
535                        );
536                    return __result;
537                }
538                let __value: #return_ty = {
539                    #(#rebinds)*
540                    #helper_block
541                };
542                __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
543                    __result_slot_index,
544                    |slot| {
545                        slot.store(__value.clone());
546                    },
547                );
548                __value
549            }
550        };
551
552        let recranpose_fn_body = if returns_unit {
553            quote! {
554                #(#param_setup_recompose)*
555                #(#rebinds_for_recompose)*
556                #recranpose_block
557                #recranpose_setter
558            }
559        } else {
560            quote! {
561                #(#param_setup_recompose)*
562                let __result_slot_index = __composer
563                    .__use_return_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
564                #(#rebinds_for_recompose)*
565                let __value: #return_ty = {
566                    #recranpose_block
567                };
568                __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
569                    __result_slot_index,
570                    |slot| {
571                        slot.store(__value.clone());
572                    },
573                );
574                #recranpose_setter
575                #invalidate_return_consumer
576                __value
577            }
578        };
579
580        let recranpose_fn = quote! {
581            #[allow(non_snake_case)]
582            fn #recranpose_fn_ident #impl_generics (
583                __composer: &#core_path::Composer
584            ) -> #return_ty #where_clause {
585                #recranpose_fn_body
586            }
587        };
588
589        let helper_fn = quote! {
590            #[allow(non_snake_case, clippy::too_many_arguments)]
591            fn #helper_ident #impl_generics (
592                __composer: &#core_path::Composer
593                #(, #helper_inputs)*
594            ) -> #return_ty #where_clause {
595                #helper_body
596            }
597        };
598
599        // Wrapper args: pass all params except unhandled impl Trait on initial call
600        let wrapper_args: Vec<TokenStream2> = param_info
601            .iter()
602            .filter_map(|info| {
603                if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
604                    None
605                } else {
606                    let ident = &info.ident;
607                    Some(quote! { #ident })
608                }
609            })
610            .collect();
611
612        let wrapped = quote!({
613            #core_path::with_current_composer(|__composer: &#core_path::Composer| {
614                __composer.with_group(#key_expr, |__composer: &#core_path::Composer| {
615                    #helper_ident(__composer #(, #wrapper_args)*)
616                })
617            })
618        });
619        *func.block = syn::parse2(wrapped).expect("failed to build block");
620        TokenStream::from(quote! {
621            #recranpose_fn
622            #helper_fn
623            #func
624        })
625    } else {
626        // no_skip path: still uses simple rebinds
627        let wrapped = quote!({
628            #core_path::with_current_composer(|__composer: &#core_path::Composer| {
629                __composer.with_group(#key_expr, |__scope: &#core_path::Composer| {
630                    #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
631                    #(#rebinds_for_no_skip)*
632                    #original_block
633                })
634            })
635        });
636        *func.block = syn::parse2(wrapped).expect("failed to build block");
637        TokenStream::from(quote! { #func })
638    }
639}