use std::collections::HashSet;
use proc_macro::TokenStream as TokenStream1;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{DeriveInput, Path, parse_macro_input};
/// Parsed form of a `#[derive(TypeSignature)]` input: everything the code
/// generator needs in order to emit the `TypeSignature` impl.
struct TypeSignatureImpl {
/// Identifier of the type the impl is generated for.
ident: syn::Ident,
/// Generics of the input type, reused for the impl header and where clause.
generics: syn::Generics,
/// Field types that receive a `: TypeSignature` bound in the generated where
/// clause. Left empty when the input declares no generic type parameters.
generic_constraints: Vec<syn::Type>,
/// One `(name, &[fields...])` tuple expression per variant. A struct
/// contributes a single entry whose name is the empty string.
variants: Vec<TokenStream>,
/// Optional replacement for the type name stored in the signature
/// (`#[type_signature(rename = "...")]`).
rename: Option<String>,
/// Path used to reach the runtime crate from generated code; defaults to
/// `::type_signature` when not overridden via `#[type_signature(crate = ...)]`.
crate_path: Path,
}
impl TryFrom<DeriveInput> for TypeSignatureImpl {
type Error = syn::Error;
fn try_from(ast: DeriveInput) -> syn::Result<Self> {
let type_attrs = TypeAttrs::parse(&ast.attrs)?;
let crate_path = type_attrs.crate_path.unwrap_or_else(|| Path {
leading_colon: Some(syn::token::PathSep(Span::call_site())),
segments: {
let mut segments = syn::punctuated::Punctuated::new();
segments.push(syn::Ident::new("type_signature", Span::call_site()).into());
segments
},
});
for param in &ast.generics.params {
if let syn::GenericParam::Const(const_param) = param {
let is_ident = matches!(
&const_param.ty,
syn::Type::Path(syn::TypePath { qself: None, path })
if path.get_ident().is_some()
);
if !is_ident {
return Err(syn::Error::new_spanned(
&const_param.ty,
"TypeSignature derive only supports const generic parameters whose type is a simple identifier (e.g. `usize`, `bool`)",
));
}
}
}
let any_generic_tys = ast
.generics
.params
.iter()
.any(|param| matches!(param, syn::GenericParam::Type(_)));
let (variants, generic_constraints) = match ast.data {
syn::Data::Struct(st) => {
let (field_impls, field_tys) = fields_info(&st.fields, &crate_path)?;
let variants = vec![quote!(("", &[ #( #field_impls ),* ]))];
(variants, field_tys)
}
syn::Data::Enum(en) => {
let rows = en
.variants
.iter()
.map(|variant| -> syn::Result<_> {
let variant_attrs = TypeAttrs::parse(&variant.attrs)?;
let variant_name = variant_attrs
.rename
.unwrap_or_else(|| variant.ident.to_string());
let (field_impls, field_tys) = fields_info(&variant.fields, &crate_path)?;
let variant_impl = quote!((#variant_name, &[ #( #field_impls ),* ]));
Ok((variant_impl, field_tys))
})
.collect::<syn::Result<Vec<_>>>()?;
let (variants, per_variant_field_tys): (Vec<_>, Vec<_>) = rows.into_iter().unzip();
let field_tys = deduplicate(per_variant_field_tys.into_iter().flatten()).collect();
(variants, field_tys)
}
syn::Data::Union(un) => un
.fields
.named
.iter()
.filter_map(|field| {
let attrs = match FieldAttrs::parse(&field.attrs) {
Ok(a) => a,
Err(e) => return Some(Err(e)),
};
if attrs.skip {
return None;
}
let name = attrs.rename.unwrap_or_else(|| {
field
.ident
.as_ref()
.expect("union fields are always named")
.to_string()
});
let ty = &field.ty;
let variant = quote!(
(#name, &[("", &<#ty as #crate_path::TypeSignature>::SIGNATURE)])
);
Some(Ok((variant, field.ty.clone())))
})
.collect::<syn::Result<Vec<_>>>()?
.into_iter()
.unzip(),
};
let generic_constraints = if any_generic_tys {
generic_constraints
} else {
Vec::new()
};
Ok(Self {
ident: ast.ident,
generics: ast.generics,
generic_constraints,
variants,
rename: type_attrs.rename,
crate_path,
})
}
}
impl quote::ToTokens for TypeSignatureImpl {
    /// Appends the generated impl to `tokens`.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        tokens.extend(self.to_token_stream());
    }

    /// Builds the `impl TypeSignature for <type>` item, including its where
    /// clause and the `SIGNATURE` associated constant.
    ///
    /// # Panics
    ///
    /// Panics if a const generic parameter's type is not a bare identifier;
    /// `TryFrom::try_from` rejects such inputs before this point.
    fn to_token_stream(&self) -> TokenStream {
        let Self {
            ident,
            generics,
            generic_constraints,
            variants,
            rename,
            crate_path,
        } = self;

        let (impl_generics, ty_generics, _) = generics.split_for_impl();

        // The user's own where predicates are re-emitted ahead of the generated ones.
        let user_where_predicates = generics
            .where_clause
            .iter()
            .flat_map(|clause| clause.predicates.iter());

        let ty_name = match rename {
            Some(custom) => custom.clone(),
            None => ident.to_string(),
        };

        // Walk the generic parameters once, sorting what each kind contributes.
        let mut type_param_bounds = Vec::new();
        let mut type_param_signatures = Vec::new();
        let mut const_param_hashes = Vec::new();
        for param in &generics.params {
            match param {
                syn::GenericParam::Type(type_param) => {
                    let name = &type_param.ident;
                    type_param_bounds.push(quote!(#name: #crate_path::TypeSignature));
                    type_param_signatures
                        .push(quote!(&<#name as #crate_path::TypeSignature>::SIGNATURE));
                }
                syn::GenericParam::Const(const_param) => {
                    let value_ty = match &const_param.ty {
                        syn::Type::Path(syn::TypePath { qself: None, path }) => path
                            .get_ident()
                            .expect("validated in TryFrom::try_from"),
                        _ => unreachable!("validated in TryFrom::try_from"),
                    };
                    // The hashing helper is selected by the parameter's type name.
                    let hash_fn =
                        syn::Ident::new(&format!("hash_const_{value_ty}"), Span::call_site());
                    let value = &const_param.ident;
                    let label = const_param.ident.to_string();
                    const_param_hashes.push(quote! { const {
                        let mut acc = #crate_path::__macro_export::hash_str(#label);
                        #crate_path::__macro_export::mix_values(
                            &mut acc,
                            #crate_path::__macro_export::#hash_fn(#value)
                        );
                        acc
                    }});
                }
                // Remaining parameter kinds (lifetimes) contribute nothing.
                _ => {}
            }
        }

        quote! {
            impl #impl_generics #crate_path::TypeSignature for #ident #ty_generics
            where
                #( #user_where_predicates, )*
                #( #type_param_bounds, )*
                #( #generic_constraints: #crate_path::TypeSignature ),*
            {
                #![allow(single_use_lifetimes, reason = "Macro-generated code")]
                const SIGNATURE: #crate_path::TypeSignatureHasher = #crate_path::TypeSignatureHasher {
                    ty_name: #ty_name,
                    ty_generics: &[ #( #type_param_signatures ),* ],
                    const_generic_hashes: &[ #( #const_param_hashes ),* ],
                    variants: &[ #( #variants ),* ],
                };
            }
        }
    }
}
#[proc_macro_derive(TypeSignature, attributes(type_signature))]
pub fn derive_type_signature(input: TokenStream1) -> TokenStream1 {
let ast = parse_macro_input!(input as DeriveInput);
match TypeSignatureImpl::try_from(ast) {
Ok(imp) => quote!(#imp),
Err(e) => e.into_compile_error(),
}
.into()
}
/// Lazily yields the items of `elems`, dropping every item that has already
/// been yielded. The first occurrence of each value is kept, in input order.
fn deduplicate<T: core::hash::Hash + Eq + Clone>(
    elems: impl IntoIterator<Item = T>,
) -> impl Iterator<Item = T> {
    let mut already_yielded = HashSet::new();
    elems.into_iter().filter_map(move |item| {
        // `insert` reports `true` only the first time a value is seen.
        let first_time = already_yielded.insert(item.clone());
        first_time.then_some(item)
    })
}
fn fields_info(
fields: &syn::Fields,
crate_path: &Path,
) -> syn::Result<(Vec<TokenStream>, Vec<syn::Type>)> {
let rows = fields
.iter()
.enumerate()
.filter_map(|(idx, field)| {
let attrs = match FieldAttrs::parse(&field.attrs) {
Ok(a) => a,
Err(e) => return Some(Err(e)),
};
if attrs.skip {
return None;
}
let name = attrs.rename.unwrap_or_else(|| {
field
.ident
.as_ref()
.map_or_else(|| idx.to_string(), syn::Ident::to_string)
});
let ty = &field.ty;
let impl_tokens = quote!((#name, &<#ty as #crate_path::TypeSignature>::SIGNATURE));
Some(Ok((impl_tokens, field.ty.clone())))
})
.collect::<syn::Result<Vec<_>>>()?;
Ok(rows.into_iter().unzip())
}
/// Options collected from `#[type_signature(...)]` attributes placed on a
/// type (the same parser is also applied to enum variants).
#[derive(Default)]
struct TypeAttrs {
/// Replacement name given via `rename = "..."`, if any.
rename: Option<String>,
/// Path to the runtime crate given via `crate = ...`, if any.
crate_path: Option<Path>,
}
impl TypeAttrs {
    /// Collects the `#[type_signature(...)]` options from `attrs`.
    ///
    /// Attributes with any other path are ignored. When a key appears more
    /// than once, the last occurrence wins.
    ///
    /// # Errors
    ///
    /// Returns a `syn::Error` if a value fails to parse or if a key other than
    /// `rename` or `crate` is used.
    fn parse(attrs: &[syn::Attribute]) -> syn::Result<Self> {
        let mut out = Self::default();
        for attr in attrs {
            if !attr.path().is_ident("type_signature") {
                continue;
            }
            attr.parse_nested_meta(|meta| {
                if meta.path.is_ident("rename") {
                    let lit: syn::LitStr = meta.value()?.parse()?;
                    out.rename = Some(lit.value());
                    Ok(())
                } else if meta.path.is_ident("crate") {
                    let crate_path: Path = meta.value()?.parse()?;
                    out.crate_path = Some(crate_path);
                    Ok(())
                } else {
                    // The message is a plain string, not a format string, so it
                    // must not contain `{...}` placeholders. The error's span
                    // already points at the offending key.
                    Err(meta.error(
                        "unrecognized type_signature attribute; expected `rename = \"...\"` or `crate = ...`",
                    ))
                }
            })?;
        }
        Ok(out)
    }
}
/// Options collected from `#[type_signature(...)]` attributes placed on a
/// field.
#[derive(Default)]
struct FieldAttrs {
/// Set by `skip`: the field is left out of the generated signature.
skip: bool,
/// Replacement field name given via `rename = "..."`, if any.
rename: Option<String>,
}
impl FieldAttrs {
    /// Collects the `#[type_signature(...)]` options from a field's attributes.
    ///
    /// Attributes with any other path are ignored. A later `rename` replaces an
    /// earlier one.
    ///
    /// # Errors
    ///
    /// Returns a `syn::Error` if the `rename` value is not a string literal or
    /// if a key other than `skip` or `rename` is used.
    fn parse(attrs: &[syn::Attribute]) -> syn::Result<Self> {
        let mut parsed = Self::default();

        let relevant = attrs
            .iter()
            .filter(|attr| attr.path().is_ident("type_signature"));

        for attr in relevant {
            attr.parse_nested_meta(|meta| {
                if meta.path.is_ident("skip") {
                    parsed.skip = true;
                    return Ok(());
                }
                if meta.path.is_ident("rename") {
                    let value = meta.value()?;
                    let literal: syn::LitStr = value.parse()?;
                    parsed.rename = Some(literal.value());
                    return Ok(());
                }
                Err(meta.error(
                    "unrecognized type_signature attribute; expected `skip` or `rename = \"...\"`",
                ))
            })?;
        }

        Ok(parsed)
    }
}