float-derive-macros 0.1.0

A crate that allows deriving `PartialEq`, `Eq` and `Hash` for types that contain floating-point (`f32`/`f64`) fields.
Documentation
use proc_macro::TokenStream;
use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::{quote, ToTokens};
use syn::{parse_macro_input, spanned::Spanned, Data, DeriveInput, Fields, Index, Type};

/// Returns `true` when `ty` is spelled as the bare primitive `f32` or `f64`.
///
/// Only single-segment paths without a qualified self (`<T>::...`) are
/// recognized, so qualified spellings such as `core::primitive::f32` or type
/// aliases are treated as ordinary types and fall back to the regular
/// `PartialEq`/`Hash` code paths.
///
/// Parenthesized types (`(f32)`) and invisible `Type::Group` wrappers are
/// unwrapped first. `macro_rules!` inserts such a group around `$t:ty`
/// fragments, so without this a float field declared through a declarative
/// macro would not be detected.
fn is_float_type(ty: &Type) -> bool {
    match ty {
        Type::Group(group) => is_float_type(&group.elem),
        Type::Paren(paren) => is_float_type(&paren.elem),
        Type::Path(path) if path.qself.is_none() => {
            let segments = &path.path.segments;
            segments.len() == 1 && (segments[0].ident == "f32" || segments[0].ident == "f64")
        }
        _ => false,
    }
}

