u_num_it/
lib.rs

1#![doc = include_str!("../README.md")]
2
3extern crate proc_macro;
4
5use std::{collections::HashMap, str::FromStr};
6
7use proc_macro2::{Literal, Span, TokenStream, TokenTree};
8
9use quote::{quote, ToTokens};
10use syn::{
11    parse::Parse, parse_macro_input, spanned::Spanned, Expr, ExprArray, ExprMatch, Ident, Pat,
12    PatRange, RangeLimits, Token,
13};
14
15/// matches `typenum::consts` in a given range or array
16///
17/// use with an open or closed range, or an array of arbitrary numbers
18///
19/// use `P` | `N` | `U` | `False` | `_` or literals `1` | `-1` as match arms
20///
21/// a `NumType` type alias is available in each match arm,
22/// resolving to the specific typenum type for that value.
23/// Use `NumType` to reference the resolved type in the match arm body.
24///
25/// ## Example (range)
26///
27/// ```
28/// let x = 3;
29///
30/// u_num_it::u_num_it!(1..10, match x {
31///     U => {
32///         // NumType is typenum::consts::U3 when x=3
33///         let val = NumType::new();
34///         println!("{:?}", val);
35///         // UInt { msb: UInt { msb: UTerm, lsb: B1 }, lsb: B1 }
36///
37///         use typenum::ToInt;
38///         let num: usize = NumType::to_int();
39///         assert_eq!(num, 3);
40///     }
41/// })
42/// ```
43///
44/// ## Example (array)
45///
46/// ```
47/// let x = 8;
48///
49/// u_num_it::u_num_it!([1, 2, 8, 22], match x {
50///     P => {
51///         // NumType is typenum::consts::P8 when x=8
52///         use typenum::ToInt;
53///         let num: i32 = NumType::to_int();
54///         assert_eq!(num, 8);
55///     }
56/// })
57/// ```
58///
59/// ## Example (negative literal)
60/// ```
61/// let result = u_num_it::u_num_it!(-5..=5, match -3 {
62///     -3 => {
63///         use typenum::ToInt;
64///         let n: i32 = NumType::to_int();
65///         assert_eq!(n, -3);
66///         "ok"
67///     },
68///     N => "neg",
69///     _ => "other"
70/// });
71/// assert_eq!(result, "ok");
72/// ```
73#[proc_macro]
74pub fn u_num_it(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
75    let UNumIt { range, arms, expr } = parse_macro_input!(tokens as UNumIt);
76
77    let pos_u = arms.contains_key(&UType::U);
78
79    let expanded_arms = range.iter().filter_map(|i| {
80        // First check if there's a specific literal match for this number
81        if let Some(body) = arms.get(&UType::Literal(*i)) {
82            return Some(make_match_arm(i, body, UType::Literal(*i)));
83        }
84
85        // Otherwise, use the general type patterns
86        match i {
87            0 => arms
88                .get(&UType::False)
89                .map(|body| make_match_arm(i, body, UType::False)),
90            i if *i < 0 => arms
91                .get(&UType::N)
92                .map(|body| make_match_arm(i, body, UType::N)),
93            i if *i > 0 => {
94                if pos_u {
95                    arms.get(&UType::U)
96                        .map(|body| make_match_arm(i, body, UType::U))
97                } else {
98                    arms.get(&UType::P)
99                        .map(|body| make_match_arm(i, body, UType::P))
100                }
101            }
102            _ => unreachable!(),
103        }
104    });
105
106    let fallback = arms
107        .get(&UType::None)
108        .map(|body| {
109            quote! {
110                _ => {
111                    #body
112                },
113            }
114        })
115        .unwrap_or_else(|| {
116            let first = range.first().unwrap_or(&0);
117            let last = range.last().unwrap_or(&0);
118            quote! {
119                i => unreachable!("{i} not in range {}..={}", #first, #last),
120            }
121        });
122
123    let expanded = quote! {
124        match #expr {
125            #(#expanded_arms)*
126            #fallback
127        }
128    };
129
130    proc_macro::TokenStream::from(expanded)
131}
132
/// Kind of arm written in a `u_num_it!` match; used as the key when looking up arm bodies.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum UType {
    /// `N` arm: negative values, expanded as `typenum::consts::N{n}`.
    N,
    /// `P` arm: positive values, expanded as `typenum::consts::P{n}`.
    P,
    /// `U` arm: positive values, expanded as `typenum::consts::U{n}`.
    U,
    /// `False` arm: the value zero.
    False,
    /// `_` arm: catch-all fallback.
    None,
    /// Integer literal arm such as `3` or `-1`, matching exactly that value.
    Literal(isize),
}
142
143impl std::fmt::Display for UType {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        match self {
146            UType::N => write!(f, "N"),
147            UType::P => write!(f, "P"),
148            UType::U => write!(f, "U"),
149            UType::False => write!(f, "False"),
150            UType::None => write!(f, ""),
151            UType::Literal(_) => write!(f, ""),
152        }
153    }
154}
155
156struct UNumIt {
157    range: Vec<isize>,
158    arms: HashMap<UType, Box<Expr>>,
159    expr: Box<Expr>,
160}
161
162fn range_boundary(val: &Option<Box<Expr>>) -> syn::Result<Option<isize>> {
163    if let Some(val) = val.clone() {
164        let string = val.to_token_stream().to_string().replace(' ', "");
165        let value = string
166            .parse::<isize>()
167            .map_err(|e| syn::Error::new(val.span(), format!("{e}: `{string}`").as_str()))?;
168
169        Ok(Some(value))
170    } else {
171        Ok(None)
172    }
173}
174
175impl Parse for UNumIt {
176    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
177        // Try to parse as array first, then fallback to range
178        let range: Vec<isize> = if input.peek(syn::token::Bracket) {
179            // Parse array syntax: [1, 2, 8, 22]
180            let array: ExprArray = input.parse()?;
181            let mut vals = array
182                .elems
183                .iter()
184                .map(|expr| {
185                    let raw = expr.to_token_stream().to_string();
186                    let norm = raw.replace([' ', '_'], "");
187                    norm.parse::<isize>().map_err(|e| {
188                        syn::Error::new(
189                            expr.span(),
190                            format!("invalid number in array: {e}: `{raw}` (normalized `{norm}`)"),
191                        )
192                    })
193                })
194                .collect::<syn::Result<Vec<isize>>>()?;
195            vals.sort();
196            vals.dedup();
197            vals
198        } else {
199            // Parse range syntax: 1..10 or 1..=10
200            let range: PatRange = input.parse()?;
201            let start = range_boundary(&range.start)?.unwrap_or(0);
202            let end = range_boundary(&range.end)?.unwrap_or(isize::MAX);
203            match &range.limits {
204                RangeLimits::HalfOpen(_) => (start..end).collect(),
205                RangeLimits::Closed(_) => (start..=end).collect(),
206            }
207        };
208
209        input.parse::<Token![,]>()?;
210        let matcher: ExprMatch = input.parse()?;
211
212        let mut arms = HashMap::new();
213
214        for arm in matcher.arms.iter() {
215            let u_type = match &arm.pat {
216                Pat::Ident(t) => match t.ident.to_token_stream().to_string().as_str() {
217                    "N" => UType::N,
218                    "P" => UType::P,
219                    "U" => UType::U,
220                    "False" => UType::False,
221                    _ => {
222                        return Err(syn::Error::new(
223                            t.span(),
224                            "expected idents N | P | U | False | _",
225                        ))
226                    }
227                },
228                Pat::Lit(lit_expr) => {
229                    // Parse literal numbers in match arms (normalize spaces & underscores; base-10 only)
230                    let raw = lit_expr.to_token_stream().to_string();
231                    let norm = raw.replace([' ', '_'], "");
232                    if norm.starts_with("0x") || norm.starts_with("0b") || norm.starts_with("0o") {
233                        return Err(syn::Error::new(
234                            lit_expr.span(),
235                            format!("unsupported non-decimal literal `{raw}`"),
236                        ));
237                    }
238                    let value = norm.parse::<isize>().map_err(|e| {
239                        syn::Error::new(
240                            lit_expr.span(),
241                            format!("invalid literal: {e}: `{raw}` (normalized `{norm}`)"),
242                        )
243                    })?;
244                    UType::Literal(value)
245                }
246                Pat::Wild(_) => UType::None,
247                _ => return Err(syn::Error::new(arm.pat.span(), "expected ident")),
248            };
249            let arm_expr = arm.body.clone();
250            if arms.insert(u_type, arm_expr.clone()).is_some() {
251                return Err(syn::Error::new(arm_expr.span(), "duplicate type"));
252            }
253        }
254
255        if arms.get(&UType::P).and(arms.get(&UType::U)).is_some() {
256            return Err(syn::Error::new(
257                matcher.span(),
258                "ambiguous type, don't use P and U in the same macro call",
259            ));
260        }
261
262        // Check for conflict between literal 0 and False (they represent the same value in typenum)
263        if arms
264            .get(&UType::Literal(0))
265            .and(arms.get(&UType::False))
266            .is_some()
267        {
268            return Err(syn::Error::new(
269                matcher.span(),
270                "ambiguous type, don't use literal 0 and False in the same macro call (they represent the same value)",
271            ));
272        }
273
274        let expr = matcher.expr;
275
276        Ok(UNumIt { range, arms, expr })
277    }
278}
279
280fn make_match_arm(i: &isize, body: &Expr, u_type: UType) -> TokenStream {
281    let match_expr = TokenTree::Literal(Literal::from_str(i.to_string().as_str()).unwrap());
282
283    // Determine the typenum type for all cases
284    let i_str = if *i != 0 {
285        i.abs().to_string()
286    } else {
287        Default::default()
288    };
289
290    // Determine the type variant based on UType
291    let u_type_for_typenum = match u_type {
292        UType::Literal(0) => UType::False,
293        UType::Literal(val) if val < 0 => UType::N,
294        UType::Literal(val) if val > 0 => UType::P,
295        _ => u_type,
296    };
297
298    let typenum_type = TokenTree::Ident(Ident::new(
299        format!("{}{}", u_type_for_typenum, i_str).as_str(),
300        Span::mixed_site(),
301    ));
302    let type_variant = quote!(typenum::consts::#typenum_type);
303
304    // All match arms get NumType and use body as-is (no pattern replacement)
305    let body_tokens = body.to_token_stream();
306
307    quote! {
308        #match_expr => {
309            type NumType = #type_variant;
310            #body_tokens
311        },
312    }
313}