Skip to main content

cranpose_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use proc_macro_crate::{crate_name, FoundCrate};
4use quote::quote;
5use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Type};
6
7/// Check if a type is Fn-like (impl FnMut/Fn/FnOnce, Box<dyn FnMut>, generic with Fn bound, etc.)
8/// For generic type parameters (e.g., `F` where F: FnMut()), we need to check the bounds.
9fn is_fn_like_type(ty: &Type) -> bool {
10    match ty {
11        // impl FnMut(...) + 'static, impl Fn(...), etc.
12        Type::ImplTrait(impl_trait) => impl_trait.bounds.iter().any(|bound| {
13            if let syn::TypeParamBound::Trait(trait_bound) = bound {
14                let path = &trait_bound.path;
15                if let Some(segment) = path.segments.last() {
16                    let ident_str = segment.ident.to_string();
17                    return ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce";
18                }
19            }
20            false
21        }),
22        // Box<dyn FnMut(...)>
23        Type::Path(type_path) => {
24            if let Some(segment) = type_path.path.segments.last() {
25                if segment.ident == "Box" {
26                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
27                        if let Some(syn::GenericArgument::Type(Type::TraitObject(trait_obj))) =
28                            args.args.first()
29                        {
30                            return trait_obj.bounds.iter().any(|bound| {
31                                if let syn::TypeParamBound::Trait(trait_bound) = bound {
32                                    let path = &trait_bound.path;
33                                    if let Some(segment) = path.segments.last() {
34                                        let ident_str = segment.ident.to_string();
35                                        return ident_str == "FnMut"
36                                            || ident_str == "Fn"
37                                            || ident_str == "FnOnce";
38                                    }
39                                }
40                                false
41                            });
42                        }
43                    }
44                }
45            }
46            false
47        }
48        // bare fn(...) -> ...
49        Type::BareFn(_) => true,
50        _ => false,
51    }
52}
53
54/// Check if a generic type parameter has Fn-like bounds by looking at the where clause and bounds
55fn is_generic_fn_like(ty: &Type, generics: &syn::Generics) -> bool {
56    // Extract the ident for Type::Path that might be a generic param
57    let type_ident = match ty {
58        Type::Path(type_path) if type_path.path.segments.len() == 1 => {
59            &type_path.path.segments[0].ident
60        }
61        _ => return false,
62    };
63
64    // Check if it's a type parameter with Fn bounds
65    for param in &generics.params {
66        if let syn::GenericParam::Type(type_param) = param {
67            if type_param.ident == *type_ident {
68                // Check the bounds on the type parameter
69                for bound in &type_param.bounds {
70                    if let syn::TypeParamBound::Trait(trait_bound) = bound {
71                        if let Some(segment) = trait_bound.path.segments.last() {
72                            let ident_str = segment.ident.to_string();
73                            if ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce" {
74                                return true;
75                            }
76                        }
77                    }
78                }
79            }
80        }
81    }
82
83    // Also check where clause
84    if let Some(where_clause) = &generics.where_clause {
85        for predicate in &where_clause.predicates {
86            if let syn::WherePredicate::Type(pred) = predicate {
87                if let Type::Path(bounded_type) = &pred.bounded_ty {
88                    if bounded_type.path.segments.len() == 1
89                        && bounded_type.path.segments[0].ident == *type_ident
90                    {
91                        for bound in &pred.bounds {
92                            if let syn::TypeParamBound::Trait(trait_bound) = bound {
93                                if let Some(segment) = trait_bound.path.segments.last() {
94                                    let ident_str = segment.ident.to_string();
95                                    if ident_str == "FnMut"
96                                        || ident_str == "Fn"
97                                        || ident_str == "FnOnce"
98                                    {
99                                        return true;
100                                    }
101                                }
102                            }
103                        }
104                    }
105                }
106            }
107        }
108    }
109
110    false
111}
112
113/// Unified check: is this type Fn-like, either syntactically or via generic bounds?
114fn is_fn_param(ty: &Type, generics: &syn::Generics) -> bool {
115    is_fn_like_type(ty) || is_generic_fn_like(ty, generics)
116}
117
118/// Check if a type is `impl Fn() + ...` or `impl FnMut() + ...` with **zero** arguments.
119/// Only these can be stored through [`CallbackHolder`] (excludes `FnOnce` which can't be
120/// called more than once).
121fn is_zero_arg_fn_impl_trait(ty: &Type) -> bool {
122    if let Type::ImplTrait(impl_trait) = ty {
123        impl_trait.bounds.iter().any(|bound| {
124            if let syn::TypeParamBound::Trait(trait_bound) = bound {
125                if let Some(segment) = trait_bound.path.segments.last() {
126                    let ident_str = segment.ident.to_string();
127                    if ident_str == "Fn" || ident_str == "FnMut" {
128                        if let syn::PathArguments::Parenthesized(args) = &segment.arguments {
129                            return args.inputs.is_empty();
130                        }
131                    }
132                }
133            }
134            false
135        })
136    } else {
137        false
138    }
139}
140
141fn core_crate_path() -> TokenStream2 {
142    let crate_name = crate_name("cranpose")
143        .ok()
144        .or_else(|| crate_name("cranpose-core").ok());
145
146    match crate_name {
147        Some(FoundCrate::Itself) => quote!(crate),
148        Some(FoundCrate::Name(name)) => {
149            let ident = Ident::new(&name, Span::call_site());
150            quote!(#ident)
151        }
152        None => quote!(cranpose_core),
153    }
154}
155
156#[proc_macro_attribute]
157pub fn composable(attr: TokenStream, item: TokenStream) -> TokenStream {
158    let attr_tokens = TokenStream2::from(attr);
159    let mut enable_skip = true;
160    let core_path = core_crate_path();
161    if !attr_tokens.is_empty() {
162        match syn::parse2::<Ident>(attr_tokens) {
163            Ok(ident) if ident == "no_skip" => enable_skip = false,
164            Ok(other) => {
165                return syn::Error::new_spanned(other, "unsupported composable attribute")
166                    .to_compile_error()
167                    .into();
168            }
169            Err(err) => {
170                return err.to_compile_error().into();
171            }
172        }
173    }
174
175    let mut func = parse_macro_input!(item as ItemFn);
176
177    struct ParamInfo {
178        ident: Ident,
179        pat: Box<Pat>,
180        ty: Type,
181        pat_is_mut: bool,
182        is_impl_trait: bool,
183    }
184
185    let mut param_info: Vec<ParamInfo> = Vec::new();
186
187    for (index, arg) in func.sig.inputs.iter_mut().enumerate() {
188        if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
189            let pat_is_mut = matches!(
190                pat.as_ref(),
191                Pat::Ident(pat_ident) if pat_ident.mutability.is_some()
192            );
193            let is_impl_trait = matches!(**ty, Type::ImplTrait(_));
194
195            if is_impl_trait {
196                let original_pat: Box<Pat> = pat.clone();
197                if let Pat::Ident(pat_ident) = &**pat {
198                    param_info.push(ParamInfo {
199                        ident: pat_ident.ident.clone(),
200                        pat: original_pat,
201                        ty: ty.as_ref().clone(),
202                        pat_is_mut,
203                        is_impl_trait: true,
204                    });
205                } else {
206                    param_info.push(ParamInfo {
207                        ident: Ident::new(&format!("__arg{}", index), Span::call_site()),
208                        pat: original_pat,
209                        ty: ty.as_ref().clone(),
210                        pat_is_mut,
211                        is_impl_trait: true,
212                    });
213                }
214            } else {
215                let ident = Ident::new(&format!("__arg{}", index), Span::call_site());
216                let original_pat: Box<Pat> = pat.clone();
217                *pat = Box::new(syn::parse_quote! { #ident });
218                param_info.push(ParamInfo {
219                    ident,
220                    pat: original_pat,
221                    ty: ty.as_ref().clone(),
222                    pat_is_mut,
223                    is_impl_trait: false,
224                });
225            }
226        }
227    }
228
229    let original_block = func.block.clone();
230    let helper_block = original_block.clone();
231    let recranpose_block = original_block.clone();
232    let key_expr = quote! { #core_path::location_key(file!(), line!(), column!()) };
233
234    // Rebinds will be generated later in the helper_body context where we have access to slots
235    let rebinds_for_no_skip: Vec<_> = param_info
236        .iter()
237        .map(|info| {
238            let ident = &info.ident;
239            let pat = &info.pat;
240            quote! { let #pat = #ident; }
241        })
242        .collect();
243
244    let return_ty: syn::Type = match &func.sig.output {
245        ReturnType::Default => syn::parse_quote! { () },
246        ReturnType::Type(_, ty) => ty.as_ref().clone(),
247    };
248    let _helper_ident = Ident::new(
249        &format!("__cranpose_impl_{}", func.sig.ident),
250        Span::call_site(),
251    );
252    let generics = func.sig.generics.clone();
253    let (_impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
254
255    let _helper_inputs: Vec<TokenStream2> = param_info
256        .iter()
257        .map(|info| {
258            let ident = &info.ident;
259            let ty = &info.ty;
260            quote! { #ident: #ty }
261        })
262        .collect();
263
264    // Check if any params are impl Trait that we can't store in a slot.
265    // Zero-arg Fn-like impl traits (impl Fn() + 'static) are handled via CallbackHolder.
266    let has_unhandled_impl_trait = param_info
267        .iter()
268        .any(|info| info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty));
269
270    if enable_skip && !has_unhandled_impl_trait {
271        let helper_ident = Ident::new(
272            &format!("__cranpose_impl_{}", func.sig.ident),
273            Span::call_site(),
274        );
275        let generics = func.sig.generics.clone();
276        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
277        let ty_generics_turbofish = ty_generics.as_turbofish();
278
279        // Helper function signature: all params except unhandled impl Trait.
280        // Zero-arg Fn impl traits are included (they become anonymous generics).
281        let helper_inputs: Vec<TokenStream2> = param_info
282            .iter()
283            .filter_map(|info| {
284                if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
285                    None
286                } else {
287                    let ident = &info.ident;
288                    let ty = &info.ty;
289                    Some(quote! { #ident: #ty })
290                }
291            })
292            .collect();
293
294        // Separate Fn-like params from regular params
295        let param_state_slots: Vec<Ident> = (0..param_info.len())
296            .map(|index| Ident::new(&format!("__param_state_slot{}", index), Span::call_site()))
297            .collect();
298
299        let param_setup: Vec<TokenStream2> = param_info
300            .iter()
301            .zip(param_state_slots.iter())
302            .map(|(info, slot_ident)| {
303                // Zero-arg Fn impl traits and generic Fn params → CallbackHolder
304                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
305                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
306                {
307                    let ident = &info.ident;
308                    quote! {
309                        let #slot_ident = __composer
310                            .use_value_slot(|| #core_path::CallbackHolder::new());
311                        __composer.with_slot_value::<#core_path::CallbackHolder, _>(
312                            #slot_ident,
313                            |holder| {
314                                holder.update(#ident);
315                            },
316                        );
317                        __changed = true;
318                    }
319                } else if info.is_impl_trait {
320                    // Non-Fn impl trait – cannot store, always mark changed
321                    quote! { __changed = true; }
322                } else {
323                    let ident = &info.ident;
324                    let ty = &info.ty;
325                    quote! {
326                        let #slot_ident = __composer
327                            .use_value_slot(|| #core_path::ParamState::<#ty>::default());
328                        if __composer.with_slot_value_mut::<#core_path::ParamState<#ty>, _>(
329                            #slot_ident,
330                            |state| state.update(&#ident),
331                        )
332                        {
333                            __changed = true;
334                        }
335                    }
336                }
337            })
338            .collect();
339
340        let param_setup_recompose: Vec<TokenStream2> = param_info
341            .iter()
342            .zip(param_state_slots.iter())
343            .map(|(info, slot_ident)| {
344                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
345                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
346                {
347                    quote! {
348                        let #slot_ident = __composer
349                            .use_value_slot(|| #core_path::CallbackHolder::new());
350                    }
351                } else if info.is_impl_trait {
352                    quote! {}
353                } else {
354                    let ty = &info.ty;
355                    quote! {
356                        let #slot_ident = __composer
357                            .use_value_slot(|| #core_path::ParamState::<#ty>::default());
358                    }
359                }
360            })
361            .collect();
362
363        let rebinds: Vec<TokenStream2> = param_info
364            .iter()
365            .zip(param_state_slots.iter())
366            .map(|(info, slot_ident)| {
367                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
368                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
369                {
370                    let pat = &info.pat;
371                    let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
372                    if can_add_mut && !info.pat_is_mut {
373                        quote! {
374                            #[allow(unused_mut)]
375                            let mut #pat = __composer
376                                .with_slot_value::<#core_path::CallbackHolder, _>(
377                                    #slot_ident,
378                                    |holder| holder.clone_rc(),
379                                );
380                        }
381                    } else {
382                        quote! {
383                            #[allow(unused_mut)]
384                            let #pat = __composer
385                                .with_slot_value::<#core_path::CallbackHolder, _>(
386                                    #slot_ident,
387                                    |holder| holder.clone_rc(),
388                                );
389                        }
390                    }
391                } else if info.is_impl_trait {
392                    quote! {}
393                } else {
394                    let pat = &info.pat;
395                    let ident = &info.ident;
396                    quote! {
397                        let #pat = #ident;
398                    }
399                }
400            })
401            .collect();
402
403        let rebinds_for_recompose: Vec<TokenStream2> = param_info
404            .iter()
405            .zip(param_state_slots.iter())
406            .map(|(info, slot_ident)| {
407                if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
408                    || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
409                {
410                    let pat = &info.pat;
411                    let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
412                    if can_add_mut && !info.pat_is_mut {
413                        quote! {
414                            #[allow(unused_mut)]
415                            let mut #pat = __composer
416                                .with_slot_value::<#core_path::CallbackHolder, _>(
417                                    #slot_ident,
418                                    |holder| holder.clone_rc(),
419                                );
420                        }
421                    } else {
422                        quote! {
423                            #[allow(unused_mut)]
424                            let #pat = __composer
425                                .with_slot_value::<#core_path::CallbackHolder, _>(
426                                    #slot_ident,
427                                    |holder| holder.clone_rc(),
428                                );
429                        }
430                    }
431                } else if info.is_impl_trait {
432                    quote! {}
433                } else {
434                    let pat = &info.pat;
435                    let ty = &info.ty;
436                    quote! {
437                        let #pat = __composer
438                            .with_slot_value::<#core_path::ParamState<#ty>, _>(
439                                #slot_ident,
440                                |state| {
441                                    state
442                                        .value()
443                                        .expect("composable parameter missing for recomposition")
444                                },
445                            );
446                    }
447                }
448            })
449            .collect();
450
451        let recranpose_fn_ident = Ident::new(
452            &format!("__cranpose_recranpose_{}", func.sig.ident),
453            Span::call_site(),
454        );
455
456        let recranpose_setter = quote! {
457            {
458                __composer.set_recranpose_callback(move |
459                    __composer: &#core_path::Composer|
460                {
461                    #recranpose_fn_ident #ty_generics_turbofish (
462                        __composer
463                    );
464                });
465            }
466        };
467
468        let helper_body = quote! {
469            let __current_scope = __composer
470                .current_recranpose_scope()
471                .expect("missing recompose scope");
472            let mut __changed = __current_scope.should_recompose();
473            #(#param_setup)*
474            let __result_slot_index = __composer
475                .use_value_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
476            let __has_previous = __composer
477                .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
478                    __result_slot_index,
479                    |slot| slot.get().is_some(),
480                );
481            if !__changed && __has_previous {
482                __composer.skip_current_group();
483                let __result = __composer
484                    .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
485                        __result_slot_index,
486                        |slot| {
487                            slot.get()
488                                .expect("composable return value missing during skip")
489                        },
490                    );
491                return __result;
492            }
493            let __value: #return_ty = {
494                #(#rebinds)*
495                #helper_block
496            };
497            __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
498                __result_slot_index,
499                |slot| {
500                    slot.store(__value.clone());
501                },
502            );
503            #recranpose_setter
504            __value
505        };
506
507        let recranpose_fn_body = quote! {
508            #(#param_setup_recompose)*
509            let __result_slot_index = __composer
510                .use_value_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
511            #(#rebinds_for_recompose)*
512            let __value: #return_ty = {
513                #recranpose_block
514            };
515            __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
516                __result_slot_index,
517                |slot| {
518                    slot.store(__value.clone());
519                },
520            );
521            #recranpose_setter
522            __value
523        };
524
525        let recranpose_fn = quote! {
526            #[allow(non_snake_case)]
527            fn #recranpose_fn_ident #impl_generics (
528                __composer: &#core_path::Composer
529            ) -> #return_ty #where_clause {
530                #recranpose_fn_body
531            }
532        };
533
534        let helper_fn = quote! {
535            #[allow(non_snake_case, clippy::too_many_arguments)]
536            fn #helper_ident #impl_generics (
537                __composer: &#core_path::Composer
538                #(, #helper_inputs)*
539            ) -> #return_ty #where_clause {
540                #helper_body
541            }
542        };
543
544        // Wrapper args: pass all params except unhandled impl Trait on initial call
545        let wrapper_args: Vec<TokenStream2> = param_info
546            .iter()
547            .filter_map(|info| {
548                if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
549                    None
550                } else {
551                    let ident = &info.ident;
552                    Some(quote! { #ident })
553                }
554            })
555            .collect();
556
557        let wrapped = quote!({
558            #core_path::with_current_composer(|__composer: &#core_path::Composer| {
559                __composer.with_group(#key_expr, |__composer: &#core_path::Composer| {
560                    #helper_ident(__composer #(, #wrapper_args)*)
561                })
562            })
563        });
564        func.block = Box::new(syn::parse2(wrapped).expect("failed to build block"));
565        TokenStream::from(quote! {
566            #recranpose_fn
567            #helper_fn
568            #func
569        })
570    } else {
571        // no_skip path: still uses simple rebinds
572        let wrapped = quote!({
573            #core_path::with_current_composer(|__composer: &#core_path::Composer| {
574                __composer.with_group(#key_expr, |__scope: &#core_path::Composer| {
575                    #(#rebinds_for_no_skip)*
576                    #original_block
577                })
578            })
579        });
580        func.block = Box::new(syn::parse2(wrapped).expect("failed to build block"));
581        TokenStream::from(quote! { #func })
582    }
583}