1extern crate proc_macro;
2
3use std::{collections::HashMap, str::FromStr};
4
5use proc_macro2::{Group, Literal, Span, TokenStream, TokenTree};
6
7use quote::{quote, ToTokens};
8use syn::{
9 parse::Parse, parse_macro_input, spanned::Spanned, Expr, ExprMatch, Ident, Pat, PatRange,
10 RangeLimits, Token,
11};
12
/// Kind of arm pattern accepted by `u_num_it!`.
///
/// `N`, `P`, `U` and `False` are placeholder idents that get substituted with a
/// concrete `typenum::consts` type; `None` stands for the `_` arm and `Literal`
/// for an arm written as an explicit integer.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum UType {
    /// Placeholder used for negative values (`typenum::consts::N*`).
    N,
    /// Placeholder used for positive values (`typenum::consts::P*`).
    P,
    /// Placeholder used for positive values, substituted with `typenum::consts::U*`.
    U,
    /// Placeholder used for zero (`typenum::consts::False`).
    False,
    /// The wildcard `_` arm.
    None,
    /// An arm whose pattern is this exact integer.
    Literal(isize),
}

impl std::fmt::Display for UType {
    /// Writes the `typenum::consts` name prefix for this kind. Kinds without an
    /// associated typenum type (`None`, `Literal`) write nothing.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let prefix = match *self {
            UType::N => "N",
            UType::P => "P",
            UType::U => "U",
            UType::False => "False",
            UType::None | UType::Literal(_) => "",
        };
        f.write_str(prefix)
    }
}
35
36struct UNumIt {
37 range: Vec<isize>,
38 arms: HashMap<UType, Box<Expr>>,
39 expr: Box<Expr>,
40}
41
42fn range_boundary(val: &Option<Box<Expr>>) -> syn::Result<Option<isize>> {
43 if let Some(val) = val.clone() {
44 let string = val.to_token_stream().to_string().replace(' ', "");
45 let value = string
46 .parse::<isize>()
47 .map_err(|e| syn::Error::new(val.span(), format!("{e}: `{string}`").as_str()))?;
48
49 Ok(Some(value))
50 } else {
51 Ok(None)
52 }
53}
54
55impl Parse for UNumIt {
56 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
57 let range: PatRange = input.parse()?;
58
59 let start = range_boundary(&range.start)?.unwrap_or(0);
60 let end = range_boundary(&range.end)?.unwrap_or(isize::MAX);
61
62 let range = match &range.limits {
63 RangeLimits::HalfOpen(_) => (start..end).collect(),
64 RangeLimits::Closed(_) => (start..=end).collect(),
65 };
66
67 input.parse::<Token![,]>()?;
68 let matcher: ExprMatch = input.parse()?;
69
70 let mut arms = HashMap::new();
71
72 for arm in matcher.arms.iter() {
73 let u_type = match &arm.pat {
74 Pat::Ident(t) => match t.ident.to_token_stream().to_string().as_str() {
75 "N" => UType::N,
76 "P" => UType::P,
77 "U" => UType::U,
78 "False" => UType::False,
79 _ => {
80 return Err(syn::Error::new(
81 t.span(),
82 "exepected idents N | P | U, False or _",
83 ))
84 }
85 },
86 Pat::Lit(lit_expr) => {
87 let lit_str = lit_expr.to_token_stream().to_string();
89 let value = lit_str.parse::<isize>().map_err(|e| {
90 syn::Error::new(lit_expr.span(), format!("invalid literal: {e}"))
91 })?;
92 UType::Literal(value)
93 }
94 Pat::Wild(_) => UType::None,
95 _ => return Err(syn::Error::new(arm.pat.span(), "exepected ident")),
96 };
97 let arm_expr = arm.body.clone();
98 if arms.insert(u_type, arm_expr.clone()).is_some() {
99 return Err(syn::Error::new(arm_expr.span(), "duplicate type"));
100 }
101 }
102
103 if arms.get(&UType::P).and(arms.get(&UType::U)).is_some() {
104 return Err(syn::Error::new(
105 matcher.span(),
106 "ambiguous type, don't use P and U in the same macro call",
107 ));
108 }
109
110 if arms.get(&UType::Literal(0)).and(arms.get(&UType::False)).is_some() {
112 return Err(syn::Error::new(
113 matcher.span(),
114 "ambiguous type, don't use literal 0 and False in the same macro call (they represent the same value)",
115 ));
116 }
117
118 let expr = matcher.expr;
119
120 Ok(UNumIt { range, arms, expr })
121 }
122}
123
124fn make_body_variant(body: TokenStream, type_variant: TokenStream, u_type: UType) -> TokenStream {
125 let tokens = body.into_iter().fold(vec![], |mut acc, token| {
126 let type_variant = type_variant.clone();
127 match token {
128 TokenTree::Ident(ref ident) => {
129 if *ident == u_type.to_string() {
130 acc.extend(quote!(#type_variant).to_token_stream());
131 } else {
132 acc.push(token);
133 }
134 }
135 TokenTree::Group(ref group) => {
136 let inner = make_body_variant(group.stream(), type_variant, u_type);
137 acc.push(TokenTree::Group(Group::new(group.delimiter(), inner)));
138 }
139 _ => acc.push(token),
140 };
141 acc
142 });
143
144 quote! {#(#tokens)*}
145}
146
147fn make_match_arm(i: &isize, body: &Expr, u_type: UType) -> TokenStream {
148 let match_expr = TokenTree::Literal(Literal::from_str(i.to_string().as_str()).unwrap());
149
150 if let UType::Literal(_) = u_type {
152 let body_tokens = body.to_token_stream();
153 return quote! {
154 #match_expr => {
155 #body_tokens
156 },
157 };
158 }
159
160 let i_str = if *i != 0 {
162 i.abs().to_string()
163 } else {
164 Default::default()
165 };
166 let typenum_type = TokenTree::Ident(Ident::new(
167 format!("{}{}", u_type, i_str).as_str(),
168 Span::mixed_site(),
169 ));
170 let type_variant = quote!(typenum::consts::#typenum_type);
171 let body_variant = make_body_variant(body.to_token_stream(), type_variant, u_type);
172
173 quote! {
174 #match_expr => {
175 #body_variant
176 },
177 }
178}
179
180#[proc_macro]
200pub fn u_num_it(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
201 let UNumIt { range, arms, expr } = parse_macro_input!(tokens as UNumIt);
202
203 let pos_u = arms.get(&UType::U).is_some();
204
205 let expanded_arms = range.iter().filter_map(|i| {
206 if let Some(body) = arms.get(&UType::Literal(*i)) {
208 return Some(make_match_arm(i, body, UType::Literal(*i)));
209 }
210
211 match i {
213 0 => arms
214 .get(&UType::False)
215 .map(|body| make_match_arm(i, body, UType::False)),
216 i if *i < 0 => arms
217 .get(&UType::N)
218 .map(|body| make_match_arm(i, body, UType::N)),
219 i if *i > 0 => {
220 if pos_u {
221 arms.get(&UType::U)
222 .map(|body| make_match_arm(i, body, UType::U))
223 } else {
224 arms.get(&UType::P)
225 .map(|body| make_match_arm(i, body, UType::P))
226 }
227 }
228 _ => unreachable!(),
229 }
230 });
231
232 let fallback = arms
233 .get(&UType::None)
234 .map(|body| {
235 quote! {
236 _ => {
237 #body
238 }
239 }
240 })
241 .unwrap_or_else(|| {
242 let first = range.first().unwrap_or(&0);
243 let last = range.last().unwrap_or(&0);
244 quote! {
245 i => unreachable!("{i} is not in range {}-{:?}", #first, #last)
246 }
247 });
248
249 let expanded = quote! {
250 match #expr {
251 #(#expanded_arms)*
252 #fallback
253 }
254 };
255
256 proc_macro::TokenStream::from(expanded)
257}