enumeric/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Error, Expr, ExprLit, ExprRange, ItemEnum, Lit, LitInt, Meta};
4
5#[proc_macro_attribute]
6pub fn range_enum(_: TokenStream, item: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(item as ItemEnum);
8    let mut generated_variants = Vec::default();
9
10    let vis = &input.vis;
11    let generics = &input.generics;
12    let enum_ident = &input.ident;
13
14    for variant in input.variants.iter() {
15        let mut range_variant = false;
16        for attr in &variant.attrs {
17            if attr.path().is_ident("range") {
18                // parse range of discriminant
19                let Meta::List(meta_list) = attr.meta.clone() else {
20                    continue;
21                };
22
23                let Expr::Range(range) = syn::parse2::<Expr>(meta_list.tokens.clone()).unwrap()
24                else {
25                    continue;
26                };
27
28                let range = match ParsedRange::try_new(range) {
29                    Ok(r) => r,
30                    Err(err) => return err.to_compile_error().into(),
31                };
32
33                // parse base variant name
34                let base = &variant.ident;
35
36                let (start, end) = match (range.start, range.end) {
37                    (Some(start), Some(end)) => (start, end),
38                    _ => unimplemented!("Currently only x..y and x..=y supported."),
39                };
40
41                for i in start..end {
42                    let variant_name = syn::Ident::new(&format!("{}{}", base, i), base.span());
43                    let fields = &variant.fields;
44                    let discriminant = variant
45                        .discriminant
46                        .as_ref()
47                        .map(|(_, expr)| quote! { = #expr });
48
49                    generated_variants.push(quote! {
50                        #variant_name #fields #discriminant,
51                    });
52                }
53                range_variant = true;
54                break;
55            }
56        }
57
58        // keep original variant if not range
59        if !range_variant {
60            let variant_name = &variant.ident;
61            let fields = &variant.fields;
62            let discriminant = variant
63                .discriminant
64                .as_ref()
65                .map(|(_, expr)| quote! { = #expr });
66
67            generated_variants.push(quote! {
68                #variant_name #fields #discriminant,
69            });
70        }
71    }
72    let output = quote! {
73        #vis enum #enum_ident #generics {
74            #(#generated_variants)*
75        }
76    };
77    output.into()
78}
79
// FIXME: use a wider type than `u64` so every integer literal can be represented.
/// Integer bounds extracted from a `#[range(...)]` attribute.
///
/// `end` is exclusive: `ParsedRange::try_new` stores an inclusive `x..=y` with
/// `end` set to `y + 1`. A bound is `None` when the range expression omits it.
#[derive(Debug, Clone, Copy)]
struct ParsedRange {
    /// First value of the range, if one was written.
    start: Option<u64>,
    /// One past the last value of the range, if an end was written.
    end: Option<u64>,
}
86impl ParsedRange {
87    fn try_new(range: ExprRange) -> Result<ParsedRange, Error> {
88        let start = match range.start.as_deref() {
89            Some(Expr::Lit(ExprLit {
90                lit: Lit::Int(i), ..
91            })) => Some(parse_litint_auto(i)),
92            Some(expr) => {
93                return Err(Error::new_spanned(
94                    expr,
95                    "Expected integer literal for range start.",
96                ))
97            }
98            _ => None,
99        };
100
101        let end_raw = match range.end.as_deref() {
102            Some(Expr::Lit(ExprLit {
103                lit: Lit::Int(i), ..
104            })) => Some(parse_litint_auto(i)),
105            Some(expr) => {
106                return Err(Error::new_spanned(
107                    expr,
108                    "Expected integer literal for range end.",
109                ))
110            }
111            _ => None,
112        };
113
114        let end = if let Some(end) = end_raw {
115            Some(match range.limits {
116                syn::RangeLimits::Closed(_) => end + 1,
117                syn::RangeLimits::HalfOpen(_) => end,
118            })
119        } else {
120            None
121        };
122
123        Ok(ParsedRange { start, end })
124    }
125}
126fn parse_litint_auto(lit: &LitInt) -> u64 {
127    let s = lit.to_string();
128    if let Some(hex) = s.strip_prefix("0x") {
129        u64::from_str_radix(hex, 16).unwrap()
130    } else if let Some(oct) = s.strip_prefix("0o") {
131        u64::from_str_radix(oct, 8).unwrap()
132    } else if let Some(bin) = s.strip_prefix("0b") {
133        u64::from_str_radix(bin, 2).unwrap()
134    } else {
135        s.parse::<u64>().unwrap()
136    }
137}