float_eq_derive 0.4.1

Derive macro support for float_eq.
Documentation
use proc_macro2::{Ident, Span, TokenStream};
use quote::ToTokens;
use syn::{
    spanned::Spanned, Attribute, Data, DeriveInput, Fields, FieldsNamed, FieldsUnnamed, Lit,
    LitInt, Meta, NestedMeta, Type,
};

/// Name of a struct field, which can be emitted as tokens through its
/// `ToTokens` implementation.
pub enum FieldName<'a> {
    /// Identifier of a named field, borrowed from the parsed input.
    Ident(&'a Ident),
    /// Position of a tuple-struct field, stored as a literal.
    Num(Lit),
}

impl ToTokens for FieldName<'_> {
    /// Appends the field name to `tokens`, delegating to whichever value
    /// (identifier or literal) this name wraps.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        match self {
            Self::Ident(field_ident) => field_ident.to_tokens(tokens),
            Self::Num(field_index) => field_index.to_tokens(tokens),
        }
    }
}

/// Name and type of a single struct field, borrowed from the parsed input.
pub struct FieldInfo<'a> {
    /// How the field is referred to: an identifier or a tuple position.
    pub name: FieldName<'a>,
    /// The field's declared type.
    pub ty: &'a Type,
}

/// The kind of field list a struct declares.
pub enum FieldListType {
    /// Named fields, e.g. `struct S { a: T }`.
    Named,
    /// Unnamed (tuple) fields, e.g. `struct S(T);`.
    Tuple,
    /// No fields, e.g. `struct S;`.
    Unit,
}

/// The fields of a struct together with the kind of field list they came from.
pub struct FieldInfoList<'a> {
    /// Whether the struct has named fields, tuple fields, or no fields.
    pub ty: FieldListType,
    /// One entry per field; empty for unit structs.
    fields: Vec<FieldInfo<'a>>,
}

impl FieldInfoList<'_> {
    /// Applies `func` to every field in order and returns the token streams
    /// it produces, one per field.
    pub fn expand<F: std::ops::Fn(&FieldInfo) -> TokenStream>(&self, func: F) -> Vec<TokenStream> {
        // The callback already has the shape `map` expects, so pass it directly.
        self.fields.iter().map(func).collect()
    }
}

pub fn all_fields_info<'a>(
    trait_name: &str,
    input: &'a DeriveInput,
) -> Result<FieldInfoList<'a>, syn::Error> {
    if !input.generics.params.is_empty() {
        return Err(syn::Error::new(
            Span::call_site(),
            "This trait does not yet support derive for generic types.",
        ));
    }

    match &input.data {
        Data::Struct(data) => match &data.fields {
            Fields::Named(FieldsNamed { named, .. }) => Ok(FieldInfoList {
                ty: FieldListType::Named,
                fields: named.iter().map(named_field_info).collect(),
            }),
            Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => Ok(FieldInfoList {
                ty: FieldListType::Tuple,
                fields: unnamed.iter().enumerate().map(unnamed_field_info).collect(),
            }),
            Fields::Unit => Ok(FieldInfoList {
                ty: FieldListType::Unit,
                fields: Vec::new(),
            }),
        },
        _ => Err(syn::Error::new(
            input.ident.span(),
            format!("{} may only be derived for structs.", trait_name),
        )),
    }
}

/// Builds the `FieldInfo` for a field of a struct with named fields.
///
/// # Panics
///
/// Panics if `field` has no identifier, i.e. it is not a named field.
fn named_field_info(field: &syn::Field) -> FieldInfo {
    let ident = field.ident.as_ref().expect("Expected named field");
    FieldInfo {
        name: FieldName::Ident(ident),
        ty: &field.ty,
    }
}

/// Builds the `FieldInfo` for a tuple-struct field from its `(index, field)`
/// pair; the index becomes an integer literal naming the field.
fn unnamed_field_info((n, field): (usize, &syn::Field)) -> FieldInfo {
    let digits = n.to_string();
    let index = LitInt::new(&digits, Span::call_site());
    FieldInfo {
        name: FieldName::Num(Lit::Int(index)),
        ty: &field.ty,
    }
}

/// Options gathered from the `#[float_eq(...)]` attributes on a type.
#[derive(Default)]
pub struct FloatEqAttr {
    /// Name of the type the attributes were found on; used in help messages.
    struct_name: String,
    /// Value of the `ulps = "..."` option, if one was given.
    ulps_type_name: Option<Ident>,
    /// Value of the `all_epsilon = "..."` option, if one was given.
    all_epsilon_type_name: Option<Ident>,
}

