// File: libtor_derive/lib.rs

extern crate proc_macro;
extern crate proc_macro2;
extern crate quote;
extern crate syn;

use proc_macro2::{Ident, Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::parse::{Parse, ParseStream};
use syn::spanned::Spanned;
use syn::{braced, parenthesized, parse_macro_input, token, Data, DeriveInput, Fields, Token};

12#[cfg_attr(feature = "debug", derive(Debug))]
13struct ExpandToArg {
14    keyword: Ident,
15    name: syn::Lit,
16}
17
18impl Parse for ExpandToArg {
19    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
20        let keyword = input.parse()?;
21        input.parse::<Token![=]>()?;
22        let name = input.parse()?;
23
24        Ok(ExpandToArg { keyword, name })
25    }
26}
27
28#[cfg_attr(feature = "debug", derive(Debug))]
29struct TestStruct {
30    args_group: Option<TokenStream>,
31    expected: syn::LitStr,
32}
33
34impl Parse for TestStruct {
35    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
36        let keyword: Ident = input.parse()?;
37        if keyword != "test" {
38            return Err(syn::Error::new(keyword.span(), "expected `test`"));
39        }
40        input.parse::<Token![=]>()?;
41
42        let args_group: Option<TokenStream> = if input.peek(token::Brace) {
43            let content;
44            braced!(content in input);
45            let content: TokenStream = content.parse()?;
46
47            Some(quote! {
48                { #content }
49            })
50        } else if input.peek(token::Paren) {
51            let content;
52            parenthesized!(content in input);
53            let content: TokenStream = content.parse()?;
54
55            Some(quote! {
56                ( #content )
57            })
58        } else {
59            None
60        };
61
62        input.parse::<Token![=>]>()?;
63
64        let expected = input.parse()?;
65        if let syn::Lit::Str(expected) = expected {
66            Ok(TestStruct {
67                args_group,
68                expected,
69            })
70        } else {
71            Err(syn::Error::new(keyword.span(), "expected a string literal"))
72        }
73    }
74}
75
76fn split_first_space_args(val: TokenStream) -> TokenStream {
77    quote! {
78        {
79            let formatted = #val;
80            let parts = formatted.splitn(2, " ").collect::<Vec<_>>();
81
82            let mut answer = vec![parts[0].to_string()];
83            if let Some(part) = parts.get(1) {
84                answer.push(part.to_string());
85            }
86
87            answer
88        }
89    }
90}
91
92fn generate_test(
93    parsed: TestStruct,
94    test_count: usize,
95    enum_name: &Ident,
96    name: &Ident,
97    span: Span,
98) -> TokenStream {
99    let test_name = format_ident!("TEST_{}_{}", name, test_count);
100    let args_group = &parsed.args_group.unwrap_or_default();
101    let expected = &parsed.expected;
102
103    quote_spanned! {span=>
104        #[test]
105        fn #test_name() {
106            use Expand;
107
108            let v = #enum_name::#name#args_group;
109            println!("{:?} => {}", v, v.expand_cli());
110            assert_eq!(v.expand_cli(), #expected);
111        }
112    }
113}
114
115#[proc_macro_derive(Expand, attributes(expand_to))]
116pub fn derive_helper_attr(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
117    let input = parse_macro_input!(input as DeriveInput);
118    let enum_name = &input.ident;
119
120    let (match_body, test_funcs) = match input.data {
121        Data::Enum(data) => {
122            let mut stream = TokenStream::new();
123            let mut test_stream = TokenStream::new();
124
125            for variant in data.variants {
126                let span = &variant.span();
127                let name = &variant.ident;
128                let mut name_string = name.to_string();
129                let mut test_count = 0;
130                let mut implemented_with = false;
131
132                let mut fmt_attr = None;
133
134                for attr in &variant.attrs {
135                    if attr.path.get_ident() != Some(&format_ident!("expand_to")) {
136                        continue;
137                    }
138
139                    if attr.parse_args::<syn::Lit>().is_ok() {
140                        fmt_attr = Some(attr);
141                    } else if let Ok(arg) = attr.parse_args::<ExpandToArg>() {
142                        if arg.keyword == "rename" {
143                            if let syn::Lit::Str(lit_str) = arg.name {
144                                name_string = lit_str.value();
145                            } else {
146                                let tokens = quote_spanned! {*span=>
147                                    #enum_name::#name{..} => compile_error!("`rename` must be followed by a string literal, eg #[expand_to(rename = \"example\")]"),
148                                };
149                                stream.extend(tokens);
150                            }
151                        } else if arg.keyword == "with" {
152                            let tokens = if let syn::Lit::Str(lit_str) = arg.name {
153                                let ident = Ident::new(&lit_str.value(), *span);
154
155                                let matcher = match variant.fields {
156                                    Fields::Unnamed(_) => quote! {
157                                        #enum_name::#name(..)
158                                    },
159                                    Fields::Named(_) => quote! {
160                                        #enum_name::#name{..}
161                                    },
162                                    Fields::Unit => quote! {
163                                        #enum_name::#name()
164                                    },
165                                };
166
167                                quote_spanned! {*span=>
168                                    #matcher => #ident(self),
169                                }
170                            } else {
171                                quote_spanned! {*span=>
172                                    #enum_name::#name{..} => compile_error!("`with` must be followed by a string literal, eg #[expand_to(with = \"my_custom_function\")]"),
173                                }
174                            };
175
176                            stream.extend(tokens);
177                            implemented_with = true;
178                        }
179                    } else {
180                        // TODO: add those example as doc attributes
181                        if let Ok(parsed) = attr.parse_args::<TestStruct>() {
182                            test_stream
183                                .extend(generate_test(parsed, test_count, enum_name, name, *span));
184                            test_count += 1;
185                        }
186                    }
187                }
188
189                if implemented_with {
190                    continue;
191                }
192
193                let ignore_filter = |field: &&syn::Field| {
194                    !field.attrs.iter().any(|a| {
195                        a.parse_args::<syn::Ident>()
196                            .and_then(|ident| Ok(ident == "ignore"))
197                            .unwrap_or(false)
198                    })
199                };
200                let tokens = match (variant.fields, fmt_attr) {
201                    (Fields::Named(_), None) => {
202                        quote_spanned! {*span=>
203                            #enum_name::#name{..} => compile_error!("Named fields require an explicit expansion attribute"),
204                        }
205                    }
206                    (Fields::Named(fields), Some(attr)) => {
207                        let args: TokenStream = attr.parse_args().unwrap();
208
209                        let fmt_params = fields.named.iter().filter(ignore_filter).map(|f| {
210                            let ident = &f.ident;
211                            quote_spanned! {f.span()=> #ident = #ident }
212                        });
213                        let expand_params = fields.named.iter().map(|f| {
214                            let ident = &f.ident;
215                            quote_spanned! {f.span()=> #ident }
216                        });
217
218                        let fmt_str_quoted = quote! { format!(#args, #(#fmt_params, )*) };
219                        let content = split_first_space_args(fmt_str_quoted);
220                        quote_spanned! {attr.span()=>
221                            #enum_name::#name{#(#expand_params, )*} => {
222                                #content
223                            },
224                        }
225                    }
226                    (Fields::Unnamed(fields), attr) => {
227                        let expand_params = (0..fields.unnamed.len())
228                            .map(|i| Ident::new(&format!("p_{}", i), i.span()));
229                        let fmt_params = (0..fields.unnamed.len())
230                            .filter(|i| ignore_filter(&&fields.unnamed[*i]))
231                            .map(|i| Ident::new(&format!("p_{}", i), i.span()));
232
233                        if let Some(attr) = attr {
234                            let args: TokenStream = attr.parse_args().unwrap();
235                            let fmt_str_quoted = quote! { format!(#args, #(#fmt_params, )*) };
236                            let content = split_first_space_args(fmt_str_quoted);
237                            quote_spanned! {*span=>
238                                #enum_name::#name(#(#expand_params, )*) => {
239                                    #content
240                                },
241                            }
242                        } else {
243                            let fmt_str = (0..fields.unnamed.len())
244                                .map(|_| "{}")
245                                .collect::<Vec<&str>>()
246                                .join(" ");
247                            quote_spanned! {*span=>
248                                #enum_name::#name(#(#expand_params, )*) => vec![#name_string.to_string(), format!(#fmt_str, #(#fmt_params, )*)],
249                            }
250                        }
251                    }
252                    (Fields::Unit, None) => quote! {
253                        #enum_name::#name => vec![#name_string.to_string()],
254                    },
255                    (Fields::Unit, Some(attr)) => {
256                        let args: TokenStream = attr.parse_args().unwrap();
257                        let args_str_quoted = quote! { #args.to_string() };
258                        let content = split_first_space_args(args_str_quoted);
259                        quote! {
260                            #enum_name::#name => #content,
261                        }
262                    }
263                };
264
265                stream.extend(tokens);
266            }
267
268            (stream, test_stream)
269        }
270        _ => unimplemented!(),
271    };
272
273    let test_mod_name = Ident::new(
274        &format!("_GENERATED_TESTS_FOR_{}", enum_name),
275        enum_name.span(),
276    );
277
278    let name = input.ident;
279    let expanded = quote! {
280        impl Expand for #name {
281            fn expand(&self) -> Vec<String> {
282                #[allow(unused)]
283                #[allow(clippy::useless_format)]
284                match self {
285                    #match_body
286                }
287            }
288        }
289
290        #[cfg(test)]
291        #[allow(non_snake_case)]
292        mod #test_mod_name {
293            use super::*;
294
295            #test_funcs
296        }
297    };
298
299    proc_macro::TokenStream::from(expanded)
300}