1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
//! # `::func_wrap`
//!
//! Helper crate for procedural macro authors who wish to duplicate some
//! received function inside its own body, so as to be able to _wrap_ it with
//! some prologue, epilogue, caching, _etc._
//!
//! ## Examples
//!
//! See [https://docs.rs/require_unsafe_in_body] for a real-life example of
//! using it.
//!
//! [https://docs.rs/require_unsafe_in_body]: https://docs.rs/require_unsafe_in_body

#![allow(nonstandard_style, unused_imports)]

use ::proc_macro2::{
    Span, TokenStream,
};

use ::quote::{
    format_ident, quote, quote_spanned, ToTokens,
};

use ::syn::{*,
    parse::{Parse, Parser, ParseStream},
    punctuated::Punctuated,
    spanned::Spanned,
    Result,
};

use ::core::{mem, ops::Not as _};

/// The kind of item block enclosing the function to wrap, when that function
/// is a method rather than a free function.
#[derive(Copy, Clone)]
pub enum ImplOrTrait<'__> {
    /// A default method body written inside a `trait` definition.
    DefaultMethod {
        /// Name of the trait being defined.
        trait_name: &'__ Ident,
    },

    /// A method written inside an `impl` block.
    ImplMethod {
        /// The `Self` type of the `impl` block.
        implementor: &'__ Type,
        /// The implemented trait; `None` for an inherent `impl`.
        trait_name: Option<&'__ Path>,
    },
}
use ImplOrTrait::*;

/// The pieces needed to emit an inlined copy of a function, immediately
/// followed by a call to it; interpolate it with its [`ToTokens`] impl.
///
/// Usually obtained through [`func_wrap`].
pub struct WrappedFuncCall<'__> {
    /// Generics and kind of the enclosing `trait` / `impl` block;
    /// `None` when the function is a free function.
    pub outer_scope: Option<(&'__ Generics, ImplOrTrait<'__>)>,

    /// Signature of the function copy.
    pub sig: Signature,

    /// Body of the function copy.
    pub block: Block,

    /// Expressions passed, in order, as the arguments of the call.
    pub call_site_args: Vec<Expr>,

    /// Tokens appended after the call (`.await` for `async` functions).
    pub awaited: Option<TokenStream>,
}