/// Builds the equality check for a single field of `self` against `other`.
///
/// Every call after the first prefixes its output with `&&`, so concatenating
/// the results for all fields forms one boolean conjunction. `is_first` says
/// whether the `&&` must be omitted, and it is always cleared by this call.
/// Fields whose type `is_float_type` accepts are compared through
/// `::float_derive::utils::eq`; all other fields use `==`.
fn partial_eq_impl(ty: &Type, self_tokens: &impl ToTokens, other_tokens: &impl ToTokens, is_first: &mut bool) -> TokenStream2 {
    let was_first = std::mem::replace(is_first, false);
    let connector = if was_first { quote! {} } else { quote! { && } };

    match is_float_type(ty) {
        true => quote! { #connector ::float_derive::utils::eq(#self_tokens, #other_tokens) },
        false => quote! { #connector #self_tokens == #other_tokens },
    }
}

#[proc_macro_derive(FloatPartialEq)]
pub fn derive_partial_eq(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let ident = &input.ident;

    TokenStream::from(match input.data {
        Data::Enum(data) => {
            let mut num_non_unit_variants = 0;
            let mut num_unit_variants = 0;

            let variants = data
                .variants
                .iter()
                .map(|variant| {
                    let variant_ident = &variant.ident;
                    let mut is_first = true;

                    match &variant.fields {
                        Fields::Named(fields) => {
                            let self_args = fields.named.iter().enumerate().map(|(i, field)| {
                                let field_ident = &field.ident.as_ref().unwrap();
                                let arg_ident = Ident::new(&format!("__self_{i}"), field.span());

                                quote! {
                                    #field_ident: #arg_ident,
                                }
                            });

                            let other_args = fields.named.iter().enumerate().map(|(i, field)| {
                                let field_ident = &field.ident.as_ref().unwrap();
                                let arg_ident = Ident::new(&format!("__arg1_{i}"), field.span());

                                quote! {
                                    #field_ident: #arg_ident,
                                }
                            });

                            let impls = fields.named.iter().enumerate().map(|(i, field)| {
                                let self_ident = Ident::new(&format!("__self_{i}"), field.span());
                                let other_ident = Ident::new(&format!("__arg1_{i}"), field.span());

                                partial_eq_impl(&field.ty, &quote! { *#self_ident}, &quote! { *#other_ident }, &mut is_first)
                            });

                            num_non_unit_variants += 1;
                            quote! {
                                (
                                    #ident::#variant_ident { #(#self_args)* },
                                    #ident::#variant_ident { #(#other_args)* }
                                ) => {
                                    #(#impls)*
                                }
                            }
                        }
                        Fields::Unnamed(fields) => {
                            let self_args = fields.unnamed.iter().enumerate().map(|(i, field)| {
                                let arg_ident = Ident::new(&format!("__self_{i}"), field.span());

                                quote! {
                                    #arg_ident,
                                }
                            });

                            let other_args = fields.unnamed.iter().enumerate().map(|(i, field)| {
                                let arg_ident = Ident::new(&format!("__arg1_{i}"), field.span());

                                quote! {
                                    #arg_ident,
                                }
                            });

                            let impls = fields.unnamed.iter().enumerate().map(|(i, field)| {
                                let self_ident = Ident::new(&format!("__self_{i}"), field.span());
                                let other_ident = Ident::new(&format!("__arg1_{i}"), field.span());

                                partial_eq_impl(&field.ty, &quote! { *#self_ident}, &quote! { *#other_ident }, &mut is_first)
                            });

                            num_non_unit_variants += 1;
                            quote! {
                                (
                                    #ident::#variant_ident(#(#self_args)*),
                                    #ident::#variant_ident(#(#other_args)*)
                                ) => {
                                    #(#impls)*
                                }
                            }
                        }
                        Fields::Unit => {
                            num_unit_variants += 1;
                            TokenStream2::new()
                        }
                    }
                })
                .collect::<TokenStream2>();

            let num_variants = num_non_unit_variants + num_unit_variants;

            let body = if num_non_unit_variants == 0 && num_variants < 2 {
                quote! { true }
            } else {
                let default_pattern = if num_unit_variants == 0 {
                    quote! {
                        _ => unsafe { ::core::intrinsics::unreachable() }
                    }
                } else {
                    quote! {
                        _ => true
                    }
                };

                let matched_variants = quote! {
                    match (self, other) {
                        #variants
                        #default_pattern
                    }
                };

                if num_variants > 1 {
                    let tags = quote! {
                        let __self_tag = ::core::intrinsics::discriminant_value(self);
                        let __arg1_tag = ::core::intrinsics::discriminant_value(other);
                        __self_tag == __arg1_tag
                    };

                    if num_non_unit_variants > 0 {
                        quote! {
                            #tags && #matched_variants
                        }
                    } else {
                        tags
                    }
                } else {
                    matched_variants
                }
            };

            quote! {
                #[automatically_derived]
                impl ::core::marker::StructuralPartialEq for #ident {}
                #[automatically_derived]
                impl ::core::cmp::PartialEq for #ident {
                    #[inline]
                    fn eq(&self, other: &#ident) -> bool {
                        #body
                    }
                }
            }
        }
        Data::Struct(data) => {
            let mut is_first = true;

            let fields = match data.fields {
                Fields::Named(fields) => {
                    fields
                        .named
                        .iter()
                        .map(|field| {
                            let ident = field.ident.as_ref().unwrap();

                            partial_eq_impl(&field.ty, &quote! { self.#ident }, &quote! { other.#ident }, &mut is_first)
                        })
                        .collect::<TokenStream2>()
                }
                Fields::Unnamed(fields) => {
                    fields
                        .unnamed
                        .iter()
                        .enumerate()
                        .map(|(i, field)| {
                            let index = Index { index: i as _, span: field.span() };
                            partial_eq_impl(&field.ty, &quote! { self.#index }, &quote! { other.#index }, &mut is_first)
                        })
                        .collect::<TokenStream2>()
                }
                Fields::Unit => TokenStream2::new()
            };

            quote! {
                #[automatically_derived]
                impl ::core::marker::StructuralPartialEq for #ident {}
                #[automatically_derived]
                impl ::core::cmp::PartialEq for #ident {
                    #[inline]
                    fn eq(&self, other: &#ident) -> bool {
                        #fields
                    }
                }
            }
        }
        Data::Union(_) => panic!("this trait cannot be derived for unions")
    })
}

#[proc_macro_derive(FloatEq)]
pub fn derive_eq(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let ident = &input.ident;

    TokenStream::from(quote! {
        #[automatically_derived]
        impl ::std::cmp::Eq for #ident {}
    })
}

/// Emits one statement that feeds the value `tokens` evaluates to into the
/// hasher named `state`, which must be in scope wherever the output is spliced.
///
/// `tokens` is passed through unchanged as the first argument. The callers in
/// this file pass a reference. Types accepted by `is_float_type` go through
/// `::float_derive::utils::hash`; everything else uses its own `Hash` impl.
fn hash_impl(ty: &Type, tokens: &impl ToTokens) -> TokenStream2 {
    let hash_fn = if is_float_type(ty) {
        quote! { ::float_derive::utils::hash }
    } else {
        quote! { ::core::hash::Hash::hash }
    };

    quote! {
        #hash_fn(#tokens, state);
    }
}

#[proc_macro_derive(FloatHash)]
pub fn derive_hash(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let ident = &input.ident;

    TokenStream::from(match input.data {
        Data::Enum(data) => {
            let variants = {
                let mut has_non_unit_variants = false;
                let mut has_unit_variants = false;

                let variants = data
                    .variants
                    .iter()
                    .map(|variant| {
                        let variant_ident = &variant.ident;

                        match &variant.fields {
                            Fields::Named(fields) => {
                                let args = fields.named.iter().map(|field| {
                                    let field_ident = field.ident.as_ref().unwrap();

                                    quote! {
                                        #field_ident,
                                    }
                                });

                                let impls = fields.named.iter().map(|field| {
                                    let field_ident = field.ident.as_ref().unwrap();
                                    hash_impl(&field.ty, &quote! { #field_ident})
                                });

                                has_non_unit_variants = true;
                                quote! {
                                    #ident::#variant_ident { #(#args)* } => { #(#impls)* }
                                }
                            }
                            Fields::Unnamed(fields) => {
                                let args = fields.unnamed.iter().enumerate().map(|(i, field)| {
                                    let field_ident = Ident::new(&format!("__self_{i}"), field.span());

                                    quote! {
                                        #field_ident,
                                    }
                                });

                                let impls = fields.unnamed.iter().enumerate().map(|(i, field)| {
                                    let field_ident = Ident::new(&format!("__self_{i}"), field.span());
                                    hash_impl(&field.ty, &quote! { #field_ident})
                                });

                                has_non_unit_variants = true;
                                quote! {
                                    #ident::#variant_ident(#(#args)*) => { #(#impls)* }
                                }
                            }
                            Fields::Unit => {
                                has_unit_variants = true;
                                TokenStream2::new()
                            }
                        }
                    })
                    .collect::<TokenStream2>();

                let default_pattern = if has_unit_variants {
                    quote! { _ => () }
                } else {
                    TokenStream2::new()
                };

                if has_non_unit_variants {
                    quote! {
                        match self {
                            #variants
                            #default_pattern
                        }
                    }
                } else {
                    TokenStream2::new()
                }
            };

            quote! {
                #[automatically_derived]
                impl ::std::hash::Hash for #ident {
                    fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) {
                        let __self_tag = ::core::intrinsics::discriminant_value(self);
                        ::core::hash::Hash::hash(&__self_tag, state);

                        #variants
                    }
                }
            }
        }
        Data::Struct(data) => {
            let fields = match data.fields {
                Fields::Named(fields) => {
                    fields
                        .named
                        .iter()
                        .map(|field| {
                            let ident = field.ident.as_ref().unwrap();
                            hash_impl(&field.ty, &quote! { &self.#ident})
                        })
                        .collect::<TokenStream2>()
                }
                Fields::Unnamed(fields) => {
                    fields
                        .unnamed
                        .iter()
                        .enumerate()
                        .map(|(i, field)| {
                            let index = Index { index: i as _, span: field.span() };
                            hash_impl(&field.ty, &quote! { &self.#index})
                        })
                        .collect::<TokenStream2>()
                }
                Fields::Unit => TokenStream2::new()
            };

            quote! {
                #[automatically_derived]
                impl ::std::hash::Hash for #ident {
                    fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) {
                        #fields
                    }
                }
            }
        }
        Data::Union(_) => panic!("this trait cannot be derived for unions")
    })
}