1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Error, Expr, ExprLit, ExprRange, ItemEnum, Lit, LitInt, Meta};
4
5#[proc_macro_attribute]
6pub fn range_enum(_: TokenStream, item: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(item as ItemEnum);
8 let mut generated_variants = Vec::default();
9
10 let vis = &input.vis;
11 let generics = &input.generics;
12 let enum_ident = &input.ident;
13
14 for variant in input.variants.iter() {
15 let mut range_variant = false;
16 for attr in &variant.attrs {
17 if attr.path().is_ident("range") {
18 let Meta::List(meta_list) = attr.meta.clone() else {
20 continue;
21 };
22
23 let Expr::Range(range) = syn::parse2::<Expr>(meta_list.tokens.clone()).unwrap()
24 else {
25 continue;
26 };
27
28 let range = match ParsedRange::try_new(range) {
29 Ok(r) => r,
30 Err(err) => return err.to_compile_error().into(),
31 };
32
33 let base = &variant.ident;
35
36 let (start, end) = match (range.start, range.end) {
37 (Some(start), Some(end)) => (start, end),
38 _ => unimplemented!("Currently only x..y and x..=y supported."),
39 };
40
41 for i in start..end {
42 let variant_name = syn::Ident::new(&format!("{}{}", base, i), base.span());
43 let fields = &variant.fields;
44 let discriminant = variant
45 .discriminant
46 .as_ref()
47 .map(|(_, expr)| quote! { = #expr });
48
49 generated_variants.push(quote! {
50 #variant_name #fields #discriminant,
51 });
52 }
53 range_variant = true;
54 break;
55 }
56 }
57
58 if !range_variant {
60 let variant_name = &variant.ident;
61 let fields = &variant.fields;
62 let discriminant = variant
63 .discriminant
64 .as_ref()
65 .map(|(_, expr)| quote! { = #expr });
66
67 generated_variants.push(quote! {
68 #variant_name #fields #discriminant,
69 });
70 }
71 }
72 let output = quote! {
73 #vis enum #enum_ident #generics {
74 #(#generated_variants)*
75 }
76 };
77 output.into()
78}
79
80#[derive(Copy, Clone, Debug)]
82struct ParsedRange {
83 start: Option<u64>,
84 end: Option<u64>,
85}
86impl ParsedRange {
87 fn try_new(range: ExprRange) -> Result<ParsedRange, Error> {
88 let start = match range.start.as_deref() {
89 Some(Expr::Lit(ExprLit {
90 lit: Lit::Int(i), ..
91 })) => Some(parse_litint_auto(i)),
92 Some(expr) => {
93 return Err(Error::new_spanned(
94 expr,
95 "Expected integer literal for range start.",
96 ))
97 }
98 _ => None,
99 };
100
101 let end_raw = match range.end.as_deref() {
102 Some(Expr::Lit(ExprLit {
103 lit: Lit::Int(i), ..
104 })) => Some(parse_litint_auto(i)),
105 Some(expr) => {
106 return Err(Error::new_spanned(
107 expr,
108 "Expected integer literal for range end.",
109 ))
110 }
111 _ => None,
112 };
113
114 let end = if let Some(end) = end_raw {
115 Some(match range.limits {
116 syn::RangeLimits::Closed(_) => end + 1,
117 syn::RangeLimits::HalfOpen(_) => end,
118 })
119 } else {
120 None
121 };
122
123 Ok(ParsedRange { start, end })
124 }
125}
126fn parse_litint_auto(lit: &LitInt) -> u64 {
127 let s = lit.to_string();
128 if let Some(hex) = s.strip_prefix("0x") {
129 u64::from_str_radix(hex, 16).unwrap()
130 } else if let Some(oct) = s.strip_prefix("0o") {
131 u64::from_str_radix(oct, 8).unwrap()
132 } else if let Some(bin) = s.strip_prefix("0b") {
133 u64::from_str_radix(bin, 2).unwrap()
134 } else {
135 s.parse::<u64>().unwrap()
136 }
137}