clap_handler_derive/
lib.rs

1use proc_macro::{TokenStream};
2use proc_macro_error::proc_macro_error;
3use quote::quote;
4use syn::{AttributeArgs, ItemFn, Meta, NestedMeta, DeriveInput, Data, parse_macro_input, DataStruct, DataEnum, Fields, Type, Ident};
5
6#[proc_macro_derive(Handler, attributes(handler_inject))]
7#[proc_macro_error]
8pub fn derive_handler(item: TokenStream) -> TokenStream {
9    let input = parse_macro_input!(item as DeriveInput);
10    let name = &input.ident;
11
12    let context_injector = match input.attrs.iter()
13        .find(|attr| attr.path.is_ident("handler_inject"))
14        .and_then(|a| a.parse_args::<Ident>().ok()) {
15        Some(ident) => {
16            cfg_if::cfg_if! {
17                if #[cfg(feature = "async")] {
18                    quote! { self.#ident(ctx).await?; }
19                } else {
20                    quote! { self.#ident(ctx)?; }
21                }
22            }
23        }
24        None => quote! {},
25    };
26
27    let expanded = match input.data {
28        Data::Struct(DataStruct { fields, .. }) => {
29            match fields {
30                Fields::Named(ref fields_name) => {
31                    let subcommand_field: Option<syn::Ident> = fields_name.named.iter().find_map(|field| {
32                        for attr in field.attrs.iter() {
33                            if attr.path.is_ident("clap") {
34                                let ident: syn::Ident = attr.parse_args().ok()?;
35                                if ident == "subcommand" {
36                                    return Some(field.ident.clone().unwrap());
37                                }
38                            }
39                        }
40                        None
41                    });
42
43                    match subcommand_field {
44                        Some(subcommand_field) => {
45                            #[cfg(not(feature = "async"))]
46                            quote! {
47                                impl clap_handler::Handler for #name {
48                                    fn handle_command(&mut self, ctx: &mut clap_handler::Context) -> anyhow::Result<()> {
49                                        #context_injector
50                                        Ok(())
51                                    }
52
53                                    fn handle_subcommand(&mut self, ctx: clap_handler::Context) -> anyhow::Result<()> {
54                                        clap_handler::Handler::execute(&mut self.#subcommand_field, ctx)
55                                    }
56                                }
57                            }
58
59                            #[cfg(feature = "async")]
60                            quote! {
61                                #[clap_handler::async_trait]
62                                impl clap_handler::Handler for #name {
63                                    async fn handle_command(&mut self, ctx: &mut clap_handler::Context) -> anyhow::Result<()> {
64                                        #context_injector
65                                        Ok(())
66                                    }
67
68                                    async fn handle_subcommand(&mut self, ctx: clap_handler::Context) -> anyhow::Result<()> {
69                                        clap_handler::Handler::execute(&mut self.#subcommand_field, ctx).await
70                                    }
71                                }
72                            }
73                        }
74                        None => panic!("Struct without #[clap(subcommand)] is not supported!"),
75                    }
76                }
77                _ => panic!("Unnamed fields or None struct is not supported"),
78            }
79        }
80        Data::Enum(DataEnum { variants, .. }) => {
81            let subcommands: Vec<_> = variants.iter().map(|v| {
82                let ident = &v.ident;
83                quote! { #name::#ident }
84            }).collect();
85            #[cfg(not(feature = "async"))]
86            quote! {
87                impl clap_handler::Handler for #name {
88                    fn execute(&mut self, mut ctx: clap_handler::Context) -> anyhow::Result<()> {
89                        match self {
90                            #(#subcommands(s) => clap_handler::Handler::execute(s, ctx),)*
91                        }
92                    }
93                }
94            }
95            #[cfg(feature = "async")]
96            quote! {
97                #[clap_handler::async_trait]
98                impl clap_handler::Handler for #name {
99                    async fn execute(&mut self, mut ctx: clap_handler::Context) -> anyhow::Result<()> {
100                        match self {
101                            #(#subcommands(s) => clap_handler::Handler::execute(s, ctx).await,)*
102                        }
103                    }
104                }
105            }
106        }
107        _ => panic!("Union type is not supported"),
108    };
109    expanded.into()
110}
111
112#[proc_macro_attribute]
113#[proc_macro_error]
114pub fn handler(args: TokenStream, input: TokenStream) -> TokenStream {
115    let attr = parse_macro_input!(args as AttributeArgs);
116    let attr = match attr.get(0).as_ref().unwrap() {
117        NestedMeta::Meta(Meta::Path(ref attr_ident)) => attr_ident.get_ident().unwrap(),
118        _ => unreachable!("it not gonna happen."),
119    };
120
121    let func = parse_macro_input!(input as ItemFn);
122    let func_block = &func.block;
123    let func_sig = func.sig;
124    let func_name = &func_sig.ident;
125    let func_generics = &func_sig.generics;
126    let func_inputs = &func_sig.inputs;
127    let func_output = &func_sig.output;
128    let types: Vec<_> = func_inputs.iter().map(|i| {
129        match i {
130            syn::FnArg::Typed(ty) => {
131                let ty: &Type = &ty.ty;
132                match ty {
133                    Type::Reference(r) => {
134                        if r.mutability.is_some() {
135                            quote! { ctx.get_mut().unwrap() }
136                        } else {
137                            quote! { ctx.get().unwrap() }
138                        }
139                    }
140                    _ => {
141                        // owned type
142                        // TODO: do not unwrap when ty is Option<T>
143                        // TODO: do not deref when ty is Box<T>
144                        quote! { *ctx.take().unwrap() }
145                    }
146                }
147            }
148            _ => unreachable!("syntax error"),
149        }
150    }).collect();
151
152    cfg_if::cfg_if! {
153        if #[cfg(feature = "async")] {
154            let expanded = quote! {
155                #[clap_handler::async_trait]
156                impl clap_handler::Handler for #attr {
157                    async fn handle_command(&mut self, ctx: &mut clap_handler::Context) -> anyhow::Result<()> {
158                        async fn #func_name #func_generics(#func_inputs)#func_output {
159                            #func_block
160                        }
161                        let result = #func_name(#(#types,)*);
162                        Ok(result.await?)
163                    }
164                }
165            };
166        } else {
167            let expanded = quote! {
168                impl clap_handler::Handler for #attr {
169                    fn handle_command(&mut self, ctx: &mut clap_handler::Context) -> anyhow::Result<()> {
170                        fn #func_name #func_generics(#func_inputs)#func_output {
171                            #func_block
172                        }
173                        let result = #func_name(#(#types,)*);
174                        Ok(result?)
175                    }
176                }
177            };
178        }
179    }
180
181    expanded.into()
182}