1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
extern crate proc_macro;
extern crate proc_macro2;
#[macro_use]
extern crate quote;
extern crate syn;

use proc_macro::TokenStream;
use syn::{Path, Type, TypePath};

/// Derives `std::ops::Deref` for a struct with a single field (optionally accompanied by a
/// `PhantomData` field).
///
/// `Target` is the type chosen by `parse_fields`, and `deref` evaluates the accessor
/// expression `parse_fields` builds for that field. The item's generics and where-clause
/// are carried over unchanged onto the impl.
///
/// # Panics
///
/// Panics (reported by rustc as a derive error) if the input cannot be parsed as a
/// `syn::DeriveInput`, or if `parse_fields` rejects the item's shape.
#[proc_macro_derive(Deref)]
pub fn derive_deref(input: TokenStream) -> TokenStream {
    let ast: syn::DeriveInput = syn::parse(input).unwrap();
    let (target, accessor) = parse_fields(&ast, false);

    let ident = &ast.ident;
    let (impl_gen, ty_gen, where_gen) = ast.generics.split_for_impl();

    let expanded = quote! {
        impl #impl_gen ::std::ops::Deref for #ident #ty_gen
        #where_gen
        {
            type Target = #target;

            fn deref(&self) -> &Self::Target {
                #accessor
            }
        }
    };
    expanded.into()
}

/// Derives `std::ops::DerefMut`, forwarding to the same field that `parse_fields` selects.
///
/// The generated method returns `&mut Self::Target` and emits no `Target` of its own, so the
/// type must also implement `Deref` (`Deref` is a supertrait of `DerefMut`), typically via
/// `#[derive(Deref)]`.
///
/// # Panics
///
/// Panics (reported by rustc as a derive error) if the input cannot be parsed as a
/// `syn::DeriveInput`, or if `parse_fields` rejects the item's shape.
#[proc_macro_derive(DerefMut)]
pub fn derive_deref_mut(input: TokenStream) -> TokenStream {
    let ast: syn::DeriveInput = syn::parse(input).unwrap();
    // Only the accessor is needed; the target type comes from the `Deref` impl.
    let accessor = parse_fields(&ast, true).1;

    let ident = &ast.ident;
    let (impl_gen, ty_gen, where_gen) = ast.generics.split_for_impl();

    let expanded = quote! {
        impl #impl_gen ::std::ops::DerefMut for #ident #ty_gen
        #where_gen
        {
            fn deref_mut(&mut self) -> &mut Self::Target {
                #accessor
            }
        }
    };
    expanded.into()
}

fn parse_fields(item: &syn::DeriveInput, mutable: bool) -> (syn::Type, proc_macro2::TokenStream) {
    let trait_name = if mutable { "DerefMut" } else { "Deref" };
    let fields = match item.data {
        syn::Data::Struct(ref body) => body.fields.iter().collect::<Vec<&syn::Field>>(),
        _ => panic!("#[derive({})] can only be used on structs", trait_name),
    };

    let field_ty = match fields.len() {
        1 => Some(fields[0].ty.clone()),
        2 => {
            if let Type::Path(TypePath { path: Path { segments, .. }, .. }) = &fields[1].ty {
                if segments
                    .last()
                    .expect("Expected path to have at least one segment")
                    .value()
                    .ident == "PhantomData"
                {
                    Some(fields[0].ty.clone())
                } else {
                    None
                }
            } else {
                None
            }
        },
        _ => None,
    };
    let field_ty = field_ty
        .unwrap_or_else(|| {
            panic!(
                "#[derive({})] can only be used on structs with one field, \
                 and optionally a second `PhantomData` field.",
                 trait_name,
            )
        });

    let field_name = match fields[0].ident {
        Some(ref ident) => quote!(#ident),
        None => quote!(0),
    };

    match (field_ty, mutable) {
        (syn::Type::Reference(syn::TypeReference { elem, .. }), _) => (*elem.clone(), quote!(self.#field_name)),
        (x, true) => (x, quote!(&mut self.#field_name)),
        (x, false) => (x, quote!(&self.#field_name)),
    }
}