enum_procs/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{parse_macro_input, Data, DeriveInput, Fields, Variant};
4
5/// Derive macro generating an impl of the trait `PartialEq` that compare enum only by variant
6/// Enum::Variant(value) == Enum::Variant(other_value) => true
7#[proc_macro_derive(PartialEqVariant)]
8pub fn eq_variant(input: TokenStream) -> TokenStream {
9    let input = parse_macro_input!(input as DeriveInput);
10
11    let name = &input.ident;
12    if let Data::Enum(data) = &input.data {
13        &data.variants
14    } else {
15        panic!("PartialEqVariant can only be derived for enums");
16    };
17
18    TokenStream::from(quote! {
19        impl PartialEq for #name {
20            fn eq(&self, other: &Self) -> bool {
21                std::mem::discriminant(self) == std::mem::discriminant(other)
22            }
23        }
24    })
25}
26
27/// Derive macro generating an impl of the trait `From` for all types
28/// inside tuple variants with one type
29#[proc_macro_derive(AutoFrom)]
30pub fn auto_from(input: TokenStream) -> TokenStream {
31    let input = parse_macro_input!(input as DeriveInput);
32
33    let name = &input.ident;
34    let variants = if let Data::Enum(data) = &input.data {
35        &data.variants
36    } else {
37        panic!("AutoFrom can only be derived for enums");
38    };
39
40    let implementations = variants.iter().map(|Variant { ident, fields, .. }| {
41        let variant = quote! { #name::#ident };
42        match fields {
43            Fields::Unnamed(fields) => {
44                if fields.unnamed.len() != 1 {
45                    return quote! {};
46                }
47
48                let field = fields.unnamed.iter().next().unwrap();
49
50                let first = quote! {
51                    impl From<#field> for #name {
52                        fn from(item: #field) -> Self {
53                            #variant(item)
54                        }
55                    }
56                };
57
58                let second = {
59                    if field.to_token_stream().to_string() == "String" {
60                        quote! {
61                            impl From<&str> for #name {
62                                fn from(item: &str) -> Self {
63                                    #variant(item.to_owned())
64                                }
65                            }
66                        }
67                    } else {
68                        quote! {}
69                    }
70                };
71
72                quote! {
73                    #first
74                    #second
75                }
76            }
77            _ => quote! {},
78        }
79    });
80
81    let expanded = quote! {
82        #(#implementations)*
83    };
84
85    TokenStream::from(expanded)
86}