1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::quote;
4use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Type};
5
6fn is_fn_like_type(ty: &Type) -> bool {
9 match ty {
10 Type::ImplTrait(impl_trait) => impl_trait.bounds.iter().any(|bound| {
12 if let syn::TypeParamBound::Trait(trait_bound) = bound {
13 let path = &trait_bound.path;
14 if let Some(segment) = path.segments.last() {
15 let ident_str = segment.ident.to_string();
16 return ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce";
17 }
18 }
19 false
20 }),
21 Type::Path(type_path) => {
23 if let Some(segment) = type_path.path.segments.last() {
24 if segment.ident == "Box" {
25 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
26 if let Some(syn::GenericArgument::Type(Type::TraitObject(trait_obj))) =
27 args.args.first()
28 {
29 return trait_obj.bounds.iter().any(|bound| {
30 if let syn::TypeParamBound::Trait(trait_bound) = bound {
31 let path = &trait_bound.path;
32 if let Some(segment) = path.segments.last() {
33 let ident_str = segment.ident.to_string();
34 return ident_str == "FnMut"
35 || ident_str == "Fn"
36 || ident_str == "FnOnce";
37 }
38 }
39 false
40 });
41 }
42 }
43 }
44 }
45 false
46 }
47 Type::BareFn(_) => true,
49 _ => false,
50 }
51}
52
53fn is_generic_fn_like(ty: &Type, generics: &syn::Generics) -> bool {
55 let type_ident = match ty {
57 Type::Path(type_path) if type_path.path.segments.len() == 1 => {
58 &type_path.path.segments[0].ident
59 }
60 _ => return false,
61 };
62
63 for param in &generics.params {
65 if let syn::GenericParam::Type(type_param) = param {
66 if type_param.ident == *type_ident {
67 for bound in &type_param.bounds {
69 if let syn::TypeParamBound::Trait(trait_bound) = bound {
70 if let Some(segment) = trait_bound.path.segments.last() {
71 let ident_str = segment.ident.to_string();
72 if ident_str == "FnMut" || ident_str == "Fn" || ident_str == "FnOnce" {
73 return true;
74 }
75 }
76 }
77 }
78 }
79 }
80 }
81
82 if let Some(where_clause) = &generics.where_clause {
84 for predicate in &where_clause.predicates {
85 if let syn::WherePredicate::Type(pred) = predicate {
86 if let Type::Path(bounded_type) = &pred.bounded_ty {
87 if bounded_type.path.segments.len() == 1
88 && bounded_type.path.segments[0].ident == *type_ident
89 {
90 for bound in &pred.bounds {
91 if let syn::TypeParamBound::Trait(trait_bound) = bound {
92 if let Some(segment) = trait_bound.path.segments.last() {
93 let ident_str = segment.ident.to_string();
94 if ident_str == "FnMut"
95 || ident_str == "Fn"
96 || ident_str == "FnOnce"
97 {
98 return true;
99 }
100 }
101 }
102 }
103 }
104 }
105 }
106 }
107 }
108
109 false
110}
111
112fn is_fn_param(ty: &Type, generics: &syn::Generics) -> bool {
114 is_fn_like_type(ty) || is_generic_fn_like(ty, generics)
115}
116
117#[proc_macro_attribute]
118pub fn composable(attr: TokenStream, item: TokenStream) -> TokenStream {
119 let attr_tokens = TokenStream2::from(attr);
120 let mut enable_skip = true;
121 if !attr_tokens.is_empty() {
122 match syn::parse2::<Ident>(attr_tokens) {
123 Ok(ident) if ident == "no_skip" => enable_skip = false,
124 Ok(other) => {
125 return syn::Error::new_spanned(other, "unsupported composable attribute")
126 .to_compile_error()
127 .into();
128 }
129 Err(err) => {
130 return err.to_compile_error().into();
131 }
132 }
133 }
134
135 let mut func = parse_macro_input!(item as ItemFn);
136
137 struct ParamInfo {
138 ident: Ident,
139 pat: Box<Pat>,
140 ty: Type,
141 pat_is_mut: bool,
142 is_impl_trait: bool,
143 }
144
145 let mut param_info: Vec<ParamInfo> = Vec::new();
146
147 for (index, arg) in func.sig.inputs.iter_mut().enumerate() {
148 if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
149 let pat_is_mut = matches!(
150 pat.as_ref(),
151 Pat::Ident(pat_ident) if pat_ident.mutability.is_some()
152 );
153 let is_impl_trait = matches!(**ty, Type::ImplTrait(_));
154
155 if is_impl_trait {
156 let original_pat: Box<Pat> = pat.clone();
157 if let Pat::Ident(pat_ident) = &**pat {
158 param_info.push(ParamInfo {
159 ident: pat_ident.ident.clone(),
160 pat: original_pat,
161 ty: ty.as_ref().clone(),
162 pat_is_mut,
163 is_impl_trait: true,
164 });
165 } else {
166 param_info.push(ParamInfo {
167 ident: Ident::new(&format!("__arg{}", index), Span::call_site()),
168 pat: original_pat,
169 ty: ty.as_ref().clone(),
170 pat_is_mut,
171 is_impl_trait: true,
172 });
173 }
174 } else {
175 let ident = Ident::new(&format!("__arg{}", index), Span::call_site());
176 let original_pat: Box<Pat> = pat.clone();
177 *pat = Box::new(syn::parse_quote! { #ident });
178 param_info.push(ParamInfo {
179 ident,
180 pat: original_pat,
181 ty: ty.as_ref().clone(),
182 pat_is_mut,
183 is_impl_trait: false,
184 });
185 }
186 }
187 }
188
189 let original_block = func.block.clone();
190 let helper_block = original_block.clone();
191 let recranpose_block = original_block.clone();
192 let key_expr = quote! { cranpose_core::location_key(file!(), line!(), column!()) };
193
194 let rebinds_for_no_skip: Vec<_> = param_info
196 .iter()
197 .map(|info| {
198 let ident = &info.ident;
199 let pat = &info.pat;
200 quote! { let #pat = #ident; }
201 })
202 .collect();
203
204 let return_ty: syn::Type = match &func.sig.output {
205 ReturnType::Default => syn::parse_quote! { () },
206 ReturnType::Type(_, ty) => ty.as_ref().clone(),
207 };
208 let _helper_ident = Ident::new(
209 &format!("__cranpose_impl_{}", func.sig.ident),
210 Span::call_site(),
211 );
212 let generics = func.sig.generics.clone();
213 let (_impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
214
215 let _helper_inputs: Vec<TokenStream2> = param_info
216 .iter()
217 .map(|info| {
218 let ident = &info.ident;
219 let ty = &info.ty;
220 quote! { #ident: #ty }
221 })
222 .collect();
223
224 let has_impl_trait = param_info
226 .iter()
227 .any(|info| matches!(info.ty, Type::ImplTrait(_)));
228
229 if enable_skip && !has_impl_trait {
230 let helper_ident = Ident::new(
231 &format!("__cranpose_impl_{}", func.sig.ident),
232 Span::call_site(),
233 );
234 let generics = func.sig.generics.clone();
235 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
236 let ty_generics_turbofish = ty_generics.as_turbofish();
237
238 let helper_inputs: Vec<TokenStream2> = param_info
240 .iter()
241 .filter_map(|info| {
242 if info.is_impl_trait {
243 None
244 } else {
245 let ident = &info.ident;
246 let ty = &info.ty;
247 Some(quote! { #ident: #ty })
248 }
249 })
250 .collect();
251
252 let param_state_slots: Vec<Ident> = (0..param_info.len())
254 .map(|index| Ident::new(&format!("__param_state_slot{}", index), Span::call_site()))
255 .collect();
256
257 let param_setup: Vec<TokenStream2> = param_info
258 .iter()
259 .zip(param_state_slots.iter())
260 .map(|(info, slot_ident)| {
261 if info.is_impl_trait {
262 quote! { __changed = true; }
263 } else if is_fn_param(&info.ty, &generics) {
264 let ident = &info.ident;
265 quote! {
266 let #slot_ident = __composer
267 .use_value_slot(|| cranpose_core::CallbackHolder::new());
268 __composer.with_slot_value::<cranpose_core::CallbackHolder, _>(
269 #slot_ident,
270 |holder| {
271 holder.update(#ident);
272 },
273 );
274 __changed = true;
275 }
276 } else {
277 let ident = &info.ident;
278 let ty = &info.ty;
279 quote! {
280 let #slot_ident = __composer
281 .use_value_slot(|| cranpose_core::ParamState::<#ty>::default());
282 if __composer.with_slot_value_mut::<cranpose_core::ParamState<#ty>, _>(
283 #slot_ident,
284 |state| state.update(&#ident),
285 )
286 {
287 __changed = true;
288 }
289 }
290 }
291 })
292 .collect();
293
294 let param_setup_recompose: Vec<TokenStream2> = param_info
295 .iter()
296 .zip(param_state_slots.iter())
297 .map(|(info, slot_ident)| {
298 if info.is_impl_trait {
299 quote! {}
300 } else if is_fn_param(&info.ty, &generics) {
301 quote! {
302 let #slot_ident = __composer
303 .use_value_slot(|| cranpose_core::CallbackHolder::new());
304 }
305 } else {
306 let ty = &info.ty;
307 quote! {
308 let #slot_ident = __composer
309 .use_value_slot(|| cranpose_core::ParamState::<#ty>::default());
310 }
311 }
312 })
313 .collect();
314
315 let rebinds: Vec<TokenStream2> = param_info
316 .iter()
317 .zip(param_state_slots.iter())
318 .map(|(info, slot_ident)| {
319 if info.is_impl_trait {
320 quote! {}
321 } else if is_fn_param(&info.ty, &generics) {
322 let pat = &info.pat;
323 let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
324 if can_add_mut && !info.pat_is_mut {
325 quote! {
326 #[allow(unused_mut)]
327 let mut #pat = __composer
328 .with_slot_value::<cranpose_core::CallbackHolder, _>(
329 #slot_ident,
330 |holder| holder.clone_rc(),
331 );
332 }
333 } else {
334 quote! {
335 #[allow(unused_mut)]
336 let #pat = __composer
337 .with_slot_value::<cranpose_core::CallbackHolder, _>(
338 #slot_ident,
339 |holder| holder.clone_rc(),
340 );
341 }
342 }
343 } else {
344 let pat = &info.pat;
345 let ident = &info.ident;
346 quote! {
347 let #pat = #ident;
348 }
349 }
350 })
351 .collect();
352
353 let rebinds_for_recompose: Vec<TokenStream2> = param_info
354 .iter()
355 .zip(param_state_slots.iter())
356 .map(|(info, slot_ident)| {
357 if info.is_impl_trait {
358 quote! {}
359 } else if is_fn_param(&info.ty, &generics) {
360 let pat = &info.pat;
361 let can_add_mut = matches!(pat.as_ref(), Pat::Ident(_));
362 if can_add_mut && !info.pat_is_mut {
363 quote! {
364 #[allow(unused_mut)]
365 let mut #pat = __composer
366 .with_slot_value::<cranpose_core::CallbackHolder, _>(
367 #slot_ident,
368 |holder| holder.clone_rc(),
369 );
370 }
371 } else {
372 quote! {
373 #[allow(unused_mut)]
374 let #pat = __composer
375 .with_slot_value::<cranpose_core::CallbackHolder, _>(
376 #slot_ident,
377 |holder| holder.clone_rc(),
378 );
379 }
380 }
381 } else {
382 let pat = &info.pat;
383 let ty = &info.ty;
384 quote! {
385 let #pat = __composer
386 .with_slot_value::<cranpose_core::ParamState<#ty>, _>(
387 #slot_ident,
388 |state| {
389 state
390 .value()
391 .expect("composable parameter missing for recomposition")
392 },
393 );
394 }
395 }
396 })
397 .collect();
398
399 let recranpose_fn_ident = Ident::new(
400 &format!("__cranpose_recranpose_{}", func.sig.ident),
401 Span::call_site(),
402 );
403
404 let recranpose_setter = quote! {
405 {
406 __composer.set_recranpose_callback(move |
407 __composer: &cranpose_core::Composer|
408 {
409 #recranpose_fn_ident #ty_generics_turbofish (
410 __composer
411 );
412 });
413 }
414 };
415
416 let helper_body = quote! {
417 let __current_scope = __composer
418 .current_recranpose_scope()
419 .expect("missing recompose scope");
420 let mut __changed = __current_scope.should_recompose();
421 #(#param_setup)*
422 let __result_slot_index = __composer
423 .use_value_slot(|| cranpose_core::ReturnSlot::<#return_ty>::default());
424 let __has_previous = __composer
425 .with_slot_value::<cranpose_core::ReturnSlot<#return_ty>, _>(
426 __result_slot_index,
427 |slot| slot.get().is_some(),
428 );
429 if !__changed && __has_previous {
430 __composer.skip_current_group();
431 let __result = __composer
432 .with_slot_value::<cranpose_core::ReturnSlot<#return_ty>, _>(
433 __result_slot_index,
434 |slot| {
435 slot.get()
436 .expect("composable return value missing during skip")
437 },
438 );
439 return __result;
440 }
441 let __value: #return_ty = {
442 #(#rebinds)*
443 #helper_block
444 };
445 __composer.with_slot_value_mut::<cranpose_core::ReturnSlot<#return_ty>, _>(
446 __result_slot_index,
447 |slot| {
448 slot.store(__value.clone());
449 },
450 );
451 #recranpose_setter
452 __value
453 };
454
455 let recranpose_fn_body = quote! {
456 #(#param_setup_recompose)*
457 let __result_slot_index = __composer
458 .use_value_slot(|| cranpose_core::ReturnSlot::<#return_ty>::default());
459 #(#rebinds_for_recompose)*
460 let __value: #return_ty = {
461 #recranpose_block
462 };
463 __composer.with_slot_value_mut::<cranpose_core::ReturnSlot<#return_ty>, _>(
464 __result_slot_index,
465 |slot| {
466 slot.store(__value.clone());
467 },
468 );
469 #recranpose_setter
470 __value
471 };
472
473 let recranpose_fn = quote! {
474 #[allow(non_snake_case)]
475 fn #recranpose_fn_ident #impl_generics (
476 __composer: &cranpose_core::Composer
477 ) -> #return_ty #where_clause {
478 #recranpose_fn_body
479 }
480 };
481
482 let helper_fn = quote! {
483 #[allow(non_snake_case)]
484 fn #helper_ident #impl_generics (
485 __composer: &cranpose_core::Composer
486 #(, #helper_inputs)*
487 ) -> #return_ty #where_clause {
488 #helper_body
489 }
490 };
491
492 let wrapper_args: Vec<TokenStream2> = param_info
494 .iter()
495 .filter_map(|info| {
496 if info.is_impl_trait {
497 None
498 } else {
499 let ident = &info.ident;
500 Some(quote! { #ident })
501 }
502 })
503 .collect();
504
505 let wrapped = quote!({
506 cranpose_core::with_current_composer(|__composer: &cranpose_core::Composer| {
507 __composer.with_group(#key_expr, |__composer: &cranpose_core::Composer| {
508 #helper_ident(__composer #(, #wrapper_args)*)
509 })
510 })
511 });
512 func.block = Box::new(syn::parse2(wrapped).expect("failed to build block"));
513 TokenStream::from(quote! {
514 #recranpose_fn
515 #helper_fn
516 #func
517 })
518 } else {
519 let wrapped = quote!({
521 cranpose_core::with_current_composer(|__composer: &cranpose_core::Composer| {
522 __composer.with_group(#key_expr, |__scope: &cranpose_core::Composer| {
523 #(#rebinds_for_no_skip)*
524 #original_block
525 })
526 })
527 });
528 func.block = Box::new(syn::parse2(wrapped).expect("failed to build block"));
529 TokenStream::from(quote! { #func })
530 }
531}