logos_display/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use std::matches;
4
5use proc_macro2::{Spacing, TokenTree};
6use quote::{quote, ToTokens};
7use syn::{spanned::Spanned, DataEnum, DeriveInput, Ident, Lit, LitStr, Result};
8
9#[proc_macro_derive(Display, attributes(display_override, display_concat))]
10pub fn logos_display(input: TokenStream) -> TokenStream {
11    _logos_display(input.into(), false).into()
12}
13
14#[proc_macro_derive(Debug, attributes(display_override, display_concat))]
15pub fn logos_debug(input: TokenStream) -> TokenStream {
16    _logos_display(input.into(), true).into()
17}
18
19fn _logos_display(input: TokenStream2, debug: bool) -> TokenStream2 {
20    let ast = match syn::parse2::<DeriveInput>(input) {
21        Ok(res) => res,
22        Err(e) => return e.to_compile_error(),
23    };
24    let span = ast.span();
25    let ident = ast.ident;
26    let mut concat = Some("/".to_string());
27    for attr in ast.attrs.into_iter() {
28        if let syn::Meta::List(l) = attr.meta {
29            if l.path.is_ident("display_concat") {
30                let as_str = l.tokens.to_string();
31                let cand = match syn::parse2::<LitStr>(l.tokens) {
32                    Ok(res) => Ok(res.value()),
33                    Err(e) => {
34                        let resp = as_str;
35                        if resp == "None" {
36                            Ok(resp)
37                        } else {
38                            Err(syn::Error::new(
39                                e.span(),
40                                "Concat must be either a string or None",
41                            ))
42                        }
43                    }
44                };
45                let litstr = match cand {
46                    Ok(res) => res,
47                    Err(e) => return e.to_compile_error(),
48                };
49                if litstr == "None" {
50                    concat = None;
51                } else {
52                    concat = Some(litstr);
53                }
54                break;
55            }
56        }
57    }
58    let resp = match ast.data {
59        syn::Data::Enum(e) => logos_display_derive(e, &ident, concat, debug),
60        _ => Err(syn::Error::new(span, "Can only derive display for enums")),
61    };
62    let traitt = if debug {
63        quote!(core::fmt::Debug)
64    } else {
65        quote!(core::fmt::Display)
66    };
67    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
68    match resp {
69        Ok(stream) => {
70            quote! {
71                impl #impl_generics #traitt for #ident #ty_generics #where_clause {
72                    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
73                        use core::fmt::Write;
74                        match &self {
75                            #stream
76                        }
77                    }
78                }
79            }
80        }
81        Err(e) => e.to_compile_error(),
82    }
83}
84
85fn gen_anon_args(n: usize) -> Vec<TokenStream2> {
86    let mut args = Vec::new();
87    for i in 1..=n {
88        let arg_ident = Ident::new(&format!("_arg{}", i), proc_macro2::Span::call_site());
89        args.push(quote!(#arg_ident));
90    }
91    args
92}
93
94fn logos_display_derive(
95    e: DataEnum,
96    ident: &Ident,
97    concat: Option<String>,
98    include_inner: bool,
99) -> Result<TokenStream2> {
100    let mut repr_map: Vec<(TokenStream2, TokenStream2)> = Vec::with_capacity(e.variants.len());
101    for variant in e.variants.into_iter() {
102        let res = match variant.fields {
103            syn::Fields::Named(f) if include_inner => {
104                let args = gen_anon_args(f.named.len());
105                let key_list = quote!(#(#args),*);
106                let key_list = quote!({#key_list});
107                let ref_array = quote!(#(&#args),*);
108                Some((key_list, Some(ref_array)))
109            }
110            syn::Fields::Unnamed(f) if include_inner => {
111                let args = gen_anon_args(f.unnamed.len());
112                let key_list = quote!(#(#args),*);
113                let key_list = quote!((#key_list));
114                let ref_array = quote!(#(&#args),*);
115                Some((key_list, Some(ref_array)))
116            }
117            syn::Fields::Named(..) => {
118                let key_list = quote!({ .. });
119                Some((key_list, None))
120            }
121            syn::Fields::Unnamed(..) => {
122                let key_list = quote!((..));
123                Some((key_list, None))
124            }
125            _ => None,
126        };
127        let id = variant.ident;
128        let mut repr = id.to_string().into_token_stream();
129        let mut found = None;
130        for attr in variant.attrs.into_iter() {
131            if let syn::Meta::List(l) = attr.meta {
132                if l.path.is_ident("display_override") {
133                    let litstr = match syn::parse2::<LitStr>(l.tokens) {
134                        Ok(res) => res,
135                        Err(e) => return Err(e),
136                    };
137                    found = Some(litstr.value());
138                    break;
139                } else if l.path.is_ident("token") || l.path.is_ident("regex") {
140                    let mut new_stream = TokenStream2::new();
141                    let span = l.span();
142                    for tt in l.tokens.into_iter() {
143                        if matches!(tt, TokenTree::Punct(ref punct) if punct.as_char() == ',' && punct.spacing() == Spacing::Alone)
144                        {
145                            break;
146                        } else {
147                            new_stream.extend(Some(tt));
148                        }
149                    }
150                    let lit: Lit = syn::parse2(new_stream)?;
151                    if let Lit::Str(s) = lit {
152                        let string = s.value();
153                        if let Some(f) = found {
154                            if let Some(ref conc) = concat {
155                                found = Some(format!("{}{}{}", f, conc, string))
156                            } else {
157                                found = Some(string);
158                            }
159                        } else {
160                            found = Some(string);
161                        }
162                    } else {
163                        return Err(syn::Error::new(
164                            span,
165                            "Error extracting token from attribute, not a string",
166                        ));
167                    }
168                }
169            }
170        }
171        if let Some(string) = found {
172            repr = string.into_token_stream();
173        }
174        if let Some((key_list, ref_array)) = res {
175            match ref_array {
176                None => repr_map.push((quote!(#id #key_list), quote!(write!(f, "{}", #repr)))),
177                Some(ref_array) => repr_map.push((
178                    quote!(#id #key_list),
179                    quote!(write!(f, "{}{:?}", #repr, [#ref_array])),
180                )),
181            }
182        } else {
183            repr_map.push((quote!(#id), quote!(write!(f, "{}", #repr))));
184        }
185    }
186    let arms: Vec<TokenStream2> = repr_map
187        .iter()
188        .map(|(k, v)| quote!(#ident::#k => #v,))
189        .collect();
190    Ok(quote!(#( #arms )*))
191}
192
193#[cfg(test)]
194mod tests {
195    use super::_logos_display;
196    use assert_tokenstreams_eq::assert_tokenstreams_eq;
197    use proc_macro2::TokenStream;
198    use quote::quote;
199
200    fn expect(
201        arms: TokenStream,
202        debug: bool,
203        generic_lifetimes: Option<&[TokenStream]>,
204        where_clauses: Option<&[TokenStream]>,
205    ) -> TokenStream {
206        let traitt = if debug {
207            quote!(core::fmt::Debug)
208        } else {
209            quote!(core::fmt::Display)
210        };
211
212        let generic_lifetimes = match generic_lifetimes {
213            None => quote!(),
214            Some(generic_lifetimes) => quote!(<#(#generic_lifetimes),*>),
215        };
216
217        let where_clauses = match where_clauses {
218            None => quote!(),
219            Some(where_clauses) => quote!(where #(#where_clauses),*),
220        };
221
222        quote!(
223            impl #generic_lifetimes #traitt for A #generic_lifetimes #where_clauses {
224                fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
225                    use core::fmt::Write;
226                    match &self {
227                        #arms
228                    }
229                }
230            }
231        )
232    }
233
234    #[test]
235    fn test_basic() {
236        let input = quote!(
237            enum A {
238                #[token("{")]
239                LCur,
240
241                #[regex("}")]
242                RCur,
243            }
244        );
245        let arms = quote!(
246            A::LCur => write!(f, "{}" ,"{"),
247            A::RCur => write!(f, "{}", "}"),
248        );
249        let expected = expect(arms, false, None, None);
250        let result = _logos_display(input, false);
251        assert_tokenstreams_eq!(&result, &expected);
252    }
253
254    #[test]
255    fn test_override() {
256        let input = quote!(
257            enum A {
258                #[display_override("fancy curly thing")]
259                #[token("{")]
260                LCur,
261
262                #[regex("}")]
263                RCur,
264
265                #[token("-")]
266                #[display_override("dash")]
267                Minus,
268            }
269        );
270        let arms = quote!(
271            A::LCur => write!(f, "{}", "fancy curly thing"),
272            A::RCur => write!(f, "{}", "}"),
273            A::Minus => write!(f, "{}", "dash"),
274        );
275        let expected = expect(arms, false, None, None);
276        let result = _logos_display(input, false);
277        assert_tokenstreams_eq!(&result, &expected);
278    }
279
280    #[test]
281    fn test_quoted_string() {
282        let input = quote!(
283            enum A {
284                #[token("{")]
285                LCur,
286
287                #[regex("\".*\"")]
288                RCur,
289            }
290        );
291        let arms = quote!(
292            A::LCur => write!(f, "{}", "{"),
293            A::RCur => write!(f, "{}", "\".*\""),
294        );
295        let expected = expect(arms, false, None, None);
296        let result = _logos_display(input, false);
297        assert_tokenstreams_eq!(&result, &expected);
298    }
299
300    #[test]
301    fn test_raw_string() {
302        let input = quote!(
303            enum A {
304                #[token("{")]
305                LCur,
306
307                #[regex(r#"".*""#)]
308                RCur,
309            }
310        );
311        let arms = quote!(
312            A::LCur => write!(f, "{}", "{"),
313            A::RCur => write!(f, "{}", "\".*\""),
314        );
315        let expected = expect(arms, false, None, None);
316        let result = _logos_display(input, false);
317        assert_tokenstreams_eq!(&result, &expected);
318    }
319
320    #[test]
321    fn test_concat_basic() {
322        let input = quote!(
323            enum A {
324                #[token("{")]
325                #[token("}")]
326                Cur,
327            }
328        );
329        let arms = quote!(
330            A::Cur => write!(f, "{}", "{/}"),
331        );
332        let expected = expect(arms, false, None, None);
333        let result = _logos_display(input, false);
334        assert_tokenstreams_eq!(&result, &expected);
335    }
336
337    #[test]
338    fn test_concat_some() {
339        let input = quote!(
340            #[display_concat(" or ")]
341            enum A {
342                #[token("{")]
343                #[token("}")]
344                Cur,
345            }
346        );
347        let arms = quote!(
348            A::Cur => write!(f, "{}", "{ or }"),
349        );
350        let expected = expect(arms, false, None, None);
351        let result = _logos_display(input, false);
352        assert_tokenstreams_eq!(&result, &expected);
353    }
354
355    #[test]
356    fn test_concat_none() {
357        let input = quote!(
358            #[display_concat(None)]
359            enum A {
360                #[token("{")]
361                #[token("}")]
362                Cur,
363            }
364        );
365        let arms = quote!(
366            A::Cur => write!(f, "{}", "}"),
367        );
368        let expected = expect(arms, false, None, None);
369        let result = _logos_display(input, false);
370        assert_tokenstreams_eq!(&result, &expected);
371    }
372
373    #[test]
374    fn test_with_args() {
375        let input = quote!(
376            enum A {
377                #[regex("[a-z]", |lex| funny_business(lex.slice()))]
378                Reg(First, Second, Third),
379
380                #[regex("[A-Z]", |lex| more_funny(lex.slice()))]
381                Reg2 { first: Type, second: Another },
382            }
383        );
384        let arms_debug = quote!(
385            A::Reg(_arg1, _arg2, _arg3) => write!(f, "{}{:?}", "[a-z]", [&_arg1, &_arg2, &_arg3]),
386            A::Reg2{_arg1, _arg2} => write!(f, "{}{:?}", "[A-Z]", [&_arg1, &_arg2]),
387        );
388        let expected_debug = expect(arms_debug, true, None, None);
389        let result_debug = _logos_display(input.clone(), true);
390        assert_tokenstreams_eq!(&result_debug, &expected_debug);
391        let arms_display = quote!(
392            A::Reg(..) => write!(f, "{}", "[a-z]"),
393            A::Reg2{..} => write!(f, "{}", "[A-Z]"),
394        );
395        let expected_display = expect(arms_display, false, None, None);
396        let result_display = _logos_display(input, false);
397        assert_tokenstreams_eq!(&result_display, &expected_display);
398    }
399
400    #[test]
401    fn test_with_generics() {
402        let input = quote!(
403            enum A<'a, 'b, T, U>
404            where
405                T: U,
406            {
407                #[regex("[a-z]", |lex| funny_business(lex.slice()))]
408                Reg(First<'a>, Second<'b>, Third<'a>, T),
409
410                #[regex("[A-Z]", |lex| more_funny(lex.slice()))]
411                Reg2 {
412                    first: Type<'a>,
413                    second: Another<'b>,
414                    third: U,
415                },
416            }
417        );
418        let arms_debug = quote!(
419            A::Reg(_arg1, _arg2, _arg3, _arg4) => write!(f, "{}{:?}", "[a-z]", [&_arg1, &_arg2, &_arg3, &_arg4]),
420            A::Reg2{_arg1, _arg2, _arg3} => write!(f, "{}{:?}", "[A-Z]", [&_arg1, &_arg2, &_arg3]),
421        );
422        let generics = [quote!('a), quote!('b), quote!(T), quote!(U)];
423        let wheres = [quote!(T:U)];
424        let expected_debug = expect(arms_debug, true, Some(&generics), Some(&wheres));
425        let result_debug = _logos_display(input.clone(), true);
426        assert_tokenstreams_eq!(&result_debug, &expected_debug);
427        let arms_display = quote!(
428            A::Reg(..) => write!(f, "{}", "[a-z]"),
429            A::Reg2{..} => write!(f, "{}", "[A-Z]"),
430        );
431        let expected_display = expect(arms_display, false, Some(&generics), Some(&wheres));
432        let result_display = _logos_display(input, false);
433        assert_tokenstreams_eq!(&result_display, &expected_display);
434    }
435}