makeit_derive/
lib.rs

1#[macro_use]
2extern crate quote;
3
4use proc_macro::TokenStream;
5use quote::ToTokens;
6use syn::spanned::Spanned;
7use syn::{GenericParam, ItemStruct, Visibility};
8
/// Upper-cases the first character of `s` and leaves the rest untouched (`foo_bar` -> `Foo_bar`).
///
/// Upper-casing may expand a single character into several (`ß` becomes `SS`); the empty string
/// is returned unchanged.
fn capitalize(s: &str) -> String {
    match s.chars().next() {
        Some(first) => {
            let mut capitalized: String = first.to_uppercase().collect();
            capitalized.push_str(&s[first.len_utf8()..]);
            capitalized
        }
        None => String::new(),
    }
}
16
17#[proc_macro_derive(Builder, attributes(default))]
18pub fn derive_builder(input: TokenStream) -> TokenStream {
19    let input: ItemStruct = syn::parse(input).unwrap();
20    let struct_name = &input.ident;
21    // You can't accuse me of being original. 🤷‍♂️
22    let builder_name = format_ident!("{}Builder", input.ident);
23    // We'll nest the entirety of the builder's helper types in a private module so that they don't
24    // leak into the user's scope.
25    let mod_name = format_ident!("{}Fields", builder_name);
26
27    let struct_generics = input.generics.params.iter().collect::<Vec<_>>();
28    // The type parameter names representing each field of the type being built.
29    let mut set_fields_generics = vec![];
30    // The type names representing fields that have been initialized.
31    let mut all_set = vec![];
32    // The type names representing fields that have not yet been initialized.
33    let mut all_unset = vec![];
34
35    // These are the generic parameters for the `impl` that lets the user call `.build()`. They
36    // normally would all have to be "field_foo_set" and need no params beyond the underlying
37    // type's, but we support default values so we need to account for them to let people build
38    // without setting those.
39    let mut buildable_generics = vec![];
40    let mut buildable_generics_use = vec![];
41
42    let mut default_where_clauses = vec![];
43
44    for (i, field) in input.fields.iter().enumerate() {
45        // We'll use these as the name of the type parameters for the builder's fields.
46        let field_name = format_ident!(
47            "{}",
48            match &field.ident {
49                Some(field) => capitalize(&field.to_string()),
50                None => format!("Field{}", i), // Idents can't start with numbers.
51            }
52        );
53        // We'll use these as the base for the types representing the builder state.
54        let field_generic_name = format_ident!(
55            "Field{}",
56            match &field.ident {
57                Some(field) => capitalize(&field.to_string()),
58                None => format!("{}", i),
59            }
60        );
61        let set_field_generic_name = format_ident!("{}Set", field_name);
62        let unset_field_generic_name = format_ident!("{}Unset", field_name);
63
64        if field.attrs.iter().any(|attr| attr.path.is_ident("default")) {
65            let ty = &field.ty;
66            buildable_generics.push(field_generic_name.clone());
67            buildable_generics_use.push(field_generic_name.clone());
68            default_where_clauses.push(quote_spanned!(ty.span() => #ty: ::std::default::Default));
69        } else {
70            buildable_generics_use.push(set_field_generic_name.clone());
71        }
72        set_fields_generics.push(field_generic_name);
73        all_set.push(set_field_generic_name);
74        all_unset.push(unset_field_generic_name);
75    }
76
77    // `input.generics.params` contains bounds. Here we get only the params without the bounds for
78    // use in type uses, not `impl` declarations.
79    let use_struct_generics = input
80        .generics
81        .params
82        .iter()
83        .map(|param| match param {
84            GenericParam::Type(p) => {
85                let ident = &p.ident;
86                quote!(#ident)
87            }
88            GenericParam::Lifetime(p) => {
89                let lt = &p.lifetime;
90                quote!(#lt)
91            }
92            GenericParam::Const(p) => {
93                let ident = &p.ident;
94                quote!(#ident)
95            }
96        })
97        .collect::<Vec<_>>();
98
99    let comma = if use_struct_generics.is_empty() {
100        quote!()
101    } else {
102        quote!(,)
103    };
104
105    let constrained_generics = quote!(<#(#struct_generics),* #comma #(#set_fields_generics),*>);
106    let where_clause = &input.generics.where_clause;
107    let where_clause = if where_clause.is_some() {
108        quote!(#where_clause, #(#default_where_clauses),*)
109    } else {
110        quote!(where #(#default_where_clauses),*)
111    };
112    let use_generics = quote!(<#(#use_struct_generics),* #comma #(#set_fields_generics),*>);
113
114    // Construct each of the setter methods. These desugar roughly to the following signature:
115    //
116    //   fn set_<field_name>(self, value: <field_type>) -> <Type>Builder
117    //
118    let setters = input.fields.iter().enumerate().map(|(i, f)| {
119
120        let (field, method_name) = match &f.ident {
121            Some(field) => (quote!(#field), format_ident!("set_{}", field)),
122            None => {
123                let i = syn::Index::from(i);
124                (quote!(#i), format_ident!("set_{}", i))
125            }
126        };
127        let inner_method_name = format_ident!("inner_{}", method_name);
128        let decl_generics = set_fields_generics
129            .iter()
130            .enumerate()
131            .filter(|(j, _)| i!=*j)
132            .map(|(_, f)| f);
133        let decl_generics = quote!(<#(#struct_generics),* #comma #(#decl_generics),*>);
134        let unset_generics = set_fields_generics
135            .iter()
136            .zip(input.fields.iter())
137            .enumerate()
138            .map(|(j, (g, f))| if i == j {
139                // FIXME: dedup this logic.
140                let field_name = format_ident!("{}", match &f.ident {
141                    Some(field) => capitalize(&field.to_string()),
142                    None => format!("Field{}", i),
143                });
144                let f = format_ident!("{}Unset", field_name);
145                quote!(#f)
146            } else {
147                quote!(#g)
148            });
149        let unset_generics = quote!(<#(#use_struct_generics),* #comma #(#unset_generics),*>);
150        let set_generics = set_fields_generics
151            .iter().zip(input.fields.iter()).enumerate().map(|(j, (g, f))| if i == j {
152            let field_name = format_ident!("{}", match &f.ident {
153                Some(field) => capitalize(&field.to_string()),
154                None => format!("Field{}", i),
155            });
156            let f = format_ident!("{}Set", field_name);
157            quote!(#f)
158        } else {
159            quote!(#g)
160        });
161        let set_generics = quote!(<#(#use_struct_generics),* #comma #(#set_generics),*>);
162        let ty = &f.ty;
163        quote! {
164            impl #decl_generics #builder_name #unset_generics #where_clause {
165                #[must_use]
166                pub fn #method_name(mut self, value: #ty) -> #builder_name #set_generics {
167                    self.#inner_method_name(value);
168                    // We do the following instead of `::core::mem::transmute(self)` here
169                    // because we can't `transmute` on fields that involve generics.
170                    let ptr = &self as *const #builder_name #unset_generics as *const #builder_name #set_generics;
171                    ::core::mem::forget(self);
172                    unsafe {
173                        ptr.read()
174                    }
175                }
176
177                fn #inner_method_name(&mut self, value: #ty) {
178                    let inner = self.inner.as_mut_ptr();
179                    // We know that `inner` is a valid pointer that we can write to.
180                    unsafe {
181                        ::core::ptr::addr_of_mut!((*inner).#field).write(value);
182                    }
183                }
184            }
185        }
186    });
187    let field_ptr_methods = input.fields.iter().enumerate().map(|(i, f)| {
188        let (field, method_name) = match &f.ident {
189            Some(field) => (quote!(#field), format_ident!("ptr_{}", i)),
190            None => {
191                let i = syn::Index::from(i);
192                (quote!(#i), format_ident!("ptr_{}", i))
193            }
194        };
195        let ty = &f.ty;
196        quote! {
197            /// Returns a mutable pointer to a field of the type being built. This is useful if the
198            /// initialization requires subtle unsafe shenanigans. You will need to call
199            /// `.unsafe_build()` after ensuring all of the fields have been initialized.
200            #[must_use]
201            pub unsafe fn #method_name(&mut self) -> *mut #ty {
202                let inner = self.inner.as_mut_ptr();
203                ::core::ptr::addr_of_mut!((*inner).#field)
204            }
205        }
206    });
207
208    let vis = match &input.vis {
209        // For private `struct`s we need to change teh visibility of their builders to be
210        // accessible from their scope without leaking as `pub`.
211        Visibility::Inherited => quote!(pub(super)),
212        vis => quote!(#vis),
213    };
214
215    let defaults = input.fields.iter().enumerate().filter_map(|(i, f)| {
216        let field = match &f.ident {
217            Some(field) => format_ident!("inner_set_{}", field),
218            None => format_ident!("inner_set_{}", i),
219        };
220        f.attrs
221            .iter()
222            .find(|attr| attr.path.is_ident("default"))
223            .map(|attr| {
224                let default = &attr.tokens;
225                if default.is_empty() {
226                    quote!(builder.#field(::std::default::Default::default());)
227                } else {
228                    quote!(builder.#field(#default);)
229                }
230            })
231    });
232    // Construct the params for the `impl` item that provides the `build` method. Normally it would
233    // be straightforward: you just specify that all the type params corresponding to fields are
234    // set to the `Set` state, but that doesn't account for defaulted type params.
235    let build_generics = input.generics.params.iter().collect::<Vec<_>>();
236    let build_generics = if buildable_generics.is_empty() {
237        quote!(<#(#build_generics),*>)
238    } else {
239        let comma = if build_generics.is_empty() {
240            quote!()
241        } else {
242            quote!(,)
243        };
244        quote!(<#(#build_generics),* #comma #(#buildable_generics),*>)
245    };
246    let build_use_generics =
247        quote!(<#(#use_struct_generics),* #comma #(#buildable_generics_use),*>);
248
249    let builder_assoc_type = quote! {
250        type Builder = #builder_name<#(#use_struct_generics),* #comma #(#all_unset),*>;
251    };
252
253    let input = quote! {
254        #[allow(non_snake_case)]
255        #[deny(unused_must_use, clippy::pedantic)]
256        mod #mod_name {
257            use super::*;
258
259            #[must_use]
260            #[repr(transparent)]
261            #vis struct #builder_name #constrained_generics #where_clause {
262                inner: ::core::mem::MaybeUninit<#struct_name<#(#use_struct_generics),*>>,
263                __fields: ::core::marker::PhantomData<(#(#set_fields_generics),*)>,
264            }
265
266            #(pub struct #all_set;)*
267            #(pub struct #all_unset;)*
268
269            impl<#(#struct_generics),*> ::makeit::Buildable for #struct_name <#(#use_struct_generics),*>
270            #where_clause
271            {
272                #builder_assoc_type
273
274                /// Returns a builder that lets you initialize `Self` field by field in a zero-cost,
275                /// type-safe manner.
276                #[must_use]
277                #[allow(unused_parens)]
278                fn builder() -> Self::Builder {
279                    let mut builder = #builder_name {
280                        inner: unsafe {
281                            ::core::mem::MaybeUninit::<Self>::uninit()
282                        },
283                        __fields: ::core::marker::PhantomData,
284                    };
285                    #(#defaults)*
286                    builder
287                }
288            }
289
290            impl #build_generics #builder_name #build_use_generics #where_clause {
291                /// Finalize the builder.
292                #[must_use]
293                pub fn build(self) -> #struct_name<#(#use_struct_generics),*> {
294                    // This method is only callable if all of the fields have been initialized, making
295                    // the underlying value at `inner` correctly formed.
296                    unsafe { self.unsafe_build() }
297                }
298            }
299
300            #(#setters)*
301
302            impl #constrained_generics #builder_name #use_generics #where_clause {
303
304                #(#field_ptr_methods)*
305
306                /// HERE BE DRAGONS!
307                ///
308                /// # Safety
309                ///
310                /// You're dealing with `MaybeUninit`. If you have to research what that is, you don't
311                /// want this.
312                #[must_use]
313                pub unsafe fn maybe_uninit(self) -> ::core::mem::MaybeUninit<#struct_name<#(#use_struct_generics),*>> {
314                    self.inner
315                }
316
317                /// Only call if you have set a field through their mutable pointer, instead
318                /// of using the type-safe builder. It is your responsibility to ensure that
319                /// all fields have been set before doing this.
320                #[must_use]
321                pub unsafe fn unsafe_build(self) -> #struct_name<#(#use_struct_generics),*> {
322                    self.inner.assume_init()
323                }
324            }
325        }
326    };
327
328    TokenStream::from(input.into_token_stream())
329}