human-units-derive 0.4.0

A unit serialization and formatting library designed for configuration files and command-line arguments.
Documentation
use proc_macro::Span;
use proc_macro::TokenStream;
use quote::quote;
use syn::parse_macro_input;
use syn::punctuated::Punctuated;
use syn::Data;
use syn::DeriveInput;
use syn::Expr;
use syn::Fields;
use syn::Ident;
use syn::Lit;
use syn::Meta;
use syn::Path;
use syn::PathArguments;
use syn::PathSegment;
use syn::Type;

#[proc_macro_attribute]
pub fn si_unit(args: TokenStream, item: TokenStream) -> TokenStream {
    let args = parse_macro_input!(args with Punctuated::<Meta, syn::Token![,]>::parse_terminated);
    let symbol = args
        .iter()
        .find_map(|meta| match meta {
            Meta::NameValue(nv) => {
                let path = nv.path.get_ident()?;
                if path != "symbol" {
                    return None;
                }
                match &nv.value {
                    Expr::Lit(literal) => match &literal.lit {
                        Lit::Str(symbol) => Some(symbol.value()),
                        _ => panic!("`si_unit(symbol = \"...\")` should be a string literal"),
                    },
                    Expr::Group(group) => match &*group.expr {
                        Expr::Lit(literal) => match &literal.lit {
                            Lit::Str(symbol) => Some(symbol.value()),
                            _ => panic!("`si_unit(symbol = \"...\")` should be a string literal"),
                        },
                        _ => panic!("`si_unit(symbol = \"...\")` should be a literal"),
                    },
                    _ => panic!("`si_unit(symbol = \"...\")` should be a literal"),
                }
            }
            _ => None,
        })
        .expect("`si_unit` should at least contain `symbol = \"...\"` attribute");
    let internal = args.iter().any(|meta| match meta {
        Meta::Path(path) => {
            let Some(path) = path.get_ident() else {
                return false;
            };
            path == "internal"
        }
        _ => false,
    });
    let item = parse_macro_input!(item as DeriveInput);
    let newtype = &item.ident;
    let Data::Struct(data) = &item.data else {
        panic!("`si_unit` can only be applied to structs with a single unnamed field");
    };
    let Fields::Unnamed(fields) = &data.fields else {
        panic!("`si_unit` can only be applied to structs with a single unnamed field");
    };
    if fields.unnamed.len() != 1 {
        panic!("`si_unit` can only be applied to structs with a single unnamed field");
    }
    let field = fields.unnamed.first().expect("Checked the length above");
    let Type::Path(path) = &field.ty else {
        panic!("`si_unit`: the struct field should be a primitive unsigned integer");
    };
    let uint = path
        .path
        .get_ident()
        .expect("Failed to parse the type of the struct field");
    if !is_supported_type(uint) {
        panic!("`si_unit`: the struct field should be a primitive unsigned integer, supported types: {UINT_TYPES:?}");
    }
    let uint_string_len = max_string_len(uint.to_string().as_str());
    let min_prefix_len = 1;
    let max_string_len = uint_string_len + 1 + min_prefix_len + symbol.len();
    let serde_visitor = Ident::new(&format!("{newtype}HumanUnitsSerdeVisitor"), newtype.span());
    let crate_name = if internal {
        let mut segments = Punctuated::new();
        segments.push_value(PathSegment {
            ident: Ident::new("crate", Span::call_site().into()),
            arguments: PathArguments::None,
        });
        Path {
            leading_colon: None,
            segments,
        }
    } else {
        let mut segments = Punctuated::new();
        segments.push_value(PathSegment {
            ident: Ident::new("human_units", Span::call_site().into()),
            arguments: PathArguments::None,
        });
        Path {
            leading_colon: Some(Default::default()),
            segments,
        }
    };
    let write_unit = Ident::new(&format!("write_unit_{uint}"), Span::call_site().into());
    let serde = cfg!(feature = "serde").then_some(quote! {
        impl serde::Serialize for #newtype {
            fn serialize<S>(&self, s: S) -> core::result::Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                let mut buf = #crate_name::Buffer::<{ #newtype::MAX_STRING_LEN }>::new();
                buf.#write_unit(self.0, #symbol);
                s.serialize_str(unsafe { buf.as_str() })
            }
        }

        impl<'a> serde::Deserialize<'a> for #newtype {
            fn deserialize<D>(d: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'a>,
            {
                d.deserialize_str(#serde_visitor)
            }
        }

        struct #serde_visitor;

        impl<'a> serde::de::Visitor<'a> for #serde_visitor {
            type Value = #newtype;

            fn expecting(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
                f.write_str(concat!("A string obtained by `", stringify!(#newtype), "::to_string`"))
            }

            fn visit_str<E>(self, value: &str) -> core::result::Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                value
                    .parse()
                    .map_err(|_| E::custom(concat!("Invalid `", stringify!(#newtype), "`")))
            }
        }
    });
    quote! {
        #item

        impl #newtype {
            /// Max. length in string form.
            pub const MAX_STRING_LEN: usize = #max_string_len;

            /// Unit symbol.
            pub const SYMBOL: &'static str = #symbol;

            /// Convert from a value without SI prefix.
            pub fn from_si(value: #uint) -> Self {
                Self(value * 1_000_000_000)
            }
        }

        impl #crate_name::si::FormatSi for #newtype {
            fn format_si(&self) -> #crate_name::si::FormattedUnit<'static> {
                #crate_name::si::FormatSiUnit::format_si_unit(self.0, #symbol)
            }
        }

        impl core::fmt::Display for #newtype {
            fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
                let mut buf = #crate_name::Buffer::<{ Self::MAX_STRING_LEN }>::new();
                buf.#write_unit(self.0, #symbol);
                f.write_str(unsafe { buf.as_str() })
            }
        }

        impl core::str::FromStr for #newtype {
            type Err = #crate_name::si::Error;
            fn from_str(other: &str) -> Result<Self, Self::Err> {
                #crate_name::si::SiFromStr::si_unit_from_str(other, Self::SYMBOL).map(#newtype)
            }
        }

        #serde
    }
    .into()
}

