Skip to main content

func_wrap/
lib.rs

1//! # `::func_wrap`
2//!
//! Helper crate for procedural macro authors who wish to duplicate a function
//! they receive *inside its own body*, so as to be able to _wrap_ it with some
//! prologue, epilogue, caching, _etc._
6//!
7//! ## Examples
8//!
9//! See [https://docs.rs/require_unsafe_in_body] for a real-life example of
10//! using it.
11//!
12//! [https://docs.rs/require_unsafe_in_body]: https://docs.rs/require_unsafe_in_body
13
14#![allow(nonstandard_style, unused_imports)]
15
16use ::proc_macro2::{
17    Span, TokenStream,
18};
19
20use ::quote::{
21    format_ident, quote, quote_spanned, ToTokens,
22};
23
24use ::syn::{*,
25    parse::{Parse, Parser, ParseStream},
26    punctuated::Punctuated,
27    spanned::Spanned,
28    Result,
29};
30
31use ::core::{mem, ops::Not as _};
32
33#[derive(Clone, Copy)]
34pub
35enum ImplOrTrait<'__> {
36    /// Default implementation of methods within a trait definition.
37    DefaultMethod {
38        trait_name: &'__ Ident,
39    },
40
41    /// An implementation of methods within an `impl` block.
42    ImplMethod {
43        implementor: &'__ Type,
44        trait_name: Option<&'__ Path>, // `None` if inherent impl.
45    }
46}
47use ImplOrTrait::*;
48
49pub
50struct WrappedFuncCall<'__> {
51    pub
52    outer_scope: Option<(&'__ Generics, ImplOrTrait<'__>)>,
53
54    pub
55    sig: Signature,
56
57    pub
58    block: Block,
59
60    pub
61    call_site_args: Vec<Expr>,
62
63    pub
64    awaited: Option<TokenStream>,
65}
66
67pub
68fn func_wrap<'lt> (
69    func: &'_ mut Signature,
70    block: Block,
71    outer_scope: Option<(&'lt Generics, ImplOrTrait<'lt>)>,
72) -> Option<WrappedFuncCall<'lt>>
73{Some({
74    WrappedFuncCall {
75        sig: func.clone(),
76        call_site_args:
77            func.inputs
78                .iter_mut()
79                .enumerate()
80                .map(|(n, fn_arg)| Some(match *fn_arg {
81                    | FnArg::Receiver(ref receiver) => {
82                        if outer_scope.is_none() { return None; }
83                        let self_ = format_ident!(
84                            "self",
85                            span = receiver.self_token.span,
86                        )
87                        ;
88                        parse_quote!( #self_ )
89                    },
90                    | FnArg::Typed(ref mut pat_ty) => {
91                        if let Pat::Ident(ref mut pat) = *pat_ty.pat {
92                            let ident = &mut pat.ident;
93                            if *ident == "self" {
94                                if outer_scope.is_none() { return None; }
95                            } else {
96                                if ident.to_string().starts_with("arg_") {
97                                    *ident = format_ident!("arg_{}", n);
98                                }
99                            }
100                            parse_quote!( #ident )
101                        } else {
102                            let ident = format_ident!("arg_{}", n);
103                            *pat_ty.pat = parse_quote!( #ident );
104                            parse_quote!( #ident )
105                        }
106                    },
107                }))
108                .collect::<Option<Vec<Expr>>>()?
109        ,
110        outer_scope,
111        block,
112        awaited: func.asyncness.map(|_| quote!( .await )),
113    }
114})}
115
116impl ToTokens for WrappedFuncCall<'_> {
117    fn to_tokens (self: &'_ Self, out: &'_ mut TokenStream)
118    {
119        let Self { sig, outer_scope, block, call_site_args, awaited } = self;
120        let fname = &sig.ident;
121        let (_, temp, _) = sig.generics.split_for_impl();
122        // let turbofish = temp.as_turbofish();
123        // Turbofish currently bundles the lifetime parameters, which for
124        // a function param, leads to an error with late-bound lifetimes 😫
125        // Manually hand-roll our own turbofish, then:
126        let turbofish = {
127            let _ = temp;
128            let each_ty = sig.generics.type_params().map(|it| &it.ident);
129            let each_const = sig.generics.const_params().map(|it| &it.ident);
130            quote!(
131                ::<
132                    #(#each_ty ,)*
133                    #( {#each_const} ),*
134                >
135            )
136        };
137        out.extend(match outer_scope {
138            | None => quote!(
139                ({
140                    #[inline(always)]
141                    #sig
142                    #block
143
144                    #fname #turbofish
145                })(#(#call_site_args),*) #awaited
146            ),
147
148            | Some((
149                generics,
150                DefaultMethod { trait_name },
151            )) => {
152                let (intro_generics, feed_generics, where_clauses) =
153                    generics.split_for_impl()
154                ;
155                let trait_def = quote!(
156                    trait __FuncWrap #intro_generics
157                    :
158                        #trait_name #feed_generics
159                    #where_clauses
160                    {
161                        #[inline(always)]
162                        #sig
163                        #block
164                    }
165                );
166                let mut impl_generics = (*generics).clone();
167                impl_generics.params.push(parse_quote!(
168                    __Self: ?Sized + #trait_name #feed_generics
169                ));
170                let (impl_generics, _, _) = impl_generics.split_for_impl();
171                quote!(
172                    ({
173                        #trait_def
174
175                        impl #impl_generics
176                            __FuncWrap #feed_generics
177                        for
178                            __Self
179                        #where_clauses
180                        {}
181
182                        <Self as __FuncWrap #feed_generics>::#fname #turbofish
183                    })(#(#call_site_args),*) #awaited
184                )
185            },
186
187            | Some((
188                generics,
189                ImplMethod { implementor, trait_name },
190            )) => {
191                let (intro_generics, feed_generics, where_clauses) =
192                    generics.split_for_impl()
193                ;
194                let mut empty_sig = sig.clone();
195                empty_sig.inputs.iter_mut().for_each(|fn_arg| match *fn_arg {
196                    | FnArg::Typed(ref mut pat_ty)
197                        if matches!(
198                            *pat_ty.pat,
199                            Pat::Ident(ref pat)
200                            if pat.ident == "self"
201                        ).not()
202                    => {
203                        *pat_ty.pat = parse_quote!( _ );
204                    },
205                    | _ => {},
206                });
207                let super_trait = trait_name.map(|it| quote!( : #it ));
208                quote!(
209                    ({
210                        trait __FuncWrap #intro_generics
211                            #super_trait
212                        #where_clauses
213                        {
214                            #empty_sig;
215                        }
216
217                        impl #intro_generics
218                            __FuncWrap #feed_generics
219                        for
220                            #implementor
221                        #where_clauses
222                        {
223                            #[inline(always)]
224                            #sig
225                            #block
226                        }
227
228                        <Self as __FuncWrap #feed_generics>::#fname #turbofish
229                    })(#(#call_site_args),*) #awaited
230                )
231            },
232        })
233    }
234}
235
236pub
237fn parse_and_func_wrap_with (
238    input: impl Into<TokenStream>,
239    mut with: impl FnMut(
240        &'_ mut ImplItemMethod,
241        Option<WrappedFuncCall<'_>>,
242    ) -> Result<()>,
243) -> Result<Item>
244{Ok({
245    let mut input: Item = parse2(input.into())?;
246    match input {
247        | Item::Fn(ref mut it_fn) => {
248            let outer_scope = None;
249            let ItemFn { attrs, vis, sig, block } =
250                mem::replace(it_fn, parse_quote!( fn __() {} ))
251            ;
252            let mut func = ImplItemMethod {
253                attrs, vis, sig,
254                block: parse_quote!( {} ),
255                defaultness: None,
256            };
257            let wrapped_func = func_wrap(
258                &mut func.sig,
259                *block,
260                outer_scope,
261            );
262            let () = with(&mut func, wrapped_func)?;
263            let ImplItemMethod { attrs, vis, sig, block, .. } = func;
264            *it_fn = ItemFn {
265                attrs, vis, sig, block: Box::new(block),
266            };
267        },
268
269        | Item::Trait(ref mut it_trait) => {
270            let outer_scope = Some((
271                &it_trait.generics,
272                ImplOrTrait::DefaultMethod { trait_name: &it_trait.ident },
273            ));
274
275            it_trait.items.iter_mut().try_for_each(|it| Result::Ok(match *it {
276                | TraitItem::Method(ref mut method) => match method.default {
277                    | Some(ref mut block) => {
278                        let block = mem::replace(block, parse_quote!( {} ));
279                        let TraitItemMethod { attrs, sig, .. } =
280                            mem::replace(method, parse_quote!(fn __ () {}))
281                        ;
282                        let mut func = ImplItemMethod {
283                            attrs, sig,
284                            vis: Visibility::Inherited,
285                            block: parse_quote!( {} ),
286                            defaultness: None,
287                        };
288                        let wrapped_func = func_wrap(
289                            &mut func.sig,
290                            block,
291                            outer_scope,
292                        );
293                        let () = with(&mut func, wrapped_func)?;
294                        let ImplItemMethod { attrs, sig, block, .. } = func;
295                        *method = TraitItemMethod {
296                            attrs, sig, default: Some(block),
297                            semi_token: None,
298                        };
299                    },
300                    | _ => {},
301                },
302                | _ => {},
303            }))?;
304        },
305
306        | Item::Impl(ref mut it_impl) => {
307            let outer_scope = Some((
308                &it_impl.generics,
309                ImplOrTrait::ImplMethod {
310                    implementor: &it_impl.self_ty,
311                    trait_name: it_impl.trait_.as_ref().map(|(_, it, _)| it)
312                },
313            ));
314
315            it_impl.items.iter_mut().try_for_each(|it| Result::Ok(match *it {
316                | ImplItem::Method(ref mut func) => {
317                    let wrapped_func = func_wrap(
318                        &mut func.sig,
319                        mem::replace(&mut func.block, parse_quote!( {} )),
320                        outer_scope,
321                    );
322                    let () = with(func, wrapped_func)?;
323                },
324                | _ => {},
325            }))?;
326        },
327
328        | otherwise => return Err(Error::new(otherwise.span(),
329            "Expected an `fn` item, a `trait` definition, or an `impl` block."
330        )),
331    }
332    input
333})}
334
// Unit tests: compiled only under `cfg(test)`; the module body lives in a
// separate file (`tests.rs` or `tests/mod.rs` next to this one).
#[cfg(test)]
mod tests;