// kast_try_hash_derive/lib.rs

extern crate proc_macro;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::spanned::Spanned;

6#[proc_macro_derive(TryHash, attributes(try_hash))]
7pub fn derive_try_hash(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
8    let input: syn::DeriveInput = syn::parse_macro_input!(input);
9    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
10    let ident = input.ident;
11    let try_hash = quote! { try_hash::TryHash };
12
13    struct Field {
14        /// `0` is a valid name originally
15        name: syn::Member,
16        /// because `0` is not a valid name anymore
17        new_name: syn::Ident,
18        do_try: bool,
19    }
20    fn fields(fields: &syn::Fields) -> impl Iterator<Item = Field> + '_ {
21        fields.iter().enumerate().map(|(index, field)| Field {
22            name: match &field.ident {
23                Some(ident) => syn::Member::Named(ident.clone()),
24                None => syn::Member::Unnamed(syn::Index {
25                    index: index as u32,
26                    span: field.span(),
27                }),
28            },
29            new_name: syn::Ident::new(&format!("__self_{index}"), field.span()),
30            do_try: field
31                .attrs
32                .iter()
33                .any(|attr| attr.path().is_ident("try_hash")),
34        })
35    }
36    fn hash_fields(fields: impl IntoIterator<Item = Field>) -> TokenStream {
37        fn hash_field(field: Field) -> TokenStream {
38            let ident = &field.new_name;
39            if field.do_try {
40                quote!( try_hash::TryHash::try_hash(#ident, hasher)?; )
41            } else {
42                quote!( std::hash::Hash::hash(#ident, hasher); )
43            }
44        }
45        let hash_fields = fields.into_iter().map(hash_field);
46        quote! {
47            #( #hash_fields )*
48        }
49    }
50    match input.data {
51        syn::Data::Struct(data) => {
52            let field_names = fields(&data.fields).map(|field| field.name);
53            let field_new_names = fields(&data.fields).map(|field| field.new_name);
54            let hash_fields = hash_fields(fields(&data.fields));
55            quote! {
56                impl #impl_generics #try_hash for #ident #ty_generics #where_clause {
57                    type Error = Box<dyn std::error::Error + Send + Sync>;
58                    fn try_hash(&self, hasher: &mut impl std::hash::Hasher) -> Result<(), Self::Error> {
59                        let Self { #(#field_names: #field_new_names),* } = self;
60                        #hash_fields;
61                        Ok(())
62                    }
63                }
64            }
65        }
66        syn::Data::Enum(data) => {
67            let variants = data.variants.iter().map(|variant| {
68                let ident = &variant.ident;
69                let field_names = fields(&variant.fields).map(|field| field.name);
70                let field_new_names: Vec<_> = fields(&variant.fields).map(|field| field.new_name).collect();
71                let hash_fields = hash_fields(fields(&variant.fields));
72                quote! {
73                    Self::#ident { #(#field_names : #field_new_names),* } => {
74                        #hash_fields
75                    }
76                }
77            });
78            quote! {
79                impl #impl_generics try_hash::TryHash for #ident #ty_generics #where_clause {
80                    type Error = Box<dyn std::error::Error + Send + Sync>;
81                    fn try_hash(&self, hasher: &mut impl std::hash::Hasher) -> Result<(), Self::Error> {
82                        std::hash::Hash::hash(&std::mem::discriminant(self), hasher);
83                        match self {
84                            #(#variants)*
85                        }
86                        Ok(())
87                    }
88                }
89            }
90        }
91        syn::Data::Union(_) => panic!("union no support"),
92    }
93    .into()
94}