apyee_macros/
lib.rs

1extern crate proc_macro;
2
3use convert_case::{Case, Casing};
4use proc_macro::TokenStream;
5use proc_macro2::{Span, TokenStream as TokenStream2};
6use quote::{format_ident, quote, quote_spanned, ToTokens};
7use syn::spanned::Spanned;
8use syn::{Data, Error, Fields, Type};
9
/// Expands to a compile error carrying `$message`, spanned at the macro call
/// site and converted (via `Into`) to whatever token-stream type the
/// surrounding expression expects.
macro_rules! derive_error {
    ($message:tt) => {
        Into::into(Error::new(Span::call_site(), $message).to_compile_error())
    };
}
17
18#[proc_macro_derive(GetParams)]
19pub fn get_params_derive(input: TokenStream) -> TokenStream {
20    let ast = syn::parse(input).unwrap();
21    impl_get_params(&ast)
22}
23
24fn impl_get_params(ast: &syn::DeriveInput) -> TokenStream {
25    let name = &ast.ident;
26    let data = &ast.data;
27
28    let mut variant_match_arms;
29
30    match data {
31        Data::Enum(data_enum) => {
32            variant_match_arms = TokenStream2::new();
33
34            for variant in &data_enum.variants {
35                let variant_name = &variant.ident;
36
37                // Variant can have unnamed fields like `Variant(i32, i64)`
38                // Variant can have named fields like `Variant {x: i32, y: i32}`
39                // Variant can be named Unit like `Variant`
40                match &variant.fields {
41                    Fields::Unnamed(fields) => {
42                        let mut alphabet = (b'a'..=b'z')
43                            .map(char::from)
44                            .map(String::from)
45                            .map(|s| format_ident!("{}", s))
46                            .collect::<Vec<_>>();
47
48                        let mut field_names = TokenStream2::new();
49                        let mut vec_inits = TokenStream2::new();
50                        let mut vec_extends = TokenStream2::new();
51                        for field in fields.unnamed.iter() {
52                            let field_name = alphabet.remove(0);
53                            field_names.extend(quote_spanned! { field.span() =>
54                                #field_name,
55                            });
56
57                            // if let Type::Path(type_path) = &field.ty {
58                            //     println!("{}", type_path.into_token_stream());
59                            // }
60
61                            match &field.ty {
62                                Type::Path(type_path)
63                                    if type_path.clone().into_token_stream().to_string()
64                                        == "bool" =>
65                                {
66                                    vec_inits.extend(quote_spanned! {variant.span()=>
67                                    match #field_name {true => "on", false => "off"}.into(),})
68                                }
69                                // check if type is a vec
70                                Type::Path(type_path)
71                                    if type_path
72                                        .clone()
73                                        .into_token_stream()
74                                        .to_string()
75                                        .starts_with("Vec<") =>
76                                {
77                                    vec_extends.extend(quote_spanned! {variant.span()=>
78                                    vec.extend(#field_name.iter().map(serde_json::Value::from));})
79                                }
80                                Type::Path(type_path)
81                                    if type_path
82                                        .clone()
83                                        .into_token_stream()
84                                        .to_string()
85                                        .starts_with("Vec <") =>
86                                {
87                                    vec_extends.extend(quote_spanned! {variant.span()=>
88                                    vec.extend(#field_name.iter().map(serde_json::Value::from));})
89                                }
90                                // check if type is an option
91                                Type::Path(type_path)
92                                    if type_path
93                                        .clone()
94                                        .into_token_stream()
95                                        .to_string()
96                                        .starts_with("Option <") =>
97                                {
98                                    vec_extends.extend(quote_spanned! {variant.span()=>
99                                        if let Some(val) = #field_name {
100                                            vec.push(val.to_owned().into());
101                                        }
102                                    });
103                                }
104                                Type::Path(type_path)
105                                    if type_path
106                                        .clone()
107                                        .into_token_stream()
108                                        .to_string()
109                                        .starts_with("Option<") =>
110                                {
111                                    vec_extends.extend(quote_spanned! {variant.span()=>
112                                        if let Some(val) = #field_name {
113                                            vec.push(val.to_owned().into());
114                                        }
115                                    });
116                                }
117
118                                _ => vec_inits.extend(quote_spanned! {variant.span()=>
119                                #field_name.to_owned().into(),}),
120                            }
121                        }
122
123                        if vec_extends.is_empty() {
124                            variant_match_arms.extend(quote_spanned! {variant.span()=>
125                                #name::#variant_name (#field_names) => {
126                                    vec![#vec_inits]
127                                },
128                            });
129                        } else {
130                            variant_match_arms.extend(quote_spanned! {variant.span()=>
131                                        #name::#variant_name (#field_names) => {
132                                            let mut vec = vec![#vec_inits];
133                                            #vec_extends
134                                            vec
135                                        },
136                            });
137                        }
138                    }
139                    Fields::Unit => {
140                        variant_match_arms.extend(quote_spanned! {variant.span()=>
141                                    #name::#variant_name => vec![],
142                        });
143                    }
144                    Fields::Named(_) => {
145                        todo!("Named fields");
146                    }
147                };
148            }
149        }
150        _ => return derive_error!("get_params is only implemented for enums"),
151    };
152
153    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
154
155    let expanded = quote! {
156        impl #impl_generics #name #ty_generics #where_clause {
157            pub(crate) fn get_params(&self) -> Vec<serde_json::Value> {
158            match self {
159                #variant_match_arms
160            }
161        }
162        }
163    };
164
165    TokenStream::from(expanded)
166}
167
168#[proc_macro_derive(IntoJsonValue)]
169pub fn into_json_value_derive(input: TokenStream) -> TokenStream {
170    let ast = syn::parse(input).unwrap();
171    impl_into_json_value_derive(&ast)
172}
173
174fn impl_into_json_value_derive(ast: &syn::DeriveInput) -> TokenStream {
175    let name = &ast.ident;
176    let data = &ast.data;
177
178    let mut variant_match_arms;
179
180    match data {
181        Data::Enum(data_enum) => {
182            variant_match_arms = TokenStream2::new();
183
184            for variant in &data_enum.variants {
185                let variant_name = &variant.ident;
186
187                // Variant can have unnamed fields like `Variant(i32, i64)`
188                // Variant can have named fields like `Variant {x: i32, y: i32}`
189                // Variant can be named Unit like `Variant`
190                let fields_in_variant = match &variant.fields {
191                    Fields::Unnamed(_) => quote_spanned! {variant.span()=> (..) },
192                    Fields::Unit => quote_spanned! { variant.span()=> },
193                    Fields::Named(_) => quote_spanned! {variant.span()=> {..} },
194                };
195
196                let variant_string = variant_name.to_string().to_case(Case::Snake).to_string();
197
198                variant_match_arms.extend(quote_spanned! {variant.span()=>
199                            #name::#variant_name #fields_in_variant => #variant_string.to_string().into(),
200                });
201            }
202        }
203        _ => return derive_error!("get_params is only implemented for enums"),
204    };
205
206    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
207
208    let expanded = quote! {
209        impl From<#impl_generics #name #ty_generics> for serde_json::Value #where_clause {
210            fn from(val: #name) -> Self {
211                match val {
212                    #variant_match_arms
213                }
214            }
215        }
216        impl From<#impl_generics &#name #ty_generics> for serde_json::Value #where_clause {
217            fn from(val: &#name) -> Self {
218                match *val {
219                    #variant_match_arms
220                }
221            }
222        }
223    };
224
225    TokenStream::from(expanded)
226}
227
228#[proc_macro_derive(FromRawCommand)]
229pub fn from_raw_command_derive(input: TokenStream) -> TokenStream {
230    let ast = syn::parse(input).unwrap();
231    impl_from_raw_command_derive(&ast)
232}
233
234fn impl_from_raw_command_derive(ast: &syn::DeriveInput) -> TokenStream {
235    let name = &ast.ident;
236    let data = &ast.data;
237
238    let mut method_construction;
239
240    match data {
241        Data::Enum(data_enum) => {
242            method_construction = TokenStream2::new();
243
244            for variant in &data_enum.variants {
245                let variant_name = &variant.ident;
246                let variant_snake_case = variant_name.to_string().to_case(Case::Snake).to_string();
247
248                let mut param_construction = TokenStream2::new();
249
250                match &variant.fields {
251                    Fields::Unnamed(fields) => {
252                        for (i, field) in fields.unnamed.iter().enumerate() {
253                            match &field.ty {
254                                Type::Path(type_path)
255                                    if type_path.clone().into_token_stream().to_string()
256                                        == "bool" =>
257                                {
258                                    param_construction.extend(quote_spanned! {variant.span()=>
259                                        if raw.params.len() > #i { match raw.params[#i].as_str().unwrap() { "on" => true, "off" => false, _ => false } } else { panic!("Value for non optional field '{} - {}' in '{}' is missing", #i+1, stringify!(#field), stringify!(#variant_name)) },
260                                    });
261                                }
262                                Type::Path(type_path)
263                                    if type_path
264                                        .clone()
265                                        .into_token_stream()
266                                        .to_string()
267                                        .starts_with("Option <") =>
268                                {
269                                    param_construction.extend(quote_spanned! {variant.span()=>
270                                                if raw.params.len() > #i { Some(serde_json::from_value(raw.params[#i].to_owned()).unwrap()) } else { None },
271                                            });
272                                }
273                                _ => {
274                                    param_construction.extend(quote_spanned! {variant.span()=>
275                                        if raw.params.len() > #i { serde_json::from_value(raw.params[#i].to_owned()).unwrap() } else { panic!("Value for non optional field '{} - {}' in '{}' is missing", #i+1, stringify!(#field), stringify!(#variant_name)) },
276                                    });
277                                }
278                            }
279                        }
280                    }
281                    Fields::Unit => {}
282                    _ => todo!(),
283                };
284
285                if param_construction.is_empty() {
286                    method_construction.extend(quote_spanned! {variant.span()=>
287                        #variant_snake_case => #name::#variant_name,
288                    });
289                } else {
290                    method_construction.extend(quote_spanned! {variant.span()=>
291                        #variant_snake_case => #name::#variant_name(#param_construction),
292                    });
293                }
294            }
295        }
296        _ => return derive_error!("FromRawCommand is only implemented for enums"),
297    };
298
299    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
300
301    let expanded = quote! {
302        impl From<RawCommand> for #impl_generics #name #ty_generics #where_clause {
303            fn from(raw: RawCommand) -> Self {
304                match raw.method.as_str() {
305                    #method_construction
306                    _ => panic!("Unknown method"),
307                }
308            }
309        }
310        impl From<&RawCommand> for #impl_generics #name #ty_generics #where_clause {
311            fn from(raw: &RawCommand) -> Self {
312                match raw.method.as_str() {
313                    #method_construction
314                    _ => panic!("Unknown method"),
315                }
316            }
317        }
318    };
319
320    TokenStream::from(expanded)
321}