// bon_macros/builder/builder_gen/generic_setters.rs

1use super::models::BuilderGenCtx;
2use crate::parsing::ItemSigConfig;
3use crate::util::prelude::*;
4use std::collections::BTreeSet;
5use syn::punctuated::Punctuated;
6use syn::token::Where;
7use syn::visit::Visit;
8
9pub(super) struct GenericSettersCtx<'a> {
10    base: &'a BuilderGenCtx,
11    config: &'a ItemSigConfig<String>,
12}
13
14impl<'a> GenericSettersCtx<'a> {
15    pub(super) fn new(base: &'a BuilderGenCtx, config: &'a ItemSigConfig<String>) -> Self {
16        Self { base, config }
17    }
18
19    pub(super) fn generic_setter_methods(&self) -> Result<TokenStream> {
20        let generics = &self.base.generics.decl_without_defaults;
21
22        let type_param_idents: Vec<&syn::Ident> = generics
23            .iter()
24            .filter_map(|param| match param {
25                syn::GenericParam::Type(type_param) => Some(&type_param.ident),
26                _ => None,
27            })
28            .collect();
29
30        // Check for interdependent type parameters in generic bounds
31        for param in generics {
32            if let syn::GenericParam::Type(type_param) = param {
33                let mut params = TypeParamFinder::new(&type_param_idents);
34
35                for bound in &type_param.bounds {
36                    params.visit_type_param_bound(bound);
37                }
38
39                // Self-referential type params are fine
40                params.found.remove(&type_param.ident);
41
42                if let Some(first_param) = params.found.iter().next() {
43                    let params_str = params
44                        .found
45                        .iter()
46                        .map(|p| format!("`{p}`"))
47                        .collect::<Vec<_>>()
48                        .join(", ");
49                    bail!(
50                        first_param,
51                        "generic conversion methods cannot be generated for interdependent type parameters; \
52                         the bounds on generic parameter `{}` reference other type parameters: {}\n\
53                         \n\
54                         Consider removing `generics(setters(...))` or restructuring your types to avoid interdependencies",
55                        type_param.ident,
56                        params_str
57                    );
58                }
59            }
60        }
61
62        // Check for interdependent type parameters in where clauses
63        if let Some(where_clause) = &self.base.generics.where_clause {
64            for predicate in &where_clause.predicates {
65                let mut params = TypeParamFinder::new(&type_param_idents);
66                params.visit_where_predicate(predicate);
67                if params.found.len() > 1 {
68                    let params_str = params
69                        .found
70                        .iter()
71                        .map(|p| format!("`{p}`"))
72                        .collect::<Vec<_>>()
73                        .join(", ");
74                    bail!(
75                        predicate,
76                        "generic conversion methods cannot be generated for interdependent type parameters; \
77                         the where clause predicate references multiple type parameters: {}\n\
78                         \n\
79                         Consider removing `generics(setters(...))` or restructuring your types to avoid interdependencies",
80                        params_str
81                    );
82                }
83            }
84        }
85
86        let mut methods = Vec::with_capacity(generics.len());
87
88        for (index, param) in generics.iter().enumerate() {
89            match param {
90                syn::GenericParam::Type(type_param) => {
91                    methods.push(self.generic_setter_method(index, type_param));
92                }
93                syn::GenericParam::Const(const_param) => {
94                    bail!(
95                        &const_param.ident,
96                        "const generic parameters are not yet supported with `generics(setters(...))`; \
97                         only type parameters can be overridden, feel free to open an issue if you need \
98                         this feature"
99                    );
100                }
101                syn::GenericParam::Lifetime(_) => {
102                    // Skip lifetimes, they don't get setters
103                }
104            }
105        }
106
107        Ok(quote! {
108            #(#methods)*
109        })
110    }
111
112    fn generic_setter_method(
113        &self,
114        param_index: usize,
115        type_param: &syn::TypeParam,
116    ) -> TokenStream {
117        let builder_ident = &self.base.builder_type.ident;
118        let state_var = &self.base.state_var;
119        let where_clause = &self.base.generics.where_clause;
120
121        let param_ident = &type_param.ident;
122        let method_name = self.method_name(param_ident);
123
124        let vis = self
125            .config
126            .vis
127            .as_ref()
128            .map(|v| &v.value)
129            .unwrap_or(&self.base.builder_type.vis);
130
131        let docs = self.method_docs(param_ident);
132
133        // Build the generic arguments for the output type, where the current parameter
134        // is replaced with a new type variable. Even though the `GenericsNamespace`
135        let new_type_var = self
136            .base
137            .namespace
138            // Add `New` prefix to make the type variable more readable in the docs and IDE hints
139            .unique_ident(format!("New{param_ident}"));
140
141        // Copy the bounds from the original type parameter to the new one
142        let bounds = &type_param.bounds;
143        let new_type_param = if bounds.is_empty() {
144            quote!(#new_type_var)
145        } else {
146            quote!(#new_type_var: #bounds)
147        };
148
149        let output_generic_args = self
150            .base
151            .generics
152            .args
153            .iter()
154            .enumerate()
155            .map(|(i, arg)| {
156                if i == param_index {
157                    quote!(#new_type_var)
158                } else {
159                    quote!(#arg)
160                }
161            })
162            .collect::<Vec<_>>();
163
164        // Check which named members use this generic parameter
165        let mut runtime_asserts = Vec::new();
166        let mut type_state_bounds = Vec::new();
167        let named_member_conversions = self
168            .base
169            .named_members()
170            .enumerate()
171            .map(|(idx, member)| {
172                let uses_param = member_uses_generic_param(member, param_ident);
173                let index = syn::Index::from(idx);
174                if uses_param {
175                    // Add compile-time type state constraint
176                    let state_mod = &self.base.state_mod.ident;
177                    let field_pascal = &member.name.pascal;
178                    type_state_bounds.push(quote! {
179                        #state_var::#field_pascal: #state_mod::IsUnset
180                    });
181
182                    // Add runtime assert that this field is None
183                    let field_ident = &member.name.orig;
184                    let message = format!(
185                        "BUG: field `{field_ident}` should be None \
186                        when converting generic parameter `{param_ident}`"
187                    );
188                    runtime_asserts.push(quote! {
189                        ::core::assert!(named.#index.is_none(), #message);
190                    });
191                    // Field uses the generic parameter, so create a new None
192                    quote!(::core::option::Option::None)
193                } else {
194                    // Field doesn't use the generic parameter, so move it from the tuple
195                    quote!(named.#index)
196                }
197            })
198            .collect::<Vec<_>>();
199
200        let receiver_field = self.base.receiver().map(|receiver| {
201            let ident = &receiver.field_ident;
202            quote!(#ident: self.#ident,)
203        });
204
205        let start_fn_fields = self.base.start_fn_args().map(|member| {
206            let ident = &member.ident;
207            quote!(#ident: self.#ident,)
208        });
209
210        let custom_fields = self.base.custom_fields().map(|field| {
211            let ident = &field.ident;
212            quote!(#ident: self.#ident,)
213        });
214
215        // Extend where clause with type state bounds and update type parameter references
216        let extended_where_clause = {
217            let mut clause = where_clause.clone().unwrap_or_else(|| syn::WhereClause {
218                where_token: Where::default(),
219                predicates: Punctuated::default(),
220            });
221
222            for predicate in &mut clause.predicates {
223                replace_type_param_in_predicate(predicate, param_ident, &new_type_var);
224            }
225
226            for bound in type_state_bounds {
227                clause.predicates.push(syn::parse_quote!(#bound));
228            }
229
230            (!clause.predicates.is_empty()).then(|| clause)
231        };
232
233        quote! {
234            #(#docs)*
235            #[inline(always)]
236            #vis fn #method_name<#new_type_param>(
237                self
238            ) -> #builder_ident<#(#output_generic_args,)* #state_var>
239            #extended_where_clause
240            {
241                let named = self.__unsafe_private_named;
242
243                // Runtime safety asserts to ensure fields using the converted
244                // generic parameter are None
245                #(#runtime_asserts)*
246
247                #builder_ident {
248                    __unsafe_private_phantom: ::core::marker::PhantomData,
249                    #receiver_field
250                    #(#start_fn_fields)*
251                    #(#custom_fields)*
252                    __unsafe_private_named: (
253                        #(#named_member_conversions,)*
254                    ),
255                }
256            }
257        }
258    }
259
260    fn method_name(&self, param_ident: &syn::Ident) -> syn::Ident {
261        let param_name_snake = param_ident.pascal_to_snake_case();
262
263        // Name is guaranteed to be present due to validation in parse_setters_config
264        let name_pattern = &self
265            .config
266            .name
267            .as_ref()
268            .expect("name should be validated")
269            .value;
270
271        let method_name = name_pattern.replace("{}", &param_name_snake.to_string());
272
273        syn::Ident::new(&method_name, param_ident.span())
274    }
275
276    fn method_docs(&self, param_ident: &syn::Ident) -> Vec<syn::Attribute> {
277        // If custom docs are provided, use them
278        if let Some(ref docs) = self.config.docs {
279            return docs.value.clone();
280        }
281
282        // Otherwise, generate default documentation
283        let doc = format!(
284            "Convert the `{param_ident}` generic parameter to a different type.\n\
285            \n\
286            This method allows changing the type of the `{param_ident}` parameter on the builder, \
287            which is useful when you need to build up values with different types at \
288            different stages of construction."
289        );
290
291        vec![syn::parse_quote!(#[doc = #doc])]
292    }
293}
294
295struct TypeParamFinder<'ty, 'ast> {
296    type_params: &'ty [&'ty syn::Ident],
297
298    // Use a `BTreeSet` for deterministic ordering
299    found: BTreeSet<&'ast syn::Ident>,
300}
301
302impl<'ty> TypeParamFinder<'ty, '_> {
303    fn new(type_params: &'ty [&'ty syn::Ident]) -> Self {
304        Self {
305            type_params,
306            found: BTreeSet::new(),
307        }
308    }
309}
310
311impl<'ast> Visit<'ast> for TypeParamFinder<'_, 'ast> {
312    fn visit_path(&mut self, path: &'ast syn::Path) {
313        // Check if this path is one of our type parameters
314        if let Some(param) = path.get_ident() {
315            if self.type_params.contains(&param) {
316                self.found.insert(param);
317            }
318        }
319
320        // Continue visiting nested paths
321        syn::visit::visit_path(self, path);
322    }
323}
324
325fn replace_type_param_in_predicate(
326    predicate: &mut syn::WherePredicate,
327    old_param: &syn::Ident,
328    new_param: &syn::Ident,
329) {
330    use syn::visit_mut::VisitMut;
331
332    struct TypeParamReplacer<'a> {
333        old_param: &'a syn::Ident,
334        new_param: &'a syn::Ident,
335    }
336
337    impl VisitMut for TypeParamReplacer<'_> {
338        fn visit_path_mut(&mut self, path: &mut syn::Path) {
339            // Replace simple paths like `T`
340            if path.is_ident(self.old_param) {
341                if let Some(segment) = path.segments.first_mut() {
342                    segment.ident = self.new_param.clone();
343                }
344            }
345            // Continue visiting nested paths
346            syn::visit_mut::visit_path_mut(self, path);
347        }
348
349        fn visit_type_path_mut(&mut self, type_path: &mut syn::TypePath) {
350            // Handle qualified paths like T::Assoc
351            if let Some(qself) = &mut type_path.qself {
352                self.visit_type_mut(&mut qself.ty);
353            }
354            self.visit_path_mut(&mut type_path.path);
355        }
356    }
357
358    let mut replacer = TypeParamReplacer {
359        old_param,
360        new_param,
361    };
362    replacer.visit_where_predicate_mut(predicate);
363}
364
365/// Check if a member's type uses a specific generic parameter
366fn member_uses_generic_param(member: &super::NamedMember, param_ident: &syn::Ident) -> bool {
367    let member_ty = member.underlying_norm_ty();
368    type_uses_generic_param(member_ty, param_ident)
369}
370
371/// Recursively check if a type uses a specific generic parameter
372fn type_uses_generic_param(ty: &syn::Type, param_ident: &syn::Ident) -> bool {
373    struct GenericParamVisitor<'a> {
374        param_ident: &'a syn::Ident,
375        found: bool,
376    }
377
378    impl<'ast> Visit<'ast> for GenericParamVisitor<'_> {
379        fn visit_type_path(&mut self, type_path: &'ast syn::TypePath) {
380            // Early return if already found to avoid unnecessary recursion
381            if self.found {
382                return;
383            }
384
385            // Check if the path is the generic parameter we're looking for
386            if type_path.path.is_ident(self.param_ident) {
387                self.found = true;
388                return;
389            }
390
391            // For qualified paths like T::Assoc or <T as Trait>::Assoc,
392            // check if the first segment (or qself) uses the generic parameter
393
394            if let Some(qself) = &type_path.qself {
395                // For <T as Trait>::Assoc syntax
396                self.visit_type(&qself.ty);
397            } else if let Some(segment) = type_path.path.segments.first() {
398                // For T::Assoc syntax
399                if segment.ident == *self.param_ident {
400                    self.found = true;
401                    return;
402                }
403            }
404
405            // Continue visiting the rest of the type path
406            syn::visit::visit_type_path(self, type_path);
407        }
408    }
409
410    let mut visitor = GenericParamVisitor {
411        param_ident,
412        found: false,
413    };
414    visitor.visit_type(ty);
415    visitor.found
416}