1#![deny(unsafe_code)]
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use proc_macro_crate::{crate_name, FoundCrate};
6use quote::quote;
7use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Type};
8
9fn is_fn_like_type(ty: &Type) -> bool {
12 match ty {
13 Type::ImplTrait(impl_trait) => impl_trait.bounds.iter().any(|bound| {
15 if let syn::TypeParamBound::Trait(trait_bound) = bound {
16 let path = &trait_bound.path;
17 if let Some(segment) = path.segments.last() {
18 let ident_str = segment.ident.to_string();
19 return ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce";
20 }
21 }
22 false
23 }),
24 Type::Path(type_path) => {
26 if let Some(segment) = type_path.path.segments.last() {
27 if segment.ident == "Box" {
28 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
29 if let Some(syn::GenericArgument::Type(Type::TraitObject(trait_obj))) =
30 args.args.first()
31 {
32 return trait_obj.bounds.iter().any(|bound| {
33 if let syn::TypeParamBound::Trait(trait_bound) = bound {
34 let path = &trait_bound.path;
35 if let Some(segment) = path.segments.last() {
36 let ident_str = segment.ident.to_string();
37 return ident_str == "FnMut"
38 || ident_str == "Fn"
39 || ident_str == "FnOnce";
40 }
41 }
42 false
43 });
44 }
45 }
46 }
47 }
48 false
49 }
50 Type::BareFn(_) => true,
52 _ => false,
53 }
54}
55
56fn is_generic_fn_like(ty: &Type, generics: &syn::Generics) -> bool {
58 let type_ident = match ty {
60 Type::Path(type_path) if type_path.path.segments.len() == 1 => {
61 &type_path.path.segments[0].ident
62 }
63 _ => return false,
64 };
65
66 for param in &generics.params {
68 if let syn::GenericParam::Type(type_param) = param {
69 if type_param.ident == *type_ident {
70 for bound in &type_param.bounds {
72 if let syn::TypeParamBound::Trait(trait_bound) = bound {
73 if let Some(segment) = trait_bound.path.segments.last() {
74 let ident_str = segment.ident.to_string();
75 if ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce" {
76 return true;
77 }
78 }
79 }
80 }
81 }
82 }
83 }
84
85 if let Some(where_clause) = &generics.where_clause {
87 for predicate in &where_clause.predicates {
88 if let syn::WherePredicate::Type(pred) = predicate {
89 if let Type::Path(bounded_type) = &pred.bounded_ty {
90 if bounded_type.path.segments.len() == 1
91 && bounded_type.path.segments[0].ident == *type_ident
92 {
93 for bound in &pred.bounds {
94 if let syn::TypeParamBound::Trait(trait_bound) = bound {
95 if let Some(segment) = trait_bound.path.segments.last() {
96 let ident_str = segment.ident.to_string();
97 if ident_str == "FnMut"
98 || ident_str == "Fn"
99 || ident_str == "FnOnce"
100 {
101 return true;
102 }
103 }
104 }
105 }
106 }
107 }
108 }
109 }
110 }
111
112 false
113}
114
115fn is_fn_param(ty: &Type, generics: &syn::Generics) -> bool {
117 is_fn_like_type(ty) || is_generic_fn_like(ty, generics)
118}
119
120fn is_zero_arg_fn_impl_trait(ty: &Type) -> bool {
124 if let Type::ImplTrait(impl_trait) = ty {
125 impl_trait.bounds.iter().any(|bound| {
126 if let syn::TypeParamBound::Trait(trait_bound) = bound {
127 if let Some(segment) = trait_bound.path.segments.last() {
128 let ident_str = segment.ident.to_string();
129 if ident_str == "Fn" || ident_str == "FnMut" {
130 if let syn::PathArguments::Parenthesized(args) = &segment.arguments {
131 return args.inputs.is_empty();
132 }
133 }
134 }
135 }
136 false
137 })
138 } else {
139 false
140 }
141}
142
143fn is_node_id_return(ty: &Type) -> bool {
144 matches!(
145 ty,
146 Type::Path(type_path)
147 if type_path
148 .path
149 .segments
150 .last()
151 .is_some_and(|segment| segment.ident == "NodeId")
152 )
153}
154
155fn core_crate_path() -> TokenStream2 {
156 let crate_name = crate_name("cranpose")
157 .ok()
158 .or_else(|| crate_name("cranpose-core").ok());
159
160 match crate_name {
161 Some(FoundCrate::Itself) => quote!(crate),
162 Some(FoundCrate::Name(name)) => {
163 let ident = Ident::new(&name, Span::call_site());
164 quote!(#ident)
165 }
166 None => quote!(cranpose_core),
167 }
168}
169
170#[proc_macro_attribute]
171pub fn composable(attr: TokenStream, item: TokenStream) -> TokenStream {
172 let attr_tokens = TokenStream2::from(attr);
173 let mut enable_skip = true;
174 let core_path = core_crate_path();
175 if !attr_tokens.is_empty() {
176 match syn::parse2::<Ident>(attr_tokens) {
177 Ok(ident) if ident == "no_skip" => enable_skip = false,
178 Ok(other) => {
179 return syn::Error::new_spanned(other, "unsupported composable attribute")
180 .to_compile_error()
181 .into();
182 }
183 Err(err) => {
184 return err.to_compile_error().into();
185 }
186 }
187 }
188
189 let mut func = parse_macro_input!(item as ItemFn);
190
191 struct ParamInfo {
192 ident: Ident,
193 pat: Box<Pat>,
194 ty: Type,
195 pat_is_mut: bool,
196 is_impl_trait: bool,
197 }
198
199 let mut param_info: Vec<ParamInfo> = Vec::new();
200
201 for (index, arg) in func.sig.inputs.iter_mut().enumerate() {
202 if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
203 let pat_is_mut = matches!(
204 pat.as_ref(),
205 Pat::Ident(pat_ident) if pat_ident.mutability.is_some()
206 );
207 let is_impl_trait = matches!(**ty, Type::ImplTrait(_));
208
209 if is_impl_trait {
210 let original_pat: Box<Pat> = pat.clone();
211 if let Pat::Ident(pat_ident) = &**pat {
212 param_info.push(ParamInfo {
213 ident: pat_ident.ident.clone(),
214 pat: original_pat,
215 ty: ty.as_ref().clone(),
216 pat_is_mut,
217 is_impl_trait: true,
218 });
219 } else {
220 param_info.push(ParamInfo {
221 ident: Ident::new(&format!("__arg{}", index), Span::call_site()),
222 pat: original_pat,
223 ty: ty.as_ref().clone(),
224 pat_is_mut,
225 is_impl_trait: true,
226 });
227 }
228 } else {
229 let ident = Ident::new(&format!("__arg{}", index), Span::call_site());
230 let original_pat: Box<Pat> = pat.clone();
231 **pat = syn::parse_quote! { #ident };
232 param_info.push(ParamInfo {
233 ident,
234 pat: original_pat,
235 ty: ty.as_ref().clone(),
236 pat_is_mut,
237 is_impl_trait: false,
238 });
239 }
240 }
241 }
242
243 let scope_label_ident = func.sig.ident.clone();
244 let original_block = func.block.clone();
245 let helper_block = original_block.clone();
246 let recranpose_block = original_block.clone();
247 let key_expr = quote! { #core_path::location_key(file!(), line!(), column!()) };
248
249 let rebinds_for_no_skip: Vec<_> = param_info
251 .iter()
252 .map(|info| {
253 let ident = &info.ident;
254 let pat = &info.pat;
255 quote! { let #pat = #ident; }
256 })
257 .collect();
258
259 let return_ty: syn::Type = match &func.sig.output {
260 ReturnType::Default => syn::parse_quote! { () },
261 ReturnType::Type(_, ty) => ty.as_ref().clone(),
262 };
263 let returns_unit = match &func.sig.output {
264 ReturnType::Default => true,
265 ReturnType::Type(_, ty) => {
266 matches!(ty.as_ref(), Type::Tuple(tuple) if tuple.elems.is_empty())
267 }
268 };
269 let invalidate_return_consumer = if returns_unit || is_node_id_return(&return_ty) {
270 quote! {}
271 } else {
272 quote! { __composer.__invalidate_return_consumer_scope(); }
273 };
274 let _helper_ident = Ident::new(
275 &format!("__cranpose_impl_{}", func.sig.ident),
276 Span::call_site(),
277 );
278 let generics = func.sig.generics.clone();
279 let (_impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
280
281 let _helper_inputs: Vec<TokenStream2> = param_info
282 .iter()
283 .map(|info| {
284 let ident = &info.ident;
285 let ty = &info.ty;
286 quote! { #ident: #ty }
287 })
288 .collect();
289
290 let has_unhandled_impl_trait = param_info
293 .iter()
294 .any(|info| info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty));
295
296 if enable_skip && !has_unhandled_impl_trait {
297 let helper_ident = Ident::new(
298 &format!("__cranpose_impl_{}", func.sig.ident),
299 Span::call_site(),
300 );
301 let generics = func.sig.generics.clone();
302 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
303 let ty_generics_turbofish = ty_generics.as_turbofish();
304
305 let helper_inputs: Vec<TokenStream2> = param_info
308 .iter()
309 .filter_map(|info| {
310 if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
311 None
312 } else {
313 let ident = &info.ident;
314 let ty = &info.ty;
315 Some(quote! { #ident: #ty })
316 }
317 })
318 .collect();
319
320 let param_state_slots: Vec<Ident> = (0..param_info.len())
322 .map(|index| Ident::new(&format!("__param_state_slot{}", index), Span::call_site()))
323 .collect();
324
325 let param_setup: Vec<TokenStream2> = param_info
326 .iter()
327 .zip(param_state_slots.iter())
328 .map(|(info, slot_ident)| {
329 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
331 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
332 {
333 let ident = &info.ident;
334 quote! {
335 let #slot_ident = __composer
336 .__use_param_slot(|| #core_path::CallbackHolder::new());
337 __composer.with_slot_value::<#core_path::CallbackHolder, _>(
338 #slot_ident,
339 |holder| {
340 holder.update(#ident);
341 },
342 );
343 __changed = true;
344 }
345 } else if info.is_impl_trait {
346 quote! { __changed = true; }
348 } else {
349 let ident = &info.ident;
350 let ty = &info.ty;
351 quote! {
352 let #slot_ident = __composer
353 .__use_param_slot(|| #core_path::ParamState::<#ty>::default());
354 if __composer.with_slot_value_mut::<#core_path::ParamState<#ty>, _>(
355 #slot_ident,
356 |state| state.update(&#ident),
357 )
358 {
359 __changed = true;
360 }
361 }
362 }
363 })
364 .collect();
365
366 let param_setup_recompose: Vec<TokenStream2> = param_info
367 .iter()
368 .zip(param_state_slots.iter())
369 .map(|(info, slot_ident)| {
370 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
371 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
372 {
373 quote! {
374 let #slot_ident = __composer
375 .__use_param_slot(|| #core_path::CallbackHolder::new());
376 }
377 } else if info.is_impl_trait {
378 quote! {}
379 } else {
380 let ty = &info.ty;
381 quote! {
382 let #slot_ident = __composer
383 .__use_param_slot(|| #core_path::ParamState::<#ty>::default());
384 }
385 }
386 })
387 .collect();
388
389 let rebinds: Vec<TokenStream2> = param_info
390 .iter()
391 .zip(param_state_slots.iter())
392 .map(|(info, slot_ident)| {
393 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
394 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
395 {
396 let pat = &info.pat;
397 let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
398 if can_add_mut && !info.pat_is_mut {
399 quote! {
400 #[allow(unused_mut)]
401 let mut #pat = __composer
402 .with_slot_value::<#core_path::CallbackHolder, _>(
403 #slot_ident,
404 |holder| holder.clone_rc(),
405 );
406 }
407 } else {
408 quote! {
409 #[allow(unused_mut)]
410 let #pat = __composer
411 .with_slot_value::<#core_path::CallbackHolder, _>(
412 #slot_ident,
413 |holder| holder.clone_rc(),
414 );
415 }
416 }
417 } else if info.is_impl_trait {
418 quote! {}
419 } else {
420 let pat = &info.pat;
421 let ident = &info.ident;
422 quote! {
423 let #pat = #ident;
424 }
425 }
426 })
427 .collect();
428
429 let rebinds_for_recompose: Vec<TokenStream2> = param_info
430 .iter()
431 .zip(param_state_slots.iter())
432 .map(|(info, slot_ident)| {
433 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
434 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
435 {
436 let pat = &info.pat;
437 let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
438 if can_add_mut && !info.pat_is_mut {
439 quote! {
440 #[allow(unused_mut)]
441 let mut #pat = __composer
442 .with_slot_value::<#core_path::CallbackHolder, _>(
443 #slot_ident,
444 |holder| holder.clone_rc(),
445 );
446 }
447 } else {
448 quote! {
449 #[allow(unused_mut)]
450 let #pat = __composer
451 .with_slot_value::<#core_path::CallbackHolder, _>(
452 #slot_ident,
453 |holder| holder.clone_rc(),
454 );
455 }
456 }
457 } else if info.is_impl_trait {
458 quote! {}
459 } else {
460 let pat = &info.pat;
461 let ty = &info.ty;
462 quote! {
463 let #pat = __composer
464 .with_slot_value::<#core_path::ParamState<#ty>, _>(
465 #slot_ident,
466 |state| {
467 state
468 .value()
469 .expect("composable parameter missing for recomposition")
470 },
471 );
472 }
473 }
474 })
475 .collect();
476
477 let recranpose_fn_ident = Ident::new(
478 &format!("__cranpose_recranpose_{}", func.sig.ident),
479 Span::call_site(),
480 );
481
482 let recranpose_setter = quote! {
483 {
484 __composer.set_recranpose_callback(move |
485 __composer: &#core_path::Composer|
486 {
487 #recranpose_fn_ident #ty_generics_turbofish (
488 __composer
489 );
490 });
491 }
492 };
493
494 let helper_body = if returns_unit {
495 quote! {
496 #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
497 let __current_scope = __composer
498 .current_recranpose_scope()
499 .expect("missing recompose scope");
500 let mut __changed = __current_scope.should_recompose();
501 #(#param_setup)*
502 #recranpose_setter
503 if !__changed && __current_scope.has_composed_once() {
504 __composer.skip_current_group();
505 return;
506 }
507 #(#rebinds)*
508 #helper_block
509 }
510 } else {
511 quote! {
512 #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
513 let __current_scope = __composer
514 .current_recranpose_scope()
515 .expect("missing recompose scope");
516 let mut __changed = __current_scope.should_recompose();
517 #(#param_setup)*
518 #recranpose_setter
519 let __result_slot_index = __composer
520 .__use_return_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
521 let __has_previous = __composer
522 .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
523 __result_slot_index,
524 |slot| slot.get().is_some(),
525 );
526 if !__changed && __has_previous {
527 __composer.skip_current_group();
528 let __result = __composer
529 .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
530 __result_slot_index,
531 |slot| {
532 slot.get()
533 .expect("composable return value missing during skip")
534 },
535 );
536 return __result;
537 }
538 let __value: #return_ty = {
539 #(#rebinds)*
540 #helper_block
541 };
542 __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
543 __result_slot_index,
544 |slot| {
545 slot.store(__value.clone());
546 },
547 );
548 __value
549 }
550 };
551
552 let recranpose_fn_body = if returns_unit {
553 quote! {
554 #(#param_setup_recompose)*
555 #(#rebinds_for_recompose)*
556 #recranpose_block
557 #recranpose_setter
558 }
559 } else {
560 quote! {
561 #(#param_setup_recompose)*
562 let __result_slot_index = __composer
563 .__use_return_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
564 #(#rebinds_for_recompose)*
565 let __value: #return_ty = {
566 #recranpose_block
567 };
568 __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
569 __result_slot_index,
570 |slot| {
571 slot.store(__value.clone());
572 },
573 );
574 #recranpose_setter
575 #invalidate_return_consumer
576 __value
577 }
578 };
579
580 let recranpose_fn = quote! {
581 #[allow(non_snake_case)]
582 fn #recranpose_fn_ident #impl_generics (
583 __composer: &#core_path::Composer
584 ) -> #return_ty #where_clause {
585 #recranpose_fn_body
586 }
587 };
588
589 let helper_fn = quote! {
590 #[allow(non_snake_case, clippy::too_many_arguments)]
591 fn #helper_ident #impl_generics (
592 __composer: &#core_path::Composer
593 #(, #helper_inputs)*
594 ) -> #return_ty #where_clause {
595 #helper_body
596 }
597 };
598
599 let wrapper_args: Vec<TokenStream2> = param_info
601 .iter()
602 .filter_map(|info| {
603 if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
604 None
605 } else {
606 let ident = &info.ident;
607 Some(quote! { #ident })
608 }
609 })
610 .collect();
611
612 let wrapped = quote!({
613 #core_path::with_current_composer(|__composer: &#core_path::Composer| {
614 __composer.with_group(#key_expr, |__composer: &#core_path::Composer| {
615 #helper_ident(__composer #(, #wrapper_args)*)
616 })
617 })
618 });
619 *func.block = syn::parse2(wrapped).expect("failed to build block");
620 TokenStream::from(quote! {
621 #recranpose_fn
622 #helper_fn
623 #func
624 })
625 } else {
626 let wrapped = quote!({
628 #core_path::with_current_composer(|__composer: &#core_path::Composer| {
629 __composer.with_group(#key_expr, |__scope: &#core_path::Composer| {
630 #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
631 #(#rebinds_for_no_skip)*
632 #original_block
633 })
634 })
635 });
636 *func.block = syn::parse2(wrapped).expect("failed to build block");
637 TokenStream::from(quote! { #func })
638 }
639}