clap_config_file/
lib.rs

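//! `ClapConfigFile`: a single derive macro that merges Clap command-line parsing with a config
//! file. CLI flag names default to the kebab-case form of the field name, while config-file keys
//! default to the Rust field name; an explicit `name` overrides both.
//! `bool` fields are accepted both with and without a `default_value` (which must be `"true"` or
//! `"false"`).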
3
4use heck::ToKebabCase;
5use proc_macro::TokenStream;
6use proc_macro2::{Span, TokenStream as TokenStream2};
7use quote::quote;
8use syn::{parse_macro_input, DeriveInput, Error, LitStr};
9
10mod parse_attrs;
11use parse_attrs::*;
12
13#[proc_macro_derive(
14    ClapConfigFile,
15    attributes(config_file_name, config_file_formats, config_arg)
16)]
17pub fn derive_clap_config_file(input: TokenStream) -> TokenStream {
18    let ast = parse_macro_input!(input as DeriveInput);
19    match build_impl(ast) {
20        Ok(ts) => ts.into(),
21        Err(e) => e.to_compile_error().into(),
22    }
23}
24
25fn build_impl(ast: DeriveInput) -> syn::Result<TokenStream2> {
26    let struct_ident = &ast.ident;
27    let generics = &ast.generics;
28
29    let macro_cfg = parse_struct_level_attrs(&ast.attrs)?;
30
31    let fields_named = match &ast.data {
32        syn::Data::Struct(syn::DataStruct {
33            fields: syn::Fields::Named(ref named),
34            ..
35        }) => &named.named,
36        _ => {
37            return Err(Error::new_spanned(
38                &ast.ident,
39                "ClapConfigFile only supports a struct with named fields.",
40            ))
41        }
42    };
43
44    let field_infos = parse_fields(fields_named)?;
45    let parse_info_impl = generate_parse_info_impl(struct_ident, &field_infos, &macro_cfg);
46
47    let debug_impl = generate_debug_impl(struct_ident, generics, &field_infos);
48    let serialize_impl = generate_serialize_impl(struct_ident, generics, &field_infos);
49
50    let expanded = quote! {
51        impl #generics #struct_ident #generics {
52            pub fn parse_info() -> (Self, Option<std::path::PathBuf>, Option<&'static str>) {
53                #parse_info_impl
54            }
55            pub fn parse() -> Self {
56                Self::parse_info().0
57            }
58        }
59
60        #debug_impl
61        #serialize_impl
62    };
63
64    Ok(expanded)
65}
66
67/// Generate parse_info: ephemeral CLI + ephemeral config => unify.
68fn generate_parse_info_impl(
69    struct_ident: &syn::Ident,
70    fields: &[FieldInfo],
71    macro_cfg: &MacroConfig,
72) -> TokenStream2 {
73    let base_name = &macro_cfg.base_name;
74    let fmts = &macro_cfg.formats;
75    let fmts_list: Vec<_> = fmts.iter().map(|s| s.as_str()).collect();
76
77    // ephemeral CLI
78    let cli_ident = syn::Ident::new(&format!("__{}_Cli", struct_ident), Span::call_site());
79    let cli_fields = fields
80        .iter()
81        .filter(|f| {
82            !matches!(
83                f.arg_attrs.availability,
84                FieldAvailability::ConfigOnly | FieldAvailability::Internal
85            )
86        })
87        .map(generate_cli_field);
88
89    let cli_extras = quote! {
90        #[clap(long="no-config", default_value_t=false, help="Do not use a config file")]
91        __no_config: bool,
92
93        #[clap(long="config-file", help="Path to the config file")]
94        __config_file: Option<std::path::PathBuf>,
95    };
96    let build_cli_struct = quote! {
97        #[derive(::clap::Parser, ::std::fmt::Debug, ::std::default::Default)]
98        struct #cli_ident {
99            #cli_extras
100            #(#cli_fields),*
101        }
102    };
103
104    // ephemeral config
105    let cfg_ident = syn::Ident::new(&format!("__{}_Cfg", struct_ident), Span::call_site());
106    let cfg_fields = fields
107        .iter()
108        .filter(|f| {
109            !matches!(
110                f.arg_attrs.availability,
111                FieldAvailability::CliOnly | FieldAvailability::Internal
112            )
113        })
114        .map(generate_config_field);
115    let build_cfg_struct = quote! {
116        #[derive(::serde::Deserialize, ::std::fmt::Debug, ::std::default::Default)]
117        struct #cfg_ident {
118            #(#cfg_fields),*
119        }
120    };
121
122    let unify_stmts = fields.iter().map(unify_field);
123
124    let inline_helpers = quote! {
125        fn __inline_guess_format(path: &std::path::Path, known_formats: &[&str]) -> Option<&'static str> {
126            if let Some(ext) = path.extension().and_then(|e| e.to_str()).map(|s| s.to_lowercase()) {
127                for &f in known_formats {
128                    if ext == f {
129                        return Some(Box::leak(f.to_string().into_boxed_str()));
130                    }
131                }
132            }
133            None
134        }
135
136        fn __inline_find_config(base_name: &str, fmts: &[&str]) -> Option<std::path::PathBuf> {
137            let mut dir = std::env::current_dir().ok()?;
138            let mut found: Option<std::path::PathBuf> = None;
139
140            loop {
141                let mut found_this = vec![];
142                for &f in fmts {
143                    let candidate = dir.join(format!("{}.{}", base_name, f));
144                    if candidate.is_file() {
145                        found_this.push(candidate);
146                    }
147                }
148                if found_this.len() > 1 {
149                    eprintln!("Error: multiple config files in same dir: {:?}", found_this);
150                    std::process::exit(2);
151                } else if found_this.len() == 1 {
152                    if found.is_some() {
153                        eprintln!(
154                            "Error: multiple config files found walking up: {:?} and {:?}",
155                            found.as_ref().unwrap(), found_this[0]
156                        );
157                        std::process::exit(2);
158                    }
159                    found = Some(found_this.remove(0));
160                }
161                if !dir.pop() {
162                    break;
163                }
164            }
165            found
166        }
167    };
168
169    quote! {
170        #build_cli_struct
171        #build_cfg_struct
172
173        use ::clap::Parser;
174        let cli = #cli_ident::parse();
175
176        #inline_helpers
177
178        let mut used_path: Option<std::path::PathBuf> = None;
179        let mut used_format: Option<&'static str> = None;
180
181        let mut config_data = ::config::Config::builder();
182        if !cli.__no_config {
183            if let Some(ref path) = cli.__config_file {
184                used_path = Some(path.clone());
185                let format = __inline_guess_format(path, &[#(#fmts_list),*]);
186                if let Some(fmt) = format {
187                    let file = match fmt {
188                        "yaml" | "yml" => ::config::File::from(path.as_path()).format(::config::FileFormat::Yaml),
189                        "json" => ::config::File::from(path.as_path()).format(::config::FileFormat::Json),
190                        "toml" => ::config::File::from(path.as_path()).format(::config::FileFormat::Toml),
191                        _ => ::config::File::from(path.as_path()).format(::config::FileFormat::Yaml),
192                    };
193                    config_data = config_data.add_source(file);
194                }
195                used_format = format;
196            } else if let Some(found) = __inline_find_config(#base_name, &[#(#fmts_list),*]) {
197                used_path = Some(found.clone());
198                let format = __inline_guess_format(&found, &[#(#fmts_list),*]);
199                if let Some(fmt) = format {
200                    let file = match fmt {
201                        "yaml" | "yml" => ::config::File::from(found.as_path()).format(::config::FileFormat::Yaml),
202                        "json" => ::config::File::from(found.as_path()).format(::config::FileFormat::Json),
203                        "toml" => ::config::File::from(found.as_path()).format(::config::FileFormat::Toml),
204                        _ => ::config::File::from(found.as_path()).format(::config::FileFormat::Yaml),
205                    };
206                    config_data = config_data.add_source(file);
207                }
208                used_format = format;
209            }
210        }
211
212        let built = config_data.build().unwrap_or_else(|e| {
213            eprintln!("Failed to build config: {}", e);
214            ::config::Config::default()
215        });
216        let ephemeral_cfg: #cfg_ident = built.clone().try_deserialize().unwrap_or_else(|e| {
217            eprintln!("Failed to deserialize config into struct: {}", e);
218            eprintln!("Config data after build: {:#?}", built);
219            #cfg_ident::default()
220        });
221
222
223        let final_struct = #struct_ident {
224            #(#unify_stmts),*
225        };
226        (final_struct, used_path, used_format)
227    }
228}
229
230/// Generate ephemeral CLI field if field is not config_only
231fn generate_cli_field(field: &FieldInfo) -> TokenStream2 {
232    let ident = &field.ident;
233    let kebab_default = ident.to_string().to_kebab_case();
234    let final_name = field.arg_attrs.name.clone().unwrap_or(kebab_default);
235    let name_lit = LitStr::new(&final_name, Span::call_site());
236    let help_text = &field.arg_attrs.help_text;
237    let help_attr = if help_text.is_empty() {
238        quote!()
239    } else {
240        let help_lit = LitStr::new(help_text, Span::call_site());
241        quote!(help=#help_lit,)
242    };
243
244    if field.arg_attrs.positional {
245        // For positional arguments
246        if field.is_vec_type() {
247            quote! {
248                #[clap(value_name=#name_lit, num_args=1.., action=::clap::ArgAction::Append, #help_attr)]
249                #ident: Option<Vec<String>>
250            }
251        } else {
252            quote! {
253                #[clap(value_name=#name_lit, #help_attr)]
254                #ident: Option<String>
255            }
256        }
257    } else {
258        // short?
259        let short_attr = if let Some(ch) = field.arg_attrs.short {
260            quote!(short=#ch,)
261        } else {
262            quote!()
263        };
264
265        if field.is_bool_type() {
266            // Handle bool default_value "true"/"false"
267            if let Some(ref dv) = field.arg_attrs.default_value {
268                let is_true = dv.eq_ignore_ascii_case("true");
269                let is_false = dv.eq_ignore_ascii_case("false");
270                if !is_true && !is_false {
271                    let msg = format!(
272                        "For bool field, default_value must be \"true\" or \"false\", got {}",
273                        dv
274                    );
275                    return quote! {
276                        compile_error!(#msg);
277                        #ident: ()
278                    };
279                }
280                let bool_lit = if is_true { quote!(true) } else { quote!(false) };
281                quote! {
282                    #[clap(long=#name_lit, #short_attr default_value_t=#bool_lit, #help_attr)]
283                    #ident: Option<bool>
284                }
285            } else {
286                quote! {
287                    #[clap(long=#name_lit, #short_attr action=::clap::ArgAction::SetTrue, #help_attr)]
288                    #ident: Option<bool>
289                }
290            }
291        } else {
292            let dv_attr = if let Some(dv) = &field.arg_attrs.default_value {
293                let dv_lit = LitStr::new(dv, Span::call_site());
294                quote!(default_value=#dv_lit,)
295            } else {
296                quote!()
297            };
298            let is_vec = field.is_vec_type();
299            let multi = if is_vec {
300                quote!(num_args = 1.., action = ::clap::ArgAction::Append,)
301            } else {
302                quote!()
303            };
304            let field_ty = {
305                let t = &field.ty;
306                quote!(Option<#t>)
307            };
308
309            quote! {
310                #[clap(long=#name_lit, #short_attr #dv_attr #multi #help_attr)]
311                #ident: #field_ty
312            }
313        }
314    }
315}
316/// Generate ephemeral config field if field is not cli_only
317fn generate_config_field(field: &FieldInfo) -> TokenStream2 {
318    let ident = &field.ident;
319    let ty = &field.ty;
320
321    // Only use rename if explicitly specified
322    let rename_attr = if let Some(name) = &field.arg_attrs.name {
323        let name_lit = LitStr::new(name, Span::call_site());
324        quote!(#[serde(rename = #name_lit)])
325    } else {
326        quote!()
327    };
328
329    quote! {
330        #rename_attr
331        #[serde(default)]
332        pub #ident: #ty
333    }
334}
335
336/// Merge ephemeral CLI + ephemeral config => final
337fn unify_field(field: &FieldInfo) -> TokenStream2 {
338    let ident = &field.ident;
339    match field.arg_attrs.availability {
340        FieldAvailability::CliOnly => {
341            if field.is_vec_type() {
342                quote!(#ident: cli.#ident.unwrap_or_default())
343            } else if field.is_bool_type() {
344                quote!(#ident: cli.#ident.unwrap_or(false))
345            } else {
346                quote!(#ident: cli.#ident.unwrap_or_default())
347            }
348        }
349        FieldAvailability::ConfigOnly => {
350            quote!(#ident: ephemeral_cfg.#ident)
351        }
352        FieldAvailability::CliAndConfig => {
353            if field.is_vec_type() {
354                match field.arg_attrs.multi_value_behavior {
355                    MultiValueBehavior::Extend => quote! {
356                        #ident: {
357                            let mut merged = ephemeral_cfg.#ident.clone();
358                            if let Some(cli_vec) = cli.#ident {
359                                merged.extend(cli_vec);
360                            }
361                            merged
362                        }
363                    },
364                    MultiValueBehavior::Overwrite => quote! {
365                        #ident: cli.#ident.unwrap_or_else(|| ephemeral_cfg.#ident.clone())
366                    },
367                }
368            } else if field.is_bool_type() {
369                quote!(#ident: cli.#ident.unwrap_or(ephemeral_cfg.#ident))
370            } else {
371                quote!(#ident: cli.#ident.unwrap_or_else(|| ephemeral_cfg.#ident))
372            }
373        }
374        FieldAvailability::Internal => {
375            quote!(#ident: Default::default())
376        }
377    }
378}
379
380/// Implement Debug for final struct
381fn generate_debug_impl(
382    struct_ident: &syn::Ident,
383    generics: &syn::Generics,
384    fields: &[FieldInfo],
385) -> TokenStream2 {
386    let field_idents = fields.iter().map(|fi| &fi.ident);
387    quote! {
388        impl #generics ::std::fmt::Debug for #struct_ident #generics {
389            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
390                let mut dbg = f.debug_struct(stringify!(#struct_ident));
391                #( dbg.field(stringify!(#field_idents), &self.#field_idents); )*
392                dbg.finish()
393            }
394        }
395    }
396}
397
398/// Implement Serialize for final struct
399fn generate_serialize_impl(
400    struct_ident: &syn::Ident,
401    generics: &syn::Generics,
402    fields: &[FieldInfo],
403) -> TokenStream2 {
404    let field_idents = fields.iter().map(|fi| &fi.ident);
405    let field_names = fields.iter().map(|fi| fi.ident.to_string());
406    let num_fields = fields.len();
407
408    quote! {
409        impl #generics ::serde::Serialize for #struct_ident #generics {
410            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
411            where
412                S: ::serde::Serializer
413            {
414                use ::serde::ser::SerializeStruct;
415                let mut st = serializer.serialize_struct(
416                    stringify!(#struct_ident),
417                    #num_fields
418                )?;
419                #(
420                    st.serialize_field(#field_names, &self.#field_idents)?;
421                )*
422                st.end()
423            }
424        }
425    }
426}