1extern crate proc_macro;
2
3use std::collections::HashMap;
4
5use proc_macro2::{Span, TokenStream, TokenTree};
6
7use quote::{quote, ToTokens};
8use syn::{
9 parse::Parse, parse_macro_input, spanned::Spanned, Expr, ExprArray, ExprMatch, Ident, Pat,
10 PatRange, RangeLimits, Token,
11};
12
/// Kind of pattern used by one arm of the user's `match` inside `u_num_it!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum UType {
    /// The `N` ident pattern (negative values).
    N,
    /// The `P` ident pattern (positive values).
    P,
    /// The `U` ident pattern (positive values, alternative to `P`).
    U,
    /// The `False` ident pattern (the value zero).
    False,
    /// The `_` wildcard pattern.
    None,
    /// An integer-literal pattern for exactly this value.
    Literal(isize),
}

impl std::fmt::Display for UType {
    /// Writes the name prefix used to build a `typenum::consts` identifier:
    /// `N`, `P`, `U` or `False`. `None` and `Literal` write nothing.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let prefix = match *self {
            UType::N => "N",
            UType::P => "P",
            UType::U => "U",
            UType::False => "False",
            UType::None | UType::Literal(_) => "",
        };
        f.write_str(prefix)
    }
}
35
36struct UNumIt {
37 range: Vec<isize>,
38 arms: HashMap<UType, Box<Expr>>,
39 expr: Box<Expr>,
40}
41
42fn range_boundary(val: &Option<Box<Expr>>) -> syn::Result<Option<isize>> {
43 if let Some(val) = val.clone() {
44 let string = val.to_token_stream().to_string().replace(' ', "");
45 let value = string
46 .parse::<isize>()
47 .map_err(|e| syn::Error::new(val.span(), format!("{e}: `{string}`").as_str()))?;
48
49 Ok(Some(value))
50 } else {
51 Ok(None)
52 }
53}
54
55impl Parse for UNumIt {
56 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
57 let range: Vec<isize> = if input.peek(syn::token::Bracket) {
59 let array: ExprArray = input.parse()?;
61 let mut vals = array
62 .elems
63 .iter()
64 .map(|expr| {
65 let raw = expr.to_token_stream().to_string();
66 let norm = raw.replace([' ', '_'], "");
67 norm.parse::<isize>().map_err(|e| {
68 syn::Error::new(
69 expr.span(),
70 format!("invalid number in array: {e}: `{raw}` (normalized `{norm}`)"),
71 )
72 })
73 })
74 .collect::<syn::Result<Vec<isize>>>()?;
75 vals.sort();
76 vals.dedup();
77 vals
78 } else {
79 let range: PatRange = input.parse()?;
81 let start = range_boundary(&range.start)?.unwrap_or(0);
82 let end = range_boundary(&range.end)?.unwrap_or(isize::MAX);
83 match &range.limits {
84 RangeLimits::HalfOpen(_) => (start..end).collect(),
85 RangeLimits::Closed(_) => (start..=end).collect(),
86 }
87 };
88
89 input.parse::<Token![,]>()?;
90 let matcher: ExprMatch = input.parse()?;
91
92 let mut arms = HashMap::new();
93
94 for arm in matcher.arms.iter() {
95 let u_type = match &arm.pat {
96 Pat::Ident(t) => match t.ident.to_token_stream().to_string().as_str() {
97 "N" => UType::N,
98 "P" => UType::P,
99 "U" => UType::U,
100 "False" => UType::False,
101 _ => {
102 return Err(syn::Error::new(
103 t.span(),
104 "expected idents N | P | U | False | _",
105 ))
106 }
107 },
108 Pat::Lit(lit_expr) => {
109 let raw = lit_expr.to_token_stream().to_string();
111 let norm = raw.replace([' ', '_'], "");
112 if norm.starts_with("0x") || norm.starts_with("0b") || norm.starts_with("0o") {
113 return Err(syn::Error::new(
114 lit_expr.span(),
115 format!("unsupported non-decimal literal `{raw}`"),
116 ));
117 }
118 let value = norm.parse::<isize>().map_err(|e| {
119 syn::Error::new(
120 lit_expr.span(),
121 format!("invalid literal: {e}: `{raw}` (normalized `{norm}`)"),
122 )
123 })?;
124 UType::Literal(value)
125 }
126 Pat::Wild(_) => UType::None,
127 _ => return Err(syn::Error::new(arm.pat.span(), "expected ident")),
128 };
129 let arm_expr = arm.body.clone();
130 if arms.insert(u_type, arm_expr.clone()).is_some() {
131 return Err(syn::Error::new(arm_expr.span(), "duplicate type"));
132 }
133 }
134
135 if arms.get(&UType::P).and(arms.get(&UType::U)).is_some() {
136 return Err(syn::Error::new(
137 matcher.span(),
138 "ambiguous type, don't use P and U in the same macro call",
139 ));
140 }
141
142 if arms
144 .get(&UType::Literal(0))
145 .and(arms.get(&UType::False))
146 .is_some()
147 {
148 return Err(syn::Error::new(
149 matcher.span(),
150 "ambiguous type, don't use literal 0 and False in the same macro call (they represent the same value)",
151 ));
152 }
153
154 let expr = matcher.expr;
155
156 Ok(UNumIt { range, arms, expr })
157 }
158}
159
160fn make_match_arm(i: &isize, body: &Expr, u_type: UType) -> TokenStream {
161 let match_expr = quote!(#i);
162
163 let i_str = if *i != 0 {
165 i.abs().to_string()
166 } else {
167 Default::default()
168 };
169
170 let u_type_for_typenum = match u_type {
172 UType::Literal(0) => UType::False,
173 UType::Literal(val) if val < 0 => UType::N,
174 UType::Literal(val) if val > 0 => UType::P,
175 _ => u_type,
176 };
177
178 let typenum_type = TokenTree::Ident(Ident::new(
179 format!("{}{}", u_type_for_typenum, i_str).as_str(),
180 Span::mixed_site(),
181 ));
182 let type_variant = quote!(typenum::consts::#typenum_type);
183
184 let body_tokens = body.to_token_stream();
186
187 quote! {
188 #match_expr => {
189 type NumType = #type_variant;
190 #body_tokens
191 },
192 }
193}
194
195#[proc_macro]
254pub fn u_num_it(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
255 let UNumIt { range, arms, expr } = parse_macro_input!(tokens as UNumIt);
256
257 let pos_u = arms.contains_key(&UType::U);
258
259 let expanded_arms = range.iter().filter_map(|i| {
260 if let Some(body) = arms.get(&UType::Literal(*i)) {
262 return Some(make_match_arm(i, body, UType::Literal(*i)));
263 }
264
265 match i {
267 0 => arms
268 .get(&UType::False)
269 .map(|body| make_match_arm(i, body, UType::False)),
270 i if *i < 0 => arms
271 .get(&UType::N)
272 .map(|body| make_match_arm(i, body, UType::N)),
273 i if *i > 0 => {
274 if pos_u {
275 arms.get(&UType::U)
276 .map(|body| make_match_arm(i, body, UType::U))
277 } else {
278 arms.get(&UType::P)
279 .map(|body| make_match_arm(i, body, UType::P))
280 }
281 }
282 _ => unreachable!(),
283 }
284 });
285
286 let fallback = arms
287 .get(&UType::None)
288 .map(|body| {
289 quote! {
290 _ => {
291 #body
292 },
293 }
294 })
295 .unwrap_or_else(|| {
296 let first = range.first().unwrap_or(&0);
297 let last = range.last().unwrap_or(&0);
298 quote! {
299 i => unreachable!("{i} not in range {}..={}", #first, #last),
300 }
301 });
302
303 let expanded = quote! {
304 match #expr {
305 #(#expanded_arms)*
306 #fallback
307 }
308 };
309
310 proc_macro::TokenStream::from(expanded)
311}