use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
self, punctuated::Punctuated, token::Comma, DeriveInput, Field, Fields, FieldsNamed,
FieldsUnnamed, Ident, Variant,
};
#[proc_macro_derive(AproxEq)]
/// Entry point for `#[derive(AproxEq)]`.
///
/// Parses the annotated item as a [`DeriveInput`] and delegates to
/// `impl_aprox_eq`. If parsing fails, the `syn::Error` is emitted as a
/// spanned `compile_error!` instead of panicking inside the macro.
pub fn aprox_eq_derive(input: TokenStream) -> TokenStream {
    match syn::parse::<DeriveInput>(input) {
        Ok(ast) => impl_aprox_eq(ast),
        Err(err) => err.to_compile_error().into(),
    }
}
fn impl_aprox_eq(input: DeriveInput) -> TokenStream {
let name = &input.ident;
match &input.data {
syn::Data::Struct(syn::DataStruct { fields, .. }) => impl_struct(name, fields),
syn::Data::Enum(syn::DataEnum { variants, .. }) => impl_enum(name, variants),
_ => panic!("`AproxEq` can only derive on struct or enums"),
}
}
fn impl_struct(name: &Ident, fields: &Fields) -> TokenStream {
let condition = match &fields {
syn::Fields::Named(f) => {
let field_ids = f.named.iter().map(|field| &field.ident);
quote! { true #(&& self.#field_ids.aprox_eq(&other.#field_ids))* }
}
syn::Fields::Unnamed(f) => {
let field_ids = (0..f.unnamed.len()).map(syn::Index::from);
quote! { true #(&& self.#field_ids.aprox_eq(&other.#field_ids))* }
}
syn::Fields::Unit => quote! { true },
};
quote! {
impl AproxEq for #name {
fn aprox_eq(&self, other: &Self) -> bool {
#condition
}
}
}
.into()
}
fn impl_enum(name: &Ident, variants: &Punctuated<Variant, Comma>) -> TokenStream {
let var_impls = variants.iter().map(|variant| match &variant.fields {
Fields::Named(FieldsNamed { named, .. }) => impl_var_named_fields(variant, named),
Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => impl_var_unnamed(variant, unnamed),
Fields::Unit => quote! { (Self::#variant, Self::#variant) => true, },
});
quote! {
impl AproxEq for #name {
fn aprox_eq(&self, other: &Self) -> bool {
match (self, other) {
#( #var_impls )*
_ => false,
}
}
}
}
.into()
}
fn impl_var_named_fields(
variant: &Variant,
named: &Punctuated<Field, Comma>,
) -> proc_macro2::TokenStream {
let idents: Vec<_> = named
.iter()
.map(|field| field.ident.as_ref().unwrap())
.collect();
let idents_self: Vec<_> = idents
.iter()
.map(|ident| format_ident!("self_{}", ident))
.collect();
let idents_other: Vec<_> = idents
.iter()
.map(|ident| format_ident!("other_{}", ident))
.collect();
let var_ident = &variant.ident;
quote! {
(
Self::#var_ident { #( #idents: #idents_self ),* },
Self::#var_ident { #( #idents: #idents_other ),* }
) => {
true #( && #idents_self.aprox_eq(#idents_other) )*
}
}
.into()
}
fn impl_var_unnamed(
variant: &Variant,
unnamed: &Punctuated<Field, Comma>,
) -> proc_macro2::TokenStream {
let idents_self: Vec<_> = (0..unnamed.len())
.map(|i| format_ident!("self_{}", i))
.collect();
let idents_other: Vec<_> = (0..unnamed.len())
.map(|i| format_ident!("other_{}", i))
.collect();
let var_ident = &variant.ident;
quote! {
(
Self::#var_ident( #( #idents_self ),* ),
Self::#var_ident( #( #idents_other ),* )
) => {
true #( && #idents_self.aprox_eq(#idents_other) )*
}
}
.into()
}