1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens};
use syn::{parse, parse2, ImplItem, ImplItemFn, ItemImpl, Result};

/// Define a crate-local extension trait and implement it in one step.
///
/// Annotate an `impl Trait for Type { ... }` block. The macro declares a
/// `pub(crate)` trait called `Trait` containing the signature of every method
/// in the block. It then implements that trait for `Type` using the original
/// method bodies. The block may contain only `fn` items, and the attribute
/// itself takes no arguments.
///
/// When expansion fails, the annotated item is emitted unchanged and is
/// followed by the compile error.
///
/// # Examples
///
/// ```
/// #[local_impl::local_impl]
/// impl<T> VecExt for Vec<T> {
///     fn not_empty(&self) -> bool {
///         !self.is_empty()
///     }
/// }
///
/// # fn main() {
/// let mut v = Vec::new();
/// assert!(!v.not_empty());
/// v.push(1);
/// assert!(v.not_empty());
/// # }
/// ```
///
/// The generated trait can be imported from another module like any other
/// trait, and the impl's generic parameters may carry bounds:
///
/// ```
/// mod other_module {
///     #[local_impl::local_impl]
///     impl<T: Default> VecExt for Vec<T> {
///         fn push_default(&mut self) {
///             self.push(T::default());
///         }
///     }
/// }
///
/// # fn main() {
/// use other_module::VecExt;
/// let mut v: Vec<()> = Vec::new();
/// v.push_default();
/// v.push_default();
/// assert_eq!(v, vec![(), ()]);
/// # }
/// ```
#[proc_macro_attribute]
pub fn local_impl(attrs: TokenStream, input: TokenStream) -> TokenStream {
    // Keep an untouched copy so that a failed expansion can still emit the user's code.
    let fallback = input.clone();

    match local_impl_impl(attrs.into(), input.into()) {
        Ok(expanded) => TokenStream::from(expanded),
        Err(error) => {
            let mut output = fallback;
            output.extend(TokenStream::from(error.to_compile_error()));
            output
        }
    }
}

fn local_impl_impl(attrs: TokenStream2, input: TokenStream2) -> Result<TokenStream2> {
    let _ = parse2::<parse::Nothing>(attrs)?;
    let input = parse2::<ItemImpl>(input)?;

    let (impl_generics, generics, where_clause) = input.generics.split_for_impl();
    let trait_name = input.trait_.clone().expect("Expected a trait name").1;

    let method_heads = input
        .items
        .iter()
        .map(|item| match item {
            ImplItem::Fn(ImplItemFn { attrs: _, sig, .. }) => {
                quote! {
                    #sig;
                }
            }
            _ => panic!("Only methods are allowed"),
        })
        .collect::<Vec<_>>();

    let trait_def = quote!(
        pub(crate) trait #trait_name #generics #where_clause {
            #(#method_heads)*
        }
    );

    let items = input.items;
    let self_ty = input.self_ty;

    let impl_block = quote!(
        impl #impl_generics #trait_name #generics for #self_ty {
            #(#items)*
        }
    );

    Ok(quote!(
        #trait_def
        #impl_block
    )
    .into_token_stream())
}