1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use proc_macro_crate::{crate_name, FoundCrate};
4use quote::quote;
5use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Type};
6
7fn is_fn_like_type(ty: &Type) -> bool {
10 match ty {
11 Type::ImplTrait(impl_trait) => impl_trait.bounds.iter().any(|bound| {
13 if let syn::TypeParamBound::Trait(trait_bound) = bound {
14 let path = &trait_bound.path;
15 if let Some(segment) = path.segments.last() {
16 let ident_str = segment.ident.to_string();
17 return ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce";
18 }
19 }
20 false
21 }),
22 Type::Path(type_path) => {
24 if let Some(segment) = type_path.path.segments.last() {
25 if segment.ident == "Box" {
26 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
27 if let Some(syn::GenericArgument::Type(Type::TraitObject(trait_obj))) =
28 args.args.first()
29 {
30 return trait_obj.bounds.iter().any(|bound| {
31 if let syn::TypeParamBound::Trait(trait_bound) = bound {
32 let path = &trait_bound.path;
33 if let Some(segment) = path.segments.last() {
34 let ident_str = segment.ident.to_string();
35 return ident_str == "FnMut"
36 || ident_str == "Fn"
37 || ident_str == "FnOnce";
38 }
39 }
40 false
41 });
42 }
43 }
44 }
45 }
46 false
47 }
48 Type::BareFn(_) => true,
50 _ => false,
51 }
52}
53
54fn is_generic_fn_like(ty: &Type, generics: &syn::Generics) -> bool {
56 let type_ident = match ty {
58 Type::Path(type_path) if type_path.path.segments.len() == 1 => {
59 &type_path.path.segments[0].ident
60 }
61 _ => return false,
62 };
63
64 for param in &generics.params {
66 if let syn::GenericParam::Type(type_param) = param {
67 if type_param.ident == *type_ident {
68 for bound in &type_param.bounds {
70 if let syn::TypeParamBound::Trait(trait_bound) = bound {
71 if let Some(segment) = trait_bound.path.segments.last() {
72 let ident_str = segment.ident.to_string();
73 if ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce" {
74 return true;
75 }
76 }
77 }
78 }
79 }
80 }
81 }
82
83 if let Some(where_clause) = &generics.where_clause {
85 for predicate in &where_clause.predicates {
86 if let syn::WherePredicate::Type(pred) = predicate {
87 if let Type::Path(bounded_type) = &pred.bounded_ty {
88 if bounded_type.path.segments.len() == 1
89 && bounded_type.path.segments[0].ident == *type_ident
90 {
91 for bound in &pred.bounds {
92 if let syn::TypeParamBound::Trait(trait_bound) = bound {
93 if let Some(segment) = trait_bound.path.segments.last() {
94 let ident_str = segment.ident.to_string();
95 if ident_str == "FnMut"
96 || ident_str == "Fn"
97 || ident_str == "FnOnce"
98 {
99 return true;
100 }
101 }
102 }
103 }
104 }
105 }
106 }
107 }
108 }
109
110 false
111}
112
113fn is_fn_param(ty: &Type, generics: &syn::Generics) -> bool {
115 is_fn_like_type(ty) || is_generic_fn_like(ty, generics)
116}
117
118fn core_crate_path() -> TokenStream2 {
119 let crate_name = crate_name("cranpose")
120 .ok()
121 .or_else(|| crate_name("cranpose-core").ok());
122
123 match crate_name {
124 Some(FoundCrate::Itself) => quote!(crate),
125 Some(FoundCrate::Name(name)) => {
126 let ident = Ident::new(&name, Span::call_site());
127 quote!(#ident)
128 }
129 None => quote!(cranpose_core),
130 }
131}
132
133#[proc_macro_attribute]
134pub fn composable(attr: TokenStream, item: TokenStream) -> TokenStream {
135 let attr_tokens = TokenStream2::from(attr);
136 let mut enable_skip = true;
137 let core_path = core_crate_path();
138 if !attr_tokens.is_empty() {
139 match syn::parse2::<Ident>(attr_tokens) {
140 Ok(ident) if ident == "no_skip" => enable_skip = false,
141 Ok(other) => {
142 return syn::Error::new_spanned(other, "unsupported composable attribute")
143 .to_compile_error()
144 .into();
145 }
146 Err(err) => {
147 return err.to_compile_error().into();
148 }
149 }
150 }
151
152 let mut func = parse_macro_input!(item as ItemFn);
153
154 struct ParamInfo {
155 ident: Ident,
156 pat: Box<Pat>,
157 ty: Type,
158 pat_is_mut: bool,
159 is_impl_trait: bool,
160 }
161
162 let mut param_info: Vec<ParamInfo> = Vec::new();
163
164 for (index, arg) in func.sig.inputs.iter_mut().enumerate() {
165 if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
166 let pat_is_mut = matches!(
167 pat.as_ref(),
168 Pat::Ident(pat_ident) if pat_ident.mutability.is_some()
169 );
170 let is_impl_trait = matches!(**ty, Type::ImplTrait(_));
171
172 if is_impl_trait {
173 let original_pat: Box<Pat> = pat.clone();
174 if let Pat::Ident(pat_ident) = &**pat {
175 param_info.push(ParamInfo {
176 ident: pat_ident.ident.clone(),
177 pat: original_pat,
178 ty: ty.as_ref().clone(),
179 pat_is_mut,
180 is_impl_trait: true,
181 });
182 } else {
183 param_info.push(ParamInfo {
184 ident: Ident::new(&format!("__arg{}", index), Span::call_site()),
185 pat: original_pat,
186 ty: ty.as_ref().clone(),
187 pat_is_mut,
188 is_impl_trait: true,
189 });
190 }
191 } else {
192 let ident = Ident::new(&format!("__arg{}", index), Span::call_site());
193 let original_pat: Box<Pat> = pat.clone();
194 *pat = Box::new(syn::parse_quote! { #ident });
195 param_info.push(ParamInfo {
196 ident,
197 pat: original_pat,
198 ty: ty.as_ref().clone(),
199 pat_is_mut,
200 is_impl_trait: false,
201 });
202 }
203 }
204 }
205
206 let original_block = func.block.clone();
207 let helper_block = original_block.clone();
208 let recranpose_block = original_block.clone();
209 let key_expr = quote! { #core_path::location_key(file!(), line!(), column!()) };
210
211 let rebinds_for_no_skip: Vec<_> = param_info
213 .iter()
214 .map(|info| {
215 let ident = &info.ident;
216 let pat = &info.pat;
217 quote! { let #pat = #ident; }
218 })
219 .collect();
220
221 let return_ty: syn::Type = match &func.sig.output {
222 ReturnType::Default => syn::parse_quote! { () },
223 ReturnType::Type(_, ty) => ty.as_ref().clone(),
224 };
225 let _helper_ident = Ident::new(
226 &format!("__cranpose_impl_{}", func.sig.ident),
227 Span::call_site(),
228 );
229 let generics = func.sig.generics.clone();
230 let (_impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
231
232 let _helper_inputs: Vec<TokenStream2> = param_info
233 .iter()
234 .map(|info| {
235 let ident = &info.ident;
236 let ty = &info.ty;
237 quote! { #ident: #ty }
238 })
239 .collect();
240
241 let has_impl_trait = param_info
243 .iter()
244 .any(|info| matches!(info.ty, Type::ImplTrait(_)));
245
246 if enable_skip && !has_impl_trait {
247 let helper_ident = Ident::new(
248 &format!("__cranpose_impl_{}", func.sig.ident),
249 Span::call_site(),
250 );
251 let generics = func.sig.generics.clone();
252 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
253 let ty_generics_turbofish = ty_generics.as_turbofish();
254
255 let helper_inputs: Vec<TokenStream2> = param_info
257 .iter()
258 .filter_map(|info| {
259 if info.is_impl_trait {
260 None
261 } else {
262 let ident = &info.ident;
263 let ty = &info.ty;
264 Some(quote! { #ident: #ty })
265 }
266 })
267 .collect();
268
269 let param_state_slots: Vec<Ident> = (0..param_info.len())
271 .map(|index| Ident::new(&format!("__param_state_slot{}", index), Span::call_site()))
272 .collect();
273
274 let param_setup: Vec<TokenStream2> = param_info
275 .iter()
276 .zip(param_state_slots.iter())
277 .map(|(info, slot_ident)| {
278 if info.is_impl_trait {
279 quote! { __changed = true; }
280 } else if is_fn_param(&info.ty, &generics) {
281 let ident = &info.ident;
282 quote! {
283 let #slot_ident = __composer
284 .use_value_slot(|| #core_path::CallbackHolder::new());
285 __composer.with_slot_value::<#core_path::CallbackHolder, _>(
286 #slot_ident,
287 |holder| {
288 holder.update(#ident);
289 },
290 );
291 __changed = true;
292 }
293 } else {
294 let ident = &info.ident;
295 let ty = &info.ty;
296 quote! {
297 let #slot_ident = __composer
298 .use_value_slot(|| #core_path::ParamState::<#ty>::default());
299 if __composer.with_slot_value_mut::<#core_path::ParamState<#ty>, _>(
300 #slot_ident,
301 |state| state.update(&#ident),
302 )
303 {
304 __changed = true;
305 }
306 }
307 }
308 })
309 .collect();
310
311 let param_setup_recompose: Vec<TokenStream2> = param_info
312 .iter()
313 .zip(param_state_slots.iter())
314 .map(|(info, slot_ident)| {
315 if info.is_impl_trait {
316 quote! {}
317 } else if is_fn_param(&info.ty, &generics) {
318 quote! {
319 let #slot_ident = __composer
320 .use_value_slot(|| #core_path::CallbackHolder::new());
321 }
322 } else {
323 let ty = &info.ty;
324 quote! {
325 let #slot_ident = __composer
326 .use_value_slot(|| #core_path::ParamState::<#ty>::default());
327 }
328 }
329 })
330 .collect();
331
332 let rebinds: Vec<TokenStream2> = param_info
333 .iter()
334 .zip(param_state_slots.iter())
335 .map(|(info, slot_ident)| {
336 if info.is_impl_trait {
337 quote! {}
338 } else if is_fn_param(&info.ty, &generics) {
339 let pat = &info.pat;
340 let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
341 if can_add_mut && !info.pat_is_mut {
342 quote! {
343 #[allow(unused_mut)]
344 let mut #pat = __composer
345 .with_slot_value::<#core_path::CallbackHolder, _>(
346 #slot_ident,
347 |holder| holder.clone_rc(),
348 );
349 }
350 } else {
351 quote! {
352 #[allow(unused_mut)]
353 let #pat = __composer
354 .with_slot_value::<#core_path::CallbackHolder, _>(
355 #slot_ident,
356 |holder| holder.clone_rc(),
357 );
358 }
359 }
360 } else {
361 let pat = &info.pat;
362 let ident = &info.ident;
363 quote! {
364 let #pat = #ident;
365 }
366 }
367 })
368 .collect();
369
370 let rebinds_for_recompose: Vec<TokenStream2> = param_info
371 .iter()
372 .zip(param_state_slots.iter())
373 .map(|(info, slot_ident)| {
374 if info.is_impl_trait {
375 quote! {}
376 } else if is_fn_param(&info.ty, &generics) {
377 let pat = &info.pat;
378 let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
379 if can_add_mut && !info.pat_is_mut {
380 quote! {
381 #[allow(unused_mut)]
382 let mut #pat = __composer
383 .with_slot_value::<#core_path::CallbackHolder, _>(
384 #slot_ident,
385 |holder| holder.clone_rc(),
386 );
387 }
388 } else {
389 quote! {
390 #[allow(unused_mut)]
391 let #pat = __composer
392 .with_slot_value::<#core_path::CallbackHolder, _>(
393 #slot_ident,
394 |holder| holder.clone_rc(),
395 );
396 }
397 }
398 } else {
399 let pat = &info.pat;
400 let ty = &info.ty;
401 quote! {
402 let #pat = __composer
403 .with_slot_value::<#core_path::ParamState<#ty>, _>(
404 #slot_ident,
405 |state| {
406 state
407 .value()
408 .expect("composable parameter missing for recomposition")
409 },
410 );
411 }
412 }
413 })
414 .collect();
415
416 let recranpose_fn_ident = Ident::new(
417 &format!("__cranpose_recranpose_{}", func.sig.ident),
418 Span::call_site(),
419 );
420
421 let recranpose_setter = quote! {
422 {
423 __composer.set_recranpose_callback(move |
424 __composer: &#core_path::Composer|
425 {
426 #recranpose_fn_ident #ty_generics_turbofish (
427 __composer
428 );
429 });
430 }
431 };
432
433 let helper_body = quote! {
434 let __current_scope = __composer
435 .current_recranpose_scope()
436 .expect("missing recompose scope");
437 let mut __changed = __current_scope.should_recompose();
438 #(#param_setup)*
439 let __result_slot_index = __composer
440 .use_value_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
441 let __has_previous = __composer
442 .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
443 __result_slot_index,
444 |slot| slot.get().is_some(),
445 );
446 if !__changed && __has_previous {
447 __composer.skip_current_group();
448 let __result = __composer
449 .with_slot_value::<#core_path::ReturnSlot<#return_ty>, _>(
450 __result_slot_index,
451 |slot| {
452 slot.get()
453 .expect("composable return value missing during skip")
454 },
455 );
456 return __result;
457 }
458 let __value: #return_ty = {
459 #(#rebinds)*
460 #helper_block
461 };
462 __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
463 __result_slot_index,
464 |slot| {
465 slot.store(__value.clone());
466 },
467 );
468 #recranpose_setter
469 __value
470 };
471
472 let recranpose_fn_body = quote! {
473 #(#param_setup_recompose)*
474 let __result_slot_index = __composer
475 .use_value_slot(|| #core_path::ReturnSlot::<#return_ty>::default());
476 #(#rebinds_for_recompose)*
477 let __value: #return_ty = {
478 #recranpose_block
479 };
480 __composer.with_slot_value_mut::<#core_path::ReturnSlot<#return_ty>, _>(
481 __result_slot_index,
482 |slot| {
483 slot.store(__value.clone());
484 },
485 );
486 #recranpose_setter
487 __value
488 };
489
490 let recranpose_fn = quote! {
491 #[allow(non_snake_case)]
492 fn #recranpose_fn_ident #impl_generics (
493 __composer: &#core_path::Composer
494 ) -> #return_ty #where_clause {
495 #recranpose_fn_body
496 }
497 };
498
499 let helper_fn = quote! {
500 #[allow(non_snake_case)]
501 fn #helper_ident #impl_generics (
502 __composer: &#core_path::Composer
503 #(, #helper_inputs)*
504 ) -> #return_ty #where_clause {
505 #helper_body
506 }
507 };
508
509 let wrapper_args: Vec<TokenStream2> = param_info
511 .iter()
512 .filter_map(|info| {
513 if info.is_impl_trait {
514 None
515 } else {
516 let ident = &info.ident;
517 Some(quote! { #ident })
518 }
519 })
520 .collect();
521
522 let wrapped = quote!({
523 #core_path::with_current_composer(|__composer: &#core_path::Composer| {
524 __composer.with_group(#key_expr, |__composer: &#core_path::Composer| {
525 #helper_ident(__composer #(, #wrapper_args)*)
526 })
527 })
528 });
529 func.block = Box::new(syn::parse2(wrapped).expect("failed to build block"));
530 TokenStream::from(quote! {
531 #recranpose_fn
532 #helper_fn
533 #func
534 })
535 } else {
536 let wrapped = quote!({
538 #core_path::with_current_composer(|__composer: &#core_path::Composer| {
539 __composer.with_group(#key_expr, |__scope: &#core_path::Composer| {
540 #(#rebinds_for_no_skip)*
541 #original_block
542 })
543 })
544 });
545 func.block = Box::new(syn::parse2(wrapped).expect("failed to build block"));
546 TokenStream::from(quote! { #func })
547 }
548}