bon_macros/builder/builder_gen/
generic_setters.rs1use super::models::BuilderGenCtx;
2use crate::parsing::ItemSigConfig;
3use crate::util::prelude::*;
4use std::collections::BTreeSet;
5use syn::punctuated::Punctuated;
6use syn::token::Where;
7use syn::visit::Visit;
8
9pub(super) struct GenericSettersCtx<'a> {
10 base: &'a BuilderGenCtx,
11 config: &'a ItemSigConfig<String>,
12}
13
14impl<'a> GenericSettersCtx<'a> {
15 pub(super) fn new(base: &'a BuilderGenCtx, config: &'a ItemSigConfig<String>) -> Self {
16 Self { base, config }
17 }
18
19 pub(super) fn generic_setter_methods(&self) -> Result<TokenStream> {
20 let generics = &self.base.generics.decl_without_defaults;
21
22 let type_param_idents: Vec<&syn::Ident> = generics
23 .iter()
24 .filter_map(|param| match param {
25 syn::GenericParam::Type(type_param) => Some(&type_param.ident),
26 _ => None,
27 })
28 .collect();
29
30 for param in generics {
32 if let syn::GenericParam::Type(type_param) = param {
33 let mut params = TypeParamFinder::new(&type_param_idents);
34
35 for bound in &type_param.bounds {
36 params.visit_type_param_bound(bound);
37 }
38
39 params.found.remove(&type_param.ident);
41
42 if let Some(first_param) = params.found.iter().next() {
43 let params_str = params
44 .found
45 .iter()
46 .map(|p| format!("`{p}`"))
47 .collect::<Vec<_>>()
48 .join(", ");
49 bail!(
50 first_param,
51 "generic conversion methods cannot be generated for interdependent type parameters; \
52 the bounds on generic parameter `{}` reference other type parameters: {}\n\
53 \n\
54 Consider removing `generics(setters(...))` or restructuring your types to avoid interdependencies",
55 type_param.ident,
56 params_str
57 );
58 }
59 }
60 }
61
62 if let Some(where_clause) = &self.base.generics.where_clause {
64 for predicate in &where_clause.predicates {
65 let mut params = TypeParamFinder::new(&type_param_idents);
66 params.visit_where_predicate(predicate);
67 if params.found.len() > 1 {
68 let params_str = params
69 .found
70 .iter()
71 .map(|p| format!("`{p}`"))
72 .collect::<Vec<_>>()
73 .join(", ");
74 bail!(
75 predicate,
76 "generic conversion methods cannot be generated for interdependent type parameters; \
77 the where clause predicate references multiple type parameters: {}\n\
78 \n\
79 Consider removing `generics(setters(...))` or restructuring your types to avoid interdependencies",
80 params_str
81 );
82 }
83 }
84 }
85
86 let mut methods = Vec::with_capacity(generics.len());
87
88 for (index, param) in generics.iter().enumerate() {
89 match param {
90 syn::GenericParam::Type(type_param) => {
91 methods.push(self.generic_setter_method(index, type_param));
92 }
93 syn::GenericParam::Const(const_param) => {
94 bail!(
95 &const_param.ident,
96 "const generic parameters are not yet supported with `generics(setters(...))`; \
97 only type parameters can be overridden, feel free to open an issue if you need \
98 this feature"
99 );
100 }
101 syn::GenericParam::Lifetime(_) => {
102 }
104 }
105 }
106
107 Ok(quote! {
108 #(#methods)*
109 })
110 }
111
112 fn generic_setter_method(
113 &self,
114 param_index: usize,
115 type_param: &syn::TypeParam,
116 ) -> TokenStream {
117 let builder_ident = &self.base.builder_type.ident;
118 let state_var = &self.base.state_var;
119 let where_clause = &self.base.generics.where_clause;
120
121 let param_ident = &type_param.ident;
122 let method_name = self.method_name(param_ident);
123
124 let vis = self
125 .config
126 .vis
127 .as_ref()
128 .map(|v| &v.value)
129 .unwrap_or(&self.base.builder_type.vis);
130
131 let docs = self.method_docs(param_ident);
132
133 let new_type_var = self
136 .base
137 .namespace
138 .unique_ident(format!("New{param_ident}"));
140
141 let bounds = &type_param.bounds;
143 let new_type_param = if bounds.is_empty() {
144 quote!(#new_type_var)
145 } else {
146 quote!(#new_type_var: #bounds)
147 };
148
149 let output_generic_args = self
150 .base
151 .generics
152 .args
153 .iter()
154 .enumerate()
155 .map(|(i, arg)| {
156 if i == param_index {
157 quote!(#new_type_var)
158 } else {
159 quote!(#arg)
160 }
161 })
162 .collect::<Vec<_>>();
163
164 let mut runtime_asserts = Vec::new();
166 let mut type_state_bounds = Vec::new();
167 let named_member_conversions = self
168 .base
169 .named_members()
170 .enumerate()
171 .map(|(idx, member)| {
172 let uses_param = member_uses_generic_param(member, param_ident);
173 let index = syn::Index::from(idx);
174 if uses_param {
175 let state_mod = &self.base.state_mod.ident;
177 let field_pascal = &member.name.pascal;
178 type_state_bounds.push(quote! {
179 #state_var::#field_pascal: #state_mod::IsUnset
180 });
181
182 let field_ident = &member.name.orig;
184 let message = format!(
185 "BUG: field `{field_ident}` should be None \
186 when converting generic parameter `{param_ident}`"
187 );
188 runtime_asserts.push(quote! {
189 ::core::assert!(named.#index.is_none(), #message);
190 });
191 quote!(::core::option::Option::None)
193 } else {
194 quote!(named.#index)
196 }
197 })
198 .collect::<Vec<_>>();
199
200 let receiver_field = self.base.receiver().map(|receiver| {
201 let ident = &receiver.field_ident;
202 quote!(#ident: self.#ident,)
203 });
204
205 let start_fn_fields = self.base.start_fn_args().map(|member| {
206 let ident = &member.ident;
207 quote!(#ident: self.#ident,)
208 });
209
210 let custom_fields = self.base.custom_fields().map(|field| {
211 let ident = &field.ident;
212 quote!(#ident: self.#ident,)
213 });
214
215 let extended_where_clause = {
217 let mut clause = where_clause.clone().unwrap_or_else(|| syn::WhereClause {
218 where_token: Where::default(),
219 predicates: Punctuated::default(),
220 });
221
222 for predicate in &mut clause.predicates {
223 replace_type_param_in_predicate(predicate, param_ident, &new_type_var);
224 }
225
226 for bound in type_state_bounds {
227 clause.predicates.push(syn::parse_quote!(#bound));
228 }
229
230 (!clause.predicates.is_empty()).then(|| clause)
231 };
232
233 quote! {
234 #(#docs)*
235 #[inline(always)]
236 #vis fn #method_name<#new_type_param>(
237 self
238 ) -> #builder_ident<#(#output_generic_args,)* #state_var>
239 #extended_where_clause
240 {
241 let named = self.__unsafe_private_named;
242
243 #(#runtime_asserts)*
246
247 #builder_ident {
248 __unsafe_private_phantom: ::core::marker::PhantomData,
249 #receiver_field
250 #(#start_fn_fields)*
251 #(#custom_fields)*
252 __unsafe_private_named: (
253 #(#named_member_conversions,)*
254 ),
255 }
256 }
257 }
258 }
259
260 fn method_name(&self, param_ident: &syn::Ident) -> syn::Ident {
261 let param_name_snake = param_ident.pascal_to_snake_case();
262
263 let name_pattern = &self
265 .config
266 .name
267 .as_ref()
268 .expect("name should be validated")
269 .value;
270
271 let method_name = name_pattern.replace("{}", ¶m_name_snake.to_string());
272
273 syn::Ident::new(&method_name, param_ident.span())
274 }
275
276 fn method_docs(&self, param_ident: &syn::Ident) -> Vec<syn::Attribute> {
277 if let Some(ref docs) = self.config.docs {
279 return docs.value.clone();
280 }
281
282 let doc = format!(
284 "Convert the `{param_ident}` generic parameter to a different type.\n\
285 \n\
286 This method allows changing the type of the `{param_ident}` parameter on the builder, \
287 which is useful when you need to build up values with different types at \
288 different stages of construction."
289 );
290
291 vec![syn::parse_quote!(#[doc = #doc])]
292 }
293}
294
295struct TypeParamFinder<'ty, 'ast> {
296 type_params: &'ty [&'ty syn::Ident],
297
298 found: BTreeSet<&'ast syn::Ident>,
300}
301
302impl<'ty> TypeParamFinder<'ty, '_> {
303 fn new(type_params: &'ty [&'ty syn::Ident]) -> Self {
304 Self {
305 type_params,
306 found: BTreeSet::new(),
307 }
308 }
309}
310
311impl<'ast> Visit<'ast> for TypeParamFinder<'_, 'ast> {
312 fn visit_path(&mut self, path: &'ast syn::Path) {
313 if let Some(param) = path.get_ident() {
315 if self.type_params.contains(¶m) {
316 self.found.insert(param);
317 }
318 }
319
320 syn::visit::visit_path(self, path);
322 }
323}
324
325fn replace_type_param_in_predicate(
326 predicate: &mut syn::WherePredicate,
327 old_param: &syn::Ident,
328 new_param: &syn::Ident,
329) {
330 use syn::visit_mut::VisitMut;
331
332 struct TypeParamReplacer<'a> {
333 old_param: &'a syn::Ident,
334 new_param: &'a syn::Ident,
335 }
336
337 impl VisitMut for TypeParamReplacer<'_> {
338 fn visit_path_mut(&mut self, path: &mut syn::Path) {
339 if path.is_ident(self.old_param) {
341 if let Some(segment) = path.segments.first_mut() {
342 segment.ident = self.new_param.clone();
343 }
344 }
345 syn::visit_mut::visit_path_mut(self, path);
347 }
348
349 fn visit_type_path_mut(&mut self, type_path: &mut syn::TypePath) {
350 if let Some(qself) = &mut type_path.qself {
352 self.visit_type_mut(&mut qself.ty);
353 }
354 self.visit_path_mut(&mut type_path.path);
355 }
356 }
357
358 let mut replacer = TypeParamReplacer {
359 old_param,
360 new_param,
361 };
362 replacer.visit_where_predicate_mut(predicate);
363}
364
365fn member_uses_generic_param(member: &super::NamedMember, param_ident: &syn::Ident) -> bool {
367 let member_ty = member.underlying_norm_ty();
368 type_uses_generic_param(member_ty, param_ident)
369}
370
371fn type_uses_generic_param(ty: &syn::Type, param_ident: &syn::Ident) -> bool {
373 struct GenericParamVisitor<'a> {
374 param_ident: &'a syn::Ident,
375 found: bool,
376 }
377
378 impl<'ast> Visit<'ast> for GenericParamVisitor<'_> {
379 fn visit_type_path(&mut self, type_path: &'ast syn::TypePath) {
380 if self.found {
382 return;
383 }
384
385 if type_path.path.is_ident(self.param_ident) {
387 self.found = true;
388 return;
389 }
390
391 if let Some(qself) = &type_path.qself {
395 self.visit_type(&qself.ty);
397 } else if let Some(segment) = type_path.path.segments.first() {
398 if segment.ident == *self.param_ident {
400 self.found = true;
401 return;
402 }
403 }
404
405 syn::visit::visit_type_path(self, type_path);
407 }
408 }
409
410 let mut visitor = GenericParamVisitor {
411 param_ident,
412 found: false,
413 };
414 visitor.visit_type(ty);
415 visitor.found
416}