better_default_derive/
lib.rs1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6 parse_macro_input, parse_quote, spanned::Spanned, Data, DataEnum, DataStruct, DeriveInput,
7 Error, Fields, Generics, Ident, Type, TypePath, Variant,
8};
9
/// Name of the helper attribute (`#[default]`) that marks an enum's default variant.
const DEFAULT_VARIANT_KEYWORD: &str = "default";
11
12#[proc_macro_derive(Default, attributes(default))]
13pub fn derive(input: TokenStream) -> TokenStream {
14 let output = match __derive(parse_macro_input!(input as DeriveInput)) {
15 Ok(output) => output,
16 Err(err) => err.into_compile_error(),
17 };
18 proc_macro::TokenStream::from(output)
19}
20
21fn __derive(input: DeriveInput) -> Result<proc_macro2::TokenStream, Error> {
22 let DeriveInput {
23 attrs: _,
24 vis: _,
25 ident: input_ident,
26 mut generics,
27 data,
28 } = input;
29
30 let (body, fields) = match data {
31 Data::Struct(data) => struct_case(&input_ident, data),
32 Data::Enum(data) => enum_case(&input_ident, data),
33 Data::Union(_) => Err(Error::new_spanned(
34 &input_ident,
35 "#[derive(Default)] is not supported for unions",
36 )),
37 }?;
38
39 add_trait_bounds(&mut generics, &fields);
40 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
41
42 let output = quote! {
43 impl #impl_generics std::default::Default for #input_ident #ty_generics #where_clause {
44 fn default() -> Self {
45 #body
46 }
47 }
48 };
49
50 Ok(output)
51}
52
53fn struct_case(
54 struct_ident: &Ident,
55 data: DataStruct,
56) -> Result<(proc_macro2::TokenStream, Fields), Error> {
57 let data_constr = default_instance_constr(struct_ident, &data.fields);
58
59 Ok((data_constr, data.fields))
60}
61
62fn enum_case(
63 root_ident: &Ident,
64 data: DataEnum,
65) -> Result<(proc_macro2::TokenStream, Fields), Error> {
66 if data.variants.is_empty() {
67 return Err(Error::new_spanned(
68 root_ident,
69 "#[derive(Default)] is not supported for empty enums",
70 ));
71 }
72
73 let mut default_variants = data.variants.into_iter().filter(has_default_attr);
74
75 match (default_variants.next(), default_variants.next()) {
76 (Some(default_variant), None) => {
77 let default_variant_constr = {
78 let constr =
81 default_instance_constr(&default_variant.ident, &default_variant.fields);
82 quote!(Self::#constr)
83 };
84
85 Ok((default_variant_constr, default_variant.fields))
86 }
87
88 (Some(default_variant), Some(another_default_variant)) => {
89 let msg = "#[default] is defined multiple times";
90 if cfg!(nightly) {
91 let span = another_default_variant
92 .span()
93 .join(default_variant.span())
94 .expect("self and other are not from the same file");
95 Err(Error::new(span, msg))
96 } else {
97 Err(Error::new_spanned(another_default_variant, msg))
98 }
99 }
100 (None, _) => Err(Error::new_spanned(
101 root_ident,
102 "expected one variant with #[default]",
103 )),
104 }
105}
106
107fn default_instance_constr(data_constr_ident: &Ident, fields: &Fields) -> proc_macro2::TokenStream {
108 match fields {
109 Fields::Unit => quote!(#data_constr_ident),
110 Fields::Unnamed(unnamed) => {
111 let fields_constr = unnamed.unnamed.iter().map(|field| {
112 let ty = &field.ty;
113 quote!(#ty::default())
114 });
115 quote!(#data_constr_ident(#(#fields_constr),*))
116 }
117 Fields::Named(named) => {
118 let fields_constr = named.named.iter().map(|field| {
119 let field_name = field
120 .ident
121 .as_ref()
122 .expect("named fields should contain an ident");
123 let ty = &field.ty;
124 quote!(#field_name : #ty::default())
125 });
126 quote!(#data_constr_ident{#(#fields_constr),*})
127 }
128 }
129}
130
131fn has_default_attr(variant: &Variant) -> bool {
132 variant
133 .attrs
134 .get(0)
135 .map(|attr| attr.path().is_ident(DEFAULT_VARIANT_KEYWORD))
136 .unwrap_or_default()
137}
138
139fn add_trait_bounds(generics: &mut Generics, fields: &Fields) {
140 let used_types: HashSet<Ident> = fields
141 .iter()
142 .filter_map(|field| type_ident(&field.ty))
143 .cloned()
144 .collect();
145
146 for type_param in generics.type_params_mut() {
147 if used_types.contains(&type_param.ident) {
148 type_param
149 .bounds
150 .push(parse_quote!(::std::default::Default));
151 }
152 }
153}
154
155fn type_ident(ty: &Type) -> Option<&Ident> {
156 if let &Type::Path(TypePath {
157 qself: None,
158 ref path,
159 }) = ty
160 {
161 if path.segments.len() == 1 {
162 return Some(&path.segments.first()?.ident);
163 }
164 }
165 None
166}