1#![doc = include_str!("../README.md")]
2
3extern crate proc_macro;
4
5use std::{collections::HashMap, str::FromStr};
6
7use proc_macro2::{Literal, Span, TokenStream, TokenTree};
8
9use quote::{quote, ToTokens};
10use syn::{
11 parse::Parse, parse_macro_input, spanned::Spanned, Expr, ExprArray, ExprMatch, Ident, Pat,
12 PatRange, RangeLimits, Token,
13};
14
15#[proc_macro]
74pub fn u_num_it(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
75 let UNumIt { range, arms, expr } = parse_macro_input!(tokens as UNumIt);
76
77 let pos_u = arms.contains_key(&UType::U);
78
79 let expanded_arms = range.iter().filter_map(|i| {
80 if let Some(body) = arms.get(&UType::Literal(*i)) {
82 return Some(make_match_arm(i, body, UType::Literal(*i)));
83 }
84
85 match i {
87 0 => arms
88 .get(&UType::False)
89 .map(|body| make_match_arm(i, body, UType::False)),
90 i if *i < 0 => arms
91 .get(&UType::N)
92 .map(|body| make_match_arm(i, body, UType::N)),
93 i if *i > 0 => {
94 if pos_u {
95 arms.get(&UType::U)
96 .map(|body| make_match_arm(i, body, UType::U))
97 } else {
98 arms.get(&UType::P)
99 .map(|body| make_match_arm(i, body, UType::P))
100 }
101 }
102 _ => unreachable!(),
103 }
104 });
105
106 let fallback = arms
107 .get(&UType::None)
108 .map(|body| {
109 quote! {
110 _ => {
111 #body
112 },
113 }
114 })
115 .unwrap_or_else(|| {
116 let first = range.first().unwrap_or(&0);
117 let last = range.last().unwrap_or(&0);
118 quote! {
119 i => unreachable!("{i} not in range {}..={}", #first, #last),
120 }
121 });
122
123 let expanded = quote! {
124 match #expr {
125 #(#expanded_arms)*
126 #fallback
127 }
128 };
129
130 proc_macro::TokenStream::from(expanded)
131}
132
/// Kind of pattern accepted in a `u_num_it!` arm; used as the key that maps
/// each user arm to its body.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum UType {
    /// `N`: negative values, typed as `typenum::consts::N*`.
    N,
    /// `P`: positive values, typed as `typenum::consts::P*`.
    P,
    /// `U`: positive values, typed as `typenum::consts::U*`.
    U,
    /// `False`: the value `0`.
    False,
    /// `_`: the catch-all arm.
    None,
    /// An integer literal arm matching exactly this value.
    Literal(isize),
}
142
143impl std::fmt::Display for UType {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 match self {
146 UType::N => write!(f, "N"),
147 UType::P => write!(f, "P"),
148 UType::U => write!(f, "U"),
149 UType::False => write!(f, "False"),
150 UType::None => write!(f, ""),
151 UType::Literal(_) => write!(f, ""),
152 }
153 }
154}
155
156struct UNumIt {
157 range: Vec<isize>,
158 arms: HashMap<UType, Box<Expr>>,
159 expr: Box<Expr>,
160}
161
162fn range_boundary(val: &Option<Box<Expr>>) -> syn::Result<Option<isize>> {
163 if let Some(val) = val.clone() {
164 let string = val.to_token_stream().to_string().replace(' ', "");
165 let value = string
166 .parse::<isize>()
167 .map_err(|e| syn::Error::new(val.span(), format!("{e}: `{string}`").as_str()))?;
168
169 Ok(Some(value))
170 } else {
171 Ok(None)
172 }
173}
174
175impl Parse for UNumIt {
176 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
177 let range: Vec<isize> = if input.peek(syn::token::Bracket) {
179 let array: ExprArray = input.parse()?;
181 let mut vals = array
182 .elems
183 .iter()
184 .map(|expr| {
185 let raw = expr.to_token_stream().to_string();
186 let norm = raw.replace([' ', '_'], "");
187 norm.parse::<isize>().map_err(|e| {
188 syn::Error::new(
189 expr.span(),
190 format!("invalid number in array: {e}: `{raw}` (normalized `{norm}`)"),
191 )
192 })
193 })
194 .collect::<syn::Result<Vec<isize>>>()?;
195 vals.sort();
196 vals.dedup();
197 vals
198 } else {
199 let range: PatRange = input.parse()?;
201 let start = range_boundary(&range.start)?.unwrap_or(0);
202 let end = range_boundary(&range.end)?.unwrap_or(isize::MAX);
203 match &range.limits {
204 RangeLimits::HalfOpen(_) => (start..end).collect(),
205 RangeLimits::Closed(_) => (start..=end).collect(),
206 }
207 };
208
209 input.parse::<Token![,]>()?;
210 let matcher: ExprMatch = input.parse()?;
211
212 let mut arms = HashMap::new();
213
214 for arm in matcher.arms.iter() {
215 let u_type = match &arm.pat {
216 Pat::Ident(t) => match t.ident.to_token_stream().to_string().as_str() {
217 "N" => UType::N,
218 "P" => UType::P,
219 "U" => UType::U,
220 "False" => UType::False,
221 _ => {
222 return Err(syn::Error::new(
223 t.span(),
224 "expected idents N | P | U | False | _",
225 ))
226 }
227 },
228 Pat::Lit(lit_expr) => {
229 let raw = lit_expr.to_token_stream().to_string();
231 let norm = raw.replace([' ', '_'], "");
232 if norm.starts_with("0x") || norm.starts_with("0b") || norm.starts_with("0o") {
233 return Err(syn::Error::new(
234 lit_expr.span(),
235 format!("unsupported non-decimal literal `{raw}`"),
236 ));
237 }
238 let value = norm.parse::<isize>().map_err(|e| {
239 syn::Error::new(
240 lit_expr.span(),
241 format!("invalid literal: {e}: `{raw}` (normalized `{norm}`)"),
242 )
243 })?;
244 UType::Literal(value)
245 }
246 Pat::Wild(_) => UType::None,
247 _ => return Err(syn::Error::new(arm.pat.span(), "expected ident")),
248 };
249 let arm_expr = arm.body.clone();
250 if arms.insert(u_type, arm_expr.clone()).is_some() {
251 return Err(syn::Error::new(arm_expr.span(), "duplicate type"));
252 }
253 }
254
255 if arms.get(&UType::P).and(arms.get(&UType::U)).is_some() {
256 return Err(syn::Error::new(
257 matcher.span(),
258 "ambiguous type, don't use P and U in the same macro call",
259 ));
260 }
261
262 if arms
264 .get(&UType::Literal(0))
265 .and(arms.get(&UType::False))
266 .is_some()
267 {
268 return Err(syn::Error::new(
269 matcher.span(),
270 "ambiguous type, don't use literal 0 and False in the same macro call (they represent the same value)",
271 ));
272 }
273
274 let expr = matcher.expr;
275
276 Ok(UNumIt { range, arms, expr })
277 }
278}
279
280fn make_match_arm(i: &isize, body: &Expr, u_type: UType) -> TokenStream {
281 let match_expr = TokenTree::Literal(Literal::from_str(i.to_string().as_str()).unwrap());
282
283 let i_str = if *i != 0 {
285 i.abs().to_string()
286 } else {
287 Default::default()
288 };
289
290 let u_type_for_typenum = match u_type {
292 UType::Literal(0) => UType::False,
293 UType::Literal(val) if val < 0 => UType::N,
294 UType::Literal(val) if val > 0 => UType::P,
295 _ => u_type,
296 };
297
298 let typenum_type = TokenTree::Ident(Ident::new(
299 format!("{}{}", u_type_for_typenum, i_str).as_str(),
300 Span::mixed_site(),
301 ));
302 let type_variant = quote!(typenum::consts::#typenum_type);
303
304 let body_tokens = body.to_token_stream();
306
307 quote! {
308 #match_expr => {
309 type NumType = #type_variant;
310 #body_tokens
311 },
312 }
313}