Skip to main content

bauer_macros/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::{ToTokens, format_ident, quote, quote_spanned};
6use syn::{
7    DeriveInput, Ident, Pat, parse::ParseStream, parse_macro_input, parse_quote_spanned,
8    spanned::Spanned,
9};
10
11use crate::{
12    attr::builder::{BuilderAttr, Kind},
13    attr::field::{BuilderField, Len, Repeat},
14    util::parallel_assign,
15};
16
17mod attr;
18mod type_state;
19mod util;
20
21fn builder_fn(input: &DeriveInput, builder_attr: &BuilderAttr, builder: &Ident) -> TokenStream2 {
22    let ident = &input.ident;
23    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
24    let konst = builder_attr.konst_kw();
25    let builder_vis = &builder_attr.vis;
26
27    let name = &builder_attr.builder_fn.name;
28    let attributes = &builder_attr.builder_fn.attributes;
29
30    quote! {
31        impl #impl_generics #ident #ty_generics #where_clause {
32            #(#attributes)*
33            #builder_vis #konst fn #name() -> #builder #ty_generics {
34                #builder::new()
35            }
36        }
37    }
38}
39
40fn parse_build_attr(input: &DeriveInput) -> syn::Result<BuilderAttr> {
41    let mut out = BuilderAttr::new(input.vis.clone());
42    for attr in input.attrs.iter().filter(|a| a.path().is_ident("builder")) {
43        attr.parse_args_with(|ps: ParseStream| out.parse(ps))?;
44    }
45    Ok(out)
46}
47
48#[proc_macro_derive(Builder, attributes(builder))]
49pub fn builder(input: TokenStream) -> TokenStream {
50    let input = parse_macro_input!(input as DeriveInput);
51    let ident = &input.ident;
52
53    let builder_attr: BuilderAttr = match parse_build_attr(&input) {
54        Ok(a) => a,
55        Err(e) => return e.to_compile_error().into(),
56    };
57
58    let data_struct = match input.data {
59        syn::Data::Struct(ref data_struct) => data_struct,
60        syn::Data::Enum(data_enum) => {
61            return syn::Error::new(data_enum.enum_token.span(), "Enums are not supported.")
62                .to_compile_error()
63                .into();
64        }
65        syn::Data::Union(data_union) => {
66            return syn::Error::new(data_union.union_token.span(), "Unions are not supported.")
67                .to_compile_error()
68                .into();
69        }
70    };
71
72    let self_param = builder_attr.self_param();
73    let builder_vis = &builder_attr.vis;
74
75    let builder = format_ident!("{}Builder", ident);
76    let build_err = builder_attr.error.name(ident);
77    let inner = format_ident!("__unsafe_builder_content");
78
79    let mut tuple_index = 0;
80    let fields: Vec<_> = match data_struct.fields {
81        syn::Fields::Named(ref fields_named) => match fields_named
82            .named
83            .iter()
84            .map(|f| BuilderField::parse(f, &builder_attr, ident, &mut tuple_index))
85            .collect::<Result<_, _>>()
86        {
87            Ok(v) => v,
88            Err(e) => return e.to_compile_error().into(),
89        },
90        syn::Fields::Unnamed(_) => {
91            return syn::Error::new(ident.span(), "Unnamed fields are not supported.")
92                .to_compile_error()
93                .into();
94        }
95        syn::Fields::Unit => {
96            return syn::Error::new(ident.span(), "Unit structs are not supported.")
97                .to_compile_error()
98                .into();
99        }
100    };
101
102    let private_module = builder_attr.private_module();
103
104    if builder_attr.kind == Kind::TypeState {
105        return type_state::type_state_builder(&builder_attr, &input, fields).into();
106    }
107
108    let (field_types, init): (Vec<_>, Vec<_>) = fields
109        .iter()
110        .filter(|f| !f.should_skip())
111        .map(|f| {
112            if let Some(Repeat {
113                inner_ty,
114                array,
115                len,
116                ..
117            }) = &f.attr.repeat
118            {
119                if *array {
120                    let Len::Raw { pattern, .. } = &len else {
121                        unreachable!("If array, then Len::Raw set");
122                    };
123                    (
124                        quote! { #private_module::PushableArray<#pattern, #inner_ty> },
125                        quote! { #private_module::PushableArray::new() },
126                    )
127                } else {
128                    (
129                        quote! { ::std::vec::Vec<#inner_ty> },
130                        quote! { ::std::vec::Vec::new() },
131                    )
132                }
133            } else {
134                let ty = &f.ty;
135                (
136                    quote! { ::core::option::Option<#ty> },
137                    quote! { ::core::option::Option::None },
138                )
139            }
140        })
141        .collect();
142
143    let functions: TokenStream2 = fields
144        .iter()
145        .filter(|f| !f.should_skip())
146        .map(|f| f.function(&builder_attr, &inner))
147        .collect();
148
149    let (build_err_variants, build_err_messages): (Vec<_>, Vec<_>) = fields
150        .iter()
151        .filter(|f| !f.should_skip())
152        .flat_map(|f| {
153            let mut variants = Vec::new();
154            if let Some(err) = &f.missing_err {
155                let msg = format!("Missing required field '{}'", f.ident);
156                variants.push((
157                    err.to_token_stream(),
158                    quote! { Self::#err => write!(f, #msg) },
159                ));
160            }
161            if let Some(Repeat {
162                len: Len::Raw { pattern, error },
163                ..
164            }) = &f.attr.repeat
165            {
166                let error_msg = format!(
167                    "Invalid number of repeat arguments provided.  Expected {}, got {{}}",
168                    pattern.to_token_stream()
169                );
170                variants.push((
171                    quote! {
172                        #error(usize)
173                    },
174                    quote! {
175                        Self::#error(n) => write!(f, #error_msg, n)
176                    },
177                ));
178            }
179            variants.into_iter()
180        })
181        .collect();
182
183    let not_skipped_field_values = fields.iter().filter(|f| !f.should_skip()).map(|field| {
184        let name = &field.ident;
185        let wrapped_ty = &field.wrapped_type();
186        let field_i = field.tuple_index();
187
188        let value = if let Some(rep @ Repeat { inner_ty, collector, .. }) = &field.attr.repeat {
189            if let Len::Raw { pattern, error } = &rep.len {
190                let value = if rep.array {
191                    quote_spanned! { inner_ty.span()=> {
192                        let arr = ::core::mem::replace(&mut inner.#field_i, #private_module::PushableArray::new());
193                        arr.into_array()
194                            .expect("The match ensures the length of this array is correct")
195                    }}
196                } else {
197                    assert!(!rep.array);
198                    assert!(!builder_attr.konst);
199
200                    collector.collect(parse_quote_spanned! {inner_ty.span()=>
201                        inner.#field_i.drain(..)
202                    })
203                };
204
205                if let Pat::Ident(_) = pattern {
206                    quote_spanned! { pattern.span()=>
207                        if inner.#field_i.len() == #pattern {
208                            #value
209                        } else {
210                            return Err(#build_err::#error(self.#inner.#field_i.len()));
211                        }
212                    }
213                } else {
214                    quote_spanned! { pattern.span()=>
215                        match inner.#field_i.len() {
216                            #pattern => #value,
217                            len => return Err(#build_err::#error(len)),
218                        }
219                    }
220                }
221            } else {
222                assert!(!rep.array);
223                assert!(!builder_attr.konst);
224                collector.collect(parse_quote_spanned! {inner_ty.span()=>
225                    inner.#field_i.drain(..)
226                })
227            }
228        } else if field.wrapped_option {
229            quote! { inner.#field_i.take() }
230        } else if let Some(default) = &field.attr.default {
231            let default = default.to_value(field.attr.into);
232            quote! {
233                // NOTE: not using Option::unwrap_or_else, since it's not stable in const
234                match inner.#field_i.take() {
235                    Some(v) => v,
236                    None => #default
237                }
238            }
239        } else {
240            let err = field
241                .missing_err
242                .as_ref()
243                .expect("missing_err is set when default is none");
244            quote! {
245                // NOTE: not using Option::ok_or, since it's not stable in const
246                match inner.#field_i.take() {
247                    Some(v) => v,
248                    None => return Err(#build_err::#err),
249                }
250            }
251        };
252
253        quote! {{
254            let #name: #wrapped_ty = #value;
255            #name
256        }}
257    });
258
259    let not_skipped_fields: Vec<_> = fields
260        .iter()
261        .filter(|f| !f.should_skip())
262        .map(|f| &f.ident)
263        .collect();
264
265    let set_not_skipped_fields = parallel_assign(
266        not_skipped_fields.iter().copied(),
267        not_skipped_field_values,
268        quote! {
269            let inner = &mut self.#inner;
270        },
271    );
272
273    let set_skipped_fields = parallel_assign(
274        fields.iter().filter(|f| f.should_skip()).map(|f| &f.ident),
275        fields.iter().filter_map(BuilderField::skipped_field_value),
276        quote! {
277            #[allow(unused)]
278            let (#(#not_skipped_fields),*) = (#(&#not_skipped_fields),*);
279        },
280    );
281
282    let finish_fields = fields.iter().map(|field| &field.ident);
283
284    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
285
286    let konst = builder_attr.konst_kw();
287
288    let (ret_ty, ret_val) = if build_err_variants.is_empty() && !builder_attr.error.force {
289        (quote! { #ident #ty_generics }, quote! { ret })
290    } else {
291        (
292            quote! { ::core::result::Result<#ident #ty_generics, #build_err> },
293            quote! { Ok(ret) },
294        )
295    };
296
297    let build_fn_attributes = &builder_attr.build_fn.attributes;
298    let build_fn_name = &builder_attr.build_fn.name;
299    let build_fn = quote! {
300        #(#build_fn_attributes)*
301        #builder_vis #konst fn #build_fn_name(#self_param) -> #ret_ty {
302            #[allow(deprecated)] // #inner is set to deprecated
303            let ret = {
304                #set_not_skipped_fields
305                #set_skipped_fields
306
307                #ident {
308                    #(#finish_fields),*
309                }
310            };
311            #ret_val
312        }
313    };
314
315    let build_err_enum = if build_err_variants.is_empty() && !builder_attr.error.force {
316        quote! {}
317    } else {
318        let attributes = &builder_attr.error.attributes;
319        quote! {
320            #(#attributes)*
321            #[derive(::std::fmt::Debug, ::std::cmp::PartialEq, ::std::cmp::Eq)]
322            #[allow(enum_variant_names)]
323            #builder_vis enum #build_err {
324                #(#build_err_variants),*
325            }
326
327            impl ::core::fmt::Display for #build_err {
328                fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
329                    use ::core::fmt::Write;
330                    match *self {
331                        #(#build_err_messages),*
332                    }
333                }
334            }
335
336            impl ::core::error::Error for #build_err {}
337        }
338    };
339
340    let into_impl = if build_err_variants.is_empty() && !builder_attr.error.force {
341        quote! {
342            impl #impl_generics ::core::convert::From<#builder #ty_generics> for #ident #ty_generics #where_clause {
343                fn from(mut builder: #builder #ty_generics) -> Self {
344                    builder.#build_fn_name()
345                }
346            }
347        }
348    } else {
349        quote! {
350            #[allow(clippy::infallible_try_from)]
351            impl #impl_generics ::core::convert::TryFrom<#builder #ty_generics> for #ident #ty_generics #where_clause {
352                type Error = #build_err;
353
354                fn try_from(mut builder: #builder #ty_generics) -> Result<Self, Self::Error> {
355                    builder.#build_fn_name()
356                }
357            }
358        }
359    };
360
361    let builder_attributes = &builder_attr.attributes;
362
363    let builder_fn = builder_fn(&input, &builder_attr, &builder);
364
365    let assert_crate = builder_attr.assert_crate();
366    quote! {
367        #assert_crate
368
369        #build_err_enum
370
371        #(#builder_attributes)*
372        #[must_use = "The builder doesn't construct its type until `.build()` is called"]
373        #builder_vis struct #builder #impl_generics #where_clause {
374            #[deprecated = "This field is for internal use only; You almost certainly don't need to touch this. If you encounter a bug or missing feature, file an issue on the repo."]
375            #[doc(hidden)]
376            #inner: (#(#field_types,)*),
377        }
378
379        impl #impl_generics #builder #ty_generics #where_clause {
380            #functions
381
382            #build_fn
383        }
384
385        impl #impl_generics #builder #ty_generics #where_clause {
386            #konst fn new() -> Self {
387                Self {
388                    #inner: (#(#init,)*),
389                }
390            }
391        }
392
393        impl #impl_generics ::core::default::Default for #builder #ty_generics #where_clause {
394            fn default() -> Self {
395                Self::new()
396            }
397        }
398
399        #builder_fn
400
401        #into_impl
402    }
403    .into()
404}