/// Returns `true` when `ty` names one of the integer types listed in `UINT_TYPES`.
fn is_supported_type(ty: &Ident) -> bool {
    UINT_TYPES.iter().any(|name| ty == *name)
}

/// Returns the number of decimal digits needed to print the maximum value of the
/// unsigned integer type named `ty`.
///
/// Every type name listed as supported (`u128`, `u64`, `u32`, `u16`, `u8`) is handled,
/// so a type that passes the support check never reaches the panic below.
///
/// # Panics
///
/// Panics if `ty` is not one of the names above.
fn max_string_len(ty: &str) -> usize {
    match ty {
        "u128" => 39,
        "u64" => 20,
        "u32" => 10,
        "u16" => 5,
        // `u8::MAX` is 255.
        "u8" => 3,
        _ => panic!("`max_string_len`: Unsupported type {ty:?}"),
    }
}

/// Names of the primitive unsigned integer types that `is_supported_type` accepts as
/// the type of the newtype's single field; also printed in the "supported types" panic
/// message of `si_unit`.
const UINT_TYPES: [&str; 5] = ["u128", "u64", "u32", "u16", "u8"];

#[cfg(test)]
mod tests {
    use super::*;

    /// The hard-coded digit counts must match the printed length of each type's `MAX`.
    #[test]
    fn test_max_string_len() {
        let cases = [
            ("u128", u128::MAX.to_string()),
            ("u64", u64::MAX.to_string()),
            ("u32", u32::MAX.to_string()),
            ("u16", u16::MAX.to_string()),
        ];
        for (type_name, rendered_max) in cases {
            assert_eq!(rendered_max.len(), max_string_len(type_name));
        }
    }
}