discriminant_hash_eq/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, DeriveInput, Data, DataEnum};
6
7
8/// A derive macro that implements `Hash`, `PartialEq`, and `Eq` based on enum discriminants.
9/// 
10/// This macro ensures that hashing and equality checks are done based only on the enum variant type,
11/// ignoring any associated data. This is particularly useful for using enums with custom data in
12/// collections like `HashSet` or as keys in `HashMap`, where equality and hashing are determined by
13/// the variant type alone.
14///
15#[proc_macro_derive(DiscriminantHashEq)]
16pub fn discriminant_hash_eq_derive(input: TokenStream) -> TokenStream {
17    // Parse the input tokens into a syntax tree
18    let input = parse_macro_input!(input as DeriveInput);
19
20    // Get the name of the type we are deriving
21    let name = input.ident;
22
23    // Ensure it's an enum
24    let data = match input.data {
25        Data::Enum(data) => data,
26        _ => panic!("#[derive(DiscriminantHashEq)] is only defined for enums."),
27    };
28
29    // Generate the implementation
30    let gen = impl_discriminant_hash_eq(&name, &data);
31
32    // Return the generated implementation
33    gen.into()
34}
35
36fn impl_discriminant_hash_eq(name: &syn::Ident, data: &DataEnum) -> proc_macro2::TokenStream {
37    let variants = &data.variants;
38
39    // Generate the match arms for the discriminant
40    let match_arms = variants.iter().map(|variant| {
41        let variant_name = &variant.ident;
42        match &variant.fields {
43            syn::Fields::Named(_) | syn::Fields::Unnamed(_) => {
44                quote! {
45                    #name::#variant_name { .. } => std::mem::discriminant(self),
46                }
47            }
48            syn::Fields::Unit => {
49                quote! {
50                    #name::#variant_name => std::mem::discriminant(self),
51                }
52            }
53        }
54    });
55
56    quote! {
57        impl std::hash::Hash for #name {
58            fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
59                let discriminant = match self {
60                    #(#match_arms)*
61                };
62                discriminant.hash(state);
63            }
64        }
65
66        impl PartialEq for #name {
67            fn eq(&self, other: &Self) -> bool {
68                std::mem::discriminant(self) == std::mem::discriminant(other)
69            }
70        }
71
72        impl Eq for #name {}
73    }
74}