config_derive/
lib.rs

1mod attr;
2
3use std::collections::HashMap;
4
5#[cfg(feature = "clap")]
6use heck::ToKebabCase;
7use proc_macro::TokenStream;
8use proc_macro2::Span;
9use quote::quote;
10#[cfg(any(feature = "env", feature = "clap", feature = "ini"))]
11use syn::Generics;
12use syn::{DeriveInput, Error, Ident, Lit, Meta, Type};
13
14use crate::attr::AssignAttrs;
15
16/// proc macro to declare a configuration structure
17#[proc_macro_attribute]
18pub fn config(_attrs: TokenStream, item: TokenStream) -> TokenStream {
19    let strukt = syn::parse_macro_input!(item as DeriveInput);
20    let struct_name = strukt.ident.clone();
21    // let attrs = syn::parse_macro_input!(attrs as ConfigAttrs);
22    let mut fields_name = vec![];
23    let mut fields_doc: HashMap<String, Option<String>> = HashMap::new();
24    let mut fields_is_boolean = vec![];
25
26    let fields = if let syn::Data::Struct(fields) = &strukt.data {
27        fields.fields.iter().map(|field| {
28            let ty = field.ty.clone();
29            let ident = field.ident.clone();
30            let vis = field.vis.clone();
31            let colon = field.colon_token;
32
33            if let Some(ident) = &ident {
34                fields_name.push(ident.to_string());
35                fields_doc.insert(ident.to_string(), None);
36            }
37            // Detect boolean types for clap
38            match &ty {
39                Type::Path(pat) => {
40                    if let Some(type_name) = pat.path.get_ident() {
41                        // let ident_name = type_name.to_string();
42                        if type_name == "bool" {
43                            fields_is_boolean.push(true);
44                        } else {
45                            fields_is_boolean.push(false);
46                        }
47                    } else {
48                        fields_is_boolean.push(false);
49                    }
50                }
51                _ => {
52                    fields_is_boolean.push(false);
53                }
54            }
55            let attrs =
56                field
57                    .attrs
58                    .clone()
59                    .into_iter()
60                    .filter(|attr| match attr.path.get_ident() {
61                        Some(attr_ident) if attr_ident == "serde" => {
62                            attr.parse_args::<AssignAttrs>().is_err()
63                        }
64                        Some(attr_ident) if attr_ident == "clap" => false,
65                        Some(attr_ident) if attr_ident == "doc" => {
66                            let doc = match attr.parse_meta() {
67                                Ok(Meta::NameValue(ref nv)) if nv.path.is_ident("doc") => {
68                                    if let Lit::Str(lit_str) = &nv.lit {
69                                        Some(lit_str.value())
70                                    } else {
71                                        None
72                                    }
73                                }
74                                _ => None,
75                            };
76                            if let Some(ident) = &ident {
77                                *fields_doc.get_mut(&ident.to_string()).unwrap() = doc;
78                            }
79                            true
80                        }
81                        _ => true,
82                    });
83            let res = quote! {
84                #(#attrs)*
85                #vis #ident #colon Option<#ty>,
86            };
87
88            res
89        })
90    } else {
91        return TokenStream::from(
92            Error::new_spanned(strukt, "must be a struct").to_compile_error(),
93        );
94    };
95    let struct_vis = strukt.vis.clone();
96    let struct_gen = strukt.generics.clone();
97    let struct_where = strukt.generics.where_clause.clone();
98    let struct_attrs =
99        strukt.attrs.clone().into_iter().filter(
100            |attr| matches!(attr.path.get_ident(), Some(attr_ident) if attr_ident == "serde"),
101        );
102    let opt_struct_name = Ident::new(format!("Opt{}", struct_name).as_str(), Span::call_site());
103
104    let opt_struct = quote! {
105        #(#struct_attrs)*
106        #struct_vis struct #opt_struct_name #struct_gen #struct_where {
107            #(#fields)*
108        }
109    };
110
111    #[cfg(feature = "clap")]
112    let docs = fields_doc.values().map(|doc| match doc {
113        Some(doc) => doc.trim().to_string(),
114        None => String::new(),
115    });
116
117    let json_branch = build_json_branch();
118    let toml_branch = build_toml_branch();
119    let yaml_branch = build_yaml_branch();
120    let dhall_branch = build_dhall_branch();
121    let default_trait_branch = build_default_trait_branch();
122    let custom_fn_branch = build_custom_fn_branch();
123
124    #[cfg(not(feature = "ini"))]
125    let ini_branch = quote! {};
126    #[cfg(feature = "ini")]
127    let ini_branch = build_ini_branch(&opt_struct_name, &struct_gen);
128
129    #[cfg(not(feature = "env"))]
130    let env_branch = quote! {};
131    #[cfg(feature = "env")]
132    let env_branch = build_env_branch(&opt_struct_name, &struct_gen);
133
134    #[cfg(not(feature = "clap"))]
135    let (clap_branch, clap_method) = (quote! {}, quote! {});
136    #[cfg(feature = "clap")]
137    let (clap_branch, clap_method) = build_clap_branch(
138        &opt_struct_name,
139        &struct_gen,
140        &fields_name,
141        &fields_is_boolean,
142        docs,
143    );
144
145    #[cfg(all(not(feature = "default_trait"), not(feature = "custom_fn")))]
146    let derive_serialize = quote! {};
147    #[cfg(any(feature = "default_trait", feature = "custom_fn"))]
148    let derive_serialize = quote! { #[derive(::twelf::reexports::serde::Serialize)] };
149
150    let code = quote! {
151        #derive_serialize
152        #[derive(::twelf::reexports::serde::Deserialize)]
153        #[serde(crate = "::twelf::reexports::serde")]
154        #strukt
155
156        impl #struct_gen #struct_name #struct_gen #struct_where {
157            pub fn with_layers(layers: &[::twelf::Layer]) -> Result<Self, ::twelf::Error> {
158                use std::iter::FromIterator;
159                let mut res: std::collections::HashMap<String, ::twelf::reexports::serde_json::Value> = std::collections::HashMap::new();
160                for layer in layers {
161                    let (extension,defaulted) = Self::parse_twelf(layer)?;
162                    let extension: std::collections::HashMap<_,_> = extension
163                        .as_object()
164                        .ok_or_else(|| ::twelf::Error::InvalidFormat)?
165                        .to_owned()
166                        .into_iter().filter(|(k, v)| (!defaulted.contains_key(k) || !defaulted[k] || !res.contains_key(k)) && !v.is_null())
167                        .collect(); // must collect, as filter uses res
168
169                    res.extend(extension);
170                }
171
172                ::twelf::reexports::log::debug!(target: "twelf", "configuration:");
173                for (key, val) in &res {
174                    ::twelf::reexports::log::debug!(target: "twelf", "{}={}", key, val);
175                }
176
177                ::twelf::reexports::serde_json::from_value(::twelf::reexports::serde_json::Value::Object(::twelf::reexports::serde_json::Map::from_iter(res.into_iter()))).map_err(|e| ::twelf::Error::Deserialize(format!("config error: {}", e.to_string())))
178            }
179            #clap_method
180
181            fn parse_twelf(priority: &::twelf::Layer) -> Result<(::twelf::reexports::serde_json::Value,std::collections::HashMap<String,bool>), ::twelf::Error>
182            {
183                #[derive(::twelf::reexports::serde::Deserialize, ::twelf::reexports::serde::Serialize)]
184                #[serde(crate = "::twelf::reexports::serde")]
185                #opt_struct
186
187                let (res,defaulted) = match priority {
188                    #env_branch
189                    #json_branch
190                    #toml_branch
191                    #yaml_branch
192                    #dhall_branch
193                    #ini_branch
194                    #clap_branch
195                    #default_trait_branch
196                    #custom_fn_branch
197                    other => unimplemented!("{:?}", other)
198                };
199
200                Ok((res,defaulted.unwrap_or(std::collections::HashMap::new())))
201            }
202        }
203    };
204
205    TokenStream::from(code)
206}
207
/// Builds the clap integration for a config struct.
///
/// Returns `(clap_branch, clap_method)`:
///
/// * `clap_branch` is the `::twelf::Layer::Clap(matches)` match arm used by `parse_twelf`.
///   For every field it looks the argument up first under the field's Rust name and then
///   under its kebab-case name. Multiple raw values are joined with `,`, and at most one
///   leading and one trailing `"` are stripped from each value. For each field found it
///   records whether the value came from a clap default. The collected strings are then
///   deserialized into the Opt struct with `envy::from_iter`.
/// * `clap_method` is `pub fn clap_args() -> Vec<clap::Arg>`, returning one `--long`
///   argument per field: `SetTrue` for `bool` fields, `Set` otherwise.
///
/// `fields_name`, `fields_is_boolean` and `docs` are paired up positionally, so they must all
/// be in field declaration order.
#[cfg(feature = "clap")]
fn build_clap_branch(
    opt_struct_name: &Ident,
    struct_gen: &Generics,
    fields_name: &[String],
    fields_is_boolean: &[bool],
    docs: impl Iterator<Item = String>,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
    // Bounds/defaults are not allowed in type position; only the bare parameters are.
    let (_, ty_gen, _) = struct_gen.split_for_impl();

    let field_names_clap = fields_name
        .iter()
        .map(|field_name| field_name.to_kebab_case());
    let field_names_clap_cloned = field_names_clap.clone();

    // Pick each argument's action at expansion time rather than emitting
    // `if (true) { .. } else { .. }` into the user's crate.
    let actions = fields_is_boolean.iter().map(|&is_bool| {
        if is_bool {
            quote! { ::twelf::reexports::clap::ArgAction::SetTrue }
        } else {
            quote! { ::twelf::reexports::clap::ArgAction::Set }
        }
    });

    let clap_branch = quote! { ::twelf::Layer::Clap(matches) => {
        let mut map: std::collections::HashMap<String, String> = std::collections::HashMap::new();
        let mut defaulted: std::collections::HashMap<String, bool> = std::collections::HashMap::new();

        #(
            let field = String::from(#fields_name);

            let mut insert_into_map = |vals: ::twelf::reexports::clap::parser::RawValues, is_default: bool| {
                let key = field.clone();
                defaulted.insert(key.clone(), is_default);

                for val in vals.into_iter() {
                    // hacky way of formatting everything to a string:
                    let s = val.to_string_lossy().into_owned();
                    let s = s.strip_prefix("\"").unwrap_or(&s);
                    let s = s.strip_suffix("\"").unwrap_or(&s);

                    if let Some(existing_val) = map.get_mut(&key) {
                        existing_val.push(',');
                        existing_val.push_str(s);
                    } else {
                        map.insert(key.clone(), s.to_string());
                    }
                }
            };

            if let Some(vals) = matches.try_get_raw(#fields_name).unwrap_or(None) {
                let is_default = matches.value_source(#fields_name).unwrap() == ::twelf::reexports::clap::parser::ValueSource::DefaultValue;
                insert_into_map(vals, is_default);
            } else if let Some(vals) = matches.try_get_raw(#field_names_clap_cloned).unwrap_or(None) {
                let is_default = matches.value_source(#field_names_clap_cloned).unwrap() == ::twelf::reexports::clap::parser::ValueSource::DefaultValue;
                insert_into_map(vals, is_default);
            }
        )*

        let tmp_cfg: #opt_struct_name #ty_gen = ::twelf::reexports::envy::from_iter(map.into_iter())?;
        (::twelf::reexports::serde_json::to_value(tmp_cfg)?,Some(defaulted))
    },};

    let clap_method = quote! { pub fn clap_args() -> Vec<::twelf::reexports::clap::Arg> {
        vec![#(
            ::twelf::reexports::clap::Arg::new(#field_names_clap)
                .long(#field_names_clap)
                .help(#docs)
                .action(#actions)
        ),*]
    }};

    (clap_branch, clap_method)
}
270
/// Builds the `::twelf::Layer::Env(prefix)` match arm.
///
/// The arm deserializes the Opt struct from the process environment with `envy`. It uses
/// `envy::prefixed(prefix)` when a prefix is given and `envy::from_env` otherwise, then
/// converts the result to a JSON value. No defaulted-value map is reported (`None`).
#[cfg(feature = "env")]
fn build_env_branch(opt_struct_name: &Ident, struct_gen: &Generics) -> proc_macro2::TokenStream {
    // Bounds/defaults are not allowed in type position; only the bare parameters are.
    let (_, ty_gen, _) = struct_gen.split_for_impl();
    quote! { ::twelf::Layer::Env(prefix) => {
        let tmp_cfg: #opt_struct_name #ty_gen = match prefix {
            Some(prefix) => ::twelf::reexports::envy::prefixed(prefix).from_env()?,
            None => ::twelf::reexports::envy::from_env()?,
        };
        (::twelf::reexports::serde_json::to_value(tmp_cfg)?,None)
    },}
}
284
/// Produces the statement (without its trailing `;`) that binds `content` from `file_content`.
///
/// With the `shellexpand` feature, environment variables in the text are expanded via
/// `shellexpand::env`, and failures are propagated with `?`. Without it, `content` is simply
/// `file_content`.
#[cfg(any(
    feature = "json",
    feature = "yaml",
    feature = "toml",
    feature = "ini",
    feature = "dhall"
))]
fn build_shellexpand() -> proc_macro2::TokenStream {
    #[cfg(not(feature = "shellexpand"))]
    let binding = quote! { let content = file_content };
    #[cfg(feature = "shellexpand")]
    let binding = quote! { let content = ::twelf::reexports::shellexpand::env(&file_content)? };
    binding
}
299
300fn build_json_branch() -> proc_macro2::TokenStream {
301    #[cfg(feature = "json")]
302    let shellexpand = build_shellexpand();
303    #[cfg(feature = "json")]
304    let json_branch = quote! { ::twelf::Layer::Json(filepath) => {
305        let file_content = std::fs::read_to_string(filepath)?;
306        #shellexpand;
307        (::twelf::reexports::serde_json::from_str(&content)?,None)
308    }, };
309    #[cfg(not(feature = "json"))]
310    let json_branch = quote! {};
311    json_branch
312}
313
314fn build_toml_branch() -> proc_macro2::TokenStream {
315    #[cfg(feature = "toml")]
316    let shellexpand = build_shellexpand();
317    #[cfg(feature = "toml")]
318    let toml_branch = quote! { ::twelf::Layer::Toml(filepath) => {
319        let file_content = std::fs::read_to_string(filepath)?;
320        // Strip out comments (lines starting with #)
321        let file_content = file_content.lines().filter(|line| !line.trim().starts_with("#")).collect::<Vec<_>>().join("\n");
322
323        #shellexpand;
324        (::twelf::reexports::toml::from_str(&content)?,None)
325    }, };
326    #[cfg(not(feature = "toml"))]
327    let toml_branch = quote! {};
328    toml_branch
329}
330
331fn build_yaml_branch() -> proc_macro2::TokenStream {
332    #[cfg(feature = "yaml")]
333    let shellexpand = build_shellexpand();
334    #[cfg(feature = "yaml")]
335    let yaml_branch = quote! { ::twelf::Layer::Yaml(filepath) => {
336        let file_content = std::fs::read_to_string(filepath)?;
337        // Strip out comments (lines starting with #)
338        let file_content = file_content.lines().filter(|line| !line.trim().starts_with("#")).collect::<Vec<_>>().join("\n");
339        #shellexpand;
340        (::twelf::reexports::serde_yaml::from_str(&content)?,None)
341    }, };
342    #[cfg(not(feature = "yaml"))]
343    let yaml_branch = quote! {};
344    yaml_branch
345}
346
347fn build_dhall_branch() -> proc_macro2::TokenStream {
348    #[cfg(feature = "dhall")]
349    let shellexpand = build_shellexpand();
350    #[cfg(feature = "dhall")]
351    let dhall_branch = quote! { ::twelf::Layer::Dhall(filepath) => {
352        let file_content = std::fs::read_to_string(filepath)?;
353        // Strip out comments (lines starting with --)
354        let file_content = file_content.lines().filter(|line| !line.trim().starts_with("--")).collect::<Vec<_>>().join("\n");
355
356        #shellexpand;
357        (::twelf::reexports::serde_dhall::from_str(&content).parse()?,None)
358    }, };
359    #[cfg(not(feature = "dhall"))]
360    let dhall_branch = quote! {};
361    dhall_branch
362}
363
/// Builds the `::twelf::Layer::Ini(filepath)` match arm.
///
/// The arm reads the file and drops every line whose first non-whitespace character is `;`.
/// It then applies the `build_shellexpand` binding, parses the text into the Opt struct with
/// `serde_ini`, and converts it to a JSON value. No defaulted-value map is reported (`None`).
#[cfg(feature = "ini")]
fn build_ini_branch(opt_struct_name: &Ident, struct_gen: &Generics) -> proc_macro2::TokenStream {
    // Bounds/defaults are not allowed in type position; only the bare parameters are.
    let (_, ty_gen, _) = struct_gen.split_for_impl();
    let shellexpand = build_shellexpand();
    quote! { ::twelf::Layer::Ini(filepath) => {
        let file_content = std::fs::read_to_string(filepath)?;
        // Strip out comments (lines starting with ;)
        let file_content = file_content.lines().filter(|line| !line.trim().starts_with(";")).collect::<Vec<_>>().join("\n");
        #shellexpand;
        let tmp_cfg: #opt_struct_name #ty_gen = ::twelf::reexports::serde_ini::from_str(&content)?;
        (::twelf::reexports::serde_json::to_value(tmp_cfg)?,None)
    }, }
}
376
377fn build_default_trait_branch() -> proc_macro2::TokenStream {
378    #[cfg(feature = "default_trait")]
379    let default_trait_branch = quote! { ::twelf::Layer::DefaultTrait => (::twelf::reexports::serde_json::to_value(<Self as std::default::Default>::default())?,None), };
380    #[cfg(not(feature = "default_trait"))]
381    let default_trait_branch = quote! {};
382    default_trait_branch
383}
384
385fn build_custom_fn_branch() -> proc_macro2::TokenStream {
386    #[cfg(feature = "custom_fn")]
387    let custom_branch =
388        quote! { ::twelf::Layer::CustomFn(custom_fn) => (custom_fn.clone().0(),None), };
389    #[cfg(not(feature = "custom_fn"))]
390    let custom_branch = quote! {};
391    custom_branch
392}