/// Prepares `func` (whose body is `block`) to be wrapped.
///
/// The returned [`WrappedFuncCall`] can be interpolated (through
/// [`ToTokens`]) inside a new body for `func`. It expands to an inlined copy
/// of the original function that is immediately called with `func`'s
/// arguments.
///
/// `func.inputs` is rewritten in place so that every argument is bound to an
/// identifier that can be moved into that call:
///
///   - identifier patterns are kept, minus any `ref` or `@ subpattern` part.
///     Those already starting with `arg_` (raw or not) are renamed to
///     `arg_<index>`, so that they cannot clash with the generated names
///     below;
///
///   - any other pattern (tuple, struct, `_`, …) is replaced with
///     `arg_<index>`.
///
/// The returned `sig` is a copy of `func` taken _before_ that rewrite, so the
/// copy keeps the user's original patterns.
///
/// Returns `None` when the function takes `self` but `outer_scope` is `None`
/// (a free function has no `self` to forward).
pub
fn func_wrap<'lt> (
    func: &'_ mut Signature,
    block: Block,
    outer_scope: Option<(&'lt Generics, ImplOrTrait<'lt>)>,
) -> Option<WrappedFuncCall<'lt>>
{Some({
    WrappedFuncCall {
        // Field initializers are evaluated in order: this snapshot must be
        // taken before `func.inputs` gets rewritten just below.
        sig: func.clone(),
        call_site_args:
            func.inputs
                .iter_mut()
                .enumerate()
                .map(|(n, fn_arg)| Some(match *fn_arg {
                    | FnArg::Receiver(ref receiver) => {
                        if outer_scope.is_none() { return None; }
                        let self_ = format_ident!(
                            "self",
                            span = receiver.self_token.span,
                        )
                        ;
                        parse_quote!( #self_ )
                    },
                    | FnArg::Typed(ref mut pat_ty) => {
                        if let Pat::Ident(ref mut pat) = *pat_ty.pat {
                            // The outer binding must own the argument so that
                            // it can be moved into the inner call. Otherwise
                            // `ref x: T` would forward a `&T` where a `T` is
                            // expected, and `x @ subpat` would additionally
                            // bind (and move out of) parts of the value.
                            pat.by_ref = None;
                            pat.subpat = None;
                            let ident = &mut pat.ident;
                            if *ident == "self" {
                                if outer_scope.is_none() { return None; }
                            } else {
                                // `r#arg_0` and `arg_0` name the same binding,
                                // so ignore the raw-identifier prefix here.
                                let may_clash =
                                    ident
                                        .to_string()
                                        .trim_start_matches("r#")
                                        .starts_with("arg_")
                                ;
                                if may_clash {
                                    *ident = format_ident!("arg_{}", n);
                                }
                            }
                            parse_quote!( #ident )
                        } else {
                            let ident = format_ident!("arg_{}", n);
                            *pat_ty.pat = parse_quote!( #ident );
                            parse_quote!( #ident )
                        }
                    },
                }))
                .collect::<Option<Vec<Expr>>>()?
        ,
        outer_scope,
        block,
        awaited: func.asyncness.map(|_| quote!( .await )),
    }
})}

impl ToTokens for WrappedFuncCall<'_> {
    fn to_tokens (self: &'_ Self, out: &'_ mut TokenStream)
    {
        let Self { sig, outer_scope, block, call_site_args, awaited } = self;
        let fname = &sig.ident;
        let (_, temp, _) = sig.generics.split_for_impl();
        // let turbofish = temp.as_turbofish();
        // Turbofish currently bundles the lifetime parameters, which for
        // a function param, leads to an error with late-bound lifetimes 😫
        // Manually hand-roll our own turbofish, then:
        let turbofish = {
            let _ = temp;
            let each_ty = sig.generics.type_params().map(|it| &it.ident);
            let each_const = sig.generics.const_params().map(|it| &it.ident);
            quote!(
                ::<
                    #(#each_ty ,)*
                    #( {#each_const} ),*
                >
            )
        };
        out.extend(match outer_scope {
            | None => quote!(
                ({
                    #[inline(always)]
                    #sig
                    #block

                    #fname #turbofish
                })(#(#call_site_args),*) #awaited
            ),

            | Some((
                generics,
                DefaultMethod { trait_name },
            )) => {
                let (intro_generics, feed_generics, where_clauses) =
                    generics.split_for_impl()
                ;
                let trait_def = quote!(
                    trait __FuncWrap #intro_generics
                    :
                        #trait_name #feed_generics
                    #where_clauses
                    {
                        #[inline(always)]
                        #sig
                        #block
                    }
                );
                let mut impl_generics = (*generics).clone();
                impl_generics.params.push(parse_quote!(
                    __Self: ?Sized + #trait_name #feed_generics
                ));
                let (impl_generics, _, _) = impl_generics.split_for_impl();
                quote!(
                    ({
                        #trait_def

                        impl #impl_generics
                            __FuncWrap #feed_generics
                        for
                            __Self
                        #where_clauses
                        {}

                        <Self as __FuncWrap #feed_generics>::#fname #turbofish
                    })(#(#call_site_args),*) #awaited
                )
            },

            | Some((
                generics,
                ImplMethod { implementor, trait_name },
            )) => {
                let (intro_generics, feed_generics, where_clauses) =
                    generics.split_for_impl()
                ;
                let mut empty_sig = sig.clone();
                empty_sig.inputs.iter_mut().for_each(|fn_arg| match *fn_arg {
                    | FnArg::Typed(ref mut pat_ty)
                        if matches!(
                            *pat_ty.pat,
                            Pat::Ident(ref pat)
                            if pat.ident == "self"
                        ).not()
                    => {
                        *pat_ty.pat = parse_quote!( _ );
                    },
                    | _ => {},
                });
                let super_trait = trait_name.map(|it| quote!( : #it ));
                quote!(
                    ({
                        trait __FuncWrap #intro_generics
                            #super_trait
                        #where_clauses
                        {
                            #empty_sig;
                        }

                        impl #intro_generics
                            __FuncWrap #feed_generics
                        for
                            #implementor
                        #where_clauses
                        {
                            #[inline(always)]
                            #sig
                            #block
                        }

                        <Self as __FuncWrap #feed_generics>::#fname #turbofish
                    })(#(#call_site_args),*) #awaited
                )
            },
        })
    }
}

pub
fn parse_and_func_wrap_with (
    input: impl Into<TokenStream>,
    mut with: impl FnMut(
        &'_ mut ImplItemMethod,
        Option<WrappedFuncCall<'_>>,
    ) -> Result<()>,
) -> Result<Item>
{Ok({
    let mut input: Item = parse2(input.into())?;
    match input {
        | Item::Fn(ref mut it_fn) => {
            let outer_scope = None;
            let ItemFn { attrs, vis, sig, block } =
                mem::replace(it_fn, parse_quote!( fn __() {} ))
            ;
            let mut func = ImplItemMethod {
                attrs, vis, sig,
                block: parse_quote!( {} ),
                defaultness: None,
            };
            let wrapped_func = func_wrap(
                &mut func.sig,
                *block,
                outer_scope,
            );
            let () = with(&mut func, wrapped_func)?;
            let ImplItemMethod { attrs, vis, sig, block, .. } = func;
            *it_fn = ItemFn {
                attrs, vis, sig, block: Box::new(block),
            };
        },

        | Item::Trait(ref mut it_trait) => {
            let outer_scope = Some((
                &it_trait.generics,
                ImplOrTrait::DefaultMethod { trait_name: &it_trait.ident },
            ));

            it_trait.items.iter_mut().try_for_each(|it| Result::Ok(match *it {
                | TraitItem::Method(ref mut method) => match method.default {
                    | Some(ref mut block) => {
                        let block = mem::replace(block, parse_quote!( {} ));
                        let TraitItemMethod { attrs, sig, .. } =
                            mem::replace(method, parse_quote!(fn __ () {}))
                        ;
                        let mut func = ImplItemMethod {
                            attrs, sig,
                            vis: Visibility::Inherited,
                            block: parse_quote!( {} ),
                            defaultness: None,
                        };
                        let wrapped_func = func_wrap(
                            &mut func.sig,
                            block,
                            outer_scope,
                        );
                        let () = with(&mut func, wrapped_func)?;
                        let ImplItemMethod { attrs, sig, block, .. } = func;
                        *method = TraitItemMethod {
                            attrs, sig, default: Some(block),
                            semi_token: None,
                        };
                    },
                    | _ => {},
                },
                | _ => {},
            }))?;
        },

        | Item::Impl(ref mut it_impl) => {
            let outer_scope = Some((
                &it_impl.generics,
                ImplOrTrait::ImplMethod {
                    implementor: &it_impl.self_ty,
                    trait_name: it_impl.trait_.as_ref().map(|(_, it, _)| it)
                },
            ));

            it_impl.items.iter_mut().try_for_each(|it| Result::Ok(match *it {
                | ImplItem::Method(ref mut func) => {
                    let wrapped_func = func_wrap(
                        &mut func.sig,
                        mem::replace(&mut func.block, parse_quote!( {} )),
                        outer_scope,
                    );
                    let () = with(func, wrapped_func)?;
                },
                | _ => {},
            }))?;
        },

        | otherwise => return Err(Error::new(otherwise.span(),
            "Expected an `fn` item, a `trait` definition, or an `impl` block."
        )),
    }
    input
})}

// Out-of-line test module; only compiled for `cfg(test)` builds.
#[cfg(test)]
mod tests;