1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DeriveInput, Fields, Variant};
/// Compare enum only by variant
/// Enum::Variant(value) == Enum::Variant(other_value) => true
#[proc_macro_derive(PartialEqVariant)]
pub fn partial_eq_variant(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let name = &input.ident;
    if let Data::Enum(data) = &input.data {
        &data.variants
    } else {
        panic!("PartialEqVariant can only be derived for enums");
    };

    TokenStream::from(quote! {
        impl PartialEq for #name {
            fn eq(&self, other: &Self) -> bool {
                std::mem::discriminant(self) == std::mem::discriminant(other)
            }
        }
    })
}

/// Compare enum except last value
/// Enum::Variant(value, first_value) == Enum::Variant(value, other_second_value) => true
#[proc_macro_derive(PartialEqExceptLast)]
pub fn partial_eq_except_last(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let name = &input.ident;
    let variants = if let Data::Enum(data) = &input.data {
        &data.variants
    } else {
        panic!("PartialEqVariant can only be derived for enums");
    };

    let variant_checks = variants.iter().map(|Variant { ident, fields, .. }| {
        let variant = quote! { #name::#ident };
        match fields {
            Fields::Unnamed(fields) => {
                let all_fields = fields.unnamed.iter().rev().skip(1).rev().enumerate();

                let fields1 = all_fields.clone().map(|(i, _)| format_ident!("f_{}", i));
                let fields2 = all_fields.clone().map(|(i, _)| format_ident!("s_{}", i));

                let fields3 = fields1.clone();
                let fields4 = fields2.clone();

                if fields.unnamed.len() > 1 {
                    quote! { (#variant(#(#fields1,)* _), #variant(#(#fields2,)* _)) => #(
                        #fields3 == #fields4
                    )&&* }
                } else {
                    quote! { (#variant(_), #variant(_)) => true }
                }
            }
            Fields::Unit => quote! { (#variant, #variant) => true },

            _ => panic!(
                "PartialEqVariant can only be derived for enums with unnamed or unit variants"
            ),
        }
    });

    let expanded = quote! {
        impl PartialEq for #name {
            fn eq(&self, other: &Self) -> bool {
                match (self, other) {
                    #(#variant_checks,)*
                    _ => false,
                }
            }
        }
    };

    TokenStream::from(expanded)
}