1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
extern crate proc_macro;

use proc_macro::TokenStream;
use quote::quote;
use syn::spanned::Spanned;
use syn::{
    parse_macro_input, FnArg, ImplItem, ImplItemConst, ImplItemMethod, ImplItemType, ItemImpl, Pat,
    PatIdent, PatType, Signature, Visibility,
};

/// Declares an extension trait
///
/// # Example
///
/// ```
/// #[macro_use]
/// extern crate extension_trait;
///
/// #[extension_trait(pub)]
/// impl DoubleExt for str {
///    fn double(&self) -> String {
///        self.repeat(2)
///    }
/// }
///
/// fn main() {
///     assert_eq!("Hello".double(), "HelloHello");
/// }
/// ```
#[proc_macro_attribute]
pub fn extension_trait(args: TokenStream, input: TokenStream) -> TokenStream {
    let visibility = parse_macro_input!(args as Visibility);
    let input_cloned = input.clone();
    let ItemImpl {
        impl_token,
        attrs,
        unsafety,
        trait_,
        items,
        ..
    } = parse_macro_input!(input_cloned as ItemImpl);
    let items = items.into_iter().map(|item| match item {
        ImplItem::Const(ImplItemConst {
            attrs, ident, ty, ..
        }) => quote! { #(#attrs)* const #ident: #ty; },
        ImplItem::Method(ImplItemMethod { attrs, sig: Signature {
            constness, asyncness, unsafety, abi, ident, generics, inputs, variadic, output, ..
        }, .. }) => {
            let inputs = inputs.into_iter().map(|arg| {
                if let FnArg::Typed(PatType { attrs, pat, ty, .. }) = &arg {
                    match **pat {
                        Pat::Ident(PatIdent {
                            by_ref: None,
                            mutability: None,
                            subpat: None,
                            ..
                        }) => {},
                        _ => return quote! { #(#attrs)* _: #ty },
                    }
                }
                quote! { #arg }
            });
            let where_clause = &generics.where_clause;
            quote! {
                #(#attrs)*
                #constness #asyncness #unsafety #abi fn #ident #generics (#(#inputs,)* #variadic) #output #where_clause;
            }
        },
        ImplItem::Type(ImplItemType {
            attrs,
            ident,
            generics,
            ..
        }) => quote! { #(#attrs)* type #ident #generics; },
        _ => return syn::Error::new(item.span(), "unsupported item type").to_compile_error().into(),
    });
    if let Some((None, path, _)) = trait_ {
        let input = proc_macro2::TokenStream::from(input);
        (quote! {
            #(#attrs)*
            #visibility #unsafety trait #path {
                #(#items)*
            }
            #input
        })
        .into()
    } else {
        syn::Error::new(impl_token.span(), "extension trait name was not provided")
            .to_compile_error()
            .into()
    }
}