use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::quote;
use syn::{
Attribute, Data, DeriveInput, Error, Field, Fields, Ident, Path, Result as SynResult, Variant,
parse::ParseStream, parse_macro_input, punctuated::Punctuated, token,
};
/// Derives `ReflectHash` for a struct or enum.
///
/// The generated impl provides `const HASH: [u8; 32]`, computed at compile time from the
/// type's shape (see `derive_reflect_hash_impl`). The path to the support crate can be
/// overridden with `#[reflect_hash(crate = path::to::crate)]`.
///
/// Expansion failures (unions, malformed attributes, an undetectable support crate) are
/// reported as `compile_error!` tokens rather than panics.
#[proc_macro_derive(ReflectHash, attributes(reflect_hash))]
pub fn derive_reflect_hash(input: TokenStream) -> TokenStream {
    let parsed = parse_macro_input!(input as DeriveInput);
    derive_reflect_hash_impl(parsed)
        .unwrap_or_else(|error| error.to_compile_error())
        .into()
}
fn derive_reflect_hash_impl(input: DeriveInput) -> SynResult<TokenStream2> {
let ty_name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let crate_name = reflect_hash_crate_path(&input.attrs)?;
let hash_computation = match &input.data {
Data::Struct(data_struct) => {
generate_struct_hash(ty_name, &data_struct.fields, &crate_name)?
}
Data::Enum(data_enum) => generate_enum_hash(&data_enum.variants, &crate_name)?,
Data::Union(_) => {
return Err(Error::new_spanned(
ty_name,
"ReflectHash cannot be derived for unions",
));
}
};
Ok(quote! {
impl #impl_generics ReflectHash for #ty_name #ty_generics #where_clause {
const HASH: [u8; 32] = #hash_computation;
}
})
}
fn reflect_hash_crate_path(attrs: &[Attribute]) -> SynResult<TokenStream2> {
for attr in attrs {
if attr.path().is_ident("reflect_hash") {
let parser = |input: ParseStream| -> SynResult<Path> {
input.parse::<token::Crate>()?;
input.parse::<token::Eq>()?;
input.parse::<Path>()
};
let path = attr.parse_args_with(parser)?;
return Ok(quote!(#path));
}
}
match proc_macro_crate::crate_name("sails-reflect-hash") {
Ok(proc_macro_crate::FoundCrate::Itself) => Ok(quote!(crate)),
Ok(proc_macro_crate::FoundCrate::Name(name)) => {
let ident = Ident::new(&name, proc_macro2::Span::call_site());
Ok(quote!(::#ident))
}
Err(e) => Err(Error::new(
Span::call_site(),
format!(
"Could not detect sails-reflect-hash crate: {e}. Consider using #[reflect_hash(crate = path::to::crate)]"
),
)),
}
}
fn generate_struct_hash(
ty_name: &Ident,
fields: &Fields,
crate_name: &TokenStream2,
) -> SynResult<TokenStream2> {
let name_str = ty_name.to_string();
fn fields_hash<'a>(
fields: impl Iterator<Item = &'a Field>,
crate_name: &TokenStream2,
name_str: String,
) -> TokenStream2 {
let field_hashes = fields.map(|field| {
let ty = &field.ty;
quote! {
.update(&<#ty as ReflectHash>::HASH)
}
});
quote! {
#crate_name::keccak_const::Keccak256::new()
.update(#name_str.as_bytes())
#(#field_hashes)*
.finalize()
}
}
match fields {
Fields::Unit => {
Ok(quote! {
#crate_name::keccak_const::Keccak256::new()
.update(#name_str.as_bytes())
.finalize()
})
}
Fields::Unnamed(fields_unnamed) => {
Ok(fields_hash(
fields_unnamed.unnamed.iter(),
crate_name,
name_str,
))
}
Fields::Named(fields_named) => {
Ok(fields_hash(fields_named.named.iter(), crate_name, name_str))
}
}
}
/// Emits a const block computing the hash of an enum from its variants.
///
/// Each variant is first hashed like a struct named after the variant (via
/// `generate_struct_hash`). The resulting 32-byte digests are then fed, in declaration order,
/// into a fresh Keccak-256 hasher, which is finalized. The enum's own type name is not part of
/// the input. An enum with no variants hashes the empty input.
///
/// # Errors
///
/// Propagates the first error returned while generating a variant's hash expression.
fn generate_enum_hash(
    variants: &Punctuated<Variant, token::Comma>,
    crate_name: &TokenStream2,
) -> SynResult<TokenStream2> {
    let per_variant = variants
        .iter()
        .map(|v| generate_struct_hash(&v.ident, &v.fields, crate_name))
        .collect::<SynResult<Vec<TokenStream2>>>()?;

    Ok(quote! {
        {
            let mut final_hasher = #crate_name::keccak_const::Keccak256::new();
            #(
                {
                    let variant_hash = #per_variant;
                    final_hasher = final_hasher.update(&variant_hash);
                }
            )*
            final_hasher.finalize()
        }
    })
}