use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use proc_macro_crate::{crate_name, FoundCrate};
use quote::{format_ident, quote};
use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Type};
6
7fn is_fn_like_type(ty: &Type) -> bool {
10 match ty {
11 Type::ImplTrait(impl_trait) => impl_trait.bounds.iter().any(|bound| {
13 if let syn::TypeParamBound::Trait(trait_bound) = bound {
14 let path = &trait_bound.path;
15 if let Some(segment) = path.segments.last() {
16 let ident_str = segment.ident.to_string();
17 return ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce";
18 }
19 }
20 false
21 }),
22 Type::Path(type_path) => {
24 if let Some(segment) = type_path.path.segments.last() {
25 if segment.ident == "Box" {
26 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
27 if let Some(syn::GenericArgument::Type(Type::TraitObject(trait_obj))) =
28 args.args.first()
29 {
30 return trait_obj.bounds.iter().any(|bound| {
31 if let syn::TypeParamBound::Trait(trait_bound) = bound {
32 let path = &trait_bound.path;
33 if let Some(segment) = path.segments.last() {
34 let ident_str = segment.ident.to_string();
35 return ident_str == "FnMut"
36 || ident_str == "Fn"
37 || ident_str == "FnOnce";
38 }
39 }
40 false
41 });
42 }
43 }
44 }
45 }
46 false
47 }
48 Type::BareFn(_) => true,
50 _ => false,
51 }
52}
53
54fn is_generic_fn_like(ty: &Type, generics: &syn::Generics) -> bool {
56 let type_ident = match ty {
58 Type::Path(type_path) if type_path.path.segments.len() == 1 => {
59 &type_path.path.segments[0].ident
60 }
61 _ => return false,
62 };
63
64 for param in &generics.params {
66 if let syn::GenericParam::Type(type_param) = param {
67 if type_param.ident == *type_ident {
68 for bound in &type_param.bounds {
70 if let syn::TypeParamBound::Trait(trait_bound) = bound {
71 if let Some(segment) = trait_bound.path.segments.last() {
72 let ident_str = segment.ident.to_string();
73 if ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce" {
74 return true;
75 }
76 }
77 }
78 }
79 }
80 }
81 }
82
83 if let Some(where_clause) = &generics.where_clause {
85 for predicate in &where_clause.predicates {
86 if let syn::WherePredicate::Type(pred) = predicate {
87 if let Type::Path(bounded_type) = &pred.bounded_ty {
88 if bounded_type.path.segments.len() == 1
89 && bounded_type.path.segments[0].ident == *type_ident
90 {
91 for bound in &pred.bounds {
92 if let syn::TypeParamBound::Trait(trait_bound) = bound {
93 if let Some(segment) = trait_bound.path.segments.last() {
94 let ident_str = segment.ident.to_string();
95 if ident_str == "FnMut"
96 || ident_str == "Fn"
97 || ident_str == "FnOnce"
98 {
99 return true;
100 }
101 }
102 }
103 }
104 }
105 }
106 }
107 }
108 }
109
110 false
111}
112
113fn is_fn_param(ty: &Type, generics: &syn::Generics) -> bool {
115 is_fn_like_type(ty) || is_generic_fn_like(ty, generics)
116}
117
118fn is_zero_arg_fn_impl_trait(ty: &Type) -> bool {
122 if let Type::ImplTrait(impl_trait) = ty {
123 impl_trait.bounds.iter().any(|bound| {
124 if let syn::TypeParamBound::Trait(trait_bound) = bound {
125 if let Some(segment) = trait_bound.path.segments.last() {
126 let ident_str = segment.ident.to_string();
127 if ident_str == "Fn" || ident_str == "FnMut" {
128 if let syn::PathArguments::Parenthesized(args) = &segment.arguments {
129 return args.inputs.is_empty();
130 }
131 }
132 }
133 }
134 false
135 })
136 } else {
137 false
138 }
139}
140
141fn core_crate_path() -> TokenStream2 {
142 let crate_name = crate_name("cranpose")
143 .ok()
144 .or_else(|| crate_name("cranpose-core").ok());
145
146 match crate_name {
147 Some(FoundCrate::Itself) => quote!(crate),
148 Some(FoundCrate::Name(name)) => {
149 let ident = Ident::new(&name, Span::call_site());
150 quote!(#ident)
151 }
152 None => quote!(cranpose_core),
153 }
154}
155
156#[proc_macro_attribute]
157pub fn composable(attr: TokenStream, item: TokenStream) -> TokenStream {
158 let attr_tokens = TokenStream2::from(attr);
159 let mut enable_skip = true;
160 let core_path = core_crate_path();
161 if !attr_tokens.is_empty() {
162 match syn::parse2::<Ident>(attr_tokens) {
163 Ok(ident) if ident == "no_skip" => enable_skip = false,
164 Ok(other) => {
165 return syn::Error::new_spanned(other, "unsupported composable attribute")
166 .to_compile_error()
167 .into();
168 }
169 Err(err) => {
170 return err.to_compile_error().into();
171 }
172 }
173 }
174
175 let mut func = parse_macro_input!(item as ItemFn);
176
177 struct ParamInfo {
178 ident: Ident,
179 pat: Box<Pat>,
180 ty: Type,
181 pat_is_mut: bool,
182 is_impl_trait: bool,
183 }
184
185 let mut param_info: Vec<ParamInfo> = Vec::new();
186
187 for (index, arg) in func.sig.inputs.iter_mut().enumerate() {
188 if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
189 let pat_is_mut = matches!(
190 pat.as_ref(),
191 Pat::Ident(pat_ident) if pat_ident.mutability.is_some()
192 );
193 let is_impl_trait = matches!(**ty, Type::ImplTrait(_));
194
195 if is_impl_trait {
196 let original_pat: Box<Pat> = pat.clone();
197 if let Pat::Ident(pat_ident) = &**pat {
198 param_info.push(ParamInfo {
199 ident: pat_ident.ident.clone(),
200 pat: original_pat,
201 ty: ty.as_ref().clone(),
202 pat_is_mut,
203 is_impl_trait: true,
204 });
205 } else {
206 param_info.push(ParamInfo {
207 ident: Ident::new(&format!("__arg{}", index), Span::call_site()),
208 pat: original_pat,
209 ty: ty.as_ref().clone(),
210 pat_is_mut,
211 is_impl_trait: true,
212 });
213 }
214 } else {
215 let ident = Ident::new(&format!("__arg{}", index), Span::call_site());
216 let original_pat: Box<Pat> = pat.clone();
217 **pat = syn::parse_quote! { #ident };
218 param_info.push(ParamInfo {
219 ident,
220 pat: original_pat,
221 ty: ty.as_ref().clone(),
222 pat_is_mut,
223 is_impl_trait: false,
224 });
225 }
226 }
227 }
228
229 let scope_label_ident = func.sig.ident.clone();
230 let original_block = func.block.clone();
231 let helper_block = original_block.clone();
232 let recranpose_block = original_block.clone();
233 let key_expr = quote! { #core_path::location_key(file!(), line!(), column!()) };
234
235 let rebinds_for_no_skip: Vec<_> = param_info
237 .iter()
238 .map(|info| {
239 let ident = &info.ident;
240 let pat = &info.pat;
241 quote! { let #pat = #ident; }
242 })
243 .collect();
244
245 let return_ty: syn::Type = match &func.sig.output {
246 ReturnType::Default => syn::parse_quote! { () },
247 ReturnType::Type(_, ty) => ty.as_ref().clone(),
248 };
249 let returns_unit = match &func.sig.output {
250 ReturnType::Default => true,
251 ReturnType::Type(_, ty) => {
252 matches!(ty.as_ref(), Type::Tuple(tuple) if tuple.elems.is_empty())
253 }
254 };
255 let _helper_ident = Ident::new(
256 &format!("__cranpose_impl_{}", func.sig.ident),
257 Span::call_site(),
258 );
259 let generics = func.sig.generics.clone();
260 let (_impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
261
262 let _helper_inputs: Vec<TokenStream2> = param_info
263 .iter()
264 .map(|info| {
265 let ident = &info.ident;
266 let ty = &info.ty;
267 quote! { #ident: #ty }
268 })
269 .collect();
270
271 let has_unhandled_impl_trait = param_info
274 .iter()
275 .any(|info| info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty));
276
277 if enable_skip && !has_unhandled_impl_trait {
278 let helper_ident = Ident::new(
279 &format!("__cranpose_impl_{}", func.sig.ident),
280 Span::call_site(),
281 );
282 let generics = func.sig.generics.clone();
283 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
284 let ty_generics_turbofish = ty_generics.as_turbofish();
285
286 let helper_inputs: Vec<TokenStream2> = param_info
289 .iter()
290 .filter_map(|info| {
291 if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
292 None
293 } else {
294 let ident = &info.ident;
295 let ty = &info.ty;
296 Some(quote! { #ident: #ty })
297 }
298 })
299 .collect();
300
301 let param_state_slots: Vec<Ident> = (0..param_info.len())
303 .map(|index| Ident::new(&format!("__param_state_slot{}", index), Span::call_site()))
304 .collect();
305
306 let param_setup: Vec<TokenStream2> = param_info
307 .iter()
308 .zip(param_state_slots.iter())
309 .map(|(info, slot_ident)| {
310 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
312 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
313 {
314 let ident = &info.ident;
315 quote! {
316 let #slot_ident = __composer
317 .__use_param_slot(|| #core_path::CallbackHolder::new());
318 __composer.with_slot_value::<#core_path::CallbackHolder, _>(
319 #slot_ident,
320 |holder| {
321 holder.update(#ident);
322 },
323 );
324 __changed = true;
325 }
326 } else if info.is_impl_trait {
327 quote! { __changed = true; }
329 } else {
330 let ident = &info.ident;
331 let ty = &info.ty;
332 quote! {
333 let #slot_ident = __composer
334 .__use_param_slot(|| #core_path::ParamState::<#ty>::default());
335 if __composer.with_slot_value_mut::<#core_path::ParamState<#ty>, _>(
336 #slot_ident,
337 |state| state.update(&#ident),
338 )
339 {
340 __changed = true;
341 }
342 }
343 }
344 })
345 .collect();
346
347 let param_setup_recompose: Vec<TokenStream2> = param_info
348 .iter()
349 .zip(param_state_slots.iter())
350 .map(|(info, slot_ident)| {
351 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
352 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
353 {
354 quote! {
355 let #slot_ident = __composer
356 .__use_param_slot(|| #core_path::CallbackHolder::new());
357 }
358 } else if info.is_impl_trait {
359 quote! {}
360 } else {
361 let ty = &info.ty;
362 quote! {
363 let #slot_ident = __composer
364 .__use_param_slot(|| #core_path::ParamState::<#ty>::default());
365 }
366 }
367 })
368 .collect();
369
370 let rebinds: Vec<TokenStream2> = param_info
371 .iter()
372 .zip(param_state_slots.iter())
373 .map(|(info, slot_ident)| {
374 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
375 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
376 {
377 let pat = &info.pat;
378 let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
379 if can_add_mut && !info.pat_is_mut {
380 quote! {
381 #[allow(unused_mut)]
382 let mut #pat = __composer
383 .with_slot_value::<#core_path::CallbackHolder, _>(
384 #slot_ident,
385 |holder| holder.clone_rc(),
386 );
387 }
388 } else {
389 quote! {
390 #[allow(unused_mut)]
391 let #pat = __composer
392 .with_slot_value::<#core_path::CallbackHolder, _>(
393 #slot_ident,
394 |holder| holder.clone_rc(),
395 );
396 }
397 }
398 } else if info.is_impl_trait {
399 quote! {}
400 } else {
401 let pat = &info.pat;
402 let ident = &info.ident;
403 quote! {
404 let #pat = #ident;
405 }
406 }
407 })
408 .collect();
409
410 let rebinds_for_recompose: Vec<TokenStream2> = param_info
411 .iter()
412 .zip(param_state_slots.iter())
413 .map(|(info, slot_ident)| {
414 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
415 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
416 {
417 let pat = &info.pat;
418 let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
419 if can_add_mut && !info.pat_is_mut {
420 quote! {
421 #[allow(unused_mut)]
422 let mut #pat = __composer
423 .with_slot_value::<#core_path::CallbackHolder, _>(
424 #slot_ident,
425 |holder| holder.clone_rc(),
426 );
427 }
428 } else {
429 quote! {
430 #[allow(unused_mut)]
431 let #pat = __composer
432 .with_slot_value::<#core_path::CallbackHolder, _>(
433 #slot_ident,
434 |holder| holder.clone_rc(),
435 );
436 }
437 }
438 } else if info.is_impl_trait {
439 quote! {}
440 } else {
441 let pat = &info.pat;
442 let ty = &info.ty;
443 quote! {
444 let #pat = __composer
445 .with_slot_value::<#core_path::ParamState<#ty>, _>(
446 #slot_ident,
447 |state| {
448 state
449 .value()
450 .expect("composable parameter missing for recomposition")
451 },
452 );
453 }
454 }
455 })
456 .collect();
457
458 let recranpose_fn_ident = Ident::new(
459 &format!("__cranpose_recranpose_{}", func.sig.ident),
460 Span::call_site(),
461 );
462
463 let recranpose_setter = quote! {
464 {
465 __composer.set_recranpose_callback(move |
466 __composer: &#core_path::Composer|
467 {
468 #recranpose_fn_ident #ty_generics_turbofish (
469 __composer
470 );
471 });
472 }
473 };
474
475 let helper_body = if returns_unit {
476 quote! {
477 #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
478 let __current_scope = __composer
479 .current_recranpose_scope()
480 .expect("missing recompose scope");
481 let mut __changed = __current_scope.should_recompose();
482 #(#param_setup)*
483 #recranpose_setter
484 if !__changed && __current_scope.has_composed_once() {
485 __composer.skip_current_group();
486 return;
487 }
488 #(#rebinds)*
489 #helper_block
490 }
491 } else {
492 quote! {
493 #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
494 let __current_scope = __composer
495 .current_recranpose_scope()
496 .expect("missing recompose scope");
497 let mut __changed = __current_scope.should_recompose();
498 #(#param_setup)*
499 #recranpose_setter
500 let __result_slot_index = __composer
501 .__use_return_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
502 let __has_previous = __composer
503 .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
504 __result_slot_index,
505 |slot| slot.get().is_some(),
506 );
507 if !__changed && __has_previous {
508 __composer.skip_current_group();
509 let __result = __composer
510 .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
511 __result_slot_index,
512 |slot| {
513 slot.get()
514 .expect("composable return value missing during skip")
515 },
516 );
517 return __result;
518 }
519 let __value: #return_ty = {
520 #(#rebinds)*
521 #helper_block
522 };
523 __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
524 __result_slot_index,
525 |slot| {
526 slot.store(__value.clone());
527 },
528 );
529 __value
530 }
531 };
532
533 let recranpose_fn_body = if returns_unit {
534 quote! {
535 #(#param_setup_recompose)*
536 #(#rebinds_for_recompose)*
537 #recranpose_block
538 #recranpose_setter
539 }
540 } else {
541 quote! {
542 #(#param_setup_recompose)*
543 let __result_slot_index = __composer
544 .__use_return_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
545 #(#rebinds_for_recompose)*
546 let __value: #return_ty = {
547 #recranpose_block
548 };
549 __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
550 __result_slot_index,
551 |slot| {
552 slot.store(__value.clone());
553 },
554 );
555 #recranpose_setter
556 __value
557 }
558 };
559
560 let recranpose_fn = quote! {
561 #[allow(non_snake_case)]
562 fn #recranpose_fn_ident #impl_generics (
563 __composer: &#core_path::Composer
564 ) -> #return_ty #where_clause {
565 #recranpose_fn_body
566 }
567 };
568
569 let helper_fn = quote! {
570 #[allow(non_snake_case, clippy::too_many_arguments)]
571 fn #helper_ident #impl_generics (
572 __composer: &#core_path::Composer
573 #(, #helper_inputs)*
574 ) -> #return_ty #where_clause {
575 #helper_body
576 }
577 };
578
579 let wrapper_args: Vec<TokenStream2> = param_info
581 .iter()
582 .filter_map(|info| {
583 if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
584 None
585 } else {
586 let ident = &info.ident;
587 Some(quote! { #ident })
588 }
589 })
590 .collect();
591
592 let wrapped = quote!({
593 #core_path::with_current_composer(|__composer: &#core_path::Composer| {
594 __composer.with_group(#key_expr, |__composer: &#core_path::Composer| {
595 #helper_ident(__composer #(, #wrapper_args)*)
596 })
597 })
598 });
599 *func.block = syn::parse2(wrapped).expect("failed to build block");
600 TokenStream::from(quote! {
601 #recranpose_fn
602 #helper_fn
603 #func
604 })
605 } else {
606 let wrapped = quote!({
608 #core_path::with_current_composer(|__composer: &#core_path::Composer| {
609 __composer.with_group(#key_expr, |__scope: &#core_path::Composer| {
610 #core_path::debug_label_current_scope(stringify!(#scope_label_ident));
611 #(#rebinds_for_no_skip)*
612 #original_block
613 })
614 })
615 });
616 *func.block = syn::parse2(wrapped).expect("failed to build block");
617 TokenStream::from(quote! { #func })
618 }
619}