// better_default_derive/lib.rs

1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6    parse_macro_input, parse_quote, spanned::Spanned, Data, DataEnum, DataStruct, DeriveInput,
7    Error, Fields, Generics, Ident, Type, TypePath, Variant,
8};
9
/// Name of the helper attribute (`#[default]`) that selects an enum's default variant.
const DEFAULT_VARIANT_KEYWORD: &str = "default";
11
12#[proc_macro_derive(Default, attributes(default))]
13pub fn derive(input: TokenStream) -> TokenStream {
14    let output = match __derive(parse_macro_input!(input as DeriveInput)) {
15        Ok(output) => output,
16        Err(err) => err.into_compile_error(),
17    };
18    proc_macro::TokenStream::from(output)
19}
20
21fn __derive(input: DeriveInput) -> Result<proc_macro2::TokenStream, Error> {
22    let DeriveInput {
23        attrs: _,
24        vis: _,
25        ident: input_ident,
26        mut generics,
27        data,
28    } = input;
29
30    let (body, fields) = match data {
31        Data::Struct(data) => struct_case(&input_ident, data),
32        Data::Enum(data) => enum_case(&input_ident, data),
33        Data::Union(_) => Err(Error::new_spanned(
34            &input_ident,
35            "#[derive(Default)] is not supported for unions",
36        )),
37    }?;
38
39    add_trait_bounds(&mut generics, &fields);
40    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
41
42    let output = quote! {
43        impl #impl_generics std::default::Default for #input_ident #ty_generics #where_clause {
44            fn default() -> Self {
45                #body
46            }
47        }
48    };
49
50    Ok(output)
51}
52
53fn struct_case(
54    struct_ident: &Ident,
55    data: DataStruct,
56) -> Result<(proc_macro2::TokenStream, Fields), Error> {
57    let data_constr = default_instance_constr(struct_ident, &data.fields);
58
59    Ok((data_constr, data.fields))
60}
61
62fn enum_case(
63    root_ident: &Ident,
64    data: DataEnum,
65) -> Result<(proc_macro2::TokenStream, Fields), Error> {
66    if data.variants.is_empty() {
67        return Err(Error::new_spanned(
68            root_ident,
69            "#[derive(Default)] is not supported for empty enums",
70        ));
71    }
72
73    let mut default_variants = data.variants.into_iter().filter(has_default_attr);
74
75    match (default_variants.next(), default_variants.next()) {
76        (Some(default_variant), None) => {
77            let default_variant_constr = {
78                // Something as below would be great, but `Self::XXX` is not a valid identifier
79                // let variant_ident = Ident::new(&format!("Self::{}", &default_variant.ident), Span::call_site());
80                let constr =
81                    default_instance_constr(&default_variant.ident, &default_variant.fields);
82                quote!(Self::#constr)
83            };
84
85            Ok((default_variant_constr, default_variant.fields))
86        }
87
88        (Some(default_variant), Some(another_default_variant)) => {
89            let msg = "#[default] is defined multiple times";
90            if cfg!(nightly) {
91                let span = another_default_variant
92                    .span()
93                    .join(default_variant.span())
94                    .expect("self and other are not from the same file");
95                Err(Error::new(span, msg))
96            } else {
97                Err(Error::new_spanned(another_default_variant, msg))
98            }
99        }
100        (None, _) => Err(Error::new_spanned(
101            root_ident,
102            "expected one variant with #[default]",
103        )),
104    }
105}
106
107fn default_instance_constr(data_constr_ident: &Ident, fields: &Fields) -> proc_macro2::TokenStream {
108    match fields {
109        Fields::Unit => quote!(#data_constr_ident),
110        Fields::Unnamed(unnamed) => {
111            let fields_constr = unnamed.unnamed.iter().map(|field| {
112                let ty = &field.ty;
113                quote!(#ty::default())
114            });
115            quote!(#data_constr_ident(#(#fields_constr),*))
116        }
117        Fields::Named(named) => {
118            let fields_constr = named.named.iter().map(|field| {
119                let field_name = field
120                    .ident
121                    .as_ref()
122                    .expect("named fields should contain an ident");
123                let ty = &field.ty;
124                quote!(#field_name : #ty::default())
125            });
126            quote!(#data_constr_ident{#(#fields_constr),*})
127        }
128    }
129}
130
131fn has_default_attr(variant: &Variant) -> bool {
132    variant
133        .attrs
134        .get(0)
135        .map(|attr| attr.path().is_ident(DEFAULT_VARIANT_KEYWORD))
136        .unwrap_or_default()
137}
138
139fn add_trait_bounds(generics: &mut Generics, fields: &Fields) {
140    let used_types: HashSet<Ident> = fields
141        .iter()
142        .filter_map(|field| type_ident(&field.ty))
143        .cloned()
144        .collect();
145
146    for type_param in generics.type_params_mut() {
147        if used_types.contains(&type_param.ident) {
148            type_param
149                .bounds
150                .push(parse_quote!(::std::default::Default));
151        }
152    }
153}
154
155fn type_ident(ty: &Type) -> Option<&Ident> {
156    if let &Type::Path(TypePath {
157        qself: None,
158        ref path,
159    }) = ty
160    {
161        if path.segments.len() == 1 {
162            return Some(&path.segments.first()?.ident);
163        }
164    }
165    None
166}