use proc_macro::TokenStream;
use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::{quote, ToTokens};
use syn::{parse_macro_input, spanned::Spanned, Data, DeriveInput, Fields, Index, Type};
/// Returns `true` when `ty` is written as the bare primitive `f32` or `f64`
/// (a single-segment path, optionally with a leading `::`).
///
/// Invisible groups and parentheses are looked through. Invisible groups are
/// what syn produces when a type is forwarded through a `macro_rules!` `$t:ty`
/// fragment. Looking through them means a macro-forwarded `f32` and `(f64)`
/// are recognised too. Multi-segment paths (e.g. `core::primitive::f32`) and
/// type aliases are not detected.
fn is_float_type(ty: &Type) -> bool {
    match ty {
        Type::Group(group) => is_float_type(&group.elem),
        Type::Paren(paren) => is_float_type(&paren.elem),
        Type::Path(path) => {
            let segments = &path.path.segments;
            segments.len() == 1 && (segments[0].ident == "f32" || segments[0].ident == "f64")
        }
        _ => false,
    }
}
/// Emits the equality check for one pair of field expressions.
///
/// Float fields (per `is_float_type`) are compared by calling
/// `::float_derive::utils::eq(self_tokens, other_tokens)`. Every other field
/// uses `self_tokens == other_tokens`.
///
/// `is_first` threads through successive calls. The first emitted comparison
/// is bare, and each later one is prefixed with `&&` so the pieces concatenate
/// into a single boolean expression. The flag is always `false` on return.
fn partial_eq_impl(
    ty: &Type,
    self_tokens: &impl ToTokens,
    other_tokens: &impl ToTokens,
    is_first: &mut bool,
) -> TokenStream2 {
    let comparison = if is_float_type(ty) {
        quote! { ::float_derive::utils::eq(#self_tokens, #other_tokens) }
    } else {
        quote! { #self_tokens == #other_tokens }
    };

    let leading = *is_first;
    *is_first = false;

    if leading {
        comparison
    } else {
        quote! { && #comparison }
    }
}
#[proc_macro_derive(FloatPartialEq)]
pub fn derive_partial_eq(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let ident = &input.ident;
TokenStream::from(match input.data {
Data::Enum(data) => {
let mut num_non_unit_variants = 0;
let mut num_unit_variants = 0;
let variants = data
.variants
.iter()
.map(|variant| {
let variant_ident = &variant.ident;
let mut is_first = true;
match &variant.fields {
Fields::Named(fields) => {
let self_args = fields.named.iter().enumerate().map(|(i, field)| {
let field_ident = &field.ident.as_ref().unwrap();
let arg_ident = Ident::new(&format!("__self_{i}"), field.span());
quote! {
#field_ident: #arg_ident,
}
});
let other_args = fields.named.iter().enumerate().map(|(i, field)| {
let field_ident = &field.ident.as_ref().unwrap();
let arg_ident = Ident::new(&format!("__arg1_{i}"), field.span());
quote! {
#field_ident: #arg_ident,
}
});
let impls = fields.named.iter().enumerate().map(|(i, field)| {
let self_ident = Ident::new(&format!("__self_{i}"), field.span());
let other_ident = Ident::new(&format!("__arg1_{i}"), field.span());
partial_eq_impl(&field.ty, "e! { *#self_ident}, "e! { *#other_ident }, &mut is_first)
});
num_non_unit_variants += 1;
quote! {
(
#ident::#variant_ident { #(#self_args)* },
#ident::#variant_ident { #(#other_args)* }
) => {
#(#impls)*
}
}
}
Fields::Unnamed(fields) => {
let self_args = fields.unnamed.iter().enumerate().map(|(i, field)| {
let arg_ident = Ident::new(&format!("__self_{i}"), field.span());
quote! {
#arg_ident,
}
});
let other_args = fields.unnamed.iter().enumerate().map(|(i, field)| {
let arg_ident = Ident::new(&format!("__arg1_{i}"), field.span());
quote! {
#arg_ident,
}
});
let impls = fields.unnamed.iter().enumerate().map(|(i, field)| {
let self_ident = Ident::new(&format!("__self_{i}"), field.span());
let other_ident = Ident::new(&format!("__arg1_{i}"), field.span());
partial_eq_impl(&field.ty, "e! { *#self_ident}, "e! { *#other_ident }, &mut is_first)
});
num_non_unit_variants += 1;
quote! {
(
#ident::#variant_ident(#(#self_args)*),
#ident::#variant_ident(#(#other_args)*)
) => {
#(#impls)*
}
}
}
Fields::Unit => {
num_unit_variants += 1;
TokenStream2::new()
}
}
})
.collect::<TokenStream2>();
let num_variants = num_non_unit_variants + num_unit_variants;
let body = if num_non_unit_variants == 0 && num_variants < 2 {
quote! { true }
} else {
let default_pattern = if num_unit_variants == 0 {
quote! {
_ => unsafe { ::core::intrinsics::unreachable() }
}
} else {
quote! {
_ => true
}
};
let matched_variants = quote! {
match (self, other) {
#variants
#default_pattern
}
};
if num_variants > 1 {
let tags = quote! {
let __self_tag = ::core::intrinsics::discriminant_value(self);
let __arg1_tag = ::core::intrinsics::discriminant_value(other);
__self_tag == __arg1_tag
};
if num_non_unit_variants > 0 {
quote! {
#tags && #matched_variants
}
} else {
tags
}
} else {
matched_variants
}
};
quote! {
#[automatically_derived]
impl ::core::marker::StructuralPartialEq for #ident {}
#[automatically_derived]
impl ::core::cmp::PartialEq for #ident {
#[inline]
fn eq(&self, other: &#ident) -> bool {
#body
}
}
}
}
Data::Struct(data) => {
let mut is_first = true;
let fields = match data.fields {
Fields::Named(fields) => {
fields
.named
.iter()
.map(|field| {
let ident = field.ident.as_ref().unwrap();
partial_eq_impl(&field.ty, "e! { self.#ident }, "e! { other.#ident }, &mut is_first)
})
.collect::<TokenStream2>()
}
Fields::Unnamed(fields) => {
fields
.unnamed
.iter()
.enumerate()
.map(|(i, field)| {
let index = Index { index: i as _, span: field.span() };
partial_eq_impl(&field.ty, "e! { self.#index }, "e! { other.#index }, &mut is_first)
})
.collect::<TokenStream2>()
}
Fields::Unit => TokenStream2::new()
};
quote! {
#[automatically_derived]
impl ::core::marker::StructuralPartialEq for #ident {}
#[automatically_derived]
impl ::core::cmp::PartialEq for #ident {
#[inline]
fn eq(&self, other: &#ident) -> bool {
#fields
}
}
}
}
Data::Union(_) => panic!("this trait cannot be derived for unions")
})
}
#[proc_macro_derive(FloatEq)]
pub fn derive_eq(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let ident = &input.ident;
TokenStream::from(quote! {
#[automatically_derived]
impl ::std::cmp::Eq for #ident {}
})
}
/// Emits one statement that feeds `tokens` into the hasher named `state`.
///
/// Float fields (per `is_float_type`) go through
/// `::float_derive::utils::hash`. Everything else goes through
/// `::core::hash::Hash::hash`. `tokens` is passed as-is, so callers supply
/// whatever form (e.g. a reference) the callee expects.
fn hash_impl(ty: &Type, tokens: &impl ToTokens) -> TokenStream2 {
    let hash_fn = match is_float_type(ty) {
        true => quote! { ::float_derive::utils::hash },
        false => quote! { ::core::hash::Hash::hash },
    };
    quote! {
        #hash_fn(#tokens, state);
    }
}
#[proc_macro_derive(FloatHash)]
pub fn derive_hash(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let ident = &input.ident;
TokenStream::from(match input.data {
Data::Enum(data) => {
let variants = {
let mut has_non_unit_variants = false;
let mut has_unit_variants = false;
let variants = data
.variants
.iter()
.map(|variant| {
let variant_ident = &variant.ident;
match &variant.fields {
Fields::Named(fields) => {
let args = fields.named.iter().map(|field| {
let field_ident = field.ident.as_ref().unwrap();
quote! {
#field_ident,
}
});
let impls = fields.named.iter().map(|field| {
let field_ident = field.ident.as_ref().unwrap();
hash_impl(&field.ty, "e! { #field_ident})
});
has_non_unit_variants = true;
quote! {
#ident::#variant_ident { #(#args)* } => { #(#impls)* }
}
}
Fields::Unnamed(fields) => {
let args = fields.unnamed.iter().enumerate().map(|(i, field)| {
let field_ident = Ident::new(&format!("__self_{i}"), field.span());
quote! {
#field_ident,
}
});
let impls = fields.unnamed.iter().enumerate().map(|(i, field)| {
let field_ident = Ident::new(&format!("__self_{i}"), field.span());
hash_impl(&field.ty, "e! { #field_ident})
});
has_non_unit_variants = true;
quote! {
#ident::#variant_ident(#(#args)*) => { #(#impls)* }
}
}
Fields::Unit => {
has_unit_variants = true;
TokenStream2::new()
}
}
})
.collect::<TokenStream2>();
let default_pattern = if has_unit_variants {
quote! { _ => () }
} else {
TokenStream2::new()
};
if has_non_unit_variants {
quote! {
match self {
#variants
#default_pattern
}
}
} else {
TokenStream2::new()
}
};
quote! {
#[automatically_derived]
impl ::std::hash::Hash for #ident {
fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) {
let __self_tag = ::core::intrinsics::discriminant_value(self);
::core::hash::Hash::hash(&__self_tag, state);
#variants
}
}
}
}
Data::Struct(data) => {
let fields = match data.fields {
Fields::Named(fields) => {
fields
.named
.iter()
.map(|field| {
let ident = field.ident.as_ref().unwrap();
hash_impl(&field.ty, "e! { &self.#ident})
})
.collect::<TokenStream2>()
}
Fields::Unnamed(fields) => {
fields
.unnamed
.iter()
.enumerate()
.map(|(i, field)| {
let index = Index { index: i as _, span: field.span() };
hash_impl(&field.ty, "e! { &self.#index})
})
.collect::<TokenStream2>()
}
Fields::Unit => TokenStream2::new()
};
quote! {
#[automatically_derived]
impl ::std::hash::Hash for #ident {
fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) {
#fields
}
}
}
}
Data::Union(_) => panic!("this trait cannot be derived for unions")
})
}