Skip to main content

bauer_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{ToTokens, format_ident, quote, quote_spanned};
4use syn::{DeriveInput, parse::ParseStream, parse_macro_input, spanned::Spanned};
5
6use crate::{
7    builder::{BuilderAttr, Kind},
8    field::{BuilderField, Len, Repeat},
9};
10
11mod builder;
12mod field;
13mod type_state;
14mod util;
15
16/// The main macro.
17///
18/// # Usage
19///
20/// ```
21/// # use bauer_macros::Builder;
22/// #[derive(Builder)]
23/// pub struct Foo {
24///     #[builder(default = "42")]
25///     pub field_a: u32,
26///     pub field_b: bool,
27///     #[builder(into)]
28///     pub field_c: String,
29///     #[builder(repeat, repeat_n = 1..=3)]
30///     pub field_d: Vec<f64>,
31/// }
32/// ```
33///
34/// # Errors
35///
/// When a builder can fail, the `.build` function will return a `Result` that contains the built
/// value or a descriptive error.
38///
39/// If any of these cases are true, the `.build` function will return a `Result`:
40///
41/// **A field is required**  
/// By default all fields are required, barring some exceptions (field is `Option`, field has a
/// default value, field is `repeat`, etc.)
44///
/// **`repeat_n` is set**  
46/// If `repeat_n` is set for any field, then `.build` will return an error if the range is not
47/// satisfied.
48///
/// **Other Cases**  
50/// There are other cases where `.build` can fail, this list is non-exhaustive.
51///
52/// ## Type-State Builder
53///
54/// If `kind` is set to `"type-state"`, then the builder will _not_ return a Result, as all build
55/// conditions are validated at compile-time.
56///
57/// # Builder Attributes
58///
59/// ## **`kind`**
60///
61/// ### Possible Values
62///
63/// **`"owned"`**  
64/// The builder functions consume and generate owned values
65///
66/// ```
67/// # use bauer_macros::Builder;
68/// #[derive(Builder)]
69/// #[builder(kind = "owned")]
70/// pub struct Foo {
71///     a: u32,
72/// }
73///
74/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
75/// let foo: Foo = Foo::builder()
76///     .a(42)
77///     .build()?;
78/// # Ok(()) }
79/// ```
80///
81/// **`"borrowed"`**  
82/// The builder functions operate on mutable references to the builder
83///
84/// _Note: After calling `.build()`, the builder is reset_
85///
86/// ```
87/// # use bauer_macros::Builder;
88/// #[derive(Builder)]
89/// #[builder(kind = "borrowed")]
90/// pub struct Foo {
91///     a: u32,
92/// }
93///
94/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
95/// let mut builder = Foo::builder();
96/// builder.a(42);
97/// let foo: Foo = builder.build()?;
98/// assert_eq!(foo.a, 42);
99/// # Ok(()) }
100/// ```
101///
102/// **`"type-state"`**  
103/// The builder and its functions are generated in a way that uses the type-state pattern.  This
104/// means that things like required fields can be enforced at compile-time.  Due to the constraints
105/// with type-state builders, some attributes may be limited.  All limitations are documented with
106/// the attributes.
107///
/// The `.build` function will never return an error, as it can only be called once building the
/// final structure is guaranteed to succeed.
110///
111/// ```compile_fail
112/// # use bauer_macros::Builder;
113/// #[derive(Builder)]
114/// #[builder(kind = "type-state")]
115/// pub struct Foo {
116///     a: u32,
117/// }
118///
119/// let foo: Foo = Foo::builder().build(); // fails to compile
120/// ```
121///
122/// Default: `"owned"`
123///
124/// ## **`prefix`**/**`suffix`**
125///
126/// Default: `prefix = "", suffix = ""`
127///
128/// Set the prefix or suffix for the generated builder functions
129///
130/// ```
131/// # use bauer_macros::Builder;
132/// #[derive(Builder)]
133/// #[builder(prefix = "set_")]
134/// pub struct Foo {
135///     a: u32,
136/// }
137///
138/// let f = Foo::builder()
139///     .set_a(42)
140///     .build()
141///     .unwrap();
142/// ```
143///
144/// ## **`visibility`**
145///
146/// Default: visibility of the struct
147///
/// Set the visibility of the generated builder
149///
150/// The visibility can be set to `pub(self)` in order to make the builder private to the current
151/// module.
152///
153/// ```
154/// # use bauer_macros::Builder;
155/// #[derive(Builder)]
156/// #[builder(visibility = pub(crate))]
157/// pub struct Foo {
158///     a: u32,
159/// }
160/// ```
161///
162/// ## **`crate`**
163///
164/// Default: `bauer`
165///
/// The path under which this crate is available in the current crate.  This should only be needed
/// if you rename the dependency in your `Cargo.toml`.
168///
169/// ```
170/// # use bauer_macros::Builder;
171/// #[derive(Builder)]
172/// #[builder(crate = not_bauer)]
173/// pub struct Foo {
174///     a: u32,
175/// }
176/// ```
177///
/// # Field Attributes
179///
180/// ## **`default`**
181///
182/// Argument: Optional String
183///
/// If present, the field does not need to be specified and will default to the value provided.
/// If no value is provided to the `default` attribute, then [`Default::default`] will be used.
186///
187/// ```
188/// # use bauer_macros::Builder;
189/// # const _: &str = stringify!(
190/// #[derive(Builder)]
191/// # );
192/// # #[derive(Builder, PartialEq, Debug)]
193/// pub struct Foo {
194///     #[builder(default)]
195///     a: u32, // defaults to 0
196///     #[builder(default = "std::f32::consts::PI")]
197///     b: f32, // defaults to PI
198/// }
199///
200/// let foo = Foo::builder().build();
201/// assert_eq!(foo, Foo { a: 0, b: std::f32::consts::PI });
202///
203/// let foo = Foo::builder()
204///     .a(42)
205///     .build();
206/// assert_eq!(foo, Foo { a: 42, b: std::f32::consts::PI });
207/// ```
208///
209/// ## **`repeat`**
210///
211/// Make the method accept only a single item and build a list from it
212///
213/// When using a data structure that does not have the inner type as its singular generic, the type
214/// can be specified using `repeat = <type>`.
215///
216/// ```
217/// # use bauer_macros::Builder;
218/// # const _: &str = stringify!(
219/// #[derive(Builder)]
220/// # );
221/// # #[derive(Builder, PartialEq, Debug)]
222/// pub struct Foo {
223///     #[builder(repeat)]
224///     items: Vec<u32>,
225///     #[builder(repeat = char)]
226///     chars: String,
227/// }
228///
229/// let foo = Foo::builder()
230///     .items(0)
231///     .items(1)
232///     .items(2)
233///     .chars('a')
234///     .chars('b')
235///     .chars('c')
236///     .build();
237/// assert_eq!(
238///     foo,
239///     Foo {
240///         items: vec![0, 1, 2],
241///         chars: String::from("abc"),
242///     },
243/// );
244/// ```
245///
246/// ## **`repeat_n`**
247///
248/// Attribute `repeat` must also be specified.
249///
250/// Ensure that the length of items supplied via repeat is within a certain range.  The range can
251/// be any pattern that may be used in a `match` statement.  If this range is not met, an error
252/// will be returned.
253///
254/// #### Type-state Builder
255///
256/// When using the type-state kind, the value used is limited to the following (where `N` and `M`
257/// are integer literals)
258///
259/// - Integer Literals (`N`)
260/// - Closed Ranges (`N..M` or `N..=M`)
261/// - Minimum Ranges (`N..`)
262///
/// Note: The length of the range is limited to 64, because large ranges slow down compilation.  If
/// you require a larger range and the compile-time sacrifice is worth it, you can enable the
/// `unlimited_range` feature.
266///
267/// ```
268/// # use bauer_macros::Builder;
269/// # const _: &str = stringify!(
270/// #[derive(Builder)]
271/// # );
272/// # #[derive(Builder, PartialEq, Debug)]
273/// pub struct Foo {
274///     #[builder(repeat, repeat_n = 2..=3)]
275///     items: Vec<u32>,
276/// }
277///
278/// let foo = Foo::builder()
279///     .items(0)
280///     .items(1)
281///     .items(2)
282///     .build()
283///     .unwrap();
284/// assert_eq!(foo, Foo { items: vec![0, 1, 2] });
285///
286/// let foo = Foo::builder()
287///     .items(0)
288///     .build()
289///     .unwrap_err();
290/// assert_eq!(foo, FooBuildError::RangeItems(1));
291/// ```
292///
293/// ## **`rename`**
294///
/// Make the generated function use a different name from the field itself.
296///
297/// ```
298/// # use bauer_macros::Builder;
299/// # const _: &str = stringify!(
300/// #[derive(Builder)]
301/// # );
302/// # #[derive(Builder, PartialEq, Debug)]
303/// pub struct Foo {
304///     #[builder(repeat, rename = "item")]
305///     items: Vec<u32>,
306/// }
307///
308/// let foo = Foo::builder()
309///     .item(0)
310///     .item(1)
311///     .build();
312/// assert_eq!(foo, Foo { items: vec![0, 1] });
313/// ```
314///
315/// ## **`skip_prefix`**/**`skip_suffix`**
316///
317/// If a prefix or a suffix is specified in the builder attributes, skip applying those to the name
/// of this function.  This is especially useful with `rename`.
319///
320/// ```
321/// # use bauer_macros::Builder;
322/// # const _: &str = stringify!(
323/// #[derive(Builder)]
324/// # );
325/// # #[derive(Builder, PartialEq, Debug)]
326/// #[builder(prefix = "set_")]
327/// pub struct Foo {
328///     #[builder(repeat, rename = "item", skip_prefix)]
329///     items: Vec<u32>,
330/// }
331///
332/// let foo = Foo::builder()
333///     .item(0)
334///     .item(1)
335///     .build();
336/// assert_eq!(foo, Foo { items: vec![0, 1] });
337/// ```
338///
339/// ## **`into`**
340///
/// Make the method accept anything that can be converted into the field's type (via [`Into`]).
342///
343/// ```
344/// # use bauer_macros::Builder;
345/// # const _: &str = stringify!(
346/// #[derive(Builder)]
347/// # );
348/// # #[derive(Builder, PartialEq, Debug)]
349/// pub struct Foo {
350///     #[builder(into)]
351///     a: String,
352/// }
353///
354/// let foo = Foo::builder()
355///     .a("hello")
356///     .build()
357///     .unwrap();
358/// assert_eq!(foo, Foo { a: String::from("hello") });
359/// ```
360///
361/// ## **`tuple`**
362///
/// Rather than accepting a field that is a tuple by value, accept each element of the tuple as a
/// separate parameter to the setter function.
365///
366/// If names are specified using `tuple(name1, name2, ...)`, they will be used for the names of the
367/// parameters to the function (see example).
368///
369/// Note: If used with `repeat`, `repeat` must come before `tuple`.
370///
371/// ```
372/// # use bauer_macros::Builder;
373/// #[derive(Builder)]
374/// pub struct Foo {
375///     #[builder(tuple)]
376///     tuple: (i32, i32),
377///     #[builder(tuple(a, b))]
378///     tuple_names: (i32, i32),
379///     #[builder(into, tuple(a, b))]
380///     tuple_into: (String, f64),
381///     #[builder(repeat, tuple(foo, bar))]
382///     tuples: Vec<(i32, i32)>,
383/// }
384///
385/// let foo = Foo::builder()
386///     .tuple(0, 1)
387///     .tuple_names(2, 3)
388///     .tuple_into("pi", 3.14)
389///     .tuples(4, 5)
390///     .tuples(6, 7)
391///     .build();
392/// ```
393///
394/// ## **`adapter`**
395///
396/// Create a custom implementation for the generated function.  The adapter uses the closure syntax
397/// with types specified and will generate the method accordingly.
398///
399/// Any number of arguments are allowed and will be used in the generated function.
400///
401/// Conflicts with `into` and `tuple`.
402///
403/// ```
404/// # use bauer_macros::Builder;
405/// # const _: &str = stringify!(
406/// #[derive(Builder)]
407/// # );
408/// # #[derive(Builder, PartialEq, Debug)]
409/// pub struct Foo {
410///     #[builder(adapter = |x: u32, y: u32| format!("{}/{}", x, y))]
411///     field: String,
412/// }
413///
414/// let foo = Foo::builder()
415///     .field(5, 23)
416///     .build()
417///     .unwrap();
418/// assert_eq!(foo, Foo { field: String::from("5/23") });
419/// ```
420#[proc_macro_derive(Builder, attributes(builder))]
421pub fn builder(input: TokenStream) -> TokenStream {
422    let input = parse_macro_input!(input as DeriveInput);
423    let ident = &input.ident;
424    let vis = &input.vis;
425
426    let attr = input.attrs.iter().find(|a| a.path().is_ident("builder"));
427    let attr: BuilderAttr = if let Some(attr) = attr {
428        match attr.parse_args_with(|ps: ParseStream| BuilderAttr::parse(ps, vis.clone())) {
429            Ok(a) => a,
430            Err(e) => return e.to_compile_error().into(),
431        }
432    } else {
433        BuilderAttr::new(vis.clone())
434    };
435
436    let data_struct = match input.data {
437        syn::Data::Struct(ref data_struct) => data_struct,
438        syn::Data::Enum(data_enum) => {
439            return syn::Error::new(data_enum.enum_token.span(), "Enums are not supported.")
440                .to_compile_error()
441                .into();
442        }
443        syn::Data::Union(data_union) => {
444            return syn::Error::new(data_union.union_token.span(), "Unions are not supported.")
445                .to_compile_error()
446                .into();
447        }
448    };
449
450    let self_param = attr.self_param();
451    let builder_vis = &attr.vis;
452
453    let builder = format_ident!("{}Builder", ident);
454    let build_err = format_ident!("{}BuildError", ident);
455    let inner = format_ident!("__unsafe_builder_content");
456
457    let fields_named: Vec<_> = match data_struct.fields {
458        syn::Fields::Named(ref fields_named) => match fields_named
459            .named
460            .iter()
461            .enumerate()
462            .map(|(index, f)| BuilderField::parse(f, &attr, ident, index))
463            .collect::<Result<_, _>>()
464        {
465            Ok(v) => v,
466            Err(e) => return e.to_compile_error().into(),
467        },
468        syn::Fields::Unnamed(_) => {
469            return syn::Error::new(ident.span(), "Unnamed fields are not supported.")
470                .to_compile_error()
471                .into();
472        }
473        syn::Fields::Unit => {
474            return syn::Error::new(ident.span(), "Unit structs are not supported.")
475                .to_compile_error()
476                .into();
477        }
478    };
479
480    let private_module = attr.private_module();
481    let fields = fields_named.iter().map(|f| {
482        if let Some(Repeat {
483            inner_ty,
484            array,
485            len,
486        }) = &f.attr.repeat
487        {
488            if *array {
489                let Len::Raw { pattern, .. } = &len else {
490                    unreachable!("If array, then Len::Raw set");
491                };
492                quote! { #private_module::PushableArray<#pattern, #inner_ty> }
493            } else {
494                quote! { ::std::vec::Vec<#inner_ty> }
495            }
496        } else {
497            let ty = &f.ty;
498            quote! { ::core::option::Option<#ty> }
499        }
500    });
501
502    if attr.kind == Kind::TypeState {
503        return type_state::type_state_builder(&attr, &input, &fields_named).into();
504    }
505
506    let functions: TokenStream2 = fields_named
507        .iter()
508        .map(|f| f.function(&attr, &inner))
509        .collect();
510
511    let (build_err_variants, build_err_messages): (Vec<_>, Vec<_>) = fields_named
512        .iter()
513        .flat_map(|f| {
514            let mut variants = Vec::new();
515            if let Some(err) = &f.missing_err {
516                let msg = format!("Missing required field '{}'", f.ident);
517                variants.push((
518                    err.to_token_stream(),
519                    quote! { Self::#err => write!(f, #msg) },
520                ));
521            }
522            if let Some(Repeat {
523                len: Len::Raw { pattern, error },
524                ..
525            }) = &f.attr.repeat
526            {
527                variants.push((
528                    quote! {
529                        #error(usize)
530                    },
531                    quote!{
532                        Self::#error(n) => write!(f, "Invalid number of repeat arguments provided.  Expected {:?}, got {}", #pattern, n)
533                    },
534                ));
535            }
536            variants.into_iter()
537        })
538        .collect();
539
540    let build_fields = fields_named.iter().map(|field| {
541        let name = &field.ident;
542        let field_i = field.tuple_index();
543
544        if let Some(rep @ Repeat { inner_ty, .. }) = &field.attr.repeat {
545            if let Len::Raw { pattern, error } = &rep.len {
546                let value = if rep.array {
547                    quote_spanned! { inner_ty.span()=> {
548                        let arr = ::core::mem::take(&mut self.#inner.#field_i);
549                        arr.into_array()
550                            .expect("The match ensures the length of this array is correct")
551                    }}
552                } else {
553                    quote_spanned! { inner_ty.span()=>
554                        self.#inner.#field_i.drain(..).collect()
555                    }
556                };
557                quote_spanned! { pattern.span()=>
558                    #name: match self.#inner.#field_i.len() {
559                        #pattern => #value, // TODO: Take and then slice.try_into()
560                        len => return Err(#build_err::#error(len)),
561                    }
562                }
563            } else {
564                quote_spanned! { inner_ty.span()=>
565                    // using associated function syntax as that gives better error messages
566                    // (i.e., not "call chain may not have expected associated type"
567                    #name: ::std::iter::FromIterator::from_iter(self.#inner.#field_i.drain(..))
568                }
569            }
570        } else if field.wrapped_option {
571            quote! {
572                #name: self.#inner.#field_i
573            }
574        } else if let Some(default) = &field.attr.default {
575            if let Some(default) = default {
576                if field.attr.into {
577                    quote! {
578                        #name: self.#inner.#field_i.take().unwrap_or_else(|| #default.into())
579                    }
580                } else {
581                    quote! {
582                        #name: self.#inner.#field_i.take().unwrap_or_else(|| #default)
583                    }
584                }
585            } else {
586                quote_spanned! {
587                    field.ty.span() =>
588                    #name: self.#inner.#field_i.take().unwrap_or_default()
589                }
590            }
591        } else {
592            let err = field
593                .missing_err
594                .as_ref()
595                .expect("missing_err is set when default is none");
596            quote! {
597                #name: self.#inner.#field_i.take().ok_or(#build_err::#err)?
598            }
599        }
600    });
601
602    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
603
604    let build_fn = if build_err_variants.is_empty() {
605        quote! {
606            #builder_vis fn build(#self_param) -> #ident #ty_generics {
607                #[allow(deprecated)] // #inner is set to deprecated
608                {
609                    #ident {
610                        #(#build_fields),*
611                    }
612                }
613            }
614        }
615    } else {
616        quote! {
617            #builder_vis fn build(#self_param) -> ::core::result::Result<#ident #ty_generics, #build_err> {
618                #[allow(deprecated)] // #inner is set to deprecated
619                {
620                    Ok(#ident {
621                        #(#build_fields),*
622                    })
623                }
624            }
625        }
626    };
627
628    let build_err_enum = if build_err_variants.is_empty() {
629        quote! {}
630    } else {
631        quote! {
632            #[derive(::std::fmt::Debug, ::std::cmp::PartialEq, ::std::cmp::Eq)]
633            #builder_vis enum #build_err {
634                #(#build_err_variants),*
635            }
636
637            impl ::core::fmt::Display for #build_err {
638                fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
639                    use ::core::fmt::Write;
640                    match self {
641                        #(#build_err_messages),*
642                    }
643                }
644            }
645
646            impl ::core::error::Error for #build_err {}
647        }
648    };
649
650    let init = fields_named
651        .iter()
652        .map(|_| quote! { ::core::default::Default::default() });
653
654    let into_impl = if build_err_variants.is_empty() {
655        quote! {
656            impl #impl_generics ::core::convert::From<#builder #ty_generics> for #ident #ty_generics {
657                fn from(mut builder: #builder #ty_generics) -> Self {
658                    builder.build()
659                }
660            }
661        }
662    } else {
663        quote! {
664            impl #impl_generics ::core::convert::TryFrom<#builder #ty_generics> for #ident #ty_generics {
665                type Error = #build_err;
666
667                fn try_from(mut builder: #builder #ty_generics) -> Result<Self, Self::Error> {
668                    builder.build()
669                }
670            }
671        }
672    };
673
674    quote! {
675        #build_err_enum
676
677        #[must_use = "The builder doesn't construct its type until `.build()` is called"]
678        #builder_vis struct #builder #impl_generics {
679            #[deprecated = "This field is for internal use only; You almost certainly don't need to touch this. If you encounter a bug or missing feature, file an issue on the repo."]
680            #[doc(hidden)]
681            #inner: (#(#fields,)*),
682        }
683
684        impl #impl_generics #builder #ty_generics #where_clause {
685            #functions
686
687            #build_fn
688        }
689
690        impl #impl_generics ::core::default::Default for #builder #ty_generics #where_clause {
691            fn default() -> Self {
692                Self {
693                    #inner: (#(#init,)*),
694                }
695            }
696        }
697
698        impl #impl_generics #ident #ty_generics #where_clause {
699            #builder_vis fn builder() -> #builder #ty_generics {
700                ::core::default::Default::default()
701            }
702        }
703
704        #into_impl
705    }
706    .into()
707}