hooks_derive_core/
hook_macro.rs

1use std::borrow::Cow;
2
3use darling::FromMeta;
4use proc_macro2::Span;
5use quote::{quote, quote_spanned};
6use syn::{parse_quote_spanned, spanned::Spanned};
7
8use crate::{
9    detect_hooks, detected_hooks_to_tokens,
10    utils::{
11        chain::Chain,
12        either::Either,
13        empty_or_trailing::AutoEmptyOrTrailing,
14        group::{angled, parened},
15        map::map_to_tokens,
16        path_or_lit::PathOrLit,
17        phantom::{make_phantom_or_ref, PhantomOfTy},
18        repeat::Repeat,
19        type_generics::TypeGenericsWithoutBraces,
20    },
21    DetectedHooksTokens,
22};
23
24pub type GenericParams = syn::punctuated::Punctuated<syn::GenericParam, syn::Token![,]>;
25
26#[cfg_attr(feature = "extra-traits", derive(PartialEq, Eq))]
27#[derive(Debug, Default, FromMeta)]
28#[non_exhaustive]
29#[darling(default)]
30pub struct HookArgs {
31    /// Defaults to `::hooks::core`
32    pub hooks_core_path: Option<PathOrLit<syn::Path>>,
33
34    /// Defaults to tuple of all lifetime generics except `'hook`
35    /// and all type generics.
36    ///
37    /// For example, default bounds of the following hook is
38    /// `(&'a (), &'b (), PhantomData<T>)`
39    ///
40    /// ```
41    /// # extern crate hooks_dev as hooks;
42    /// # use std::marker::PhantomData;
43    /// # use hooks::{hook, HookBounds};
44    ///
45    /// #[hook]
46    /// fn use_my_hook<'a, 'b, T>() {
47    /// }
48    ///
49    /// fn asserts<'a, 'b, T>() -> impl HookBounds<
50    ///     Bounds = (&'a (), &'b (), PhantomData<T>)
51    /// > {
52    ///     use_my_hook()
53    /// }
54    ///
55    /// # asserts::<()>();
56    /// ```
57    pub custom_bounds: Option<syn::Type>,
58
59    /// Generic params used only in `Args`.
60    /// Currently only lifetimes without bounds are supported.
61    /// Defaults to no generics.
62    pub args_generics: GenericParams,
63}
64
65impl HookArgs {
66    #[inline]
67    pub fn transform_item_fn(
68        self,
69        mut item_fn: syn::ItemFn,
70    ) -> (syn::ItemFn, Option<darling::Error>) {
71        let error = self.transform_item_fn_in_place(&mut item_fn);
72        (item_fn, error)
73    }
74
75    pub fn transform_item_fn_in_place(
76        mut self,
77        item_fn: &mut syn::ItemFn,
78    ) -> Option<darling::Error> {
79        let mut errors = darling::error::Accumulator::default();
80
81        let hooks_core_path = self.hooks_core_path.map_or_else(
82            || syn::Path {
83                leading_colon: Some(Default::default()),
84                segments: syn::punctuated::Punctuated::from_iter([
85                    syn::PathSegment::from(syn::Ident::new("hooks", Span::call_site())),
86                    syn::PathSegment::from(syn::Ident::new("core", Span::call_site())),
87                ]),
88            },
89            PathOrLit::unwrap,
90        );
91
92        let sig = &mut item_fn.sig;
93
94        let span_fn_name = sig.ident.span();
95
96        // let token_add: syn::Token![+];
97        // let lt_hook;
98
99        let (hook_args_pat, mut hook_args_ty) = {
100            let hook_args = std::mem::take(&mut sig.inputs);
101
102            let paren_token = syn::token::Paren(span_fn_name);
103
104            let (hook_args_pat, hook_args_ty) = hook_args
105                .into_pairs()
106                .into_iter()
107                .map(|pair| {
108                    let (arg, comma) = pair.into_tuple();
109                    let comma = comma.unwrap_or_else(|| syn::Token![,](arg.span()));
110
111                    let (pat, ty) = match arg {
112                        syn::FnArg::Receiver(syn::Receiver {
113                            attrs,
114                            reference,
115                            mutability,
116                            self_token,
117                        }) => {
118                            // In fact, this branch is not valid
119                            // because self cannot appear in closure args.
120                            // But we still transform it and
121                            // let the compiler complain about it.
122                            let self_type = syn::Type::Path(syn::TypePath {
123                                qself: None,
124                                path: syn::Token![Self](self_token.span).into(),
125                            });
126
127                            if let Some((and_token, lifetime)) = reference {
128                                let ty = syn::Type::Reference(syn::TypeReference {
129                                    and_token,
130                                    lifetime,
131                                    mutability,
132                                    elem: Box::new(self_type),
133                                });
134                                let pat = syn::Pat::Ident(syn::PatIdent {
135                                    attrs,
136                                    by_ref: None,
137                                    mutability: None,
138                                    ident: self_token.into(),
139                                    subpat: None,
140                                });
141                                (pat, ty)
142                            } else {
143                                (
144                                    syn::Pat::Ident(syn::PatIdent {
145                                        attrs,
146                                        by_ref: None,
147                                        mutability,
148                                        ident: self_token.into(),
149                                        subpat: None,
150                                    }),
151                                    self_type,
152                                )
153                            }
154                        }
155                        syn::FnArg::Typed(pat_ty) => {
156                            for attr in pat_ty.attrs {
157                                errors.push(
158                                    darling::Error::custom(
159                                        "arguments of hook cannot have attributes",
160                                    )
161                                    .with_span(&attr),
162                                );
163                            }
164                            (*pat_ty.pat, *pat_ty.ty)
165                        }
166                    };
167
168                    (
169                        syn::punctuated::Pair::Punctuated(pat, comma),
170                        syn::punctuated::Pair::Punctuated(ty, comma),
171                    )
172                })
173                .unzip();
174
175            let hook_args_pat = syn::PatTuple {
176                attrs: vec![],
177                paren_token,
178                elems: hook_args_pat,
179            };
180
181            let hook_args_ty = syn::TypeTuple {
182                paren_token,
183                elems: hook_args_ty,
184            };
185
186            (hook_args_pat, hook_args_ty)
187        };
188
189        crate::utils::elided_args_generics::auto_fill_lifetimes(
190            &mut self.args_generics,
191            &mut hook_args_ty.elems,
192        );
193
194        let args_lifetimes = &self.args_generics;
195
196        let args_lifetimes_empty = args_lifetimes.is_empty();
197
198        if !args_lifetimes_empty {
199            for g in self.args_generics.iter() {
200                match g {
201                    syn::GenericParam::Lifetime(_) => {}
202                    _ => errors.push(
203                        darling::Error::custom(
204                            "Currently args_generics only supports lifetimes without bounds",
205                        )
206                        .with_span(&g),
207                    ),
208                }
209            }
210        }
211
212        let generics = &sig.generics;
213
214        let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
215
216        let default_hook_bounds_fields_eot = map_to_tokens(&generics.params, |params| {
217            params.pairs().filter_map(|p| {
218                make_phantom_or_ref(p.value()).map(|v| {
219                    Chain(
220                        v,
221                        p.punct()
222                            .map_or_else(|| Cow::Owned(Default::default()), |v| Cow::Borrowed(*v)),
223                    )
224                })
225            })
226        });
227
228        let hook_bounds = self.custom_bounds.as_ref().map_or_else(
229            || Either::A(parened(&default_hook_bounds_fields_eot)),
230            |ty| Either::B(ty),
231        );
232
233        let mut output_ty: syn::Type = {
234            let fn_rt = &mut sig.output;
235            let (ra, output_ty) = match std::mem::replace(fn_rt, syn::ReturnType::Default) {
236                syn::ReturnType::Default => {
237                    let span = fn_rt.span();
238                    (
239                        syn::Token![->](span),
240                        syn::Type::Tuple(syn::TypeTuple {
241                            paren_token: syn::token::Paren(span),
242                            elems: Default::default(),
243                        }),
244                    )
245                }
246                syn::ReturnType::Type(ra, ty) => (ra, *ty),
247            };
248
249            let (for_hook, for_lifetimes) = if args_lifetimes_empty {
250                (None, None)
251            } else {
252                (
253                    Some(
254                        Chain(syn::Token![for](span_fn_name), syn::Token![<](span_fn_name))
255                            .chain(args_lifetimes)
256                            .chain(syn::Token![>](span_fn_name)),
257                    ),
258                    Some(Chain(syn::Token![,](span_fn_name), args_lifetimes)),
259                )
260            };
261
262            let return_ty = parse_quote_spanned! { span_fn_name =>
263                impl #for_hook #hooks_core_path ::Hook<#hook_args_ty>
264                    + for<'hook #for_lifetimes> #hooks_core_path ::HookLifetime<
265                        'hook,
266                        #hook_args_ty,
267                        &'hook #hook_bounds,
268                        Value = #output_ty
269                    >
270                    + #hooks_core_path ::HookBounds<Bounds = #hook_bounds>
271            };
272
273            *fn_rt = syn::ReturnType::Type(ra, return_ty);
274
275            output_ty
276        };
277
278        // T,
279        let fn_type_generics_eot = AutoEmptyOrTrailing(TypeGenericsWithoutBraces(&generics.params));
280
281        // HooksImplTrait0: Debug, HooksImplTrait1: Any,
282        //      introduced by impl trait in return position
283        let it_impl_generics_eot = extract_impl_trait_as_type_params(&mut output_ty);
284
285        // HooksImplTrait0, HooksImplTrait1,
286        let it_type_generics_eot = map_to_tokens(&it_impl_generics_eot, |v| {
287            v.iter().map(|pair| Chain(&pair.0.ident, &pair.1))
288        });
289
290        // ( PhantomData<T>, PhantomData<HooksImplTrait0>, PhantomData<HooksImplTrait1>, )
291        let hook_types_phantom;
292        // <T: Clone, HooksImplTrait0: Debug, HooksImplTrait1: Any,>
293        let hook_types_impl_generics;
294        // <T, HooksImplTrait0, HooksImplTrait1,>
295        let hook_types_type_generics;
296        // _, _,
297        let it_generics_elided_without_braces_eot;
298
299        // where T: SomeOtherTrait, HooksImplTrait0: 'hook,
300        // let hook_lifetime_where_clause;
301        // TODO: figure out when hook_lifetime_where_clause is needed.
302
303        if it_impl_generics_eot.is_empty() {
304            hook_types_phantom = Either::A(&hook_bounds);
305            hook_types_impl_generics = Either::A(impl_generics);
306            hook_types_type_generics = Either::A(&type_generics);
307            it_generics_elided_without_braces_eot = None;
308            // hook_lifetime_where_clause = Either::A(where_clause);
309        } else {
310            hook_types_phantom = Either::B(parened(Chain(
311                &default_hook_bounds_fields_eot,
312                map_to_tokens(&it_impl_generics_eot, |v| {
313                    v.iter()
314                        .map(|pair| Chain(PhantomOfTy(&pair.0.ident), pair.1))
315                }),
316            )));
317
318            hook_types_impl_generics = Either::B(angled(Chain(
319                AutoEmptyOrTrailing(&sig.generics.params),
320                map_to_tokens(&it_impl_generics_eot, |v| v.iter()),
321            )));
322
323            hook_types_type_generics =
324                Either::B(angled(Chain(&fn_type_generics_eot, &it_type_generics_eot)));
325
326            it_generics_elided_without_braces_eot = Some(Repeat(
327                Chain(<syn::Token![_]>::default(), <syn::Token![,]>::default()),
328                it_impl_generics_eot.len(),
329            ));
330
331            // token_add = <syn::Token![+]>::default();
332
333            // lt_hook = syn::Lifetime {
334            //     apostrophe: Span::call_site(),
335            //     ident: syn::Ident::new("hook", Span::call_site()),
336            // };
337
338            // let it_where_predicates_eot = map_to_tokens(&it_impl_generics_eot, |data| {
339            //     data.iter()
340            //         .map(|tp| chain![&tp.0, &token_add, &lt_hook, &tp.1,])
341            // });
342
343            // hook_lifetime_where_clause = Either::B({
344            //     match where_clause {
345            //         None => Chain(Default::default(), Either::A(it_where_predicates_eot)),
346            //         Some(where_clause) => Chain(
347            //             where_clause.where_token,
348            //             Either::B(Chain(it_where_predicates_eot, &where_clause.predicates)),
349            //         ),
350            //     }
351            // });
352        };
353
354        // T: Clone,
355        // The generics comes from `fn`, so there won't be default types like `<T = i32>`
356        let fn_impl_generics_without_braces_eot = AutoEmptyOrTrailing(&sig.generics.params);
357
358        let mut impl_use_hook = std::mem::take(&mut item_fn.block.stmts);
359
360        let used_hooks = detect_hooks(impl_use_hook.iter_mut(), &hooks_core_path);
361
362        let impl_poll_next_update = if used_hooks.is_empty() {
363            quote_spanned! { span_fn_name =>
364                #hooks_core_path ::fn_hook::poll_next_update_ready_false
365            }
366        } else {
367            quote_spanned! { span_fn_name =>
368                #hooks_core_path ::HookPollNextUpdate::poll_next_update
369            }
370        };
371
372        let DetectedHooksTokens {
373            data_expr: expr_hooks_data,
374            fn_arg_data_pat: arg_hooks_data,
375            fn_stmts_extract_data: impl_extract_hooks_data,
376        } = detected_hooks_to_tokens(
377            used_hooks,
378            &hooks_core_path,
379            quote!(()),
380            Some(quote!(())),
381            sig.fn_token.span,
382        );
383
384        let (args_generics_for_hook_lifetime_eot, stmt_ret) = if args_lifetimes_empty {
385            let stmt_ret: syn::Expr = parse_quote_spanned! { span_fn_name =>
386                #hooks_core_path ::fn_hook::new_fn_hook::<
387                    #hook_args_ty,
388                    _,
389                    __HookTypes <#fn_type_generics_eot  #it_generics_elided_without_braces_eot>
390                >(
391                    #expr_hooks_data,
392                    #impl_poll_next_update,
393                    |#arg_hooks_data, #hook_args_pat : #hook_args_ty| {
394                        #impl_extract_hooks_data
395
396                        #(#impl_use_hook)*
397                    }
398                )
399            };
400
401            (None, stmt_ret)
402        } else {
403            let stmt_ret: syn::Expr = parse_quote_spanned! { span_fn_name =>
404                {
405                    #[inline]
406                    fn _hooks_def_fn_hook<
407                        #fn_impl_generics_without_braces_eot
408                        #(#it_impl_generics_eot)*
409                        __HooksData,
410                        __HooksPoll: ::core::ops::Fn(::core::pin::Pin<&mut __HooksData>, &mut ::core::task::Context) -> ::core::task::Poll<::core::primitive::bool>,
411                        __HooksUseHook: for<'hook, #args_lifetimes> ::core::ops::Fn(::core::pin::Pin<&'hook mut __HooksData>, #hook_args_ty) -> #output_ty,
412                    >(
413                        hooks_data: __HooksData,
414                        hooks_poll: __HooksPoll,
415                        hooks_use_hook: __HooksUseHook
416                    ) -> #hooks_core_path ::fn_hook::FnHook::<__HooksData, __HooksPoll, __HooksUseHook, __HookTypes #hook_types_type_generics> #where_clause {
417                        #hooks_core_path ::fn_hook::FnHook::<__HooksData, __HooksPoll, __HooksUseHook, __HookTypes #hook_types_type_generics>::new(
418                            hooks_data,
419                            hooks_poll,
420                            hooks_use_hook
421                        )
422                    }
423
424                    _hooks_def_fn_hook::<
425                        #fn_type_generics_eot
426                        #it_generics_elided_without_braces_eot
427                        _, _, _
428                    >(
429                        #expr_hooks_data,
430                        #impl_poll_next_update,
431                        |#arg_hooks_data, #hook_args_pat| {
432                            #impl_extract_hooks_data
433
434                            #(#impl_use_hook)*
435                        },
436                    )
437                }
438            };
439
440            (Some(AutoEmptyOrTrailing(self.args_generics)), stmt_ret)
441        };
442
443        item_fn.block.stmts = parse_quote_spanned! { span_fn_name =>
444            struct __HookTypes #hook_types_impl_generics #where_clause {
445                __: ::core::marker::PhantomData< #hook_types_phantom >
446            }
447
448            impl #hook_types_impl_generics #hooks_core_path ::HookBounds for __HookTypes #hook_types_type_generics #where_clause {
449                type Bounds = #hook_bounds;
450            }
451
452            impl <
453                'hook,
454                #args_generics_for_hook_lifetime_eot
455                #fn_impl_generics_without_braces_eot
456                #(#it_impl_generics_eot)*
457            > #hooks_core_path ::HookLifetime<'hook, #hook_args_ty, &'hook #hook_bounds>
458                for __HookTypes #hook_types_type_generics #where_clause
459            {
460                type Value = #output_ty;
461            }
462        };
463
464        item_fn.block.stmts.push(syn::Stmt::Expr(stmt_ret));
465
466        errors.finish().err()
467    }
468
469    pub fn from_punctuated_meta_list(
470        meta_list: syn::punctuated::Punctuated<syn::NestedMeta, syn::Token![,]>,
471    ) -> darling::Result<Self> {
472        let args: Vec<syn::NestedMeta> = meta_list.into_iter().collect();
473        Self::from_list(&args)
474    }
475
476    pub fn with_args_generics(mut self, args_generics: GenericParams) -> Self {
477        self.args_generics = args_generics;
478        self
479    }
480}
481
482fn replace_impl_trait_in_type(
483    ty: &mut syn::Type,
484    f: &mut impl FnMut(&mut syn::TypeImplTrait) -> syn::Type,
485) {
486    match ty {
487        syn::Type::Array(ta) => replace_impl_trait_in_type(&mut ta.elem, f),
488        syn::Type::BareFn(_) => {}
489        syn::Type::Group(g) => replace_impl_trait_in_type(&mut g.elem, f),
490        syn::Type::ImplTrait(it) => {
491            // TODO: resolve `impl Trait` in it.bounds
492            // f(it.bounds)
493
494            *ty = f(it)
495        }
496        syn::Type::Infer(_) => {}
497        syn::Type::Macro(_) => {}
498        syn::Type::Never(_) => {}
499        syn::Type::Paren(p) => {
500            let is_impl_trait = matches!(&*p.elem, syn::Type::ImplTrait(_));
501            replace_impl_trait_in_type(&mut p.elem, f);
502
503            // also remove the paren for (HookImplTrait0)
504            if is_impl_trait {
505                let new_ty =
506                    std::mem::replace(&mut *p.elem, syn::Type::Verbatim(Default::default()));
507                *ty = new_ty;
508            }
509        }
510        syn::Type::Path(tp) => {
511            if let Some(qself) = &mut tp.qself {
512                replace_impl_trait_in_type(&mut qself.ty, f);
513            }
514            for seg in tp.path.segments.iter_mut() {
515                match &mut seg.arguments {
516                    syn::PathArguments::None => {}
517                    syn::PathArguments::AngleBracketed(a) => {
518                        for arg in a.args.iter_mut() {
519                            match arg {
520                                syn::GenericArgument::Lifetime(_) => {}
521                                syn::GenericArgument::Type(ty) => {
522                                    replace_impl_trait_in_type(ty, f);
523                                }
524                                syn::GenericArgument::Const(_) => {}
525                                syn::GenericArgument::Binding(b) => {
526                                    replace_impl_trait_in_type(&mut b.ty, f);
527                                }
528                                syn::GenericArgument::Constraint(_) => {}
529                            }
530                        }
531                    }
532                    syn::PathArguments::Parenthesized(_) => {
533                        // TODO: resolve `impl Trait` in path like `Fn(impl Trait) -> impl Trait`
534                    }
535                }
536            }
537            // TODO: resolve `impl Trait` in path like `Struct<impl Trait>`
538        }
539        syn::Type::Ptr(ptr) => replace_impl_trait_in_type(&mut ptr.elem, f),
540        syn::Type::Reference(r) => replace_impl_trait_in_type(&mut r.elem, f),
541        syn::Type::Slice(s) => replace_impl_trait_in_type(&mut s.elem, f),
542        syn::Type::TraitObject(_) => {
543            // TODO: resolve `impl Trait` in to.bounds
544            // f(to.bounds)
545        }
546        syn::Type::Tuple(t) => {
547            for elem in t.elems.iter_mut() {
548                replace_impl_trait_in_type(elem, f);
549            }
550        }
551        syn::Type::Verbatim(_) => {}
552        _ => {}
553    }
554}
555
556/// The returned Punctuated is guaranteed to be `empty_or_trailing`
557fn extract_impl_trait_as_type_params(
558    output_ty: &mut syn::Type,
559) -> Vec<Chain<syn::TypeParam, syn::Token![,]>> {
560    let mut ret = vec![];
561    replace_impl_trait_in_type(output_ty, &mut |ty| {
562        let id = ret.len();
563        let span = ty.impl_token.span;
564
565        let ident = syn::Ident::new(&format!("HooksImplTrait{id}"), span);
566
567        ret.push(Chain(
568            syn::TypeParam {
569                attrs: vec![],
570                ident: ident.clone(),
571                colon_token: Some(syn::Token![:](span)),
572                bounds: std::mem::take(&mut ty.bounds),
573                eq_token: None,
574                default: None,
575            },
576            syn::Token![,](span),
577        ));
578
579        syn::Type::Path(syn::TypePath {
580            qself: None,
581            path: ident.into(),
582        })
583    });
584    ret
585}