argopt_impl/
lib.rs

1use darling::FromMeta;
2use proc_macro::TokenStream;
3use quote::quote;
4use syn::{
5    bracketed,
6    parse::{Parse, ParseStream},
7    parse_macro_input, parse_quote, parse_str,
8    punctuated::Punctuated,
9    Attribute, AttributeArgs, FnArg, Ident, ItemFn, Meta, NestedMeta, Pat, Path, Token,
10};
11
12fn gen_cmd(item: ItemFn, is_subcmd: bool, gen_verbose: bool) -> TokenStream {
13    let vis = &item.vis;
14    let fn_async = &item.sig.asyncness;
15    let fn_name = &item.sig.ident;
16    let ret_type = &item.sig.output;
17
18    let mut cmd_help = quote! {};
19    let mut app_attrs = quote! {};
20    let mut fn_attrs: Vec<Attribute> = vec![];
21
22    for attr in item.attrs.iter() {
23        if attr.path.is_ident("doc") {
24            cmd_help = quote! { #attr };
25        } else if attr.path.is_ident("opt") {
26            let tokens = &attr.tokens;
27            app_attrs = quote! { #[clap #tokens] };
28        } else {
29            fn_attrs.push(attr.clone());
30        }
31    }
32
33    let mut arg_muts = vec![];
34    let mut arg_idents = vec![];
35    let mut tmp_arg_idents = vec![];
36    let mut arg_types = vec![];
37    let mut arg_docs = vec![];
38    let mut arg_attrs = vec![];
39
40    for arg in item.sig.inputs.iter() {
41        let arg = if let FnArg::Typed(arg) = arg {
42            arg
43        } else {
44            panic!("invalid function argument");
45        };
46
47        let mut doc = quote! {};
48        let mut attrs = vec![];
49
50        for attr in arg.attrs.iter() {
51            if attr.path.is_ident("doc") {
52                doc = quote! { #attr };
53            } else if attr.path.is_ident("opt") {
54                let tokens = attr.tokens.clone();
55                let attr: NestedMeta = parse_quote!(opt #tokens);
56
57                if let NestedMeta::Meta(Meta::List(ml)) = attr {
58                    for nm in ml.nested.iter() {
59                        attrs.push(nm.clone());
60                    }
61                } else {
62                    unreachable!()
63                }
64            } else {
65                panic!("invalid argument attribute");
66            }
67        }
68
69        if let Pat::Ident(pat_ident) = arg.pat.as_ref() {
70            assert!(pat_ident.attrs.is_empty());
71            assert!(pat_ident.by_ref.is_none());
72            assert!(pat_ident.subpat.is_none());
73
74            arg_muts.push(pat_ident.mutability);
75            arg_idents.push(pat_ident.ident.clone());
76            tmp_arg_idents
77                .push(parse_str::<Ident>(&format!("tmp_var_{}", pat_ident.ident)).unwrap());
78            arg_types.push(arg.ty.as_ref().clone());
79            arg_docs.push(doc);
80            arg_attrs.push(attrs);
81        } else {
82            panic!();
83        }
84    }
85
86    let body = &item.block;
87
88    let mod_name = module_name(&fn_name.to_string());
89
90    let options_type = option_struct_name(&fn_name.to_string());
91    let opts_var_name = option_var_name(&fn_name.to_string());
92
93    let arg_attrs = arg_attrs
94        .iter()
95        .map(|attrs| {
96            if attrs.is_empty() {
97                quote! {}
98            } else {
99                quote! {
100                    #[clap( #( #attrs ),* )]
101                }
102            }
103        })
104        .collect::<Vec<_>>();
105
106    if is_subcmd {
107        let subcmd_ctor = subcmd_ctor_name(&fn_name.to_string());
108
109        quote! {
110            #[doc(hidden)]
111            pub mod #mod_name {
112                use argopt::clap;
113                use super::*;
114
115                #[doc(hidden)]
116                #[derive(clap::Parser)]
117                #[allow(non_camel_case_types)]
118                pub enum #options_type {
119                    #cmd_help
120                    #app_attrs
121                    #subcmd_ctor {
122                        #(
123                            #arg_docs
124                            #arg_attrs
125                            #arg_idents: #arg_types,
126                        )*
127                    }
128                }
129            }
130
131            #(#fn_attrs)*
132            #vis #fn_async fn #fn_name (#opts_var_name: #mod_name::#options_type) #ret_type {
133                #(
134                    let #arg_muts #arg_idents;
135                )*
136
137                {
138                    #(
139                        let #arg_muts #tmp_arg_idents;
140                    )*
141
142                    match #opts_var_name {
143                        #mod_name::#options_type::#subcmd_ctor { #(#arg_idents),* } => {
144                            #(
145                                #tmp_arg_idents = #arg_idents;
146                            )*
147                        }
148                    }
149
150                    #(
151                        #arg_idents = #tmp_arg_idents;
152                    )*
153                }
154
155                #body
156            }
157        }
158    } else {
159        let verb = if gen_verbose {
160            VerbosityCode::new(&opts_var_name)
161        } else {
162            VerbosityCode::default()
163        };
164
165        let verbose_arg = verb.arg;
166        let def_logger = verb.def_logger;
167        let set_verbosity_level = verb.set_verbosity_level;
168
169        quote! {
170            #[doc(hidden)]
171            pub mod #mod_name {
172                use argopt::clap;
173                use super::*;
174
175                #[doc(hidden)]
176                #[derive(clap::Parser)]
177                #cmd_help
178                #app_attrs
179                #[allow(non_camel_case_types)]
180                pub struct #options_type {
181                    #(
182                        #arg_docs
183                        #arg_attrs
184                        pub #arg_idents: #arg_types,
185                    )*
186                    #verbose_arg
187                }
188            }
189
190            #def_logger
191
192            #(#fn_attrs)*
193            #vis #fn_async fn #fn_name () #ret_type {
194                #(
195                    let #arg_muts #arg_idents;
196                )*
197
198                {
199                    let #opts_var_name = <#mod_name::#options_type as argopt::clap::Parser>::parse();
200                    #(
201                        #arg_idents = #opts_var_name.#arg_idents;
202                    )*
203                    #set_verbosity_level
204                }
205
206                #body
207            }
208        }
209    }
210    .into()
211}
212
213#[derive(Default)]
214struct VerbosityCode {
215    arg: proc_macro2::TokenStream,
216    def_logger: proc_macro2::TokenStream,
217    set_verbosity_level: proc_macro2::TokenStream,
218}
219
220impl VerbosityCode {
221    fn new(opts_var_name: &Ident) -> Self {
222        Self {
223            arg: quote! {
224                #[clap(short, long, parse(from_occurrences), global = true)]
225                #[doc = "Verbose mode (-v, -vv, -vvv, etc.)"]
226                pub verbose: usize,
227            }
228            .into(),
229            def_logger: quote! {
230                struct StdoutLogger;
231
232                impl log::Log for StdoutLogger {
233                    fn enabled(&self, metadata: &log::Metadata) -> bool {
234                        metadata.level() <= log::max_level()
235                    }
236
237                    fn log(&self, record: &log::Record) {
238                        if self.enabled(record.metadata()) {
239                            println!("{}", record.args());
240                        }
241                    }
242
243                    fn flush(&self) {}
244                }
245
246                static ARGOPT_LOGGER: StdoutLogger = StdoutLogger;
247            },
248            set_verbosity_level: quote! {
249                log::set_logger(&ARGOPT_LOGGER).unwrap();
250
251                log::set_max_level(
252                    if #opts_var_name.verbose + 1 == log::LevelFilter::Error as usize {
253                        log::LevelFilter::Error
254                    } else if #opts_var_name.verbose + 1 == log::LevelFilter::Warn as usize {
255                        log::LevelFilter::Warn
256                    } else if #opts_var_name.verbose + 1 == log::LevelFilter::Info as usize {
257                        log::LevelFilter::Info
258                    } else if #opts_var_name.verbose + 1 == log::LevelFilter::Debug as usize {
259                        log::LevelFilter::Debug
260                    } else {
261                        log::LevelFilter::Trace
262                    }
263                );
264            },
265        }
266    }
267}
268
269#[derive(Debug, Default, FromMeta)]
270#[darling(default)]
271struct SubCmdAttr {}
272
273#[proc_macro_attribute]
274pub fn subcmd(_attr: TokenStream, item: TokenStream) -> TokenStream {
275    // let attr = parse_macro_input!(attr as AttributeArgs);
276    // let attr = SubCmdAttr::from_list(&attr).unwrap();
277    let item = parse_macro_input!(item as ItemFn);
278    // let fn_name = &item.sig.ident;
279    gen_cmd(item, true, false)
280}
281
282#[derive(Debug, Default, FromMeta)]
283#[darling(default)]
284struct CmdAttr {
285    verbose: bool,
286}
287
288#[proc_macro_attribute]
289pub fn cmd(attr: TokenStream, item: TokenStream) -> TokenStream {
290    let attr = parse_macro_input!(attr as AttributeArgs);
291    let attr = CmdAttr::from_list(&attr).unwrap();
292    let item = parse_macro_input!(item as ItemFn);
293    gen_cmd(item, false, attr.verbose)
294}
295
296fn module_name(fn_name: &str) -> Ident {
297    parse_str(&format!("__{fn_name}__impl")).unwrap()
298}
299
300fn option_struct_name(fn_name: &str) -> Ident {
301    parse_str(&format!("Options_{fn_name}")).unwrap()
302}
303
304fn option_var_name(fn_name: &str) -> Ident {
305    parse_str(&format!("options_{fn_name}")).unwrap()
306}
307
308fn subcmd_ctor_name(fn_name: &str) -> Ident {
309    use convert_case::{Case, Casing};
310    parse_str(&fn_name.to_case(Case::UpperCamel)).unwrap()
311}
312
313#[derive(Debug, Default)]
314struct CmdGroupAttr {
315    verbose: bool,
316    commands: Vec<Path>,
317}
318
319impl Parse for CmdGroupAttr {
320    fn parse(input: ParseStream) -> syn::Result<Self> {
321        let mut ret = CmdGroupAttr::default();
322
323        while let Ok(key) = input.parse::<Ident>() {
324            if key == "verbose" {
325                ret.verbose = true;
326            } else if key == "commands" {
327                input.parse::<Token![=]>()?;
328                let cmds;
329                bracketed!(cmds in input);
330                let cmds = Punctuated::<Path, Token![,]>::parse_separated_nonempty(&cmds)?;
331                ret.commands = cmds.into_iter().collect();
332            } else {
333                panic!("unexpected attribute for cmd_group");
334            }
335
336            if input.parse::<Token![,]>().is_err() {
337                break;
338            }
339        }
340
341        Ok(ret)
342    }
343}
344
345#[proc_macro_attribute]
346pub fn cmd_group(attr: TokenStream, item: TokenStream) -> TokenStream {
347    let attr = parse_macro_input!(attr as CmdGroupAttr);
348    let item = parse_macro_input!(item as ItemFn);
349
350    let vis = item.vis;
351    let body = item.block;
352    let fn_sig = item.sig;
353
354    let mut constr_names: Vec<Ident> = vec![];
355    let mut struct_names: Vec<Path> = vec![];
356    let mut cmds = vec![];
357
358    for cmd in attr.commands.iter() {
359        cmds.push(cmd.clone());
360        constr_names.push(parse_str(&format!("Constr_{}", path_to_str(cmd))).unwrap());
361
362        let ident = option_struct_name(&cmd.segments.last().unwrap().ident.to_string());
363        let mut cmd = cmd.clone();
364        let last = cmd.segments.pop().unwrap();
365        let mod_name = module_name(&last.value().ident.to_string());
366        cmd.segments.push(mod_name.into());
367        cmd.segments.push(ident.into());
368        struct_names.push(cmd);
369    }
370
371    let options_type: Ident = parse_str("Main_options_type").unwrap();
372    let mod_name: Ident = module_name(&fn_sig.ident.to_string());
373    let commands_enum: Ident = parse_str("Main_commands").unwrap();
374    let opts_var_name: Ident = parse_str("arg_Main_commands").unwrap();
375
376    let mut cmd_help = quote! {};
377    let mut app_attrs = quote! {};
378
379    for fn_attr in item.attrs.iter() {
380        if fn_attr.path.is_ident("doc") {
381            cmd_help = quote! { #fn_attr };
382        } else if fn_attr.path.is_ident("opt") {
383            let tokens = &fn_attr.tokens;
384            app_attrs = quote! { #[clap #tokens] };
385        }
386    }
387
388    let verb = if attr.verbose {
389        VerbosityCode::new(&opts_var_name)
390    } else {
391        VerbosityCode::default()
392    };
393
394    let verbose_arg = verb.arg;
395    let def_logger = verb.def_logger;
396    let set_verbosity_level = verb.set_verbosity_level;
397
398    (quote! {
399        #[doc(hidden)]
400        pub mod #mod_name {
401            use argopt::clap;
402            use super::*;
403
404            #[derive(clap::Parser)]
405            #cmd_help
406            #app_attrs
407            #[allow(non_camel_case_types)]
408            pub struct #options_type {
409                #verbose_arg
410
411                #[clap(subcommand)]
412                pub commands: #commands_enum
413            }
414
415            #[derive(clap::Subcommand)]
416            pub enum #commands_enum {
417                #(
418                    #[clap(flatten)]
419                    #constr_names(#struct_names),
420                )*
421            }
422        }
423
424        #def_logger
425
426        #vis #fn_sig {
427            #body
428
429            let #opts_var_name = <#mod_name::#options_type as argopt::clap::Parser>::parse();
430
431            #set_verbosity_level
432
433            match #opts_var_name.commands {
434                #(
435                    #mod_name::#commands_enum::#constr_names(opts) => #cmds(opts),
436                )*
437            }
438        }
439    })
440    .into()
441}
442
443fn path_to_str(path: &Path) -> String {
444    path.segments
445        .iter()
446        .map(|r| r.ident.to_string())
447        .collect::<Vec<String>>()
448        .join("_")
449}