1use std::borrow::Cow;
2
3use darling::FromMeta;
4use proc_macro2::Span;
5use quote::{quote, quote_spanned};
6use syn::{parse_quote_spanned, spanned::Spanned};
7
8use crate::{
9 detect_hooks, detected_hooks_to_tokens,
10 utils::{
11 chain::Chain,
12 either::Either,
13 empty_or_trailing::AutoEmptyOrTrailing,
14 group::{angled, parened},
15 map::map_to_tokens,
16 path_or_lit::PathOrLit,
17 phantom::{make_phantom_or_ref, PhantomOfTy},
18 repeat::Repeat,
19 type_generics::TypeGenericsWithoutBraces,
20 },
21 DetectedHooksTokens,
22};
23
24pub type GenericParams = syn::punctuated::Punctuated<syn::GenericParam, syn::Token![,]>;
25
/// Arguments accepted by the `hook` attribute macro.
///
/// Parsed with darling ([`FromMeta`]); every field is optional because the struct is
/// marked `#[darling(default)]`.
#[cfg_attr(feature = "extra-traits", derive(PartialEq, Eq))]
#[derive(Debug, Default, FromMeta)]
#[non_exhaustive]
#[darling(default)]
pub struct HookArgs {
    /// Path through which the generated code refers to the hooks core items
    /// (`Hook`, `HookLifetime`, `HookBounds`, `fn_hook`, ...).
    ///
    /// Defaults to `::hooks::core` when not given.
    pub hooks_core_path: Option<PathOrLit<syn::Path>>,

    /// Type used as `HookBounds::Bounds` of the generated hook.
    ///
    /// When not given, a parenthesized tuple built from the function's generic
    /// parameters (see `make_phantom_or_ref`) is used instead.
    pub custom_bounds: Option<syn::Type>,

