1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use proc_macro_crate::{crate_name, FoundCrate};
4use quote::quote;
5use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Type};
6
7fn is_fn_like_type(ty: &Type) -> bool {
10 match ty {
11 Type::ImplTrait(impl_trait) => impl_trait.bounds.iter().any(|bound| {
13 if let syn::TypeParamBound::Trait(trait_bound) = bound {
14 let path = &trait_bound.path;
15 if let Some(segment) = path.segments.last() {
16 let ident_str = segment.ident.to_string();
17 return ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce";
18 }
19 }
20 false
21 }),
22 Type::Path(type_path) => {
24 if let Some(segment) = type_path.path.segments.last() {
25 if segment.ident == "Box" {
26 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
27 if let Some(syn::GenericArgument::Type(Type::TraitObject(trait_obj))) =
28 args.args.first()
29 {
30 return trait_obj.bounds.iter().any(|bound| {
31 if let syn::TypeParamBound::Trait(trait_bound) = bound {
32 let path = &trait_bound.path;
33 if let Some(segment) = path.segments.last() {
34 let ident_str = segment.ident.to_string();
35 return ident_str == "FnMut"
36 || ident_str == "Fn"
37 || ident_str == "FnOnce";
38 }
39 }
40 false
41 });
42 }
43 }
44 }
45 }
46 false
47 }
48 Type::BareFn(_) => true,
50 _ => false,
51 }
52}
53
54fn is_generic_fn_like(ty: &Type, generics: &syn::Generics) -> bool {
56 let type_ident = match ty {
58 Type::Path(type_path) if type_path.path.segments.len() == 1 => {
59 &type_path.path.segments[0].ident
60 }
61 _ => return false,
62 };
63
64 for param in &generics.params {
66 if let syn::GenericParam::Type(type_param) = param {
67 if type_param.ident == *type_ident {
68 for bound in &type_param.bounds {
70 if let syn::TypeParamBound::Trait(trait_bound) = bound {
71 if let Some(segment) = trait_bound.path.segments.last() {
72 let ident_str = segment.ident.to_string();
73 if ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce" {
74 return true;
75 }
76 }
77 }
78 }
79 }
80 }
81 }
82
83 if let Some(where_clause) = &generics.where_clause {
85 for predicate in &where_clause.predicates {
86 if let syn::WherePredicate::Type(pred) = predicate {
87 if let Type::Path(bounded_type) = &pred.bounded_ty {
88 if bounded_type.path.segments.len() == 1
89 && bounded_type.path.segments[0].ident == *type_ident
90 {
91 for bound in &pred.bounds {
92 if let syn::TypeParamBound::Trait(trait_bound) = bound {
93 if let Some(segment) = trait_bound.path.segments.last() {
94 let ident_str = segment.ident.to_string();
95 if ident_str == "FnMut"
96 || ident_str == "Fn"
97 || ident_str == "FnOnce"
98 {
99 return true;
100 }
101 }
102 }
103 }
104 }
105 }
106 }
107 }
108 }
109
110 false
111}
112
113fn is_fn_param(ty: &Type, generics: &syn::Generics) -> bool {
115 is_fn_like_type(ty) || is_generic_fn_like(ty, generics)
116}
117
118fn is_zero_arg_fn_impl_trait(ty: &Type) -> bool {
122 if let Type::ImplTrait(impl_trait) = ty {
123 impl_trait.bounds.iter().any(|bound| {
124 if let syn::TypeParamBound::Trait(trait_bound) = bound {
125 if let Some(segment) = trait_bound.path.segments.last() {
126 let ident_str = segment.ident.to_string();
127 if ident_str == "Fn" || ident_str == "FnMut" {
128 if let syn::PathArguments::Parenthesized(args) = &segment.arguments {
129 return args.inputs.is_empty();
130 }
131 }
132 }
133 }
134 false
135 })
136 } else {
137 false
138 }
139}
140
141fn core_crate_path() -> TokenStream2 {
142 let crate_name = crate_name("cranpose")
143 .ok()
144 .or_else(|| crate_name("cranpose-core").ok());
145
146 match crate_name {
147 Some(FoundCrate::Itself) => quote!(crate),
148 Some(FoundCrate::Name(name)) => {
149 let ident = Ident::new(&name, Span::call_site());
150 quote!(#ident)
151 }
152 None => quote!(cranpose_core),
153 }
154}
155
156#[proc_macro_attribute]
157pub fn composable(attr: TokenStream, item: TokenStream) -> TokenStream {
158 let attr_tokens = TokenStream2::from(attr);
159 let mut enable_skip = true;
160 let core_path = core_crate_path();
161 if !attr_tokens.is_empty() {
162 match syn::parse2::<Ident>(attr_tokens) {
163 Ok(ident) if ident == "no_skip" => enable_skip = false,
164 Ok(other) => {
165 return syn::Error::new_spanned(other, "unsupported composable attribute")
166 .to_compile_error()
167 .into();
168 }
169 Err(err) => {
170 return err.to_compile_error().into();
171 }
172 }
173 }
174
175 let mut func = parse_macro_input!(item as ItemFn);
176
177 struct ParamInfo {
178 ident: Ident,
179 pat: Box<Pat>,
180 ty: Type,
181 pat_is_mut: bool,
182 is_impl_trait: bool,
183 }
184
185 let mut param_info: Vec<ParamInfo> = Vec::new();
186
187 for (index, arg) in func.sig.inputs.iter_mut().enumerate() {
188 if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
189 let pat_is_mut = matches!(
190 pat.as_ref(),
191 Pat::Ident(pat_ident) if pat_ident.mutability.is_some()
192 );
193 let is_impl_trait = matches!(**ty, Type::ImplTrait(_));
194
195 if is_impl_trait {
196 let original_pat: Box<Pat> = pat.clone();
197 if let Pat::Ident(pat_ident) = &**pat {
198 param_info.push(ParamInfo {
199 ident: pat_ident.ident.clone(),
200 pat: original_pat,
201 ty: ty.as_ref().clone(),
202 pat_is_mut,
203 is_impl_trait: true,
204 });
205 } else {
206 param_info.push(ParamInfo {
207 ident: Ident::new(&format!("__arg{}", index), Span::call_site()),
208 pat: original_pat,
209 ty: ty.as_ref().clone(),
210 pat_is_mut,
211 is_impl_trait: true,
212 });
213 }
214 } else {
215 let ident = Ident::new(&format!("__arg{}", index), Span::call_site());
216 let original_pat: Box<Pat> = pat.clone();
217 *pat = Box::new(syn::parse_quote! { #ident });
218 param_info.push(ParamInfo {
219 ident,
220 pat: original_pat,
221 ty: ty.as_ref().clone(),
222 pat_is_mut,
223 is_impl_trait: false,
224 });
225 }
226 }
227 }
228
229 let original_block = func.block.clone();
230 let helper_block = original_block.clone();
231 let recranpose_block = original_block.clone();
232 let key_expr = quote! { #core_path::location_key(file!(), line!(), column!()) };
233
234 let rebinds_for_no_skip: Vec<_> = param_info
236 .iter()
237 .map(|info| {
238 let ident = &info.ident;
239 let pat = &info.pat;
240 quote! { let #pat = #ident; }
241 })
242 .collect();
243
244 let return_ty: syn::Type = match &func.sig.output {
245 ReturnType::Default => syn::parse_quote! { () },
246 ReturnType::Type(_, ty) => ty.as_ref().clone(),
247 };
248 let _helper_ident = Ident::new(
249 &format!("__cranpose_impl_{}", func.sig.ident),
250 Span::call_site(),
251 );
252 let generics = func.sig.generics.clone();
253 let (_impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
254
255 let _helper_inputs: Vec<TokenStream2> = param_info
256 .iter()
257 .map(|info| {
258 let ident = &info.ident;
259 let ty = &info.ty;
260 quote! { #ident: #ty }
261 })
262 .collect();
263
264 let has_unhandled_impl_trait = param_info
267 .iter()
268 .any(|info| info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty));
269
270 if enable_skip && !has_unhandled_impl_trait {
271 let helper_ident = Ident::new(
272 &format!("__cranpose_impl_{}", func.sig.ident),
273 Span::call_site(),
274 );
275 let generics = func.sig.generics.clone();
276 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
277 let ty_generics_turbofish = ty_generics.as_turbofish();
278
279 let helper_inputs: Vec<TokenStream2> = param_info
282 .iter()
283 .filter_map(|info| {
284 if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
285 None
286 } else {
287 let ident = &info.ident;
288 let ty = &info.ty;
289 Some(quote! { #ident: #ty })
290 }
291 })
292 .collect();
293
294 let param_state_slots: Vec<Ident> = (0..param_info.len())
296 .map(|index| Ident::new(&format!("__param_state_slot{}", index), Span::call_site()))
297 .collect();
298
299 let param_setup: Vec<TokenStream2> = param_info
300 .iter()
301 .zip(param_state_slots.iter())
302 .map(|(info, slot_ident)| {
303 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
305 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
306 {
307 let ident = &info.ident;
308 quote! {
309 let #slot_ident = __composer
310 .use_value_slot(|| #core_path::CallbackHolder::new());
311 __composer.with_slot_value::<#core_path::CallbackHolder, _>(
312 #slot_ident,
313 |holder| {
314 holder.update(#ident);
315 },
316 );
317 __changed = true;
318 }
319 } else if info.is_impl_trait {
320 quote! { __changed = true; }
322 } else {
323 let ident = &info.ident;
324 let ty = &info.ty;
325 quote! {
326 let #slot_ident = __composer
327 .use_value_slot(|| #core_path::ParamState::<#ty>::default());
328 if __composer.with_slot_value_mut::<#core_path::ParamState<#ty>, _>(
329 #slot_ident,
330 |state| state.update(&#ident),
331 )
332 {
333 __changed = true;
334 }
335 }
336 }
337 })
338 .collect();
339
340 let param_setup_recompose: Vec<TokenStream2> = param_info
341 .iter()
342 .zip(param_state_slots.iter())
343 .map(|(info, slot_ident)| {
344 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
345 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
346 {
347 quote! {
348 let #slot_ident = __composer
349 .use_value_slot(|| #core_path::CallbackHolder::new());
350 }
351 } else if info.is_impl_trait {
352 quote! {}
353 } else {
354 let ty = &info.ty;
355 quote! {
356 let #slot_ident = __composer
357 .use_value_slot(|| #core_path::ParamState::<#ty>::default());
358 }
359 }
360 })
361 .collect();
362
363 let rebinds: Vec<TokenStream2> = param_info
364 .iter()
365 .zip(param_state_slots.iter())
366 .map(|(info, slot_ident)| {
367 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
368 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
369 {
370 let pat = &info.pat;
371 let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
372 if can_add_mut && !info.pat_is_mut {
373 quote! {
374 #[allow(unused_mut)]
375 let mut #pat = __composer
376 .with_slot_value::<#core_path::CallbackHolder, _>(
377 #slot_ident,
378 |holder| holder.clone_rc(),
379 );
380 }
381 } else {
382 quote! {
383 #[allow(unused_mut)]
384 let #pat = __composer
385 .with_slot_value::<#core_path::CallbackHolder, _>(
386 #slot_ident,
387 |holder| holder.clone_rc(),
388 );
389 }
390 }
391 } else if info.is_impl_trait {
392 quote! {}
393 } else {
394 let pat = &info.pat;
395 let ident = &info.ident;
396 quote! {
397 let #pat = #ident;
398 }
399 }
400 })
401 .collect();
402
403 let rebinds_for_recompose: Vec<TokenStream2> = param_info
404 .iter()
405 .zip(param_state_slots.iter())
406 .map(|(info, slot_ident)| {
407 if (info.is_impl_trait && is_zero_arg_fn_impl_trait(&info.ty))
408 || (!info.is_impl_trait && is_fn_param(&info.ty, &generics))
409 {
410 let pat = &info.pat;
411 let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
412 if can_add_mut && !info.pat_is_mut {
413 quote! {
414 #[allow(unused_mut)]
415 let mut #pat = __composer
416 .with_slot_value::<#core_path::CallbackHolder, _>(
417 #slot_ident,
418 |holder| holder.clone_rc(),
419 );
420 }
421 } else {
422 quote! {
423 #[allow(unused_mut)]
424 let #pat = __composer
425 .with_slot_value::<#core_path::CallbackHolder, _>(
426 #slot_ident,
427 |holder| holder.clone_rc(),
428 );
429 }
430 }
431 } else if info.is_impl_trait {
432 quote! {}
433 } else {
434 let pat = &info.pat;
435 let ty = &info.ty;
436 quote! {
437 let #pat = __composer
438 .with_slot_value::<#core_path::ParamState<#ty>, _>(
439 #slot_ident,
440 |state| {
441 state
442 .value()
443 .expect("composable parameter missing for recomposition")
444 },
445 );
446 }
447 }
448 })
449 .collect();
450
451 let recranpose_fn_ident = Ident::new(
452 &format!("__cranpose_recranpose_{}", func.sig.ident),
453 Span::call_site(),
454 );
455
456 let recranpose_setter = quote! {
457 {
458 __composer.set_recranpose_callback(move |
459 __composer: &#core_path::Composer|
460 {
461 #recranpose_fn_ident #ty_generics_turbofish (
462 __composer
463 );
464 });
465 }
466 };
467
468 let helper_body = quote! {
469 let __current_scope = __composer
470 .current_recranpose_scope()
471 .expect("missing recompose scope");
472 let mut __changed = __current_scope.should_recompose();
473 #(#param_setup)*
474 let __result_slot_index = __composer
475 .use_value_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
476 let __has_previous = __composer
477 .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
478 __result_slot_index,
479 |slot| slot.get().is_some(),
480 );
481 if !__changed && __has_previous {
482 __composer.skip_current_group();
483 let __result = __composer
484 .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
485 __result_slot_index,
486 |slot| {
487 slot.get()
488 .expect("composable return value missing during skip")
489 },
490 );
491 return __result;
492 }
493 let __value: #return_ty = {
494 #(#rebinds)*
495 #helper_block
496 };
497 __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
498 __result_slot_index,
499 |slot| {
500 slot.store(__value.clone());
501 },
502 );
503 #recranpose_setter
504 __value
505 };
506
507 let recranpose_fn_body = quote! {
508 #(#param_setup_recompose)*
509 let __result_slot_index = __composer
510 .use_value_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
511 #(#rebinds_for_recompose)*
512 let __value: #return_ty = {
513 #recranpose_block
514 };
515 __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
516 __result_slot_index,
517 |slot| {
518 slot.store(__value.clone());
519 },
520 );
521 #recranpose_setter
522 __value
523 };
524
525 let recranpose_fn = quote! {
526 #[allow(non_snake_case)]
527 fn #recranpose_fn_ident #impl_generics (
528 __composer: &#core_path::Composer
529 ) -> #return_ty #where_clause {
530 #recranpose_fn_body
531 }
532 };
533
534 let helper_fn = quote! {
535 #[allow(non_snake_case, clippy::too_many_arguments)]
536 fn #helper_ident #impl_generics (
537 __composer: &#core_path::Composer
538 #(, #helper_inputs)*
539 ) -> #return_ty #where_clause {
540 #helper_body
541 }
542 };
543
544 let wrapper_args: Vec<TokenStream2> = param_info
546 .iter()
547 .filter_map(|info| {
548 if info.is_impl_trait && !is_zero_arg_fn_impl_trait(&info.ty) {
549 None
550 } else {
551 let ident = &info.ident;
552 Some(quote! { #ident })
553 }
554 })
555 .collect();
556
557 let wrapped = quote!({
558 #core_path::with_current_composer(|__composer: &#core_path::Composer| {
559 __composer.with_group(#key_expr, |__composer: &#core_path::Composer| {
560 #helper_ident(__composer #(, #wrapper_args)*)
561 })
562 })
563 });
564 func.block = Box::new(syn::parse2(wrapped).expect("failed to build block"));
565 TokenStream::from(quote! {
566 #recranpose_fn
567 #helper_fn
568 #func
569 })
570 } else {
571 let wrapped = quote!({
573 #core_path::with_current_composer(|__composer: &#core_path::Composer| {
574 __composer.with_group(#key_expr, |__scope: &#core_path::Composer| {
575 #(#rebinds_for_no_skip)*
576 #original_block
577 })
578 })
579 });
580 func.block = Box::new(syn::parse2(wrapped).expect("failed to build block"));
581 TokenStream::from(quote! { #func })
582 }
583}