cranpose_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::quote;
4use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Type};
5
6/// Check if a type is Fn-like (impl FnMut/Fn/FnOnce, Box<dyn FnMut>, generic with Fn bound, etc.)
7/// For generic type parameters (e.g., `F` where F: FnMut()), we need to check the bounds.
8fn is_fn_like_type(ty: &Type) -> bool {
9    match ty {
10        // impl FnMut(...) + 'static, impl Fn(...), etc.
11        Type::ImplTrait(impl_trait) => impl_trait.bounds.iter().any(|bound| {
12            if let syn::TypeParamBound::Trait(trait_bound) = bound {
13                let path = &trait_bound.path;
14                if let Some(segment) = path.segments.last() {
15                    let ident_str = segment.ident.to_string();
16                    return ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce";
17                }
18            }
19            false
20        }),
21        // Box<dyn FnMut(...)>
22        Type::Path(type_path) => {
23            if let Some(segment) = type_path.path.segments.last() {
24                if segment.ident == "Box" {
25                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
26                        if let Some(syn::GenericArgument::Type(Type::TraitObject(trait_obj))) =
27                            args.args.first()
28                        {
29                            return trait_obj.bounds.iter().any(|bound| {
30                                if let syn::TypeParamBound::Trait(trait_bound) = bound {
31                                    let path = &trait_bound.path;
32                                    if let Some(segment) = path.segments.last() {
33                                        let ident_str = segment.ident.to_string();
34                                        return ident_str == "FnMut"
35                                            || ident_str == "Fn"
36                                            || ident_str == "FnOnce";
37                                    }
38                                }
39                                false
40                            });
41                        }
42                    }
43                }
44            }
45            false
46        }
47        // bare fn(...) -> ...
48        Type::BareFn(_) => true,
49        _ => false,
50    }
51}
52
53/// Check if a generic type parameter has Fn-like bounds by looking at the where clause and bounds
54fn is_generic_fn_like(ty: &Type, generics: &syn::Generics) -> bool {
55    // Extract the ident for Type::Path that might be a generic param
56    let type_ident = match ty {
57        Type::Path(type_path) if type_path.path.segments.len() == 1 => {
58            &type_path.path.segments[0].ident
59        }
60        _ => return false,
61    };
62
63    // Check if it's a type parameter with Fn bounds
64    for param in &generics.params {
65        if let syn::GenericParam::Type(type_param) = param {
66            if type_param.ident == *type_ident {
67                // Check the bounds on the type parameter
68                for bound in &type_param.bounds {
69                    if let syn::TypeParamBound::Trait(trait_bound) = bound {
70                        if let Some(segment) = trait_bound.path.segments.last() {
71                            let ident_str = segment.ident.to_string();
72                            if ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce" {
73                                return true;
74                            }
75                        }
76                    }
77                }
78            }
79        }
80    }
81
82    // Also check where clause
83    if let Some(where_clause) = &generics.where_clause {
84        for predicate in &where_clause.predicates {
85            if let syn::WherePredicate::Type(pred) = predicate {
86                if let Type::Path(bounded_type) = &pred.bounded_ty {
87                    if bounded_type.path.segments.len() == 1
88                        && bounded_type.path.segments[0].ident == *type_ident
89                    {
90                        for bound in &pred.bounds {
91                            if let syn::TypeParamBound::Trait(trait_bound) = bound {
92                                if let Some(segment) = trait_bound.path.segments.last() {
93                                    let ident_str = segment.ident.to_string();
94                                    if ident_str == "FnMut"
95                                        || ident_str == "Fn"
96                                        || ident_str == "FnOnce"
97                                    {
98                                        return true;
99                                    }
100                                }
101                            }
102                        }
103                    }
104                }
105            }
106        }
107    }
108
109    false
110}
111
112/// Unified check: is this type Fn-like, either syntactically or via generic bounds?
113fn is_fn_param(ty: &Type, generics: &syn::Generics) -> bool {
114    is_fn_like_type(ty) || is_generic_fn_like(ty, generics)
115}
116
117#[proc_macro_attribute]
118pub fn composable(attr: TokenStream, item: TokenStream) -> TokenStream {
119    let attr_tokens = TokenStream2::from(attr);
120    let mut enable_skip = true;
121    if !attr_tokens.is_empty() {
122        match syn::parse2::<Ident>(attr_tokens) {
123            Ok(ident) if ident == "no_skip" => enable_skip = false,
124            Ok(other) => {
125                return syn::Error::new_spanned(other, "unsupported composable attribute")
126                    .to_compile_error()
127                    .into();
128            }
129            Err(err) => {
130                return err.to_compile_error().into();
131            }
132        }
133    }
134
135    let mut func = parse_macro_input!(item as ItemFn);
136
137    struct ParamInfo {
138        ident: Ident,
139        pat: Box<Pat>,
140        ty: Type,
141        pat_is_mut: bool,
142        is_impl_trait: bool,
143    }
144
145    let mut param_info: Vec<ParamInfo> = Vec::new();
146
147    for (index, arg) in func.sig.inputs.iter_mut().enumerate() {
148        if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
149            let pat_is_mut = matches!(
150                pat.as_ref(),
151                Pat::Ident(pat_ident) if pat_ident.mutability.is_some()
152            );
153            let is_impl_trait = matches!(**ty, Type::ImplTrait(_));
154
155            if is_impl_trait {
156                let original_pat: Box<Pat> = pat.clone();
157                if let Pat::Ident(pat_ident) = &**pat {
158                    param_info.push(ParamInfo {
159                        ident: pat_ident.ident.clone(),
160                        pat: original_pat,
161                        ty: ty.as_ref().clone(),
162                        pat_is_mut,
163                        is_impl_trait: true,
164                    });
165                } else {
166                    param_info.push(ParamInfo {
167                        ident: Ident::new(&format!("__arg{}", index), Span::call_site()),
168                        pat: original_pat,
169                        ty: ty.as_ref().clone(),
170                        pat_is_mut,
171                        is_impl_trait: true,
172                    });
173                }
174            } else {
175                let ident = Ident::new(&format!("__arg{}", index), Span::call_site());
176                let original_pat: Box<Pat> = pat.clone();
177                *pat = Box::new(syn::parse_quote! { #ident });
178                param_info.push(ParamInfo {
179                    ident,
180                    pat: original_pat,
181                    ty: ty.as_ref().clone(),
182                    pat_is_mut,
183                    is_impl_trait: false,
184                });
185            }
186        }
187    }
188
189    let original_block = func.block.clone();
190    let helper_block = original_block.clone();
191    let recranpose_block = original_block.clone();
192    let key_expr = quote! { cranpose_core::location_key(file!(), line!(), column!()) };
193
194    // Rebinds will be generated later in the helper_body context where we have access to slots
195    let rebinds_for_no_skip: Vec<_> = param_info
196        .iter()
197        .map(|info| {
198            let ident = &info.ident;
199            let pat = &info.pat;
200            quote! { let #pat = #ident; }
201        })
202        .collect();
203
204    let return_ty: syn::Type = match &func.sig.output {
205        ReturnType::Default => syn::parse_quote! { () },
206        ReturnType::Type(_, ty) => ty.as_ref().clone(),
207    };
208    let _helper_ident = Ident::new(
209        &format!("__cranpose_impl_{}", func.sig.ident),
210        Span::call_site(),
211    );
212    let generics = func.sig.generics.clone();
213    let (_impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
214
215    let _helper_inputs: Vec<TokenStream2> = param_info
216        .iter()
217        .map(|info| {
218            let ident = &info.ident;
219            let ty = &info.ty;
220            quote! { #ident: #ty }
221        })
222        .collect();
223
224    // Check if any params are impl Trait - if so, can't use skip optimization
225    let has_impl_trait = param_info
226        .iter()
227        .any(|info| matches!(info.ty, Type::ImplTrait(_)));
228
229    if enable_skip && !has_impl_trait {
230        let helper_ident = Ident::new(
231            &format!("__cranpose_impl_{}", func.sig.ident),
232            Span::call_site(),
233        );
234        let generics = func.sig.generics.clone();
235        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
236        let ty_generics_turbofish = ty_generics.as_turbofish();
237
238        // Helper function signature: all params except impl Trait (which can't be named)
239        let helper_inputs: Vec<TokenStream2> = param_info
240            .iter()
241            .filter_map(|info| {
242                if info.is_impl_trait {
243                    None
244                } else {
245                    let ident = &info.ident;
246                    let ty = &info.ty;
247                    Some(quote! { #ident: #ty })
248                }
249            })
250            .collect();
251
252        // Separate Fn-like params from regular params
253        let param_state_slots: Vec<Ident> = (0..param_info.len())
254            .map(|index| Ident::new(&format!("__param_state_slot{}", index), Span::call_site()))
255            .collect();
256
257        let param_setup: Vec<TokenStream2> = param_info
258            .iter()
259            .zip(param_state_slots.iter())
260            .map(|(info, slot_ident)| {
261                if info.is_impl_trait {
262                    quote! { __changed = true; }
263                } else if is_fn_param(&info.ty, &generics) {
264                    let ident = &info.ident;
265                    quote! {
266                        let #slot_ident = __composer
267                            .use_value_slot(|| cranpose_core::CallbackHolder::new());
268                        __composer.with_slot_value::<cranpose_core::CallbackHolder, _>(
269                            #slot_ident,
270                            |holder| {
271                                holder.update(#ident);
272                            },
273                        );
274                        __changed = true;
275                    }
276                } else {
277                    let ident = &info.ident;
278                    let ty = &info.ty;
279                    quote! {
280                        let #slot_ident = __composer
281                            .use_value_slot(|| cranpose_core::ParamState::<#ty>::default());
282                        if __composer.with_slot_value_mut::<cranpose_core::ParamState<#ty>, _>(
283                            #slot_ident,
284                            |state| state.update(&#ident),
285                        )
286                        {
287                            __changed = true;
288                        }
289                    }
290                }
291            })
292            .collect();
293
294        let param_setup_recompose: Vec<TokenStream2> = param_info
295            .iter()
296            .zip(param_state_slots.iter())
297            .map(|(info, slot_ident)| {
298                if info.is_impl_trait {
299                    quote! {}
300                } else if is_fn_param(&info.ty, &generics) {
301                    quote! {
302                        let #slot_ident = __composer
303                            .use_value_slot(|| cranpose_core::CallbackHolder::new());
304                    }
305                } else {
306                    let ty = &info.ty;
307                    quote! {
308                        let #slot_ident = __composer
309                            .use_value_slot(|| cranpose_core::ParamState::<#ty>::default());
310                    }
311                }
312            })
313            .collect();
314
315        let rebinds: Vec<TokenStream2> = param_info
316            .iter()
317            .zip(param_state_slots.iter())
318            .map(|(info, slot_ident)| {
319                if info.is_impl_trait {
320                    quote! {}
321                } else if is_fn_param(&info.ty, &generics) {
322                    let pat = &info.pat;
323                    let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
324                    if can_add_mut && !info.pat_is_mut {
325                        quote! {
326                            #[allow(unused_mut)]
327                            let mut #pat = __composer
328                                .with_slot_value::<cranpose_core::CallbackHolder, _>(
329                                    #slot_ident,
330                                    |holder| holder.clone_rc(),
331                                );
332                        }
333                    } else {
334                        quote! {
335                            #[allow(unused_mut)]
336                            let #pat = __composer
337                                .with_slot_value::<cranpose_core::CallbackHolder, _>(
338                                    #slot_ident,
339                                    |holder| holder.clone_rc(),
340                                );
341                        }
342                    }
343                } else {
344                    let pat = &info.pat;
345                    let ident = &info.ident;
346                    quote! {
347                        let #pat = #ident;
348                    }
349                }
350            })
351            .collect();
352
353        let rebinds_for_recompose: Vec<TokenStream2> = param_info
354            .iter()
355            .zip(param_state_slots.iter())
356            .map(|(info, slot_ident)| {
357                if info.is_impl_trait {
358                    quote! {}
359                } else if is_fn_param(&info.ty, &generics) {
360                    let pat = &info.pat;
361                    let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
362                    if can_add_mut && !info.pat_is_mut {
363                        quote! {
364                            #[allow(unused_mut)]
365                            let mut #pat = __composer
366                                .with_slot_value::<cranpose_core::CallbackHolder, _>(
367                                    #slot_ident,
368                                    |holder| holder.clone_rc(),
369                                );
370                        }
371                    } else {
372                        quote! {
373                            #[allow(unused_mut)]
374                            let #pat = __composer
375                                .with_slot_value::<cranpose_core::CallbackHolder, _>(
376                                    #slot_ident,
377                                    |holder| holder.clone_rc(),
378                                );
379                        }
380                    }
381                } else {
382                    let pat = &info.pat;
383                    let ty = &info.ty;
384                    quote! {
385                        let #pat = __composer
386                            .with_slot_value::<cranpose_core::ParamState<#ty>, _>(
387                                #slot_ident,
388                                |state| {
389                                    state
390                                        .value()
391                                        .expect("composable parameter missing for recomposition")
392                                },
393                            );
394                    }
395                }
396            })
397            .collect();
398
399        let recranpose_fn_ident = Ident::new(
400            &format!("__cranpose_recranpose_{}", func.sig.ident),
401            Span::call_site(),
402        );
403
404        let recranpose_setter = quote! {
405            {
406                __composer.set_recranpose_callback(move |
407                    __composer: &cranpose_core::Composer|
408                {
409                    #recranpose_fn_ident #ty_generics_turbofish (
410                        __composer
411                    );
412                });
413            }
414        };
415
416        let helper_body = quote! {
417            let __current_scope = __composer
418                .current_recranpose_scope()
419                .expect("missing recompose scope");
420            let mut __changed = __current_scope.should_recompose();
421            #(#param_setup)*
422            let __result_slot_index = __composer
423                .use_value_slot(|| cranpose_core::ReturnSlot::<#return_ty>::default());
424            let __has_previous = __composer
425                .with_slot_value::<cranpose_core::ReturnSlot<#return_ty>, _>(
426                    __result_slot_index,
427                    |slot| slot.get().is_some(),
428                );
429            if !__changed && __has_previous {
430                __composer.skip_current_group();
431                let __result = __composer
432                    .with_slot_value::<cranpose_core::ReturnSlot<#return_ty>, _>(
433                        __result_slot_index,
434                        |slot| {
435                            slot.get()
436                                .expect("composable return value missing during skip")
437                        },
438                    );
439                return __result;
440            }
441            let __value: #return_ty = {
442                #(#rebinds)*
443                #helper_block
444            };
445            __composer.with_slot_value_mut::<cranpose_core::ReturnSlot<#return_ty>, _>(
446                __result_slot_index,
447                |slot| {
448                    slot.store(__value.clone());
449                },
450            );
451            #recranpose_setter
452            __value
453        };
454
455        let recranpose_fn_body = quote! {
456            #(#param_setup_recompose)*
457            let __result_slot_index = __composer
458                .use_value_slot(|| cranpose_core::ReturnSlot::<#return_ty>::default());
459            #(#rebinds_for_recompose)*
460            let __value: #return_ty = {
461                #recranpose_block
462            };
463            __composer.with_slot_value_mut::<cranpose_core::ReturnSlot<#return_ty>, _>(
464                __result_slot_index,
465                |slot| {
466                    slot.store(__value.clone());
467                },
468            );
469            #recranpose_setter
470            __value
471        };
472
473        let recranpose_fn = quote! {
474            #[allow(non_snake_case)]
475            fn #recranpose_fn_ident #impl_generics (
476                __composer: &cranpose_core::Composer
477            ) -> #return_ty #where_clause {
478                #recranpose_fn_body
479            }
480        };
481
482        let helper_fn = quote! {
483            #[allow(non_snake_case)]
484            fn #helper_ident #impl_generics (
485                __composer: &cranpose_core::Composer
486                #(, #helper_inputs)*
487            ) -> #return_ty #where_clause {
488                #helper_body
489            }
490        };
491
492        // Wrapper args: pass all params except impl Trait on initial call
493        let wrapper_args: Vec<TokenStream2> = param_info
494            .iter()
495            .filter_map(|info| {
496                if info.is_impl_trait {
497                    None
498                } else {
499                    let ident = &info.ident;
500                    Some(quote! { #ident })
501                }
502            })
503            .collect();
504
505        let wrapped = quote!({
506            cranpose_core::with_current_composer(|__composer: &cranpose_core::Composer| {
507                __composer.with_group(#key_expr, |__composer: &cranpose_core::Composer| {
508                    #helper_ident(__composer #(, #wrapper_args)*)
509                })
510            })
511        });
512        func.block = Box::new(syn::parse2(wrapped).expect("failed to build block"));
513        TokenStream::from(quote! {
514            #recranpose_fn
515            #helper_fn
516            #func
517        })
518    } else {
519        // no_skip path: still uses simple rebinds
520        let wrapped = quote!({
521            cranpose_core::with_current_composer(|__composer: &cranpose_core::Composer| {
522                __composer.with_group(#key_expr, |__scope: &cranpose_core::Composer| {
523                    #(#rebinds_for_no_skip)*
524                    #original_block
525                })
526            })
527        });
528        func.block = Box::new(syn::parse2(wrapped).expect("failed to build block"));
529        TokenStream::from(quote! { #func })
530    }
531}