    /// Generic parameters that appear only in the hook's argument types.
    ///
    /// Only lifetime parameters are supported at the moment; anything else is reported
    /// as an error when the function is transformed. Lifetimes elided in the argument
    /// types may be added here automatically during the transformation.
    pub args_generics: GenericParams,
}
64
65impl HookArgs {
66 #[inline]
67 pub fn transform_item_fn(
68 self,
69 mut item_fn: syn::ItemFn,
70 ) -> (syn::ItemFn, Option<darling::Error>) {
71 let error = self.transform_item_fn_in_place(&mut item_fn);
72 (item_fn, error)
73 }
74
75 pub fn transform_item_fn_in_place(
76 mut self,
77 item_fn: &mut syn::ItemFn,
78 ) -> Option<darling::Error> {
79 let mut errors = darling::error::Accumulator::default();
80
81 let hooks_core_path = self.hooks_core_path.map_or_else(
82 || syn::Path {
83 leading_colon: Some(Default::default()),
84 segments: syn::punctuated::Punctuated::from_iter([
85 syn::PathSegment::from(syn::Ident::new("hooks", Span::call_site())),
86 syn::PathSegment::from(syn::Ident::new("core", Span::call_site())),
87 ]),
88 },
89 PathOrLit::unwrap,
90 );
91
92 let sig = &mut item_fn.sig;
93
94 let span_fn_name = sig.ident.span();
95
96 let (hook_args_pat, mut hook_args_ty) = {
100 let hook_args = std::mem::take(&mut sig.inputs);
101
102 let paren_token = syn::token::Paren(span_fn_name);
103
104 let (hook_args_pat, hook_args_ty) = hook_args
105 .into_pairs()
106 .into_iter()
107 .map(|pair| {
108 let (arg, comma) = pair.into_tuple();
109 let comma = comma.unwrap_or_else(|| syn::Token));
110
111 let (pat, ty) = match arg {
112 syn::FnArg::Receiver(syn::Receiver {
113 attrs,
114 reference,
115 mutability,
116 self_token,
117 }) => {
118 let self_type = syn::Type::Path(syn::TypePath {
123 qself: None,
124 path: syn::Token.into(),
125 });
126
127 if let Some((and_token, lifetime)) = reference {
128 let ty = syn::Type::Reference(syn::TypeReference {
129 and_token,
130 lifetime,
131 mutability,
132 elem: Box::new(self_type),
133 });
134 let pat = syn::Pat::Ident(syn::PatIdent {
135 attrs,
136 by_ref: None,
137 mutability: None,
138 ident: self_token.into(),
139 subpat: None,
140 });
141 (pat, ty)
142 } else {
143 (
144 syn::Pat::Ident(syn::PatIdent {
145 attrs,
146 by_ref: None,
147 mutability,
148 ident: self_token.into(),
149 subpat: None,
150 }),
151 self_type,
152 )
153 }
154 }
155 syn::FnArg::Typed(pat_ty) => {
156 for attr in pat_ty.attrs {
157 errors.push(
158 darling::Error::custom(
159 "arguments of hook cannot have attributes",
160 )
161 .with_span(&attr),
162 );
163 }
164 (*pat_ty.pat, *pat_ty.ty)
165 }
166 };
167
168 (
169 syn::punctuated::Pair::Punctuated(pat, comma),
170 syn::punctuated::Pair::Punctuated(ty, comma),
171 )
172 })
173 .unzip();
174
175 let hook_args_pat = syn::PatTuple {
176 attrs: vec![],
177 paren_token,
178 elems: hook_args_pat,
179 };
180
181 let hook_args_ty = syn::TypeTuple {
182 paren_token,
183 elems: hook_args_ty,
184 };
185
186 (hook_args_pat, hook_args_ty)
187 };
188
189 crate::utils::elided_args_generics::auto_fill_lifetimes(
190 &mut self.args_generics,
191 &mut hook_args_ty.elems,
192 );
193
194 let args_lifetimes = &self.args_generics;
195
196 let args_lifetimes_empty = args_lifetimes.is_empty();
197
198 if !args_lifetimes_empty {
199 for g in self.args_generics.iter() {
200 match g {
201 syn::GenericParam::Lifetime(_) => {}
202 _ => errors.push(
203 darling::Error::custom(
204 "Currently args_generics only supports lifetimes without bounds",
205 )
206 .with_span(&g),
207 ),
208 }
209 }
210 }
211
212 let generics = &sig.generics;
213
214 let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
215
216 let default_hook_bounds_fields_eot = map_to_tokens(&generics.params, |params| {
217 params.pairs().filter_map(|p| {
218 make_phantom_or_ref(p.value()).map(|v| {
219 Chain(
220 v,
221 p.punct()
222 .map_or_else(|| Cow::Owned(Default::default()), |v| Cow::Borrowed(*v)),
223 )
224 })
225 })
226 });
227
228 let hook_bounds = self.custom_bounds.as_ref().map_or_else(
229 || Either::A(parened(&default_hook_bounds_fields_eot)),
230 |ty| Either::B(ty),
231 );
232
233 let mut output_ty: syn::Type = {
234 let fn_rt = &mut sig.output;
235 let (ra, output_ty) = match std::mem::replace(fn_rt, syn::ReturnType::Default) {
236 syn::ReturnType::Default => {
237 let span = fn_rt.span();
238 (
239 syn::Token,
240 syn::Type::Tuple(syn::TypeTuple {
241 paren_token: syn::token::Paren(span),
242 elems: Default::default(),
243 }),
244 )
245 }
246 syn::ReturnType::Type(ra, ty) => (ra, *ty),
247 };
248
249 let (for_hook, for_lifetimes) = if args_lifetimes_empty {
250 (None, None)
251 } else {
252 (
253 Some(
254 Chain(syn::Token, syn::Token)
255 .chain(args_lifetimes)
256 .chain(syn::Token),
257 ),
258 Some(Chain(syn::Token, args_lifetimes)),
259 )
260 };
261
262 let return_ty = parse_quote_spanned! { span_fn_name =>
263 impl #for_hook #hooks_core_path ::Hook<#hook_args_ty>
264 + for<'hook #for_lifetimes> #hooks_core_path ::HookLifetime<
265 'hook,
266 #hook_args_ty,
267 &'hook #hook_bounds,
268 Value = #output_ty
269 >
270 + #hooks_core_path ::HookBounds<Bounds = #hook_bounds>
271 };
272
273 *fn_rt = syn::ReturnType::Type(ra, return_ty);
274
275 output_ty
276 };
277
278 let fn_type_generics_eot = AutoEmptyOrTrailing(TypeGenericsWithoutBraces(&generics.params));
280
281 let it_impl_generics_eot = extract_impl_trait_as_type_params(&mut output_ty);
284
285 let it_type_generics_eot = map_to_tokens(&it_impl_generics_eot, |v| {
287 v.iter().map(|pair| Chain(&pair.0.ident, &pair.1))
288 });
289
290 let hook_types_phantom;
292 let hook_types_impl_generics;
294 let hook_types_type_generics;
296 let it_generics_elided_without_braces_eot;
298
299 if it_impl_generics_eot.is_empty() {
304 hook_types_phantom = Either::A(&hook_bounds);
305 hook_types_impl_generics = Either::A(impl_generics);
306 hook_types_type_generics = Either::A(&type_generics);
307 it_generics_elided_without_braces_eot = None;
308 } else {
310 hook_types_phantom = Either::B(parened(Chain(
311 &default_hook_bounds_fields_eot,
312 map_to_tokens(&it_impl_generics_eot, |v| {
313 v.iter()
314 .map(|pair| Chain(PhantomOfTy(&pair.0.ident), pair.1))
315 }),
316 )));
317
318 hook_types_impl_generics = Either::B(angled(Chain(
319 AutoEmptyOrTrailing(&sig.generics.params),
320 map_to_tokens(&it_impl_generics_eot, |v| v.iter()),
321 )));
322
323 hook_types_type_generics =
324 Either::B(angled(Chain(&fn_type_generics_eot, &it_type_generics_eot)));
325
326 it_generics_elided_without_braces_eot = Some(Repeat(
327 Chain(<syn::Token![_]>::default(), <syn::Token![,]>::default()),
328 it_impl_generics_eot.len(),
329 ));
330
331 };
353
354 let fn_impl_generics_without_braces_eot = AutoEmptyOrTrailing(&sig.generics.params);
357
358 let mut impl_use_hook = std::mem::take(&mut item_fn.block.stmts);
359
360 let used_hooks = detect_hooks(impl_use_hook.iter_mut(), &hooks_core_path);
361
362 let impl_poll_next_update = if used_hooks.is_empty() {
363 quote_spanned! { span_fn_name =>
364 #hooks_core_path ::fn_hook::poll_next_update_ready_false
365 }
366 } else {
367 quote_spanned! { span_fn_name =>
368 #hooks_core_path ::HookPollNextUpdate::poll_next_update
369 }
370 };
371
372 let DetectedHooksTokens {
373 data_expr: expr_hooks_data,
374 fn_arg_data_pat: arg_hooks_data,
375 fn_stmts_extract_data: impl_extract_hooks_data,
376 } = detected_hooks_to_tokens(
377 used_hooks,
378 &hooks_core_path,
379 quote!(()),
380 Some(quote!(())),
381 sig.fn_token.span,
382 );
383
384 let (args_generics_for_hook_lifetime_eot, stmt_ret) = if args_lifetimes_empty {
385 let stmt_ret: syn::Expr = parse_quote_spanned! { span_fn_name =>
386 #hooks_core_path ::fn_hook::new_fn_hook::<
387 #hook_args_ty,
388 _,
389 __HookTypes <#fn_type_generics_eot #it_generics_elided_without_braces_eot>
390 >(
391 #expr_hooks_data,
392 #impl_poll_next_update,
393 |#arg_hooks_data, #hook_args_pat : #hook_args_ty| {
394 #impl_extract_hooks_data
395
396 #(#impl_use_hook)*
397 }
398 )
399 };
400
401 (None, stmt_ret)
402 } else {
403 let stmt_ret: syn::Expr = parse_quote_spanned! { span_fn_name =>
404 {
405 #[inline]
406 fn _hooks_def_fn_hook<
407 #fn_impl_generics_without_braces_eot
408 #(#it_impl_generics_eot)*
409 __HooksData,
410 __HooksPoll: ::core::ops::Fn(::core::pin::Pin<&mut __HooksData>, &mut ::core::task::Context) -> ::core::task::Poll<::core::primitive::bool>,
411 __HooksUseHook: for<'hook, #args_lifetimes> ::core::ops::Fn(::core::pin::Pin<&'hook mut __HooksData>, #hook_args_ty) -> #output_ty,
412 >(
413 hooks_data: __HooksData,
414 hooks_poll: __HooksPoll,
415 hooks_use_hook: __HooksUseHook
416 ) -> #hooks_core_path ::fn_hook::FnHook::<__HooksData, __HooksPoll, __HooksUseHook, __HookTypes #hook_types_type_generics> #where_clause {
417 #hooks_core_path ::fn_hook::FnHook::<__HooksData, __HooksPoll, __HooksUseHook, __HookTypes #hook_types_type_generics>::new(
418 hooks_data,
419 hooks_poll,
420 hooks_use_hook
421 )
422 }
423
424 _hooks_def_fn_hook::<
425 #fn_type_generics_eot
426 #it_generics_elided_without_braces_eot
427 _, _, _
428 >(
429 #expr_hooks_data,
430 #impl_poll_next_update,
431 |#arg_hooks_data, #hook_args_pat| {
432 #impl_extract_hooks_data
433
434 #(#impl_use_hook)*
435 },
436 )
437 }
438 };
439
440 (Some(AutoEmptyOrTrailing(self.args_generics)), stmt_ret)
441 };
442
443 item_fn.block.stmts = parse_quote_spanned! { span_fn_name =>
444 struct __HookTypes #hook_types_impl_generics #where_clause {
445 __: ::core::marker::PhantomData< #hook_types_phantom >
446 }
447
448 impl #hook_types_impl_generics #hooks_core_path ::HookBounds for __HookTypes #hook_types_type_generics #where_clause {
449 type Bounds = #hook_bounds;
450 }
451
452 impl <
453 'hook,
454 #args_generics_for_hook_lifetime_eot
455 #fn_impl_generics_without_braces_eot
456 #(#it_impl_generics_eot)*
457 > #hooks_core_path ::HookLifetime<'hook, #hook_args_ty, &'hook #hook_bounds>
458 for __HookTypes #hook_types_type_generics #where_clause
459 {
460 type Value = #output_ty;
461 }
462 };
463
464 item_fn.block.stmts.push(syn::Stmt::Expr(stmt_ret));
465
466 errors.finish().err()
467 }
468
469 pub fn from_punctuated_meta_list(
470 meta_list: syn::punctuated::Punctuated<syn::NestedMeta, syn::Token![,]>,
471 ) -> darling::Result<Self> {
472 let args: Vec<syn::NestedMeta> = meta_list.into_iter().collect();
473 Self::from_list(&args)
474 }
475
476 pub fn with_args_generics(mut self, args_generics: GenericParams) -> Self {
477 self.args_generics = args_generics;
478 self
479 }
480}
481
482fn replace_impl_trait_in_type(
483 ty: &mut syn::Type,
484 f: &mut impl FnMut(&mut syn::TypeImplTrait) -> syn::Type,
485) {
486 match ty {
487 syn::Type::Array(ta) => replace_impl_trait_in_type(&mut ta.elem, f),
488 syn::Type::BareFn(_) => {}
489 syn::Type::Group(g) => replace_impl_trait_in_type(&mut g.elem, f),
490 syn::Type::ImplTrait(it) => {
491 *ty = f(it)
495 }
496 syn::Type::Infer(_) => {}
497 syn::Type::Macro(_) => {}
498 syn::Type::Never(_) => {}
499 syn::Type::Paren(p) => {
500 let is_impl_trait = matches!(&*p.elem, syn::Type::ImplTrait(_));
501 replace_impl_trait_in_type(&mut p.elem, f);
502
503 if is_impl_trait {
505 let new_ty =
506 std::mem::replace(&mut *p.elem, syn::Type::Verbatim(Default::default()));
507 *ty = new_ty;
508 }
509 }
510 syn::Type::Path(tp) => {
511 if let Some(qself) = &mut tp.qself {
512 replace_impl_trait_in_type(&mut qself.ty, f);
513 }
514 for seg in tp.path.segments.iter_mut() {
515 match &mut seg.arguments {
516 syn::PathArguments::None => {}
517 syn::PathArguments::AngleBracketed(a) => {
518 for arg in a.args.iter_mut() {
519 match arg {
520 syn::GenericArgument::Lifetime(_) => {}
521 syn::GenericArgument::Type(ty) => {
522 replace_impl_trait_in_type(ty, f);
523 }
524 syn::GenericArgument::Const(_) => {}
525 syn::GenericArgument::Binding(b) => {
526 replace_impl_trait_in_type(&mut b.ty, f);
527 }
528 syn::GenericArgument::Constraint(_) => {}
529 }
530 }
531 }
532 syn::PathArguments::Parenthesized(_) => {
533 }
535 }
536 }
537 }
539 syn::Type::Ptr(ptr) => replace_impl_trait_in_type(&mut ptr.elem, f),
540 syn::Type::Reference(r) => replace_impl_trait_in_type(&mut r.elem, f),
541 syn::Type::Slice(s) => replace_impl_trait_in_type(&mut s.elem, f),
542 syn::Type::TraitObject(_) => {
543 }
546 syn::Type::Tuple(t) => {
547 for elem in t.elems.iter_mut() {
548 replace_impl_trait_in_type(elem, f);
549 }
550 }
551 syn::Type::Verbatim(_) => {}
552 _ => {}
553 }
554}
555
556fn extract_impl_trait_as_type_params(
558 output_ty: &mut syn::Type,
559) -> Vec<Chain<syn::TypeParam, syn::Token![,]>> {
560 let mut ret = vec![];
561 replace_impl_trait_in_type(output_ty, &mut |ty| {
562 let id = ret.len();
563 let span = ty.impl_token.span;
564
565 let ident = syn::Ident::new(&format!("HooksImplTrait{id}"), span);
566
567 ret.push(Chain(
568 syn::TypeParam {
569 attrs: vec![],
570 ident: ident.clone(),
571 colon_token: Some(syn::Token),
572 bounds: std::mem::take(&mut ty.bounds),
573 eq_token: None,
574 default: None,
575 },
576 syn::Token,
577 ));
578
579 syn::Type::Path(syn::TypePath {
580 qself: None,
581 path: ident.into(),
582 })
583 });
584 ret
585}