clap_handler_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro_error::proc_macro_error;
3use quote::quote;
4use syn::{
5    AttributeArgs, Data, DataEnum, DataStruct, DeriveInput, Fields, Ident, ItemFn, Meta,
6    NestedMeta, Type, parse_macro_input,
7};
8
9#[proc_macro_derive(Handler, attributes(handler_inject))]
10#[proc_macro_error]
11pub fn derive_handler(item: TokenStream) -> TokenStream {
12    let input = parse_macro_input!(item as DeriveInput);
13    let name = &input.ident;
14
15    let context_injector = match input
16        .attrs
17        .iter()
18        .find(|attr| attr.path.is_ident("handler_inject"))
19        .and_then(|a| a.parse_args::<Ident>().ok())
20    {
21        Some(ident) => {
22            cfg_if::cfg_if! {
23                if #[cfg(feature = "async")] {
24                    quote! { self.#ident(ctx).await?; }
25                } else {
26                    quote! { self.#ident(ctx)?; }
27                }
28            }
29        }
30        None => quote! {},
31    };
32
33    let expanded = match input.data {
34        Data::Struct(DataStruct { fields, .. }) => match fields {
35            Fields::Named(ref fields_name) => {
36                let subcommand_field: Option<syn::Ident> =
37                    fields_name.named.iter().find_map(|field| {
38                        for attr in field.attrs.iter() {
39                            if attr.path.is_ident("clap") {
40                                let ident: syn::Ident = attr.parse_args().ok()?;
41                                if ident == "subcommand" {
42                                    return Some(field.ident.clone().unwrap());
43                                }
44                            }
45                        }
46                        None
47                    });
48
49                match subcommand_field {
50                    Some(subcommand_field) => {
51                        #[cfg(not(feature = "async"))]
52                        quote! {
53                            impl clap_handler::Handler for #name {
54                                fn handle_command(&mut self, ctx: &mut clap_handler::Context) -> anyhow::Result<()> {
55                                    #context_injector
56                                    Ok(())
57                                }
58
59                                fn handle_subcommand(&mut self, ctx: clap_handler::Context) -> anyhow::Result<()> {
60                                    clap_handler::Handler::execute(&mut self.#subcommand_field, ctx)
61                                }
62                            }
63                        }
64
65                        #[cfg(feature = "async")]
66                        quote! {
67                            #[clap_handler::async_trait]
68                            impl clap_handler::Handler for #name {
69                                async fn handle_command(&mut self, ctx: &mut clap_handler::Context) -> anyhow::Result<()> {
70                                    #context_injector
71                                    Ok(())
72                                }
73
74                                async fn handle_subcommand(&mut self, ctx: clap_handler::Context) -> anyhow::Result<()> {
75                                    clap_handler::Handler::execute(&mut self.#subcommand_field, ctx).await
76                                }
77                            }
78                        }
79                    }
80                    None => panic!("Struct without #[clap(subcommand)] is not supported!"),
81                }
82            }
83            _ => panic!("Unnamed fields or None struct is not supported"),
84        },
85        Data::Enum(DataEnum { variants, .. }) => {
86            let subcommands: Vec<_> = variants
87                .iter()
88                .map(|v| {
89                    let ident = &v.ident;
90                    quote! { #name::#ident }
91                })
92                .collect();
93            #[cfg(not(feature = "async"))]
94            quote! {
95                impl clap_handler::Handler for #name {
96                    fn execute(&mut self, mut ctx: clap_handler::Context) -> anyhow::Result<()> {
97                        match self {
98                            #(#subcommands(s) => clap_handler::Handler::execute(s, ctx),)*
99                        }
100                    }
101                }
102            }
103            #[cfg(feature = "async")]
104            quote! {
105                #[clap_handler::async_trait]
106                impl clap_handler::Handler for #name {
107                    async fn execute(&mut self, mut ctx: clap_handler::Context) -> anyhow::Result<()> {
108                        match self {
109                            #(#subcommands(s) => clap_handler::Handler::execute(s, ctx).await,)*
110                        }
111                    }
112                }
113            }
114        }
115        _ => panic!("Union type is not supported"),
116    };
117    expanded.into()
118}
119
120#[proc_macro_attribute]
121#[proc_macro_error]
122pub fn handler(args: TokenStream, input: TokenStream) -> TokenStream {
123    let attr = parse_macro_input!(args as AttributeArgs);
124    let attr = match attr.get(0).as_ref().unwrap() {
125        NestedMeta::Meta(Meta::Path(attr_ident)) => attr_ident.get_ident().unwrap(),
126        _ => unreachable!("it not gonna happen."),
127    };
128
129    let func = parse_macro_input!(input as ItemFn);
130    let func_block = &func.block;
131    let func_sig = func.sig;
132    let func_name = &func_sig.ident;
133    let func_generics = &func_sig.generics;
134    let func_inputs = &func_sig.inputs;
135    let func_output = &func_sig.output;
136    let types: Vec<_> = func_inputs
137        .iter()
138        .map(|i| {
139            match i {
140                syn::FnArg::Typed(ty) => {
141                    let ty: &Type = &ty.ty;
142                    match ty {
143                        Type::Reference(r) => {
144                            if r.mutability.is_some() {
145                                quote! { ctx.get_mut().unwrap() }
146                            } else {
147                                quote! { ctx.get().unwrap() }
148                            }
149                        }
150                        _ => {
151                            // owned type
152                            // TODO: do not unwrap when ty is Option<T>
153                            // TODO: do not deref when ty is Box<T>
154                            quote! { *ctx.take().unwrap() }
155                        }
156                    }
157                }
158                _ => unreachable!("syntax error"),
159            }
160        })
161        .collect();
162
163    cfg_if::cfg_if! {
164        if #[cfg(feature = "async")] {
165            let expanded = quote! {
166                #[clap_handler::async_trait]
167                impl clap_handler::Handler for #attr {
168                    async fn handle_command(&mut self, ctx: &mut clap_handler::Context) -> anyhow::Result<()> {
169                        async fn #func_name #func_generics(#func_inputs)#func_output {
170                            #func_block
171                        }
172                        let result = #func_name(#(#types,)*);
173                        Ok(result.await?)
174                    }
175                }
176            };
177        } else {
178            let expanded = quote! {
179                impl clap_handler::Handler for #attr {
180                    fn handle_command(&mut self, ctx: &mut clap_handler::Context) -> anyhow::Result<()> {
181                        fn #func_name #func_generics(#func_inputs)#func_output {
182                            #func_block
183                        }
184                        let result = #func_name(#(#types,)*);
185                        Ok(result?)
186                    }
187                }
188            };
189        }
190    }
191
192    expanded.into()
193}