1use heck::ToKebabCase;
5use proc_macro::TokenStream;
6use proc_macro2::{Span, TokenStream as TokenStream2};
7use quote::quote;
8use syn::{parse_macro_input, DeriveInput, Error, LitStr};
9
10mod parse_attrs;
11use parse_attrs::*;
12
13#[proc_macro_derive(
14 ClapConfigFile,
15 attributes(config_file_name, config_file_formats, config_arg)
16)]
17pub fn derive_clap_config_file(input: TokenStream) -> TokenStream {
18 let ast = parse_macro_input!(input as DeriveInput);
19 match build_impl(ast) {
20 Ok(ts) => ts.into(),
21 Err(e) => e.to_compile_error().into(),
22 }
23}
24
25fn build_impl(ast: DeriveInput) -> syn::Result<TokenStream2> {
26 let struct_ident = &ast.ident;
27 let generics = &ast.generics;
28
29 let macro_cfg = parse_struct_level_attrs(&ast.attrs)?;
30
31 let fields_named = match &ast.data {
32 syn::Data::Struct(syn::DataStruct {
33 fields: syn::Fields::Named(ref named),
34 ..
35 }) => &named.named,
36 _ => {
37 return Err(Error::new_spanned(
38 &ast.ident,
39 "ClapConfigFile only supports a struct with named fields.",
40 ))
41 }
42 };
43
44 let field_infos = parse_fields(fields_named)?;
45 let parse_info_impl = generate_parse_info_impl(struct_ident, &field_infos, ¯o_cfg);
46
47 let debug_impl = generate_debug_impl(struct_ident, generics, &field_infos);
48 let serialize_impl = generate_serialize_impl(struct_ident, generics, &field_infos);
49
50 let expanded = quote! {
51 impl #generics #struct_ident #generics {
52 pub fn parse_info() -> (Self, Option<std::path::PathBuf>, Option<&'static str>) {
53 #parse_info_impl
54 }
55 pub fn parse() -> Self {
56 Self::parse_info().0
57 }
58 }
59
60 #debug_impl
61 #serialize_impl
62 };
63
64 Ok(expanded)
65}
66
67fn generate_parse_info_impl(
69 struct_ident: &syn::Ident,
70 fields: &[FieldInfo],
71 macro_cfg: &MacroConfig,
72) -> TokenStream2 {
73 let base_name = ¯o_cfg.base_name;
74 let fmts = ¯o_cfg.formats;
75 let fmts_list: Vec<_> = fmts.iter().map(|s| s.as_str()).collect();
76
77 let cli_ident = syn::Ident::new(&format!("__{}_Cli", struct_ident), Span::call_site());
79 let cli_fields = fields
80 .iter()
81 .filter(|f| {
82 !matches!(
83 f.arg_attrs.availability,
84 FieldAvailability::ConfigOnly | FieldAvailability::Internal
85 )
86 })
87 .map(generate_cli_field);
88
89 let cli_extras = quote! {
90 #[clap(long="no-config", default_value_t=false, help="Do not use a config file")]
91 __no_config: bool,
92
93 #[clap(long="config-file", help="Path to the config file")]
94 __config_file: Option<std::path::PathBuf>,
95 };
96 let build_cli_struct = quote! {
97 #[derive(::clap::Parser, ::std::fmt::Debug, ::std::default::Default)]
98 struct #cli_ident {
99 #cli_extras
100 #(#cli_fields),*
101 }
102 };
103
104 let cfg_ident = syn::Ident::new(&format!("__{}_Cfg", struct_ident), Span::call_site());
106 let cfg_fields = fields
107 .iter()
108 .filter(|f| {
109 !matches!(
110 f.arg_attrs.availability,
111 FieldAvailability::CliOnly | FieldAvailability::Internal
112 )
113 })
114 .map(generate_config_field);
115 let build_cfg_struct = quote! {
116 #[derive(::serde::Deserialize, ::std::fmt::Debug, ::std::default::Default)]
117 struct #cfg_ident {
118 #(#cfg_fields),*
119 }
120 };
121
122 let unify_stmts = fields.iter().map(unify_field);
123
124 let inline_helpers = quote! {
125 fn __inline_guess_format(path: &std::path::Path, known_formats: &[&str]) -> Option<&'static str> {
126 if let Some(ext) = path.extension().and_then(|e| e.to_str()).map(|s| s.to_lowercase()) {
127 for &f in known_formats {
128 if ext == f {
129 return Some(Box::leak(f.to_string().into_boxed_str()));
130 }
131 }
132 }
133 None
134 }
135
136 fn __inline_find_config(base_name: &str, fmts: &[&str]) -> Option<std::path::PathBuf> {
137 let mut dir = std::env::current_dir().ok()?;
138 let mut found: Option<std::path::PathBuf> = None;
139
140 loop {
141 let mut found_this = vec![];
142 for &f in fmts {
143 let candidate = dir.join(format!("{}.{}", base_name, f));
144 if candidate.is_file() {
145 found_this.push(candidate);
146 }
147 }
148 if found_this.len() > 1 {
149 eprintln!("Error: multiple config files in same dir: {:?}", found_this);
150 std::process::exit(2);
151 } else if found_this.len() == 1 {
152 if found.is_some() {
153 eprintln!(
154 "Error: multiple config files found walking up: {:?} and {:?}",
155 found.as_ref().unwrap(), found_this[0]
156 );
157 std::process::exit(2);
158 }
159 found = Some(found_this.remove(0));
160 }
161 if !dir.pop() {
162 break;
163 }
164 }
165 found
166 }
167 };
168
169 quote! {
170 #build_cli_struct
171 #build_cfg_struct
172
173 use ::clap::Parser;
174 let cli = #cli_ident::parse();
175
176 #inline_helpers
177
178 let mut used_path: Option<std::path::PathBuf> = None;
179 let mut used_format: Option<&'static str> = None;
180
181 let mut config_data = ::config::Config::builder();
182 if !cli.__no_config {
183 if let Some(ref path) = cli.__config_file {
184 used_path = Some(path.clone());
185 let format = __inline_guess_format(path, &[#(#fmts_list),*]);
186 if let Some(fmt) = format {
187 let file = match fmt {
188 "yaml" | "yml" => ::config::File::from(path.as_path()).format(::config::FileFormat::Yaml),
189 "json" => ::config::File::from(path.as_path()).format(::config::FileFormat::Json),
190 "toml" => ::config::File::from(path.as_path()).format(::config::FileFormat::Toml),
191 _ => ::config::File::from(path.as_path()).format(::config::FileFormat::Yaml),
192 };
193 config_data = config_data.add_source(file);
194 }
195 used_format = format;
196 } else if let Some(found) = __inline_find_config(#base_name, &[#(#fmts_list),*]) {
197 used_path = Some(found.clone());
198 let format = __inline_guess_format(&found, &[#(#fmts_list),*]);
199 if let Some(fmt) = format {
200 let file = match fmt {
201 "yaml" | "yml" => ::config::File::from(found.as_path()).format(::config::FileFormat::Yaml),
202 "json" => ::config::File::from(found.as_path()).format(::config::FileFormat::Json),
203 "toml" => ::config::File::from(found.as_path()).format(::config::FileFormat::Toml),
204 _ => ::config::File::from(found.as_path()).format(::config::FileFormat::Yaml),
205 };
206 config_data = config_data.add_source(file);
207 }
208 used_format = format;
209 }
210 }
211
212 let built = config_data.build().unwrap_or_else(|e| {
213 eprintln!("Failed to build config: {}", e);
214 ::config::Config::default()
215 });
216 let ephemeral_cfg: #cfg_ident = built.clone().try_deserialize().unwrap_or_else(|e| {
217 eprintln!("Failed to deserialize config into struct: {}", e);
218 eprintln!("Config data after build: {:#?}", built);
219 #cfg_ident::default()
220 });
221
222
223 let final_struct = #struct_ident {
224 #(#unify_stmts),*
225 };
226 (final_struct, used_path, used_format)
227 }
228}
229
230fn generate_cli_field(field: &FieldInfo) -> TokenStream2 {
232 let ident = &field.ident;
233 let kebab_default = ident.to_string().to_kebab_case();
234 let final_name = field.arg_attrs.name.clone().unwrap_or(kebab_default);
235 let name_lit = LitStr::new(&final_name, Span::call_site());
236 let help_text = &field.arg_attrs.help_text;
237 let help_attr = if help_text.is_empty() {
238 quote!()
239 } else {
240 let help_lit = LitStr::new(help_text, Span::call_site());
241 quote!(help=#help_lit,)
242 };
243
244 if field.arg_attrs.positional {
245 if field.is_vec_type() {
247 quote! {
248 #[clap(value_name=#name_lit, num_args=1.., action=::clap::ArgAction::Append, #help_attr)]
249 #ident: Option<Vec<String>>
250 }
251 } else {
252 quote! {
253 #[clap(value_name=#name_lit, #help_attr)]
254 #ident: Option<String>
255 }
256 }
257 } else {
258 let short_attr = if let Some(ch) = field.arg_attrs.short {
260 quote!(short=#ch,)
261 } else {
262 quote!()
263 };
264
265 if field.is_bool_type() {
266 if let Some(ref dv) = field.arg_attrs.default_value {
268 let is_true = dv.eq_ignore_ascii_case("true");
269 let is_false = dv.eq_ignore_ascii_case("false");
270 if !is_true && !is_false {
271 let msg = format!(
272 "For bool field, default_value must be \"true\" or \"false\", got {}",
273 dv
274 );
275 return quote! {
276 compile_error!(#msg);
277 #ident: ()
278 };
279 }
280 let bool_lit = if is_true { quote!(true) } else { quote!(false) };
281 quote! {
282 #[clap(long=#name_lit, #short_attr default_value_t=#bool_lit, #help_attr)]
283 #ident: Option<bool>
284 }
285 } else {
286 quote! {
287 #[clap(long=#name_lit, #short_attr action=::clap::ArgAction::SetTrue, #help_attr)]
288 #ident: Option<bool>
289 }
290 }
291 } else {
292 let dv_attr = if let Some(dv) = &field.arg_attrs.default_value {
293 let dv_lit = LitStr::new(dv, Span::call_site());
294 quote!(default_value=#dv_lit,)
295 } else {
296 quote!()
297 };
298 let is_vec = field.is_vec_type();
299 let multi = if is_vec {
300 quote!(num_args = 1.., action = ::clap::ArgAction::Append,)
301 } else {
302 quote!()
303 };
304 let field_ty = {
305 let t = &field.ty;
306 quote!(Option<#t>)
307 };
308
309 quote! {
310 #[clap(long=#name_lit, #short_attr #dv_attr #multi #help_attr)]
311 #ident: #field_ty
312 }
313 }
314 }
315}
316fn generate_config_field(field: &FieldInfo) -> TokenStream2 {
318 let ident = &field.ident;
319 let ty = &field.ty;
320
321 let rename_attr = if let Some(name) = &field.arg_attrs.name {
323 let name_lit = LitStr::new(name, Span::call_site());
324 quote!(#[serde(rename = #name_lit)])
325 } else {
326 quote!()
327 };
328
329 quote! {
330 #rename_attr
331 #[serde(default)]
332 pub #ident: #ty
333 }
334}
335
336fn unify_field(field: &FieldInfo) -> TokenStream2 {
338 let ident = &field.ident;
339 match field.arg_attrs.availability {
340 FieldAvailability::CliOnly => {
341 if field.is_vec_type() {
342 quote!(#ident: cli.#ident.unwrap_or_default())
343 } else if field.is_bool_type() {
344 quote!(#ident: cli.#ident.unwrap_or(false))
345 } else {
346 quote!(#ident: cli.#ident.unwrap_or_default())
347 }
348 }
349 FieldAvailability::ConfigOnly => {
350 quote!(#ident: ephemeral_cfg.#ident)
351 }
352 FieldAvailability::CliAndConfig => {
353 if field.is_vec_type() {
354 match field.arg_attrs.multi_value_behavior {
355 MultiValueBehavior::Extend => quote! {
356 #ident: {
357 let mut merged = ephemeral_cfg.#ident.clone();
358 if let Some(cli_vec) = cli.#ident {
359 merged.extend(cli_vec);
360 }
361 merged
362 }
363 },
364 MultiValueBehavior::Overwrite => quote! {
365 #ident: cli.#ident.unwrap_or_else(|| ephemeral_cfg.#ident.clone())
366 },
367 }
368 } else if field.is_bool_type() {
369 quote!(#ident: cli.#ident.unwrap_or(ephemeral_cfg.#ident))
370 } else {
371 quote!(#ident: cli.#ident.unwrap_or_else(|| ephemeral_cfg.#ident))
372 }
373 }
374 FieldAvailability::Internal => {
375 quote!(#ident: Default::default())
376 }
377 }
378}
379
380fn generate_debug_impl(
382 struct_ident: &syn::Ident,
383 generics: &syn::Generics,
384 fields: &[FieldInfo],
385) -> TokenStream2 {
386 let field_idents = fields.iter().map(|fi| &fi.ident);
387 quote! {
388 impl #generics ::std::fmt::Debug for #struct_ident #generics {
389 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
390 let mut dbg = f.debug_struct(stringify!(#struct_ident));
391 #( dbg.field(stringify!(#field_idents), &self.#field_idents); )*
392 dbg.finish()
393 }
394 }
395 }
396}
397
398fn generate_serialize_impl(
400 struct_ident: &syn::Ident,
401 generics: &syn::Generics,
402 fields: &[FieldInfo],
403) -> TokenStream2 {
404 let field_idents = fields.iter().map(|fi| &fi.ident);
405 let field_names = fields.iter().map(|fi| fi.ident.to_string());
406 let num_fields = fields.len();
407
408 quote! {
409 impl #generics ::serde::Serialize for #struct_ident #generics {
410 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
411 where
412 S: ::serde::Serializer
413 {
414 use ::serde::ser::SerializeStruct;
415 let mut st = serializer.serialize_struct(
416 stringify!(#struct_ident),
417 #num_fields
418 )?;
419 #(
420 st.serialize_field(#field_names, &self.#field_idents)?;
421 )*
422 st.end()
423 }
424 }
425 }
426}