builder_pattern_fsm/
lib.rs

1use std::collections::HashMap;
2
3use proc_macro2;
4use syn::spanned::Spanned;
5
/// Suffix of the `Option`-per-field storage struct that every builder state wraps.
const FIELDS_CONTAINER_SUFFIX: &str = "FieldsContainer";
/// Suffix of the terminal builder state, reached once every field has been set.
const FINAL_BUILDER_SUFFIX: &str = "FinalBuilder";
8
9fn extract_defaults(
10    fields: &[syn::Field],
11) -> std::collections::HashMap<syn::Ident, proc_macro2::TokenStream> {
12    let mut defaults = std::collections::HashMap::with_capacity(fields.len());
13    for field in fields.iter() {
14        for attr in field.attrs.iter() {
15            match attr.meta.clone() {
16                syn::Meta::List(metalist) => {
17                    match metalist
18                        .path
19                        .segments
20                        .first()
21                        .unwrap()
22                        .ident
23                        .to_string()
24                        .as_str()
25                    {
26                        "default" => {
27                            defaults.insert(field.ident.clone().unwrap(), metalist.tokens.clone());
28                        }
29                        value => panic!("Attribute '{value}' is not supported"),
30                    }
31                }
32                _ => panic!("No"),
33            }
34        }
35    }
36    defaults
37}
38
39/// Given a reference to a `syn::Type`, this function attempts to extract the inner type `T`
40/// if the input type is an `Option<T>`. If the input type is not an `Option`, the function
41/// returns `None`.
42///
43/// This extraction process is intricate because, during code generation, we lack information about the types.
44/// Consequently, we must operate with token streams.
45fn extract_from_option_type(ty: &syn::Type) -> Option<syn::Type> {
46    match ty {
47        syn::Type::Path(syn::TypePath {
48            path: syn::Path { segments, .. },
49            ..
50        }) => segments
51            .iter()
52            .find(|segment| segment.ident == stringify!(Option))
53            .map(|segment| match segment.arguments {
54                syn::PathArguments::AngleBracketed(ref inner) => inner
55                    .args
56                    .first()
57                    .map(|ty| match ty {
58                        syn::GenericArgument::Type(ty) => Some(ty.clone()),
59                        _ => None,
60                    })
61                    .flatten(),
62                _ => None,
63            })
64            .flatten(),
65        _ => None,
66    }
67}
68
69fn generate_maybe_wrapped_with_option(ty: &syn::Type) -> proc_macro2::TokenStream {
70    if extract_from_option_type(ty).is_some() {
71        return quote::quote!(#ty);
72    }
73    quote::quote!(::std::option::Option<#ty>)
74}
75
76fn generate_container_fields(data: &syn::DataStruct) -> proc_macro2::TokenStream {
77    let wrapped_fields = data.fields.iter().map(|field| {
78        let field_name = field.ident.clone();
79        let wrapped_type = generate_maybe_wrapped_with_option(&field.ty);
80        quote::quote! {
81            #field_name: #wrapped_type
82        }
83    });
84    quote::quote!(
85        #(#wrapped_fields),*
86    )
87}
88
89fn generate_container(
90    struct_name: &syn::Ident,
91    data: &syn::DataStruct,
92) -> (syn::Ident, proc_macro2::TokenStream) {
93    let fields = generate_container_fields(data);
94    let builder_name = syn::Ident::new(
95        &format!(
96            "__{name}{FIELDS_CONTAINER_SUFFIX}",
97            name = struct_name.to_string()
98        ),
99        struct_name.span(),
100    );
101    (
102        builder_name.clone(),
103        quote::quote!(
104            #[derive(Default)]
105            struct #builder_name {
106                #fields
107            }
108        ),
109    )
110}
111
112fn generate_final_builder(
113    struct_name: &syn::Ident,
114    builder_name: &syn::Ident,
115    shared_builder_name: &syn::Ident,
116    data: &syn::DataStruct,
117) -> proc_macro2::TokenStream {
118    let final_builder_fields = data.fields.iter().cloned().map(|field| {
119        let name = field.ident.unwrap();
120        if extract_from_option_type(&field.ty).is_some() {
121            return quote::quote!(
122                #name: self.shared.#name
123            );
124        }
125        quote::quote!(
126            #name: self.shared.#name.unwrap()
127        )
128    });
129    quote::quote!(
130        struct #builder_name {
131            shared: #shared_builder_name
132        }
133        impl #builder_name {
134            pub fn build(self) -> #struct_name {
135                #struct_name {
136                    #(#final_builder_fields),*
137                }
138            }
139        }
140    )
141}
142
143fn forge_cache_key(fields: &[syn::Field]) -> String {
144    fields
145        .iter()
146        .cloned()
147        .map(|field| field.ident.unwrap().to_string())
148        .collect::<Vec<_>>()
149        .join("::")
150}
151
152fn generate_builder(
153    struct_name: &syn::Ident,
154    fields_container_name: &syn::Ident,
155    fields: Vec<syn::Field>,
156    build_method_impl: &proc_macro2::TokenStream,
157    defaults: &std::collections::HashMap<syn::Ident, proc_macro2::TokenStream>,
158    generator_cache: &mut HashMap<String, (syn::Ident, proc_macro2::TokenStream)>,
159) -> syn::Ident {
160    let cache_key = forge_cache_key(fields.as_slice());
161    if let Some((ident, _)) = generator_cache.get(&cache_key) {
162        return ident.clone();
163    }
164
165    let builder_name = syn::Ident::new(
166        &format!(
167            "__{name}_{nonce}",
168            name = struct_name.to_string(),
169            nonce = uuid::Uuid::new_v4().to_string().replace("-", "")
170        ),
171        struct_name.span(),
172    );
173
174    let builder_methods = fields.iter().enumerate().map(|(index, field)| {
175        let field_name = field.ident.clone();
176
177        let mut new_fields = fields.clone();
178        new_fields.remove(index);
179
180        // Literally nothing else to generate, return to the FinalBuilder
181        let next_builder_name = if new_fields.is_empty() {
182            syn::Ident::new(
183                &format!(
184                    "__{name}{FINAL_BUILDER_SUFFIX}",
185                    name = struct_name.to_string()
186                ),
187                struct_name.span(),
188            )
189        }
190        // leftovers fields are all optional, should treat them differently
191        else if new_fields
192            .iter()
193            .all(|field| extract_from_option_type(&field.ty).is_some())
194        {
195            generate_builder(
196                struct_name,
197                fields_container_name,
198                new_fields,
199                build_method_impl,
200                defaults,
201                generator_cache,
202            )
203        }
204        // General case
205        else {
206            generate_builder(
207                struct_name,
208                fields_container_name,
209                new_fields,
210                build_method_impl,
211                defaults,
212                generator_cache,
213            )
214        };
215
216        let method_name = syn::Ident::new(
217            &format!(
218                "with_{name}",
219                name = field.ident.clone().unwrap().to_string()
220            ),
221            field.ident.span(),
222        );
223
224        let field_type = extract_from_option_type(&field.ty).unwrap_or_else(|| field.ty.clone());
225        quote::quote!(
226            pub fn #method_name(mut self, value: impl Into<#field_type>) -> #next_builder_name {
227                self.shared.#field_name = Some(value.into());
228                #next_builder_name {shared: self.shared}
229            }
230        )
231    });
232
233    let are_all_defaults = {
234        fields
235            .iter()
236            .all(|field| defaults.get(&field.ident.clone().unwrap()).is_some())
237    };
238
239    let are_all_optional = fields
240        .iter()
241        .all(|field| extract_from_option_type(&field.ty).is_some());
242
243    let build_method = if fields.is_empty() || are_all_defaults || are_all_optional {
244        build_method_impl.clone()
245    } else {
246        quote::quote!()
247    };
248
249    let builder_code = quote::quote!(
250        #[allow(non_camel_case_types)]
251        struct #builder_name {
252            shared: #fields_container_name
253        }
254
255        impl #builder_name {
256            #(#builder_methods)*
257            #build_method
258        }
259    );
260
261    generator_cache.insert(cache_key, (builder_name.clone(), builder_code));
262    builder_name
263}
264
265#[proc_macro_derive(Builder, attributes(default))]
266pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
267    let parsed = syn::parse_macro_input!(input as syn::DeriveInput);
268
269    // We support structs only
270    let syn::Data::Struct(data) = parsed.data else {
271        panic!("This macro can only operate on structs")
272    };
273
274    // And the struct must have named fields, anon are no good
275    if !data
276        .fields
277        .iter()
278        .all(|field| field.ident.clone().is_some())
279    {
280        panic!("This struct contains anon fields, which is not supported");
281    }
282
283    let (shared_builder_name, shared_builder_definition) = generate_container(&parsed.ident, &data);
284    let final_builder_name = syn::Ident::new(
285        &format!(
286            "__{name}{FINAL_BUILDER_SUFFIX}",
287            name = parsed.ident.to_string()
288        ),
289        parsed.ident.span(),
290    );
291
292    // That builder will contain the actual `build` method
293    let final_builder = generate_final_builder(
294        &parsed.ident,
295        &final_builder_name,
296        &shared_builder_name,
297        &data,
298    );
299
300    let mut generator_cache = HashMap::with_capacity(1 + 2usize.pow((data.fields.len()) as u32));
301    let struct_name = parsed.ident.clone();
302
303    let build_method_impl = {
304        let builder_fields = data.fields.iter().map(|field| {
305            let name = field.ident.clone().unwrap();
306            if extract_from_option_type(&field.ty).is_some() {
307                return quote::quote!(
308                    #name: self.shared.#name
309                );
310            }
311            quote::quote!(
312                #name: self.shared.#name.unwrap()
313            )
314        });
315        quote::quote!(
316            pub fn build(self) -> #struct_name {
317                #struct_name {
318                    #(#builder_fields),*
319                }
320            }
321        )
322    };
323
324    let defaults = extract_defaults(&data.fields.iter().cloned().collect::<Vec<_>>());
325    // That builder you'd get by invoking the `builder` method on the target struct
326    let initial_builder_name = generate_builder(
327        &struct_name,
328        &shared_builder_name,
329        data.fields.into_iter().collect(),
330        &build_method_impl,
331        &defaults,
332        &mut generator_cache,
333    );
334
335    // Recover builders implementations from the the generator cache
336    let builders = generator_cache.into_values().map(|(_, tokens)| tokens);
337    let shared_builder_defaults_setters = defaults.iter().map(|(name, tokens)| {
338        quote::quote!(
339            shared_builder.#name = Some(#tokens.into());
340        )
341    });
342
343    quote::quote!(
344        #shared_builder_definition
345        #(#builders)*
346        #final_builder
347        impl #struct_name {
348            pub fn builder() -> #initial_builder_name {
349                let mut shared_builder = #shared_builder_name::default();
350                #(#shared_builder_defaults_setters)*
351                #initial_builder_name {
352                    shared: shared_builder,
353                }
354            }
355        }
356    )
357    .into()
358}
359
#[cfg(test)]
mod tests {
    use quote::ToTokens;
    use rstest::rstest;
    use syn::{parse_quote, Type};

    // `extract_from_option_type` is syntactic: it must recognise `Option<T>` however the
    // path is spelled, and must reject types that merely contain an `Option` somewhere.
    #[rstest(
        input_type,
        expected_output_str,
        case(parse_quote! { Option<i32> }, Some("i32")),
        case(parse_quote! { Result<Option<i32>, String> }, None),
        case(parse_quote! { Option<Option<i32>> }, Some("Option < i32 >")),
        case(parse_quote! { i32 }, None),
        case(parse_quote! { Vec<String> }, None),
        case(parse_quote! { std::option::Option<u8> }, Some("u8")),
        case(parse_quote! { Option }, None),
        case(parse_quote! { &Option<i32> }, None),
        case(parse_quote! { (i32, Option<u8>) }, None)
    )]
    fn extract_from_option_type_test(input_type: Type, expected_output_str: Option<&str>) {
        let result = super::extract_from_option_type(&input_type);
        let result_str = result.map(|t| t.to_token_stream().to_string());
        let expected_output = expected_output_str.map(|s| s.to_string());
        assert_eq!(result_str, expected_output);
    }
}
379        assert_eq!(result_str, expected_output);
380    }
381}