partial_config_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro_error2::proc_macro_error;
3use proc_macro_error2::{OptionExt, ResultExt};
4use quote::ToTokens;
5use std::collections::{BTreeSet, HashMap};
6use syn::{
7    punctuated::Punctuated, token::Comma, Attribute, DeriveInput, Field, Generics, Ident, Meta,
8};
9
10#[proc_macro_error]
11#[proc_macro_derive(
12    HasPartial,
13    attributes(partial_derives, partial_rename, env_source, env, partial_only)
14)]
15pub fn has_partial(input: TokenStream) -> TokenStream {
16    let DeriveInput {
17        ident,
18        generics,
19        data,
20        attrs,
21        vis,
22    } = syn::parse_macro_input!(input as DeriveInput);
23    // TODO: support inheriting `pub(crate)
24    // TODO: panic on generics
25
26    let partial_ident = partial_struct_name(&ident, &attrs);
27
28    match vis {
29        syn::Visibility::Public(_) => {}
30        _ => {
31            proc_macro_error2::abort!(vis, "Cannot implement `HasPartial` for a private structure.";
32                help = "If your structure is private, it is better to convert to it with an `Into::into` rather than directly derive `HasPartial`, which by definition will expose some of the fields"
33            )
34        }
35    };
36
37    let strct = match data {
38        syn::Data::Struct(thing) => thing,
39        syn::Data::Enum(_) => {
40            proc_macro_error2::abort!(
41                ident, "Enums are not supported.";
42                help = "While it is possible to support `enum`s in principle, this is most likely an X-Y problem. You should use `partial_config` to build your internal `enum` with an extra layer.",
43            );
44        }
45        syn::Data::Union(_) => {
46            proc_macro_error2::abort!(
47                ident, "Data unions are not supported.";
48                help = "Data unions are not usually used in Safe Rust, even though they could be, this is discouraged in favour of Enums, which are not supported either. Consider using a `struct` instead."
49            );
50        }
51    };
52
53    let fields = match strct.fields {
54        syn::Fields::Named(namede) => namede.named,
55        syn::Fields::Unnamed(flds) => {
56            proc_macro_error2::abort!(
57                flds, "Unnamed fields can't be named in configuration layers.";
58                help = "If the field is unnamed, I cannot find a consistent way of naming them in configuration layers, because they muse be human facing. You are probably applying this derive macro to a tuple structure, which is not a sensible input."
59            );
60        }
61        syn::Fields::Unit => {
62            proc_macro_error2::abort!(
63                strct.fields, "Unit fields cannot be named.";
64                help = "If the field is unnamed, I cannot find a consistent way of naming them in configuration layers. Add a dummy field with e.g. `PhantomData` to silence this error."
65            );
66        }
67    };
68
69    let (optional_fields, required_fields): (Punctuated<Field, Comma>, Punctuated<Field, Comma>) =
70        fields.into_iter().partition(|field| is_option(&field.ty));
71
72    let required_fields: Punctuated<Field, Comma> = required_fields
73        .into_iter()
74        .map(|field| {
75            let ty = field.ty;
76            let ty: syn::Type = syn::parse_quote! { Option<#ty>};
77            Field { ty, ..field }
78        })
79        .collect();
80
81    let impl_has_partial = quote::quote! {
82        impl #generics ::partial_config::HasPartial for #ident #generics {
83            type Partial = #partial_ident #generics;
84        }
85    };
86
87    let impl_partial = impl_partial(
88        &generics,
89        &ident,
90        &partial_ident,
91        &required_fields,
92        &optional_fields,
93    )
94    .unwrap();
95
96    let all_fields: Punctuated<Field, Comma> = optional_fields
97        .iter()
98        .cloned()
99        .chain(required_fields.iter().cloned())
100        .map(|field| Field {
101            attrs: field
102                .attrs
103                .into_iter()
104                .filter(|attr| !attr.path().is_ident("env"))
105                .map(|attr| {
106                    if attr.path().is_ident("partial_only") {
107                        let contents: syn::Meta = attr
108                            .parse_args()
109                            .expect_or_abort("Attribute failed to parse");
110                        syn::parse_quote! {
111                            #[#contents]
112                        }
113                    } else {
114                        attr
115                    }
116                })
117                .collect(),
118            ..field
119        })
120        .collect();
121
122    // TODO: Forward all other derives unless otherwise specified.
123    // Do not remove serde unless required to
124    let derives: Vec<Attribute> = attribute_assign(&attrs);
125
126    let output = quote::quote! {
127        #(#derives)*
128        pub struct #partial_ident #generics {
129            #all_fields
130        }
131
132        #[automatically_derived]
133        #impl_partial
134
135        #[automatically_derived]
136        #impl_has_partial
137    };
138    TokenStream::from(output)
139}
140
141fn partial_struct_name(ident: &Ident, attrs: &Vec<Attribute>) -> Ident {
142    let mut ident = quote::format_ident!("Partial{}", ident);
143    for attr in attrs {
144        if attr.path().is_ident("partial_rename") {
145            let identifier: Ident = attr
146                .parse_args()
147                .expect_or_abort("Failed to parse partial_rename identifier");
148            ident = identifier;
149        }
150    }
151    ident
152}
153
154fn attribute_assign(attrs: &Vec<Attribute>) -> Vec<Attribute> {
155    let mut derives: Punctuated<syn::Path, Comma> = Punctuated::new();
156    let mut out_attrs: Vec<Attribute> = Vec::new();
157    for attr in attrs {
158        if attr.path().is_ident("partial_derives") {
159            let nested = attr
160                .parse_args_with(Punctuated::<Meta, syn::Token![,]>::parse_terminated)
161                .expect_or_abort("Invalid specification for `partial_derives`");
162            for item in nested {
163                match item {
164                    Meta::Path(pth) =>  {
165                        derives.push(pth);
166                    },
167                    item => proc_macro_error2::abort!(item, "The paths specified must be specific derive macros, e.g. Clone, got {} instead, which is not allowed", item.to_token_stream())
168                }
169            }
170        } else if attr.path().is_ident("partial_only") {
171            let contents: syn::Meta = attr
172                .parse_args()
173                .expect_or_abort("Attributes failed to parse");
174            out_attrs.push(syn::parse_quote! {
175                #[#contents]
176            })
177        }
178    }
179
180    // TODO: emit warning
181    if !derives.iter().any(|thing| thing.is_ident("Default")) {
182        derives.push(syn::parse_quote! {Default});
183    }
184    vec![syn::parse_quote! {
185        #[derive(#derives)]
186    }]
187}
188
189fn impl_partial(
190    generics: &Generics,
191    ident: &Ident,
192    partial_ident: &Ident,
193    required_fields: &Punctuated<Field, Comma>,
194    optional_fields: &Punctuated<Field, Comma>,
195) -> Result<proc_macro2::TokenStream, &'static str> {
196    let error: syn::Expr = syn::parse_quote! {
197        ::core::result::Result::Err(::partial_config::Error::MissingFields {
198            required_fields: missing_fields
199        })
200    };
201
202    let opt_fields: Punctuated<Ident, Comma> = optional_fields
203        .iter()
204        .cloned()
205        .filter_map(|field| field.ident)
206        .collect();
207
208    let req_fields: Punctuated<Ident, Comma> = required_fields
209        .iter()
210        .cloned()
211        .filter_map(|field| field.ident)
212        .collect();
213
214    let assembling_config: syn::Stmt = assembling_config(req_fields.len(), opt_fields.len());
215
216    let req_field_expr: Punctuated<syn::Stmt, syn::token::Semi> = req_fields
217        .iter()
218        .cloned()
219        .map(|ident| -> syn::Stmt {
220            syn::parse_quote! {
221                let #ident = match self.#ident {
222                    Some(value) => value,
223                    None => {
224                        missing_fields.push(::partial_config::MissingField(stringify!(#ident)));
225                        Default::default()
226                    }
227                };
228            }
229        })
230        .collect();
231
232    let opt_field_expr: Punctuated<syn::Stmt, syn::token::Semi> = optional_fields
233        .iter()
234        .cloned()
235        .filter_map(|field: Field| {
236            field.ident.map(|ident| -> syn::Stmt {
237                // TODO: add explicit fallback
238                syn::parse_quote! {
239                    let #ident = self.#ident;
240                }
241            })
242        })
243        .collect();
244
245    let all_fields: Punctuated<Ident, Comma> = opt_fields
246        .into_iter()
247        .chain(req_fields.into_iter())
248        .collect();
249
250    let override_expr: Punctuated<syn::Stmt, syn::token::Semi> = all_fields
251        .iter()
252        .cloned()
253        .map(|ident: Ident| -> syn::Stmt {
254            syn::parse_quote! {
255                let #ident = other.#ident.or(self.#ident);
256            }
257        })
258        .collect();
259
260    Ok(quote::quote! {
261        impl #generics ::partial_config::Partial for #partial_ident #generics {
262            type Target = #ident #generics;
263
264            type Error = ::partial_config::Error;
265
266            fn build(self) -> Result<Self::Target, Self::Error> {
267                let mut missing_fields = ::std::vec::Vec::new();
268                #assembling_config;
269
270                #req_field_expr
271                #opt_field_expr
272
273                if !missing_fields.is_empty() {
274                    #error
275                } else {
276                    Ok(
277                        Self::Target {
278                            #all_fields
279                        }
280                    )
281                }
282            }
283
284            fn override_with(self, other: Self) -> Self {
285                #override_expr
286                Self {
287                    #all_fields
288                }
289
290            }
291        }
292    })
293}
294
295fn is_option(ty: &syn::Type) -> bool {
296    match ty {
297        syn::Type::Path(path) => path
298            .path
299            .segments
300            .last()
301            .map(|segment| segment.ident == "Option")
302            .unwrap_or(false),
303        _ => false,
304    }
305}
306
307fn extract_option_generic(ty: &syn::Type) -> syn::Type {
308    match ty {
309        syn::Type::Path(path) => path
310            .path
311            .segments
312            .last()
313            .map(|segment| match &segment.arguments {
314                syn::PathArguments::None => {
315                    proc_macro_error2::abort!(segment, "The Option does not have any arguments")
316                }
317                syn::PathArguments::Parenthesized(_) => proc_macro_error2::abort!(
318                    segment,
319                    "The option cannot have parenthesised arguments"
320                ),
321                syn::PathArguments::AngleBracketed(generics) => {
322                    match generics
323                        .args
324                        .first()
325                        .expect_or_abort("Cannot have an empty set of generic arguments")
326                    {
327                        syn::GenericArgument::Lifetime(_) => todo!(),
328                        syn::GenericArgument::Type(ty) => ty.clone(),
329                        syn::GenericArgument::Const(_) => todo!(),
330                        syn::GenericArgument::AssocType(_) => todo!(),
331                        syn::GenericArgument::AssocConst(_) => todo!(),
332                        syn::GenericArgument::Constraint(_) => todo!(),
333                        _ => todo!(),
334                    }
335                }
336            })
337            .expect_or_abort("Failed to obtain type"),
338        _ => todo!("Not implemented yet"),
339    }
340}
341
// `assembling_config` selects exactly one logging statement per build, so
// enabling both backends at once is rejected up front.
#[cfg(all(feature = "tracing", feature = "log"))]
compile_error!(
    "The features \"tracing\" and \"log\" are mutually exclusive. Please either use pure tracing, or enable the \"log\" feature in \"tracing\" and use the \"log\" feature of this crate. "
);
344
345fn assembling_config(required_fields_count: usize, optional_fields_count: usize) -> syn::Stmt {
346    #[cfg(feature = "tracing")]
347    syn::parse_quote! {
348        {
349            ::tracing::info!(?self, "Building configuration {required_fields_count} ({optional_fields_count})", required_fields_count = #required_fields_count, optional_fields_count=#optional_fields_count);
350        }
351    }
352    #[cfg(feature = "log")]
353    syn::parse_quote! {
354        ::log::info!("Building configuration. {required_fields_count} ({optional_fields_count}) fields", required_fields_count = #required_fields_count, optional_fields_count=#optional_fields_count);
355    }
356    #[cfg(not(any(feature = "tracing", feature = "log")))]
357    syn::parse_quote! {
358        println!("Building configuration. {required_fields_count} ({optional_fields_count}) fields", required_fields_count = #required_fields_count, optional_fields_count=#optional_fields_count);
359    }
360}
361
362#[proc_macro_error]
363#[proc_macro_derive(EnvSourced, attributes(env_var_rename, env))]
364pub fn env_sourced(input: TokenStream) -> TokenStream {
365    let DeriveInput {
366        data,
367        attrs,
368        ident: in_ident,
369        ..
370    } = syn::parse_macro_input!(input as DeriveInput);
371
372    let out_ident: Ident = env_var_struct_name(attrs);
373    let strct = match data {
374        syn::Data::Struct(strct) => strct,
375        syn::Data::Enum(_) => panic!("Enums are not supported"),
376        syn::Data::Union(_) => panic!("Data unions are not supported"),
377    };
378
379    let fields: Punctuated<Field, Comma> = match strct.fields {
380        syn::Fields::Named(fld) => fld.named,
381        _ => unreachable!(),
382    };
383
384    let EnvVarFieldsResult {
385        fields: all_fields,
386        default_mappings,
387    } = env_var_fields(&fields);
388
389    let default_struct = impl_default_env(default_mappings);
390    let impl_source = impl_source(&fields);
391
392    let output = quote::quote! {
393    pub struct #out_ident<'a> {
394        #all_fields
395    }
396
397    impl<'a> ::partial_config::env::EnvSourced<'a> for #in_ident {
398        type Source = #out_ident<'a>;
399    }
400
401    impl<'a> #out_ident<'a> {
402        pub const fn new() -> Self {
403            #default_struct
404        }
405    }
406
407    impl<'a> Default for #out_ident<'a> {
408        fn default() -> Self {
409            #default_struct
410        }
411    }
412
413    impl<'a> ::partial_config::Source<#in_ident> for #out_ident<'a> {
414        type Error = ::partial_config::Error;
415
416        fn to_partial(self) -> Result<<#in_ident as ::partial_config::HasPartial>::Partial, Self::Error> {
417            pub type Issue86935Workaround = <#in_ident as ::partial_config::HasPartial>::Partial;
418
419            Ok(Issue86935Workaround {
420                #impl_source
421            })
422        }
423
424        fn name(&self) -> String {
425            "Environment Variables".to_owned()
426        }
427    }
428    };
429    TokenStream::from(output)
430}
431
432struct EnvVarFieldsResult {
433    fields: Punctuated<Field, Comma>,
434    default_mappings: HashMap<Ident, BTreeSet<Ident>>,
435}
436
437fn is_string(ty: &syn::Type) -> bool {
438    match ty {
439        syn::Type::Path(pth) => pth.path.is_ident("String") || pth.path.is_ident("str"),
440        syn::Type::Reference(reference) => is_string(&reference.elem),
441        _ => false,
442    }
443}
444
445fn impl_source(fields: &Punctuated<Field, Comma>) -> Punctuated<syn::FieldValue, Comma> {
446    fields
447        .iter()
448        .map(|Field { ident, ty, .. }| -> syn::FieldValue {
449            if let Some(ident) = ident {
450                if is_string(&ty) {
451                    syn::parse_quote! {
452                        #ident: ::partial_config::env::extract(&self.#ident)?
453                    }
454                } else {
455                    let inner_ty = if is_option(ty) {
456                        extract_option_generic(ty)
457                    } else {
458                        ty.clone()
459                    };
460                    syn::parse_quote! {
461                        #ident: ::partial_config::env::extract(&self.#ident)?
462                        .map(|s: String| <#inner_ty as ::core::str::FromStr>::from_str(&s))
463                        .transpose()
464                        .map_err(|e|
465                            ::partial_config::Error::ParseFieldError {
466                                field_name: stringify!(#ident),
467                                field_type: stringify!(#ty),
468                                error_condition: Box::new(e)
469                            })?
470                    }
471                }
472            } else {
473                proc_macro_error2::abort!(ident, "Non-struct like fields are not allowed");
474            }
475        })
476        .collect()
477}
478
479fn impl_default_env(default_mappings: HashMap<Ident, BTreeSet<Ident>>) -> syn::ExprStruct {
480    let elements: Punctuated<syn::FieldValue, Comma> = default_mappings
481        .iter()
482        .map(|(field_name, env_var_strings)| -> syn::FieldValue {
483            let env_var_strings: Punctuated<syn::LitStr, Comma> = env_var_strings
484                .iter()
485                .cloned()
486                .map(|ident| -> syn::LitStr {
487                    syn::LitStr::new(&ident.to_string(), proc_macro2::Span::call_site())
488                })
489                .collect();
490            syn::parse_quote! {
491                #field_name: [#env_var_strings]
492            }
493        })
494        .collect();
495
496    syn::parse_quote! {
497        Self {
498            #elements
499        }
500    }
501}
502
503fn env_var_fields(fields: &Punctuated<Field, Comma>) -> EnvVarFieldsResult {
504    let mut output = Punctuated::new();
505    let mut default_mappings: HashMap<Ident, BTreeSet<Ident>> = HashMap::new();
506    for field in fields {
507        let mut n = 0_usize;
508        field.attrs.iter().for_each(|attr| {
509            if attr.path().is_ident("env") {
510                let nested = attr.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated).expect_or_abort("Invalid specification for the `env` attribute");
511                let env_vars: BTreeSet<Ident> = nested.iter().
512                    filter_map(|item| {
513                        match item {
514                            Meta::Path(pth) => Some(pth.get_ident().expect_or_abort("Must have identifier and not a path").clone()),
515                            _ => None
516                        }
517                    })
518                    .collect();
519                n+=env_vars.len();
520                let key = field.ident.clone().expect_or_abort("Identifiers for all fields must be known at this point");
521                default_mappings.entry(key.clone())
522                    .and_modify(|previous| {
523                        if !previous.is_disjoint(&env_vars) {
524                            proc_macro_error2::emit_error!(key, "Environment variable specifications must be disjoint. The field {key} has the following duplicate specifications {:?}",
525                                previous.intersection(&env_vars).map(|ident| ident.to_string()).collect::<Vec<_>>());
526                        }
527                        previous.extend(env_vars.iter().cloned())
528                    })
529                    .or_insert(env_vars);
530            }
531        });
532        if n == 0 {
533            proc_macro_error2::emit_error!(field.ident, "At least one `env` directive must be specified";
534                help = "Try using an uppercase version of the field name: {}", field.ident.to_token_stream().to_string().to_uppercase();
535                note = "It is better to enforce that all env-var deserializeable fields are explicitly set in the code.")
536        }
537        // TODO: check uniqueness in leaf nodes
538        // TODO: Check for empty nodes and replace with uppercase
539        let ty: syn::Type = syn::parse_quote! {
540            [&'a str; #n]
541        };
542
543        output.push(Field {
544            ty,
545            attrs: vec![],
546            ..field.clone()
547        });
548    }
549
550    EnvVarFieldsResult {
551        fields: output,
552        default_mappings,
553    }
554}
555
556fn env_var_struct_name(attrs: Vec<Attribute>) -> Ident {
557    let mut ident = syn::parse_quote! { EnvVarSource };
558    for attr in attrs {
559        if attr.path().is_ident("env_var_rename") {
560            let identifier: Ident = attr
561                .parse_args()
562                .expect_or_abort("Failed to parse env_var_rename identifier. ");
563            ident = identifier;
564        }
565    }
566    ident
567}