1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use proc_macro_crate::{crate_name, FoundCrate};
4use quote::quote;
5use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Type};
6
7fn is_fn_like_type(ty: &Type) -> bool {
10 match ty {
11 Type::ImplTrait(impl_trait) => impl_trait.bounds.iter().any(|bound| {
13 if let syn::TypeParamBound::Trait(trait_bound) = bound {
14 let path = &trait_bound.path;
15 if let Some(segment) = path.segments.last() {
16 let ident_str = segment.ident.to_string();
17 return ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce";
18 }
19 }
20 false
21 }),
22 Type::Path(type_path) => {
24 if let Some(segment) = type_path.path.segments.last() {
25 if segment.ident == "Box" {
26 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
27 if let Some(syn::GenericArgument::Type(Type::TraitObject(trait_obj))) =
28 args.args.first()
29 {
30 return trait_obj.bounds.iter().any(|bound| {
31 if let syn::TypeParamBound::Trait(trait_bound) = bound {
32 let path = &trait_bound.path;
33 if let Some(segment) = path.segments.last() {
34 let ident_str = segment.ident.to_string();
35 return ident_str == "FnMut"
36 || ident_str == "Fn"
37 || ident_str == "FnOnce";
38 }
39 }
40 false
41 });
42 }
43 }
44 }
45 }
46 false
47 }
48 Type::BareFn(_) => true,
50 _ => false,
51 }
52}
53
54fn is_generic_fn_like(ty: &Type, generics: &syn::Generics) -> bool {
56 let type_ident = match ty {
58 Type::Path(type_path) if type_path.path.segments.len() == 1 => {
59 &type_path.path.segments[0].ident
60 }
61 _ => return false,
62 };
63
64 for param in &generics.params {
66 if let syn::GenericParam::Type(type_param) = param {
67 if type_param.ident == *type_ident {
68 for bound in &type_param.bounds {
70 if let syn::TypeParamBound::Trait(trait_bound) = bound {
71 if let Some(segment) = trait_bound.path.segments.last() {
72 let ident_str = segment.ident.to_string();
73 if ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce" {
74 return true;
75 }
76 }
77 }
78 }
79 }
80 }
81 }
82
83 if let Some(where_clause) = &generics.where_clause {
85 for predicate in &where_clause.predicates {
86 if let syn::WherePredicate::Type(pred) = predicate {
87 if let Type::Path(bounded_type) = &pred.bounded_ty {
88 if bounded_type.path.segments.len() == 1
89 && bounded_type.path.segments[0].ident == *type_ident
90 {
91 for bound in &pred.bounds {
92 if let syn::TypeParamBound::Trait(trait_bound) = bound {
93 if let Some(segment) = trait_bound.path.segments.last() {
94 let ident_str = segment.ident.to_string();
95 if ident_str == "FnMut"
96 || ident_str == "Fn"
97 || ident_str == "FnOnce"
98 {
99 return true;
100 }
101 }
102 }
103 }
104 }
105 }
106 }
107 }
108 }
109
110 false
111}
112
113fn is_fn_param(ty: &Type, generics: &syn::Generics) -> bool {
115 is_fn_like_type(ty) || is_generic_fn_like(ty, generics)
116}
117
118fn is_zero_arg_fn_impl_trait(ty: &Type) -> bool {
122 if let Type::ImplTrait(impl_trait) = ty {
123 impl_trait.bounds.iter().any(|bound| {
124 if let syn::TypeParamBound::Trait(trait_bound) = bound {
125 if let Some(segment) = trait_bound.path.segments.last() {
126 let ident_str = segment.ident.to_string();
127 if ident_str == "Fn" || ident_str == "FnMut" {
128 if let syn::PathArguments::Parenthesized(args) = &segment.arguments {
129 return args.inputs.is_empty();
130 }
131 }
132 }
133 }
134 false
135 })
136 } else {
137 false
138 }
139}
140
141fn is_node_id_return(ty: &Type) -> bool {
142 matches!(
143 ty,
144 Type::Path(type_path)
145 if type_path
146 .path
147 .segments
148 .last()
149 .is_some_and(|segment| segment.ident == "NodeId")
150 )
151}
152
153fn core_crate_path() -> TokenStream2 {
154 let crate_name = crate_name("cranpose")
155 .ok()
156 .or_else(|| crate_name("cranpose-core").ok());
157
158 match crate_name {
159 Some(FoundCrate::Itself) => quote!(crate),
160 Some(FoundCrate::Name(name)) => {
161 let ident = Ident::new(&name, Span::call_site());
162 quote!(#ident)
163 }
164 None => quote!(cranpose_core),
165 }
166}
167
168#[proc_macro_attribute]
169pub fn composable(attr: TokenStream, item: TokenStream) -> TokenStream {
170 let attr_tokens = TokenStream2::from(attr);
171 let mut enable_skip = true;
172 let core_path = core_crate_path();
173 if !attr_tokens.is_empty() {
174 match syn::parse2::<Ident>(attr_tokens) {
175 Ok(ident) if ident == "no_skip" => enable_skip = false,
176 Ok(other) => {
177 return syn::Error::new_spanned(other, "unsupported composable attribute")
178 .to_compile_error()
179 .into();
180 }
181 Err(err) => {
182 return err.to_compile_error().into();
183 }
184 }
185 }
186
187 let mut func = parse_macro_input!(item as ItemFn);
188
189 struct ParamInfo {
190 ident: Ident,
191 pat: Box<Pat>,
192 ty: Type,
193 pat_is_mut: bool,
194 is_impl_trait: bool,
195 }
196
197 let mut param_info: Vec<ParamInfo> = Vec::new();
198
199 for (index, arg) in func.sig.inputs.iter_mut().enumerate() {
200 if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
201 let pat_is_mut = matches!(
202 pat.as_ref(),
203 Pat::Ident(pat_ident) if pat_ident.mutability.is_some()
204 );
205 let is_impl_trait = matches!(**ty, Type::ImplTrait(_));
206
207 if is_impl_trait {
208 let original_pat: Box<Pat> = pat.clone();
209 if let Pat::Ident(pat_ident) = &**pat {
210 param_info.push(ParamInfo {
211 ident: pat_ident.ident.clone(),
212 pat: original_pat,
213 ty: ty.as_ref().clone(),
214 pat_is_mut,
215 is_impl_trait: true,
216 });
217 } else {
218 param_info.push(ParamInfo {
219 ident: Ident::new(&format!("__arg{}", index), Span::call_site()),
220 pat: original_pat,
221 ty: ty.as_ref().clone(),
222 pat_is_mut,
223 is_impl_trait: true,
224 });
225 }
226 } else {
227 let ident = Ident::new(&format!("__arg{}", index), Span::call_site());
228 let original_pat: Box<Pat> = pat.clone();
229 **pat = syn::parse_quote! { #ident };
230 param_info.push(ParamInfo {
231 ident,
232 pat: original_pat,
233 ty: ty.as_ref().clone(),
234 pat_is_mut,
235 is_impl_trait: false,
236 });
237 }
238 }
239 }
240
241 let scope_label_ident = func.sig.ident.clone();
242 let original_block = func.block.clone();
243 let helper_block = original_block.clone();
244 let recranpose_block = original_block.clone();
245 let key_expr = quote! { #core_path::location_key(file!(), line!(), column!()) };
246
247 let rebinds_for_no_skip: Vec<_> = param_info
249 .iter()
250 .map(|info| {
251 let ident = &info.ident;
252 let pat = &info.pat;
253 quote! { let #pat = #ident; }
254 })
255 .collect();
256
257 let return_ty: syn::Type = match &func.sig.output {
258 ReturnType::Default => syn::parse_quote! { () },
259 ReturnType::Type(_, ty) => ty.as_ref().clone(),
260 };
261 let returns_unit = match &func.sig.output {
262 ReturnType::Default => true,
263 ReturnType::Type(_, ty) => {
264 matches!(ty.as_ref(), Type::Tuple(tuple) if tuple.elems.is_empty())
265 }
266 };
267 let invalidate_return_consumer = if returns_unit || is_node_id_return(&return_ty) {
268 quote! {}
269 } else {
270 quote! { __composer.__invalidate_return_consumer_scope(); }
271 };
272 let _helper_ident = Ident::new(
273 &format!("__cranpose_impl_{}", func.sig.ident),
274 Span::call_site(),
275 );
276 let generics = func.sig.generics.clone();
277 let (_impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
278
279 let _helper_inputs: Vec<TokenStream2> = param_info
280 .iter()
281 .map(|info| {
282 let ident = &info.ident;
283 let ty = &info.ty;
284 quote! { #ident: #ty }
285 })
286 .collect();
287
288 let has_unhandled_impl_trait = param_info
291 .iter()
292 .any(|info| info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty));
293
294 if enable_skip && !has_unhandled_impl_trait {
295 let helper_ident = Ident::new(
296 &format!("__cranpose_impl_{}", func.sig.ident),
297 Span::call_site(),
298 );
299 let generics = func.sig.generics.clone();
300 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
301 let ty_generics_turbofish = ty_generics.as_turbofish();
302
303 let helper_inputs: Vec<TokenStream2> = param_info
306 .iter()
307 .filter_map(|info| {
308 if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
309 None
310 } else {
311 let ident = &info.ident;
312 let ty = &info.ty;
313 Some(quote! { #ident: #ty })
314 }
315 })
316 .collect();
317
318 let param_state_slots: Vec<Ident> = (0..param_info.len())
320 .map(|index| Ident::new(&format!("__param_state_slot{}", index), Span::call_site()))
321 .collect();
322
323 let param_setup: Vec<TokenStream2> = param_info
324 .iter()
325 .zip(param_state_slots.iter())
326 .map(|(info, slot_ident)| {
327 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
329 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
330 {
331 let ident = &info.ident;
332 quote! {
333 let #slot_ident = __composer
334 .__use_param_slot(|| #core_path::CallbackHolder::new());
335 __composer.with_slot_value::<#core_path::CallbackHolder, _>(
336 #slot_ident,
337 |holder| {
338 holder.update(#ident);
339 },
340 );
341 __changed = true;
342 }
343 } else if info.is_impl_trait {
344 quote! { __changed = true; }
346 } else {
347 let ident = &info.ident;
348 let ty = &info.ty;
349 quote! {
350 let #slot_ident = __composer
351 .__use_param_slot(|| #core_path::ParamState::<#ty>::default());
352 if __composer.with_slot_value_mut::<#core_path::ParamState<#ty>, _>(
353 #slot_ident,
354 |state| state.update(&#ident),
355 )
356 {
357 __changed = true;
358 }
359 }
360 }
361 })
362 .collect();
363
364 let param_setup_recompose: Vec<TokenStream2> = param_info
365 .iter()
366 .zip(param_state_slots.iter())
367 .map(|(info, slot_ident)| {
368 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
369 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
370 {
371 quote! {
372 let #slot_ident = __composer
373 .__use_param_slot(|| #core_path::CallbackHolder::new());
374 }
375 } else if info.is_impl_trait {
376 quote! {}
377 } else {
378 let ty = &info.ty;
379 quote! {
380 let #slot_ident = __composer
381 .__use_param_slot(|| #core_path::ParamState::<#ty>::default());
382 }
383 }
384 })
385 .collect();
386
387 let rebinds: Vec<TokenStream2> = param_info
388 .iter()
389 .zip(param_state_slots.iter())
390 .map(|(info, slot_ident)| {
391 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
392 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
393 {
394 let pat = &info.pat;
395 let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
396 if can_add_mut && !info.pat_is_mut {
397 quote! {
398 #[allow(unused_mut)]
399 let mut #pat = __composer
400 .with_slot_value::<#core_path::CallbackHolder, _>(
401 #slot_ident,
402 |holder| holder.clone_rc(),
403 );
404 }
405 } else {
406 quote! {
407 #[allow(unused_mut)]
408 let #pat = __composer
409 .with_slot_value::<#core_path::CallbackHolder, _>(
410 #slot_ident,
411 |holder| holder.clone_rc(),
412 );
413 }
414 }
415 } else if info.is_impl_trait {
416 quote! {}
417 } else {
418 let pat = &info.pat;
419 let ident = &info.ident;
420 quote! {
421 let #pat = #ident;
422 }
423 }
424 })
425 .collect();
426
427 let rebinds_for_recompose: Vec<TokenStream2> = param_info
428 .iter()
429 .zip(param_state_slots.iter())
430 .map(|(info, slot_ident)| {
431 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
432 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
433 {
434 let pat = &info.pat;
435 let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
436 if can_add_mut && !info.pat_is_mut {
437 quote! {
438 #[allow(unused_mut)]
439 let mut #pat = __composer
440 .with_slot_value::<#core_path::CallbackHolder, _>(
441 #slot_ident,
442 |holder| holder.clone_rc(),
443 );
444 }
445 } else {
446 quote! {
447 #[allow(unused_mut)]
448 let #pat = __composer
449 .with_slot_value::<#core_path::CallbackHolder, _>(
450 #slot_ident,
451 |holder| holder.clone_rc(),
452 );
453 }
454 }
455 } else if info.is_impl_trait {
456 quote! {}
457 } else {
458 let pat = &info.pat;
459 let ty = &info.ty;
460 quote! {
461 let #pat = __composer
462 .with_slot_value::<#core_path::ParamState<#ty>, _>(
463 #slot_ident,
464 |state| {
465 state
466 .value()
467 .expect("composable parameter missing for recomposition")
468 },
469 );
470 }
471 }
472 })
473 .collect();
474
475 let recranpose_fn_ident = Ident::new(
476 &format!("__cranpose_recranpose_{}", func.sig.ident),
477 Span::call_site(),
478 );
479
480 let recranpose_setter = quote! {
481 {
482 __composer.set_recranpose_callback(move |
483 __composer: &#core_path::Composer|
484 {
485 #recranpose_fn_ident #ty_generics_turbofish (
486 __composer
487 );
488 });
489 }
490 };
491
492 let helper_body = if returns_unit {
493 quote! {
494 #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
495 let __current_scope = __composer
496 .current_recranpose_scope()
497 .expect("missing recompose scope");
498 let mut __changed = __current_scope.should_recompose();
499 #(#param_setup)*
500 #recranpose_setter
501 if !__changed && __current_scope.has_composed_once() {
502 __composer.skip_current_group();
503 return;
504 }
505 #(#rebinds)*
506 #helper_block
507 }
508 } else {
509 quote! {
510 #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
511 let __current_scope = __composer
512 .current_recranpose_scope()
513 .expect("missing recompose scope");
514 let mut __changed = __current_scope.should_recompose();
515 #(#param_setup)*
516 #recranpose_setter
517 let __result_slot_index = __composer
518 .__use_return_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
519 let __has_previous = __composer
520 .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
521 __result_slot_index,
522 |slot| slot.get().is_some(),
523 );
524 if !__changed && __has_previous {
525 __composer.skip_current_group();
526 let __result = __composer
527 .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
528 __result_slot_index,
529 |slot| {
530 slot.get()
531 .expect("composable return value missing during skip")
532 },
533 );
534 return __result;
535 }
536 let __value: #return_ty = {
537 #(#rebinds)*
538 #helper_block
539 };
540 __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
541 __result_slot_index,
542 |slot| {
543 slot.store(__value.clone());
544 },
545 );
546 __value
547 }
548 };
549
550 let recranpose_fn_body = if returns_unit {
551 quote! {
552 #(#param_setup_recompose)*
553 #(#rebinds_for_recompose)*
554 #recranpose_block
555 #recranpose_setter
556 }
557 } else {
558 quote! {
559 #(#param_setup_recompose)*
560 let __result_slot_index = __composer
561 .__use_return_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
562 #(#rebinds_for_recompose)*
563 let __value: #return_ty = {
564 #recranpose_block
565 };
566 __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
567 __result_slot_index,
568 |slot| {
569 slot.store(__value.clone());
570 },
571 );
572 #recranpose_setter
573 #invalidate_return_consumer
574 __value
575 }
576 };
577
578 let recranpose_fn = quote! {
579 #[allow(non_snake_case)]
580 fn #recranpose_fn_ident #impl_generics (
581 __composer: &#core_path::Composer
582 ) -> #return_ty #where_clause {
583 #recranpose_fn_body
584 }
585 };
586
587 let helper_fn = quote! {
588 #[allow(non_snake_case, clippy::too_many_arguments)]
589 fn #helper_ident #impl_generics (
590 __composer: &#core_path::Composer
591 #(, #helper_inputs)*
592 ) -> #return_ty #where_clause {
593 #helper_body
594 }
595 };
596
597 let wrapper_args: Vec<TokenStream2> = param_info
599 .iter()
600 .filter_map(|info| {
601 if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
602 None
603 } else {
604 let ident = &info.ident;
605 Some(quote! { #ident })
606 }
607 })
608 .collect();
609
610 let wrapped = quote!({
611 #core_path::with_current_composer(|__composer: &#core_path::Composer| {
612 __composer.with_group(#key_expr, |__composer: &#core_path::Composer| {
613 #helper_ident(__composer #(, #wrapper_args)*)
614 })
615 })
616 });
617 *func.block = syn::parse2(wrapped).expect("failed to build block");
618 TokenStream::from(quote! {
619 #recranpose_fn
620 #helper_fn
621 #func
622 })
623 } else {
624 let wrapped = quote!({
626 #core_path::with_current_composer(|__composer: &#core_path::Composer| {
627 __composer.with_group(#key_expr, |__scope: &#core_path::Composer| {
628 #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
629 #(#rebinds_for_no_skip)*
630 #original_block
631 })
632 })
633 });
634 *func.block = syn::parse2(wrapped).expect("failed to build block");
635 TokenStream::from(quote! { #func })
636 }
637}