impl FloatEqAttr {
    pub fn ulps_type(&self) -> Result<&Ident, syn::Error> {
        self.ulps_type_name.as_ref().ok_or({
            let msg = format!(
                r#"Missing ULPs type name required to derive trait.

help: try adding `#[float_eq(ulps = "{}Ulps")]` to your type."#,
                self.struct_name
            );
            syn::Error::new(Span::call_site(), msg)
        })
    }

    pub fn all_epsilon_type(&self) -> Result<&Ident, syn::Error> {
        self.all_epsilon_type_name.as_ref().ok_or({
            let msg = format!(
                r#"Missing Epsilon type name required to derive trait.

help: try adding `#[float_eq(all_epsilon = "T")]` to your type, where T is commonly `f32` or `f64`."#
            );
            syn::Error::new(Span::call_site(), msg)
        })
    }
}

pub fn float_eq_attr(input: &DeriveInput) -> Result<FloatEqAttr, syn::Error> {
    let nv_pair_lists: Vec<Vec<NameTypePair>> = input
        .attrs
        .iter()
        .filter(|a| a.path.is_ident("float_eq"))
        .map(|a| name_type_pair_list(&input.ident, a))
        .collect::<Result<_, _>>()?;

    let mut attr_values = FloatEqAttr {
        struct_name: input.ident.to_string(),
        ..Default::default()
    };
    for nv in nv_pair_lists.into_iter().flatten() {
        let name = nv.name.to_string();
        if name == "ulps" {
            if attr_values.ulps_type_name.is_none() {
                attr_values.ulps_type_name = Some(nv.value);
            } else {
                let msg = format!(
                    r#"Expected only one ULPs type name, previously saw `ulps = "{}"`."#,
                    attr_values.ulps_type_name.unwrap().to_string()
                );
                return Err(syn::Error::new(nv.value.span(), msg));
            }
        } else if name == "all_epsilon" {
            if attr_values.all_epsilon_type_name.is_none() {
                attr_values.all_epsilon_type_name = Some(nv.value);
            } else {
                let msg = format!(
                    r#"Expected only one Epsilon type name, previously saw `all_epsilon = "{}"`."#,
                    attr_values.all_epsilon_type_name.unwrap().to_string()
                );
                return Err(syn::Error::new(nv.value.span(), msg));
            }
        } else {
            let msg = r"Not a valid float_eq derive option.";
            return Err(syn::Error::new(nv.name.span(), msg));
        }
    }

    Ok(attr_values)
}

/// Parses one `float_eq` attribute into its list of `name = "value"` options.
///
/// `struct_name` is only used to build the example shown in the error message.
///
/// # Errors
///
/// Returns an error if the attribute cannot be parsed as a meta item, is not
/// of the list form `#[float_eq(...)]`, or contains an entry that is not a
/// valid `name = "value"` pair.
fn name_type_pair_list(
    struct_name: &Ident,
    attr: &Attribute,
) -> Result<Vec<NameTypePair>, syn::Error> {
    match attr.parse_meta()? {
        Meta::List(list) => list.nested.iter().map(name_type_pair).collect(),
        _ => {
            let msg = format!(
                r#"float_eq attribute must be a list of options, for example `#[float_eq(ulps = "{}Ulps")]`"#,
                struct_name
            );
            Err(syn::Error::new(attr.span(), msg))
        }
    }
}

/// A single `name = "value"` option parsed from a `float_eq` attribute.
pub struct NameTypePair {
    /// The option name on the left-hand side.
    pub name: Ident,
    /// The string value on the right-hand side, parsed as an identifier.
    pub value: Ident,
}

pub fn name_type_pair(meta: &NestedMeta) -> Result<NameTypePair, syn::Error> {
    if let NestedMeta::Meta(Meta::NameValue(nv)) = meta {
        if let Some(name) = nv.path.get_ident() {
            if let Lit::Str(value) = &nv.lit {
                if let Ok(value) = value.parse::<Ident>() {
                    return Ok(NameTypePair {
                        name: name.clone(),
                        value: value.clone(),
                    });
                }
            }
        }
    }
    Err(syn::Error::new(
        meta.span(),
        "Expected a `name = value` pair.",
    ))
}