1extern crate proc_macro;
2
3use std::{collections::HashMap, str::FromStr};
4
5use proc_macro2::{Literal, Span, TokenStream, TokenTree};
6
7use quote::{quote, ToTokens};
8use syn::{
9 parse::Parse, parse_macro_input, spanned::Spanned, Expr, ExprArray, ExprMatch, Ident, Pat, PatRange,
10 RangeLimits, Token,
11};
12
/// Selector named by one arm pattern of the user's `match`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum UType {
    /// `N` arm: negative values, mapped to `typenum::consts::N<k>`.
    N,
    /// `P` arm: positive values, mapped to `typenum::consts::P<k>`.
    P,
    /// `U` arm: positive values, mapped to `typenum::consts::U<k>`.
    U,
    /// `False` arm: the value zero, mapped to `typenum::consts::False`.
    False,
    /// `_` arm: fallback for values without a generated arm.
    None,
    /// An explicit integer-literal arm for exactly this value.
    Literal(isize),
}

impl std::fmt::Display for UType {
    /// Writes the `typenum::consts` name prefix of this selector.
    ///
    /// `None` and `Literal` have no prefix of their own and render as empty.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let prefix = match *self {
            Self::N => "N",
            Self::P => "P",
            Self::U => "U",
            Self::False => "False",
            Self::None | Self::Literal(_) => "",
        };
        f.write_str(prefix)
    }
}
35
36struct UNumIt {
37 range: Vec<isize>,
38 arms: HashMap<UType, Box<Expr>>,
39 expr: Box<Expr>,
40}
41
42fn range_boundary(val: &Option<Box<Expr>>) -> syn::Result<Option<isize>> {
43 if let Some(val) = val.clone() {
44 let string = val.to_token_stream().to_string().replace(' ', "");
45 let value = string
46 .parse::<isize>()
47 .map_err(|e| syn::Error::new(val.span(), format!("{e}: `{string}`").as_str()))?;
48
49 Ok(Some(value))
50 } else {
51 Ok(None)
52 }
53}
54
55impl Parse for UNumIt {
56 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
57 let range: Vec<isize> = if input.peek(syn::token::Bracket) {
59 let array: ExprArray = input.parse()?;
61 array.elems.iter().map(|expr| {
62 let string = expr.to_token_stream().to_string().replace(' ', "");
63 string.parse::<isize>()
64 .map_err(|e| syn::Error::new(expr.span(), format!("invalid number in array: {e}")))
65 }).collect::<syn::Result<Vec<isize>>>()?
66 } else {
67 let range: PatRange = input.parse()?;
69 let start = range_boundary(&range.start)?.unwrap_or(0);
70 let end = range_boundary(&range.end)?.unwrap_or(isize::MAX);
71 match &range.limits {
72 RangeLimits::HalfOpen(_) => (start..end).collect(),
73 RangeLimits::Closed(_) => (start..=end).collect(),
74 }
75 };
76
77 input.parse::<Token![,]>()?;
78 let matcher: ExprMatch = input.parse()?;
79
80 let mut arms = HashMap::new();
81
82 for arm in matcher.arms.iter() {
83 let u_type = match &arm.pat {
84 Pat::Ident(t) => match t.ident.to_token_stream().to_string().as_str() {
85 "N" => UType::N,
86 "P" => UType::P,
87 "U" => UType::U,
88 "False" => UType::False,
89 _ => {
90 return Err(syn::Error::new(
91 t.span(),
92 "exepected idents N | P | U, False or _",
93 ))
94 }
95 },
96 Pat::Lit(lit_expr) => {
97 let lit_str = lit_expr.to_token_stream().to_string();
99 let value = lit_str.parse::<isize>().map_err(|e| {
100 syn::Error::new(lit_expr.span(), format!("invalid literal: {e}"))
101 })?;
102 UType::Literal(value)
103 }
104 Pat::Wild(_) => UType::None,
105 _ => return Err(syn::Error::new(arm.pat.span(), "exepected ident")),
106 };
107 let arm_expr = arm.body.clone();
108 if arms.insert(u_type, arm_expr.clone()).is_some() {
109 return Err(syn::Error::new(arm_expr.span(), "duplicate type"));
110 }
111 }
112
113 if arms.get(&UType::P).and(arms.get(&UType::U)).is_some() {
114 return Err(syn::Error::new(
115 matcher.span(),
116 "ambiguous type, don't use P and U in the same macro call",
117 ));
118 }
119
120 if arms.get(&UType::Literal(0)).and(arms.get(&UType::False)).is_some() {
122 return Err(syn::Error::new(
123 matcher.span(),
124 "ambiguous type, don't use literal 0 and False in the same macro call (they represent the same value)",
125 ));
126 }
127
128 let expr = matcher.expr;
129
130 Ok(UNumIt { range, arms, expr })
131 }
132}
133
134fn make_match_arm(i: &isize, body: &Expr, u_type: UType) -> TokenStream {
135 let match_expr = TokenTree::Literal(Literal::from_str(i.to_string().as_str()).unwrap());
136
137 let i_str = if *i != 0 {
139 i.abs().to_string()
140 } else {
141 Default::default()
142 };
143
144 let u_type_for_typenum = match u_type {
146 UType::Literal(val) if val == 0 => UType::False,
147 UType::Literal(val) if val < 0 => UType::N,
148 UType::Literal(val) if val > 0 => UType::P,
149 _ => u_type,
150 };
151
152 let typenum_type = TokenTree::Ident(Ident::new(
153 format!("{}{}", u_type_for_typenum, i_str).as_str(),
154 Span::mixed_site(),
155 ));
156 let type_variant = quote!(typenum::consts::#typenum_type);
157
158 let body_tokens = body.to_token_stream();
160
161 quote! {
162 #match_expr => {
163 type NumType = #type_variant;
164 #body_tokens
165 },
166 }
167}
168
169#[proc_macro]
213pub fn u_num_it(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
214 let UNumIt { range, arms, expr } = parse_macro_input!(tokens as UNumIt);
215
216 let pos_u = arms.get(&UType::U).is_some();
217
218 let expanded_arms = range.iter().filter_map(|i| {
219 if let Some(body) = arms.get(&UType::Literal(*i)) {
221 return Some(make_match_arm(i, body, UType::Literal(*i)));
222 }
223
224 match i {
226 0 => arms
227 .get(&UType::False)
228 .map(|body| make_match_arm(i, body, UType::False)),
229 i if *i < 0 => arms
230 .get(&UType::N)
231 .map(|body| make_match_arm(i, body, UType::N)),
232 i if *i > 0 => {
233 if pos_u {
234 arms.get(&UType::U)
235 .map(|body| make_match_arm(i, body, UType::U))
236 } else {
237 arms.get(&UType::P)
238 .map(|body| make_match_arm(i, body, UType::P))
239 }
240 }
241 _ => unreachable!(),
242 }
243 });
244
245 let fallback = arms
246 .get(&UType::None)
247 .map(|body| {
248 quote! {
249 _ => {
250 #body
251 }
252 }
253 })
254 .unwrap_or_else(|| {
255 let first = range.first().unwrap_or(&0);
256 let last = range.last().unwrap_or(&0);
257 quote! {
258 i => unreachable!("{i} is not in range {}-{:?}", #first, #last)
259 }
260 });
261
262 let expanded = quote! {
263 match #expr {
264 #(#expanded_arms)*
265 #fallback
266 }
267 };
268
269 proc_macro::TokenStream::from(expanded)
270}