derive_with/
lib.rs

1use std::collections::HashMap;
2
3use proc_macro::TokenStream;
4use proc_macro2::Ident;
5use quote::{ToTokens, format_ident, quote};
6use syn::parse::Parse;
7use syn::punctuated::Punctuated;
8use syn::token::Comma;
9use syn::{
10    Attribute, GenericParam, Generics, Index, Meta, Path, PredicateType, Token, Type, TypeParam,
11    TypePath, WhereClause, WherePredicate,
12};
13
14/// A custom derive implementation for `#[derive(With)]`
15///
16/// # Get started
17///
18/// 1.Generate with-constructor for each field
19/// ```rust
20/// use derive_with::With;
21///
22/// #[derive(With, Default)]
23/// pub struct Foo {
24///     pub a: i32,
25///     pub b: String,
26/// }
27///
28/// #[derive(With, Default)]
29/// pub struct Bar (i32, String);
30///
31/// fn test_struct() {
32///     let foo = Foo::default().with_a(1).with_b(1.to_string());
33///     assert_eq!(foo.a, 1);
34///     assert_eq!(foo.b, "1".to_string());
35///
36///     let bar = Bar::default().with_0(1).with_1(1.to_string());
37///     assert_eq!(bar.0, 1);
38///     assert_eq!(bar.1, "1".to_string());
39/// }
40/// ```
41///
42/// 2.Generate with-constructor for specific fields
43/// ```rust
44/// use derive_with::With;
45///
46/// #[derive(With, Default)]
47/// #[with(a)]
48/// pub struct Foo {
49///     pub a: i32,
50///     pub b: String,
51/// }
52///
53/// #[derive(With, Default)]
54/// #[with(1)]
55/// pub struct Bar (i32, String);
56///
57/// fn test_struct() {
58///     let foo = Foo::default().with_a(1);
59///     assert_eq!(foo.a, 1);
60///
61///     let bar = Bar::default().with_1(1.to_string());
62///     assert_eq!(bar.1, "1".to_string());
63/// }
64/// ```
65#[proc_macro_derive(With, attributes(with))]
66pub fn derive(input: TokenStream) -> TokenStream {
67    let ast: syn::DeriveInput = syn::parse(input).expect("Couldn't parse item");
68    let result = match ast.data {
69        syn::Data::Struct(ref s) => with_for_struct(&ast, &s.fields),
70        syn::Data::Enum(_) => panic!("doesn't work with enums yet"),
71        syn::Data::Union(_) => panic!("doesn't work with unions yet"),
72    };
73    result.into()
74}
75
76fn with_for_struct(ast: &syn::DeriveInput, fields: &syn::Fields) -> proc_macro2::TokenStream {
77    match *fields {
78        syn::Fields::Named(ref fields) => with_constructor_for_named(ast, &fields.named),
79        syn::Fields::Unnamed(ref fields) => with_constructor_for_unnamed(ast, &fields.unnamed),
80        syn::Fields::Unit => panic!("Unit structs are not supported"),
81    }
82}
83
84fn with_constructor_for_named(
85    ast: &syn::DeriveInput,
86    fields: &Punctuated<syn::Field, Token![,]>,
87) -> proc_macro2::TokenStream {
88    let name = &ast.ident;
89    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
90    let generics_map = index_generics(&ast.generics);
91    let where_predicate_map = index_where_predicates(&ast.generics.where_clause);
92    let with_args = parse_with_args::<Ident>(&ast.attrs);
93    let field_count = fields.len();
94
95    let mut constructors = quote!();
96    for field in fields {
97        let field_name = field.ident.as_ref().unwrap();
98        if !contains_field(&with_args, field_name) {
99            continue;
100        }
101        let field_type = &field.ty;
102        let constructor_name = format_ident!("with_{}", field_name);
103
104        // Check the type of the field
105        let constructor = match field_type {
106            // For simple path types
107            Type::Path(type_path) => {
108                // Check if the type matches some generic parameter
109                match generics_map.get(&type_path.path).cloned() {
110                    // If the type is not generic, just use the Into trait to derive the method
111                    None => generate_constructor_for_named(
112                        &constructor_name,
113                        field_name,
114                        field_type,
115                        field_count,
116                    ),
117                    // If the type is generic, allow to switch types
118                    Some(mut generic) => {
119                        let new_generic = format_ident!("W{}", generic.ident);
120                        // Update the generic ident for the new one, so that it doesn't conflict with the existing
121                        generic.ident = new_generic.clone();
122
123                        // Determine the new generics, which are the existing generics
124                        let mut new_generic_params = Vec::new();
125                        for param in &ast.generics.params {
126                            new_generic_params.push(match param {
127                                // Except for the generic parameter that matches the field type
128                                GenericParam::Type(type_param)
129                                    if type_path.path.is_ident(&type_param.ident) =>
130                                {
131                                    // That must be replaced with the new generic ident
132                                    new_generic.to_token_stream()
133                                }
134                                GenericParam::Type(type_param) => {
135                                    type_param.ident.to_token_stream()
136                                }
137                                GenericParam::Lifetime(lifetime_param) => {
138                                    lifetime_param.lifetime.to_token_stream()
139                                }
140                                GenericParam::Const(const_param) => {
141                                    const_param.ident.to_token_stream()
142                                }
143                            });
144                        }
145
146                        // Compute the new field values, as we can't deconstruct when switching types
147                        let mut other_fields = Vec::new();
148                        for other_field in fields {
149                            let other_field_name = other_field.ident.as_ref().unwrap();
150                            if other_field_name != field_name {
151                                other_fields
152                                    .push(quote! { #other_field_name: self.#other_field_name });
153                            } else {
154                                other_fields.push(quote! { #field_name });
155                            }
156                        }
157
158                        // Retrieve the where predicate affecting this field, if any
159                        let where_clause = where_predicate_map.get(&type_path.path).cloned().map(
160                            |mut predicate| {
161                                // And update the bounded type to the new generic ident
162                                predicate.bounded_ty = Type::Path(TypePath {
163                                    qself: None,
164                                    path: Path::from(new_generic.clone()),
165                                });
166                                quote! { where #predicate }
167                            },
168                        );
169
170                        quote! {
171                            pub fn #constructor_name <#generic> (self, #field_name: #new_generic)
172                            -> #name < #(#new_generic_params),* >
173                            #where_clause
174                            {
175                                #name {
176                                    #(#other_fields),*
177                                }
178                            }
179                        }
180                    }
181                }
182            }
183            // For every other field type, just use the Into trait to derive the method
184            _ => generate_constructor_for_named(
185                &constructor_name,
186                field_name,
187                field_type,
188                field_count,
189            ),
190        };
191
192        constructors = quote! {
193            #constructors
194            #constructor
195        };
196    }
197    quote! {
198        #[automatically_derived]
199        impl #impl_generics #name #ty_generics #where_clause {
200            #constructors
201        }
202    }
203}
204
205fn with_constructor_for_unnamed(
206    ast: &syn::DeriveInput,
207    fields: &Punctuated<syn::Field, Token![,]>,
208) -> proc_macro2::TokenStream {
209    let name = &ast.ident;
210    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
211    let generics_map = index_generics(&ast.generics);
212    let where_predicate_map = index_where_predicates(&ast.generics.where_clause);
213    let with_args = parse_with_args::<Index>(&ast.attrs);
214
215    let mut constructors = quote!();
216    for (index, field) in fields.iter().enumerate() {
217        let index = syn::Index::from(index);
218        if !contains_field(&with_args, &index) {
219            continue;
220        }
221        let field_type = &field.ty;
222        let field_name = format_ident!("field_{}", index);
223        let constructor_name = format_ident!("with_{}", index);
224
225        // Check the type of the field
226        let constructor = match field_type {
227            // For simple path types
228            Type::Path(type_path) => {
229                // Check if the type matches some generic parameter
230                match generics_map.get(&type_path.path).cloned() {
231                    // If the type is not generic, just use the Into trait to derive the method
232                    None => generate_constructor_for_unnamed(
233                        &constructor_name,
234                        index,
235                        &field_name,
236                        field_type,
237                    ),
238
239                    // If the type is generic, allow to switch types
240                    Some(mut generic) => {
241                        let new_generic = format_ident!("W{}", generic.ident);
242                        // Update the generic ident for the new one, so that it doesn't conflict with the existing
243                        generic.ident = new_generic.clone();
244
245                        // Determine the new generics, which are the existing generics
246                        let mut new_generic_params = Vec::new();
247                        for param in &ast.generics.params {
248                            new_generic_params.push(match param {
249                                // Except for the generic parameter that matches the field type
250                                GenericParam::Type(type_param)
251                                    if type_path.path.is_ident(&type_param.ident) =>
252                                {
253                                    // That must be replaced with the new generic ident
254                                    new_generic.to_token_stream()
255                                }
256                                GenericParam::Type(type_param) => {
257                                    type_param.ident.to_token_stream()
258                                }
259                                GenericParam::Lifetime(lifetime_param) => {
260                                    lifetime_param.lifetime.to_token_stream()
261                                }
262                                GenericParam::Const(const_param) => {
263                                    const_param.ident.to_token_stream()
264                                }
265                            });
266                        }
267
268                        // Compute the new field values
269                        let mut other_fields = Vec::new();
270                        for (other_index, _) in fields.iter().enumerate() {
271                            let other_index = syn::Index::from(other_index);
272                            if other_index != index {
273                                other_fields.push(quote! { self.#other_index });
274                            } else {
275                                other_fields.push(quote! { #field_name });
276                            }
277                        }
278
279                        // Retrieve the where predicate affecting this field, if any
280                        let where_clause = where_predicate_map.get(&type_path.path).cloned().map(
281                            |mut predicate| {
282                                // And update the bounded type to the new generic ident
283                                predicate.bounded_ty = Type::Path(TypePath {
284                                    qself: None,
285                                    path: Path::from(new_generic.clone()),
286                                });
287                                quote! { where #predicate }
288                            },
289                        );
290
291                        quote! {
292                            pub fn #constructor_name <#generic> (self, #field_name: #new_generic)
293                            -> #name < #(#new_generic_params),* >
294                            #where_clause
295                            {
296                                #name ( #(#other_fields),* )
297                            }
298                        }
299                    }
300                }
301            }
302            // For every other field type, just use the Into trait to derive the method
303            _ => {
304                generate_constructor_for_unnamed(&constructor_name, index, &field_name, field_type)
305            }
306        };
307
308        constructors = quote! {
309            #constructors
310            #constructor
311        };
312    }
313    quote! {
314        #[automatically_derived]
315        impl #impl_generics #name #ty_generics #where_clause {
316            #constructors
317        }
318    }
319}
320
321fn parse_with_args<T: Parse>(attrs: &[Attribute]) -> Option<Punctuated<T, Comma>> {
322    if let Some(attr) = attrs.iter().find(|attr| attr.path().is_ident("with")) {
323        match &attr.meta {
324            Meta::List(list) => Some(
325                list.parse_args_with(Punctuated::<T, Comma>::parse_terminated)
326                    .expect("Couldn't parse with args"),
327            ),
328            _ => panic!("`with` attribute should like `#[with(a, b, c)]`"),
329        }
330    } else {
331        None
332    }
333}
334
335fn contains_field<T: Parse + PartialEq>(
336    with_args: &Option<Punctuated<T, Comma>>,
337    item: &T,
338) -> bool {
339    with_args.is_none() || with_args.as_ref().unwrap().iter().any(|arg| arg == item)
340}
341
342fn index_generics(generics: &Generics) -> HashMap<Path, TypeParam> {
343    generics
344        .params
345        .iter()
346        .filter_map(|p| match p {
347            GenericParam::Type(type_param) => Some(type_param),
348            _ => None,
349        })
350        .map(|p| (Path::from(p.ident.clone()), p.clone()))
351        .collect()
352}
353
354fn index_where_predicates(where_clause: &Option<WhereClause>) -> HashMap<Path, PredicateType> {
355    where_clause
356        .as_ref()
357        .map(|w| {
358            w.predicates
359                .iter()
360                .filter_map(|p| match p {
361                    WherePredicate::Type(t) => Some(t),
362                    _ => None,
363                })
364                .filter_map(|t| match &t.bounded_ty {
365                    Type::Path(type_path) => Some((type_path.path.clone(), t.clone())),
366                    _ => None,
367                })
368                .collect()
369        })
370        .unwrap_or_default()
371}
372
373fn generate_constructor_for_named(
374    constructor_name: &Ident,
375    field_name: &Ident,
376    field_type: &Type,
377    field_count: usize,
378) -> proc_macro2::TokenStream {
379    let field_arg_type = match field_type {
380        Type::Path(type_path) if is_builtin_numeric_type(&type_path.path) => quote! { #field_type },
381        _ => quote! { impl Into<#field_type> },
382    };
383    if field_count == 1 {
384        quote! {
385            pub fn #constructor_name(self, #field_name: #field_arg_type) -> Self {
386                Self {
387                    #field_name: #field_name.into(),
388                }
389            }
390        }
391    } else {
392        quote! {
393            pub fn #constructor_name(self, #field_name: #field_arg_type) -> Self {
394                Self {
395                    #field_name: #field_name.into(),
396                    ..self
397                }
398            }
399        }
400    }
401}
402
403fn generate_constructor_for_unnamed(
404    constructor_name: &Ident,
405    field_index: Index,
406    field_name: &Ident,
407    field_type: &Type,
408) -> proc_macro2::TokenStream {
409    let field_arg_type = match field_type {
410        Type::Path(type_path) if is_builtin_numeric_type(&type_path.path) => {
411            quote! { #field_type }
412        }
413        _ => quote! { impl Into<#field_type> },
414    };
415    quote! {
416        pub fn #constructor_name(mut self, #field_name: #field_arg_type) -> Self {
417            self.#field_index = #field_name.into();
418            self
419        }
420    }
421}
422
423/// Check if a path represents a built-in numeric type
424fn is_builtin_numeric_type(path: &Path) -> bool {
425    // Get the string representation of the path
426    let path_str = path.to_token_stream().to_string();
427
428    // Check for common numeric types
429    matches!(
430        path_str.as_str(),
431        "i8" | "i16"
432            | "i32"
433            | "i64"
434            | "i128"
435            | "isize"
436            | "u8"
437            | "u16"
438            | "u32"
439            | "u64"
440            | "u128"
441            | "usize"
442            | "f32"
443            | "f64"
444    )
445}