nan_default/
lib.rs

//! # NaN Default Derive Macro
//!
//! A derive macro which lets you create structs whose floating-point
//! members (`f32`, `f64`, and arrays of them) are initialized to `NAN`
//! rather than `0.0` by default. Struct members which are not floats get
//! `Default::default()` assignments, and partial defaulting via
//! `..Default::default()` is also supported.
//!
//! This can be convenient when using the special NaN state of a
//! floating-point number to designate it as invalid or uninitialized,
//! rather than wrapping it in an `Option`, which takes some extra
//! space — especially if you have many of them. Of course, any
//! operations on such variables need to check for this state, just as
//! they would need to check for the `Some` variant of an `Option`
//! (albeit without idiomatic `if let` statements, mapping, and so on).
//! Depending on the application, `is_nan()` or `!is_finite()` may be
//! the appropriate check.
//!
//! Floats are recognised from how the field type is written: a type
//! alias or generic parameter that resolves to a float is not detected
//! and gets `Default::default()` (i.e. `0.0`).
//!
//! # Example
//! ```
//! use nan_default::NanDefault;
//!
//! #[derive(NanDefault)]
//! struct Data {
//!     a: f32,
//!     b: [f64; 10],
//!     c: i32,
//! }
//!
//! let data = Data::default();
//! assert!(data.a.is_nan());
//! assert!(data.b[9].is_nan());
//! assert_eq!(data.c, 0);
//!
//! // Partial defaulting: set some fields explicitly, default the rest.
//! let partial = Data { c: 7, ..Default::default() };
//! assert!(partial.a.is_nan());
//! assert_eq!(partial.c, 7);
//! ```
//!
38
39use proc_macro2::TokenStream;
40use quote::{quote, ToTokens};
41use syn::{parse::Error, parse_macro_input, DeriveInput, Type};
42
43fn get_defaults(body: &syn::Fields) -> Result<TokenStream, Error> {
44    match body {
45        syn::Fields::Named(fields) => {
46            let defaults = fields
47                .named
48                .iter()
49                .map(|field| {
50                    let (ty, len) = match &field.ty {
51                        Type::Path(v) => (v, None),
52                        Type::Array(a) => match *a.elem {
53                            Type::Path(ref v) => {
54                                (v, Some(a.len.to_token_stream()))
55                            }
56                            _ => unimplemented!(),
57                        },
58                        _ => unimplemented!(),
59                    };
60
61                    let name = field.ident.as_ref().unwrap().to_string();
62                    let ty = &ty.path.segments[0].ident.to_string();
63                    match ty.as_str() {
64                        "f32" | "f64" => Ok(match len {
65                            None => format!("{name} : {ty}::NAN"),
66                            Some(len) => {
67                                format!("{name} : [{ty}::NAN; {}]", len)
68                            }
69                        }
70                        .parse::<TokenStream>()?),
71
72                        _ => Ok(format!("{name} : Default::default()")
73                            .parse::<TokenStream>()?),
74                    }
75                })
76                .collect::<Result<Vec<_>, Error>>()?;
77
78            Ok(quote! {
79                #( #defaults ),*
80            })
81        }
82        syn::Fields::Unnamed(ref _fields) => Ok(quote! {}),
83        &syn::Fields::Unit => Ok(quote! {}),
84    }
85}
86
87fn impl_nan_derive(input: &DeriveInput) -> Result<TokenStream, Error> {
88    let name = &input.ident;
89
90    let (impl_generics, ty_generics, where_clause) =
91        input.generics.split_for_impl();
92
93    match input.data {
94        syn::Data::Struct(ref body) => {
95            let defaults = get_defaults(&body.fields)?;
96
97            let output = quote! {
98                #[automatically_derived]
99                impl #impl_generics Default for #name #ty_generics #where_clause {
100                    fn default() -> Self {
101                        Self {
102                            #defaults
103                        }
104                    }
105                }
106            };
107            Ok(output)
108        }
109        _ => Err(Error::new(name.span(), "Unsupported type")),
110    }
111}
112
113#[proc_macro_derive(NanDefault)]
114pub fn derive_nan_default(
115    input: proc_macro::TokenStream,
116) -> proc_macro::TokenStream {
117    let input = parse_macro_input!(input as DeriveInput);
118    match impl_nan_derive(&input) {
119        Ok(output) => output.into(),
120        Err(error) => error.to_compile_error().into(),
121    }
122}