human_units_derive/
lib.rs

1#![allow(non_snake_case)]
2
3use core::ops::RangeInclusive;
4use core::str::FromStr;
5use proc_macro::Span;
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::parse_macro_input;
9use syn::punctuated::Punctuated;
10use syn::Data;
11use syn::DeriveInput;
12use syn::Expr;
13use syn::Fields;
14use syn::Ident;
15use syn::Lit;
16use syn::Meta;
17use syn::Path;
18use syn::PathArguments;
19use syn::PathSegment;
20use syn::Type;
21
22#[proc_macro_attribute]
23pub fn si_unit(args: TokenStream, item: TokenStream) -> TokenStream {
24    generic_unit(args, item, "si_unit", "si")
25}
26
27#[proc_macro_attribute]
28pub fn iec_unit(args: TokenStream, item: TokenStream) -> TokenStream {
29    generic_unit(args, item, "iec_unit", "iec")
30}
31
32fn generic_unit(
33    args: TokenStream,
34    item: TokenStream,
35    macro_name: &str,
36    system: &str,
37    // TODO ignore case: either lowercase or uppercase
38) -> TokenStream {
39    let SYSTEM = system.to_ascii_uppercase();
40    let system = Ident::new(system, Span::call_site().into());
41    let item = parse_macro_input!(item as DeriveInput);
42    let newtype = &item.ident;
43    let uint = {
44        let Data::Struct(data) = &item.data else {
45            panic!("`{macro_name}` can only be applied to structs with a single unnamed field");
46        };
47        let Fields::Unnamed(fields) = &data.fields else {
48            panic!("`{macro_name}` can only be applied to structs with a single unnamed field");
49        };
50        if fields.unnamed.len() != 1 {
51            panic!("`{macro_name}` can only be applied to structs with a single unnamed field");
52        }
53        let field = fields.unnamed.first().expect("Checked the length above");
54        let Type::Path(path) = &field.ty else {
55            panic!("`{macro_name}`: the struct field should be a primitive unsigned integer");
56        };
57        let uint = path
58            .path
59            .get_ident()
60            .expect("Failed to parse the type of the struct field");
61        if !is_supported_type(uint) {
62            panic!("`{macro_name}`: the struct field should be a primitive unsigned integer, supported types: {UINT_TYPES:?}");
63        }
64        uint
65    };
66    let args = parse_macro_input!(args with Punctuated::<Meta, syn::Token![,]>::parse_terminated);
67    let mut symbol = None;
68    let mut internal = false;
69    let mut iec_min_power: Option<usize> = None;
70    let mut iec_max_power: Option<usize> = None;
71    let mut si_min_power: Option<usize> = None;
72    let mut si_max_power: Option<usize> = None;
73    for arg in args.iter() {
74        match arg {
75            Meta::Path(path) => match path.get_ident() {
76                Some(path) if path == "internal" => internal = true,
77                _ => panic!("Invalid argument {:?}", path.get_ident()),
78            },
79            Meta::NameValue(nv) => match nv.path.get_ident() {
80                Some(name) if name == "symbol" => {
81                    let value = parse_string(&nv.value).unwrap_or_else(|_| {
82                        panic!("`{macro_name}({name} = \"...\")` should be a string literal")
83                    });
84                    symbol = Some(value);
85                }
86                Some(name) if name == "min_prefix" && system == "si" => {
87                    si_min_power = Some(parse_si_unit(&nv.value, macro_name, name) as usize)
88                }
89                Some(name) if name == "max_prefix" && system == "si" => {
90                    si_max_power = Some(parse_si_unit(&nv.value, macro_name, name) as usize)
91                }
92                Some(name) if name == "min_prefix" && system == "iec" => {
93                    iec_min_power = Some(parse_iec_unit(&nv.value, macro_name, name) as usize)
94                }
95                Some(name) if name == "max_prefix" && system == "iec" => {
96                    iec_max_power = Some(parse_iec_unit(&nv.value, macro_name, name) as usize)
97                }
98                _ => panic!("Invalid argument {:?}", nv.path.get_ident()),
99            },
100            _ => panic!("Invalid argument"),
101        }
102    }
103    let symbol = symbol.unwrap_or_else(|| {
104        panic!("`{macro_name}` should at least contain `symbol = \"...\"` attribute")
105    });
106    let (si_min_power, si_max_power) = match (si_min_power, si_max_power) {
107        (None, None) => get_power_of_1000_range(uint.to_string().as_str()).into_inner(),
108        (Some(min), None) => (
109            min,
110            (min + get_power_of_1000_count(uint.to_string().as_str()) - 1)
111                .min(SiUnit::Quetta as usize),
112        ),
113        (None, Some(max)) => (
114            (max + 1).saturating_sub(get_power_of_1000_count(uint.to_string().as_str())),
115            max,
116        ),
117        (Some(min), Some(max)) => (min, max),
118    };
119    let (iec_min_power, iec_max_power) = match (iec_min_power, iec_max_power) {
120        (None, None) => get_power_of_1024_range(uint.to_string().as_str()).into_inner(),
121        (Some(min), None) => (
122            min,
123            (min + get_power_of_1024_count(uint.to_string().as_str()) - 1)
124                .min(SiUnit::Quetta as usize),
125        ),
126        (None, Some(max)) => (
127            (max + 1).saturating_sub(get_power_of_1024_count(uint.to_string().as_str())),
128            max,
129        ),
130        (Some(min), Some(max)) => (min, max),
131    };
132    let (min_power, max_power, power_count, prefix_strs) = if system == "si" {
133        (
134            si_min_power,
135            si_max_power,
136            get_power_of_1000_count(uint.to_string().as_str()),
137            SiUnit::PREFIXES,
138        )
139    } else {
140        (
141            iec_min_power,
142            iec_max_power,
143            get_power_of_1024_count(uint.to_string().as_str()),
144            IecUnit::PREFIXES,
145        )
146    };
147    if max_power < min_power {
148        panic!("`min_prefix` should be less than or equal to `max_prefix`");
149    }
150    if max_power - min_power > power_count {
151        panic!("`min_prefix..max_prefix` range is too big for `{uint}`");
152    }
153    let powers_rev = (min_power..=max_power).rev().collect::<Vec<_>>();
154    let min_prefix_name = match system {
155        ref s if s == "si" => format!("{:?}", SiUnit::try_from(min_power).unwrap()),
156        ref s if s == "iec" => format!("{:?}", IecUnit::try_from(min_power).unwrap()),
157        _ => unreachable!(),
158    };
159    let min_prefix_name = Ident::new(&min_prefix_name, Span::call_site().into());
160    let max_prefix_name = match system {
161        ref s if s == "si" => format!("{:?}", SiUnit::try_from(max_power).unwrap()),
162        ref s if s == "iec" => format!("{:?}", IecUnit::try_from(max_power).unwrap()),
163        _ => unreachable!(),
164    };
165    let max_prefix_name = Ident::new(&max_prefix_name, Span::call_site().into());
166    let uint_string_len = max_uint_string_len(uint.to_string().as_str());
167    let space_len = 1;
168    let max_prefix_len = prefix_strs[min_power].len();
169    let max_string_len = uint_string_len + space_len + max_prefix_len + symbol.len();
170    let serde_visitor = Ident::new(&format!("{newtype}HumanUnitsSerdeVisitor"), newtype.span());
171    let crate_name = if internal {
172        let mut segments = Punctuated::new();
173        segments.push_value(PathSegment {
174            ident: Ident::new("crate", Span::call_site().into()),
175            arguments: PathArguments::None,
176        });
177        Path {
178            leading_colon: None,
179            segments,
180        }
181    } else {
182        let mut segments = Punctuated::new();
183        segments.push_value(PathSegment {
184            ident: Ident::new("human_units", Span::call_site().into()),
185            arguments: PathArguments::None,
186        });
187        Path {
188            leading_colon: Some(Default::default()),
189            segments,
190        }
191    };
192    let scale: u16 = if system == "si" { 1000 } else { 1024 };
193    let write_unit = Ident::new(
194        &format!("write_unit_{uint}_{scale}"),
195        Span::call_site().into(),
196    );
197    let from = Ident::new(&format!("from_{system}"), Span::call_site().into());
198    let try_with_prefix = Ident::new(
199        &format!("try_with_{system}_prefix"),
200        Span::call_site().into(),
201    );
202    let with_prefix = Ident::new(&format!("with_{system}_prefix"), Span::call_site().into());
203    let format = Ident::new(&format!("format_{system}"), Span::call_site().into());
204    let unit_from_str = Ident::new(&format!("{uint}_unit_from_str"), Span::call_site().into());
205    let prefixes = Ident::new(&format!("{SYSTEM}_PREFIXES"), Span::call_site().into());
206    let serde = cfg!(feature = "serde").then_some(quote! {
207        impl #crate_name::imp::serde::Serialize for #newtype {
208            fn serialize<S>(&self, s: S) -> ::core::result::Result<S::Ok, S::Error>
209            where
210                S: #crate_name::imp::serde::Serializer,
211            {
212                let mut buf = #crate_name::Buffer::<{ #newtype::MAX_STRING_LEN }>::new();
213                buf.#write_unit::<#min_power, #max_power>(self.0, #symbol);
214                s.serialize_str(unsafe { buf.as_str() })
215            }
216        }
217
218        impl<'a> #crate_name::imp::serde::Deserialize<'a> for #newtype {
219            fn deserialize<D>(d: D) -> Result<Self, D::Error>
220            where
221                D: #crate_name::imp::serde::Deserializer<'a>,
222            {
223                d.deserialize_str(#serde_visitor)
224            }
225        }
226
227        struct #serde_visitor;
228
229        impl<'a> #crate_name::imp::serde::de::Visitor<'a> for #serde_visitor {
230            type Value = #newtype;
231
232            fn expecting(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
233                f.write_str(concat!("A string obtained by `", stringify!(#newtype), "::to_string`"))
234            }
235
236            fn visit_str<E>(self, value: &str) -> ::core::result::Result<Self::Value, E>
237            where
238                E: #crate_name::imp::serde::de::Error,
239            {
240                value
241                    .parse()
242                    .map_err(|_| E::custom(concat!("Invalid `", stringify!(#newtype), "`")))
243            }
244        }
245    });
246    let format_unit = match system {
247        ref s if s == "si" => {
248            quote! {
249                #(
250                    {
251                        const SCALE: #uint = (1000 as #uint).pow(#powers_rev as u32 - #min_power as u32);
252                        if self.0 >= SCALE {
253                            let integer = self.0 / SCALE;
254                            let mut fraction = self.0 % SCALE;
255                            if fraction != 0 {
256                                // Compute the first digit of the fractional part.
257                                fraction /= (SCALE / 10);
258                            }
259                            debug_assert!(fraction <= 9);
260                            return #crate_name::imp::FormattedUnit::new(
261                                #crate_name::imp::#prefixes[#powers_rev],
262                                Self::SYMBOL,
263                                integer,
264                                fraction as u8,
265                            );
266                        }
267                    }
268                )*
269                let integer = self.0;
270                let fraction = 0;
271                #crate_name::imp::FormattedUnit::new(
272                    #crate_name::imp::#prefixes[#min_power],
273                    Self::SYMBOL,
274                    integer,
275                    fraction,
276                )
277            }
278        }
279        ref s if s == "iec" => {
280            quote! {
281                #(
282                    {
283                        const SCALE: #uint = (1024 as #uint).pow(#powers_rev as u32 - #min_power as u32);
284                        if self.0 >= SCALE {
285                            let integer = self.0 / SCALE;
286                            let mut fraction = self.0 % SCALE;
287                            if fraction != 0 {
288                                // Compute the first digit of the fractional part,
289                                // i.e. `fraction = fraction * 10 / SCALE` without an
290                                // overflow.
291                                fraction = match fraction.checked_mul(5) {
292                                    Some(numerator) => numerator / (SCALE / 2),
293                                    None => {
294                                        debug_assert!(0 == SCALE % 16);
295                                        (fraction / 8) * 5 / (SCALE / 16)
296                                    }
297                                };
298                            }
299                            debug_assert!(fraction <= 9);
300                            return #crate_name::imp::FormattedUnit::new(
301                                #crate_name::imp::#prefixes[#powers_rev],
302                                Self::SYMBOL,
303                                integer,
304                                fraction as u8,
305                            );
306                        }
307                    }
308                )*
309                let integer = self.0;
310                #crate_name::imp::FormattedUnit::new(
311                    #crate_name::imp::#prefixes[#min_power],
312                    Self::SYMBOL,
313                    integer,
314                    0,
315                )
316            }
317        }
318        _ => unreachable!(),
319    };
320    // TODO Implement via Vec::set_len
321    //let to_string_fast = cfg!(feature = "alloc").then_some(quote!{
322    //    impl #newtype {
323    //        pub fn to_string_fast(&self) -> String {
324    //            let mut buf = #crate_name::Buffer::<{ Self::MAX_STRING_LEN }>::new();
325    //            buf.#write_unit::<#min_power, #max_power>(self.0, #symbol);
326    //            ::alloc::string::String::from(unsafe { buf.as_str() })
327    //        }
328    //    }
329    //});
330    quote! {
331        #item
332
333        impl #newtype {
334            /// Max. length in string form.
335            pub const MAX_STRING_LEN: usize = #max_string_len;
336
337            /// Unit symbol.
338            pub const SYMBOL: &'static str = #symbol;
339
340            #[doc = concat!("Minimum ", #SYSTEM, " prefix.")]
341            pub const MIN_PREFIX: #crate_name::#system::Prefix = #crate_name::#system::Prefix::#min_prefix_name;
342
343            #[doc = concat!("Maximum ", #SYSTEM, " prefix.")]
344            pub const MAX_PREFIX: #crate_name::#system::Prefix = #crate_name::#system::Prefix::#max_prefix_name;
345
346            #[doc = concat!("Convert from ", #SYSTEM, " value without prefix. Panics if the prefix is out of range.")]
347            pub const fn #from(value: #uint) -> Self {
348                let prefix = #crate_name::#system::Prefix::None;
349                Self::#with_prefix(value, prefix)
350            }
351
352            #[doc = concat!("Convert from ", #SYSTEM, " value with prefix. Panics if the prefix is out of range.")]
353            #[inline]
354            pub const fn #with_prefix(value: #uint, prefix: #crate_name::#system::Prefix) -> Self {
355                let power = prefix as u32;
356                assert!(#min_power as u32 <= power && power <= #max_power as u32, "Invalid prefix");
357                let factor = (#scale as #uint).pow(power - #min_power as u32);
358                Self(value * factor)
359            }
360
361            #[doc = concat!("Convert from ", #SYSTEM, " value with prefix. Returns an error if the prefix is out of range.")]
362            #[inline]
363            pub const fn #try_with_prefix(value: #uint, prefix: #crate_name::#system::Prefix) -> Result<Self, #crate_name::Error> {
364                let power = prefix as u32;
365                if (#max_power as u32) < power && power < (#min_power as u32) {
366                    return Err(#crate_name::Error);
367                }
368                let factor = (#scale as #uint).pow(power - #min_power as u32);
369                let v = match value.checked_mul(factor) {
370                    Some(value) => value,
371                    None => return Err(#crate_name::Error),
372                };
373                Ok(Self(v))
374            }
375
376            /// Represent the value as a number using the largest possible unit prefix.
377            #[allow(clippy::modulo_one)]
378            pub const fn #format(&self) -> #crate_name::imp::FormattedUnit<'static, #uint, { Self::MAX_STRING_LEN }> {
379                #format_unit
380            }
381        }
382
383        impl ::core::fmt::Display for #newtype {
384            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
385                let mut buf = #crate_name::Buffer::<{ Self::MAX_STRING_LEN }>::new();
386                buf.#write_unit::<#min_power, #max_power>(self.0, #symbol);
387                f.write_str(unsafe { buf.as_str() })
388            }
389        }
390
391        impl ::core::str::FromStr for #newtype {
392            type Err = #crate_name::Error;
393            fn from_str(other: &str) -> Result<Self, Self::Err> {
394                #crate_name::imp::#unit_from_str::<{ #scale as #uint }>(
395                    other,
396                    Self::SYMBOL,
397                    &#crate_name::imp::#prefixes[#min_power..=#max_power],
398                ).map(#newtype)
399            }
400        }
401
402        #serde
403    }
404    .into()
405}
406
407fn parse_string(value: &Expr) -> Result<String, InvalidString> {
408    match value {
409        Expr::Lit(literal) => match &literal.lit {
410            Lit::Str(s) => Ok(s.value()),
411            _ => Err(InvalidString),
412        },
413        Expr::Group(group) => match &*group.expr {
414            Expr::Lit(literal) => match &literal.lit {
415                Lit::Str(s) => Ok(s.value()),
416                _ => Err(InvalidString),
417            },
418            _ => Err(InvalidString),
419        },
420        _ => Err(InvalidString),
421    }
422}
423
424fn parse_si_unit(value: &Expr, macro_name: &str, name: &Ident) -> SiUnit {
425    let value = parse_string(value)
426        .unwrap_or_else(|_| panic!("`{macro_name}({name} = \"...\")` should be a string literal"));
427    value.parse().unwrap_or_else(|_| {
428        panic!(
429            "Unsupported value of `{name}`: {value:?}. Supported values are {:?} and their long counterparts {:?}",
430            SiUnit::PREFIXES,
431            SiUnit::LONG_PREFIXES,
432        )
433    })
434}
435
436fn parse_iec_unit(value: &Expr, macro_name: &str, name: &Ident) -> IecUnit {
437    let value = parse_string(value)
438        .unwrap_or_else(|_| panic!("`{macro_name}({name} = \"...\")` should be a string literal"));
439    value.parse().unwrap_or_else(|_| {
440        panic!(
441            "Unsupported value of `{name}`: {value:?}. Supported values are {:?} and their long counterparts {:?}",
442            IecUnit::PREFIXES,
443            IecUnit::LONG_PREFIXES,
444        )
445    })
446}
447
/// Error returned by `parse_string` when an attribute value is not a string literal
/// (neither directly nor inside a single `Expr::Group`).
#[derive(Debug)]
struct InvalidString;
450
// Defines a prefix enum (`$enum`) and its error type (`$error`) from a table of
// `(Variant discriminant "short-prefix" "long-prefix")` rows.
//
// Generated items:
// - `enum $enum` with the listed explicit discriminants;
// - `$enum::PREFIXES` / `$enum::LONG_PREFIXES`: short/long prefix strings in row order.
//   `generic_unit` indexes `PREFIXES` by discriminant, so rows are expected to be listed in
//   ascending discriminant order starting at 0;
// - `impl FromStr`: exact, case-sensitive match on the short prefixes first, then an ASCII
//   case-insensitive match on the long prefixes;
// - `impl TryFrom<usize>`: discriminant -> variant;
// - `struct $error`: the unit error type of both conversions.
macro_rules! define_prefix {
    ($enum: ident
     $error: ident
     $((
         $name: ident
         $value: literal
         $prefix: literal
         $long_prefix: literal
     ))+) => {
        #[derive(Debug, Copy, Clone)]
        enum $enum {
            $(
                $name = $value,
            )+
        }

        impl $enum {
            // Short prefixes ("k", "Mi", ...), in row order.
            const PREFIXES: &'static [&'static str] = {
                &[$($prefix,)+]
            };

            // Long prefixes ("kilo", "mebi", ...), in row order.
            const LONG_PREFIXES: &'static [&'static str] = {
                &[$($long_prefix,)+]
            };
        }

        impl FromStr for $enum {
            type Err = $error;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    // Short prefixes are case-sensitive ("m" = milli vs "M" = mega).
                    $($prefix => Ok(Self::$name),)+
                    // Long prefixes are matched ignoring ASCII case.
                    $(s if $long_prefix.eq_ignore_ascii_case(s) => Ok(Self::$name),)+
                    _ => Err($error),
                }
            }
        }

        impl TryFrom<usize> for $enum {
            type Error = $error;

            fn try_from(other: usize) -> Result<Self, Self::Error> {
                match other {
                    $($value => Ok(Self::$name),)+
                    _ => Err($error),
                }
            }
        }

        #[derive(Debug)]
        struct $error;
    }
}
504
// SI prefixes ordered from smallest (quecto) to largest (quetta). `generic_unit` scales
// consecutive entries by a factor of 1000. Columns: variant, discriminant (= power index,
// also the index into `SiUnit::PREFIXES`), short prefix, long prefix. `None` (10) is the
// unprefixed unit. Note that "μ" is 2 bytes in UTF-8, which matters wherever prefix
// lengths are measured with `str::len`.
define_prefix! {
    SiUnit
    InvalidSiUnit
    (Quecto 0 "q" "quecto")
    (Ronto 1 "r" "ronto")
    (Yocto 2 "y" "yocto")
    (Zepto 3 "z" "zepto")
    (Atto 4 "a" "atto")
    (Femto 5 "f" "femto")
    (Pico 6 "p" "pico")
    (Nano 7 "n" "nano")
    (Micro 8 "μ" "micro")
    (Milli 9 "m" "milli")
    (None 10 "" "")
    (Kilo 11 "k" "kilo")
    (Mega 12 "M" "mega")
    (Giga 13 "G" "giga")
    (Tera 14 "T" "tera")
    (Peta 15 "P" "peta")
    (Exa 16 "E" "exa")
    (Zetta 17 "Z" "zetta")
    (Yotta 18 "Y" "yotta")
    (Ronna 19 "R" "ronna")
    (Quetta 20 "Q" "quetta")
}
530
// IEC binary prefixes ordered from the unprefixed unit (`None`, 0) up to quebi. `generic_unit`
// scales consecutive entries by a factor of 1024. Columns: variant, discriminant (= power
// index, also the index into `IecUnit::PREFIXES`), short prefix, long prefix.
define_prefix! {
    IecUnit
    InvalidIecUnit
    (None 0 "" "")
    (Kibi 1 "Ki" "kibi")
    (Mebi 2 "Mi" "mebi")
    (Gibi 3 "Gi" "gibi")
    (Tebi 4 "Ti" "tebi")
    (Pebi 5 "Pi" "pebi")
    (Exbi 6 "Ei" "exbi")
    (Zebi 7 "Zi" "zebi")
    (Yobi 8 "Yi" "yobi")
    (Robi 9 "Ri" "robi")
    (Quebi 10 "Qi" "quebi")
}
546
547fn is_supported_type(ty: &Ident) -> bool {
548    UINT_TYPES.iter().any(|t| ty == t)
549}
550
/// Number of decimal digits in the largest value of the unsigned integer type named `ty`.
///
/// # Panics
///
/// If `ty` is not one of `u128`, `u64`, `u32`, `u16`.
fn max_uint_string_len(ty: &str) -> usize {
    let largest: u128 = match ty {
        "u128" => u128::MAX,
        "u64" => u64::MAX.into(),
        "u32" => u32::MAX.into(),
        "u16" => u16::MAX.into(),
        _ => panic!("`max_uint_string_len`: Unsupported type {ty:?}"),
    };
    largest.to_string().len()
}
560
561fn get_power_of_1000_count(ty: &str) -> usize {
562    let range = get_power_of_1000_range(ty);
563    range.end() - range.start() + 1
564}
565
566fn get_power_of_1000_range(ty: &str) -> RangeInclusive<usize> {
567    match ty {
568        "u128" => SiUnit::Nano as usize..=SiUnit::Ronna as usize,
569        "u64" => SiUnit::Nano as usize..=SiUnit::Giga as usize,
570        "u32" => SiUnit::Nano as usize..=SiUnit::None as usize,
571        "u16" => SiUnit::Nano as usize..=SiUnit::Micro as usize,
572        _ => panic!("`max_power`: Unsupported type {ty:?}"),
573    }
574}
575
576fn get_power_of_1024_count(ty: &str) -> usize {
577    let range = get_power_of_1024_range(ty);
578    range.end() - range.start() + 1
579}
580
581fn get_power_of_1024_range(ty: &str) -> RangeInclusive<usize> {
582    match ty {
583        "u128" => IecUnit::None as usize..=IecUnit::Quebi as usize,
584        "u64" => IecUnit::None as usize..=IecUnit::Exbi as usize,
585        "u32" => IecUnit::None as usize..=IecUnit::Gibi as usize,
586        "u16" => IecUnit::None as usize..=IecUnit::Kibi as usize,
587        _ => panic!("`max_power`: Unsupported type {ty:?}"),
588    }
589}
590
/// Field types accepted by `si_unit`/`iec_unit`.
///
/// Every entry must be handled by `max_uint_string_len`, `get_power_of_1000_range` and
/// `get_power_of_1024_range`. `u8` is not listed because none of those helpers support it;
/// rejecting it here gives a clear error instead of an internal "Unsupported type" panic.
const UINT_TYPES: [&str; 4] = ["u128", "u64", "u32", "u16"];
592
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_string_len() {
        for (ty, max) in [
            ("u128", u128::MAX.to_string().len()),
            ("u64", u64::MAX.to_string().len()),
            ("u32", u32::MAX.to_string().len()),
            ("u16", u16::MAX.to_string().len()),
        ] {
            assert_eq!(max, max_uint_string_len(ty));
        }
    }

    /// The generated code evaluates `scale.pow(power - min_power)` in the field type for
    /// every power of the default range. The largest one must fit, and the next one must
    /// not (unless the prefix table is exhausted), otherwise the default range is too
    /// wide or needlessly narrow.
    #[test]
    fn test_power_ranges_fit_type() {
        for (ty, max) in [
            ("u128", u128::MAX),
            ("u64", u128::from(u64::MAX)),
            ("u32", u128::from(u32::MAX)),
            ("u16", u128::from(u16::MAX)),
        ] {
            for (scale, count, last, table_end) in [
                (
                    1000_u128,
                    get_power_of_1000_count(ty),
                    *get_power_of_1000_range(ty).end(),
                    SiUnit::Quetta as usize,
                ),
                (
                    1024_u128,
                    get_power_of_1024_count(ty),
                    *get_power_of_1024_range(ty).end(),
                    IecUnit::Quebi as usize,
                ),
            ] {
                let count = u32::try_from(count).unwrap();
                let fits = |exp: u32| matches!(scale.checked_pow(exp), Some(v) if v <= max);
                assert!(fits(count - 1), "{ty}: {scale}^{} must fit", count - 1);
                assert!(
                    !fits(count) || last == table_end,
                    "{ty}: {scale}^{count} fits, so the default range could be wider"
                );
            }
        }
    }

    /// `PREFIXES`/`LONG_PREFIXES` are indexed by discriminant, so parsing the entry at
    /// index `i` (short, or long in any case) must yield the variant with discriminant `i`.
    #[test]
    fn test_prefix_tables_round_trip() {
        for (power, (short, long)) in SiUnit::PREFIXES
            .iter()
            .zip(SiUnit::LONG_PREFIXES)
            .enumerate()
        {
            assert_eq!(short.parse::<SiUnit>().unwrap() as usize, power);
            assert_eq!(long.to_ascii_uppercase().parse::<SiUnit>().unwrap() as usize, power);
            assert_eq!(SiUnit::try_from(power).unwrap() as usize, power);
        }
        assert!(SiUnit::try_from(SiUnit::PREFIXES.len()).is_err());

        for (power, (short, long)) in IecUnit::PREFIXES
            .iter()
            .zip(IecUnit::LONG_PREFIXES)
            .enumerate()
        {
            assert_eq!(short.parse::<IecUnit>().unwrap() as usize, power);
            assert_eq!(long.to_ascii_uppercase().parse::<IecUnit>().unwrap() as usize, power);
            assert_eq!(IecUnit::try_from(power).unwrap() as usize, power);
        }
        assert!(IecUnit::try_from(IecUnit::PREFIXES.len()).is_err());
    }
}