use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{parse::Error, parse_macro_input, DeriveInput, Type};
fn get_defaults(body: &syn::Fields) -> Result<TokenStream, Error> {
match body {
syn::Fields::Named(fields) => {
let defaults = fields
.named
.iter()
.map(|field| {
let (ty, len) = match &field.ty {
Type::Path(v) => (v, None),
Type::Array(a) => match *a.elem {
Type::Path(ref v) => {
(v, Some(a.len.to_token_stream()))
}
_ => unimplemented!(),
},
_ => unimplemented!(),
};
let name = field.ident.as_ref().unwrap().to_string();
let ty = &ty.path.segments[0].ident.to_string();
match ty.as_str() {
"f32" | "f64" => Ok(match len {
None => format!("{name} : {ty}::NAN"),
Some(len) => {
format!("{name} : [{ty}::NAN; {}]", len)
}
}
.parse::<TokenStream>()?),
_ => Ok(format!("{name} : Default::default()")
.parse::<TokenStream>()?),
}
})
.collect::<Result<Vec<_>, Error>>()?;
Ok(quote! {
#( #defaults ),*
})
}
syn::Fields::Unnamed(ref _fields) => Ok(quote! {}),
&syn::Fields::Unit => Ok(quote! {}),
}
}
fn impl_nan_derive(input: &DeriveInput) -> Result<TokenStream, Error> {
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) =
input.generics.split_for_impl();
match input.data {
syn::Data::Struct(ref body) => {
let defaults = get_defaults(&body.fields)?;
let output = quote! {
#[automatically_derived]
impl #impl_generics Default for #name #ty_generics #where_clause {
fn default() -> Self {
Self {
#defaults
}
}
}
};
Ok(output)
}
_ => Err(Error::new(name.span(), "Unsupported type")),
}
}
/// Entry point for `#[derive(NanDefault)]`.
///
/// * Parses the annotated item as a `DeriveInput`. A parse failure makes
///   `parse_macro_input!` return its compile error immediately.
/// * Hands the parsed item to `impl_nan_derive`.
/// * Any `Err` it returns is rendered as a compile error rather than a panic.
#[proc_macro_derive(NanDefault)]
pub fn derive_nan_default(
    input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let derive_input = parse_macro_input!(input as DeriveInput);
    let expanded = impl_nan_derive(&derive_input)
        .unwrap_or_else(|err| err.to_compile_error());
    proc_macro::TokenStream::from(expanded)
}