u_num_it/
lib.rs

1extern crate proc_macro;
2
3use std::{collections::HashMap, str::FromStr};
4
5use proc_macro2::{Literal, Span, TokenStream, TokenTree};
6
7use quote::{quote, ToTokens};
8use syn::{
9    parse::Parse, parse_macro_input, spanned::Spanned, Expr, ExprArray, ExprMatch, Ident, Pat,
10    PatRange, RangeLimits, Token,
11};
12
/// Kind of match arm a pattern in the `u_num_it!` input selects.
///
/// `N`, `P`, `U` and `False` also serve as the name prefix of the generated
/// `typenum::consts` type (see the `Display` impl).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum UType {
    /// Arm for negative values.
    N,
    /// Arm for positive values, using signed `P*` consts.
    P,
    /// Arm for positive values, using unsigned `U*` consts.
    U,
    /// Arm for zero.
    False,
    /// The `_` wildcard arm.
    None,
    /// Arm for one specific integer literal.
    Literal(isize),
}
22
23impl std::fmt::Display for UType {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            UType::N => write!(f, "N"),
27            UType::P => write!(f, "P"),
28            UType::U => write!(f, "U"),
29            UType::False => write!(f, "False"),
30            UType::None => write!(f, ""),
31            UType::Literal(_) => write!(f, ""),
32        }
33    }
34}
35
36struct UNumIt {
37    range: Vec<isize>,
38    arms: HashMap<UType, Box<Expr>>,
39    expr: Box<Expr>,
40}
41
42fn range_boundary(val: &Option<Box<Expr>>) -> syn::Result<Option<isize>> {
43    if let Some(val) = val.clone() {
44        let string = val.to_token_stream().to_string().replace(' ', "");
45        let value = string
46            .parse::<isize>()
47            .map_err(|e| syn::Error::new(val.span(), format!("{e}: `{string}`").as_str()))?;
48
49        Ok(Some(value))
50    } else {
51        Ok(None)
52    }
53}
54
55impl Parse for UNumIt {
56    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
57        // Try to parse as array first, then fallback to range
58        let range: Vec<isize> = if input.peek(syn::token::Bracket) {
59            // Parse array syntax: [1, 2, 8, 22]
60            let array: ExprArray = input.parse()?;
61            array
62                .elems
63                .iter()
64                .map(|expr| {
65                    let string = expr.to_token_stream().to_string().replace(' ', "");
66                    string.parse::<isize>().map_err(|e| {
67                        syn::Error::new(expr.span(), format!("invalid number in array: {e}"))
68                    })
69                })
70                .collect::<syn::Result<Vec<isize>>>()?
71        } else {
72            // Parse range syntax: 1..10 or 1..=10
73            let range: PatRange = input.parse()?;
74            let start = range_boundary(&range.start)?.unwrap_or(0);
75            let end = range_boundary(&range.end)?.unwrap_or(isize::MAX);
76            match &range.limits {
77                RangeLimits::HalfOpen(_) => (start..end).collect(),
78                RangeLimits::Closed(_) => (start..=end).collect(),
79            }
80        };
81
82        input.parse::<Token![,]>()?;
83        let matcher: ExprMatch = input.parse()?;
84
85        let mut arms = HashMap::new();
86
87        for arm in matcher.arms.iter() {
88            let u_type = match &arm.pat {
89                Pat::Ident(t) => match t.ident.to_token_stream().to_string().as_str() {
90                    "N" => UType::N,
91                    "P" => UType::P,
92                    "U" => UType::U,
93                    "False" => UType::False,
94                    _ => {
95                        return Err(syn::Error::new(
96                            t.span(),
97                            "exepected idents N | P | U, False or _",
98                        ))
99                    }
100                },
101                Pat::Lit(lit_expr) => {
102                    // Parse literal numbers in match arms
103                    let lit_str = lit_expr.to_token_stream().to_string();
104                    let value = lit_str.parse::<isize>().map_err(|e| {
105                        syn::Error::new(lit_expr.span(), format!("invalid literal: {e}"))
106                    })?;
107                    UType::Literal(value)
108                }
109                Pat::Wild(_) => UType::None,
110                _ => return Err(syn::Error::new(arm.pat.span(), "exepected ident")),
111            };
112            let arm_expr = arm.body.clone();
113            if arms.insert(u_type, arm_expr.clone()).is_some() {
114                return Err(syn::Error::new(arm_expr.span(), "duplicate type"));
115            }
116        }
117
118        if arms.get(&UType::P).and(arms.get(&UType::U)).is_some() {
119            return Err(syn::Error::new(
120                matcher.span(),
121                "ambiguous type, don't use P and U in the same macro call",
122            ));
123        }
124
125        // Check for conflict between literal 0 and False (they represent the same value in typenum)
126        if arms
127            .get(&UType::Literal(0))
128            .and(arms.get(&UType::False))
129            .is_some()
130        {
131            return Err(syn::Error::new(
132                matcher.span(),
133                "ambiguous type, don't use literal 0 and False in the same macro call (they represent the same value)",
134            ));
135        }
136
137        let expr = matcher.expr;
138
139        Ok(UNumIt { range, arms, expr })
140    }
141}
142
143fn make_match_arm(i: &isize, body: &Expr, u_type: UType) -> TokenStream {
144    let match_expr = TokenTree::Literal(Literal::from_str(i.to_string().as_str()).unwrap());
145
146    // Determine the typenum type for all cases
147    let i_str = if *i != 0 {
148        i.abs().to_string()
149    } else {
150        Default::default()
151    };
152
153    // Determine the type variant based on UType
154    let u_type_for_typenum = match u_type {
155        UType::Literal(0) => UType::False,
156        UType::Literal(val) if val < 0 => UType::N,
157        UType::Literal(val) if val > 0 => UType::P,
158        _ => u_type,
159    };
160
161    let typenum_type = TokenTree::Ident(Ident::new(
162        format!("{}{}", u_type_for_typenum, i_str).as_str(),
163        Span::mixed_site(),
164    ));
165    let type_variant = quote!(typenum::consts::#typenum_type);
166
167    // All match arms get NumType and use body as-is (no pattern replacement)
168    let body_tokens = body.to_token_stream();
169
170    quote! {
171        #match_expr => {
172            type NumType = #type_variant;
173            #body_tokens
174        },
175    }
176}
177
178/// matches `typenum::consts` in a given range or array
179///
180/// use with an open or closed range, or an array of arbitrary numbers
181///
182/// use `P` | `N` | `U` | `False` | `_` or literals `1` | `-1` as match arms
183///
184/// a `NumType` type alias is available in each match arm,
185/// resolving to the specific typenum type for that value.
186/// Use `NumType` to reference the resolved type in the match arm body.
187///
188/// ## Example (range)
189///
190/// ```
191/// let x = 3;
192///
193/// u_num_it::u_num_it!(1..10, match x {
194///     U => {
195///         // NumType is typenum::consts::U3 when x=3
196///         let val = NumType::new();
197///         println!("{:?}", val);
198///         // UInt { msb: UInt { msb: UTerm, lsb: B1 }, lsb: B1 }
199///
200///         use typenum::ToInt;
201///         let num: usize = NumType::to_int();
202///         assert_eq!(num, 3);
203///     }
204/// })
205/// ```
206///
207/// ## Example (array)
208///
209/// ```
210/// let x = 8;
211///
212/// u_num_it::u_num_it!([1, 2, 8, 22], match x {
213///     P => {
214///         // NumType is typenum::consts::P8 when x=8
215///         use typenum::ToInt;
216///         let num: i32 = NumType::to_int();
217///         assert_eq!(num, 8);
218///     }
219/// })
220/// ```
221#[proc_macro]
222pub fn u_num_it(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
223    let UNumIt { range, arms, expr } = parse_macro_input!(tokens as UNumIt);
224
225    let pos_u = arms.contains_key(&UType::U);
226
227    let expanded_arms = range.iter().filter_map(|i| {
228        // First check if there's a specific literal match for this number
229        if let Some(body) = arms.get(&UType::Literal(*i)) {
230            return Some(make_match_arm(i, body, UType::Literal(*i)));
231        }
232
233        // Otherwise, use the general type patterns
234        match i {
235            0 => arms
236                .get(&UType::False)
237                .map(|body| make_match_arm(i, body, UType::False)),
238            i if *i < 0 => arms
239                .get(&UType::N)
240                .map(|body| make_match_arm(i, body, UType::N)),
241            i if *i > 0 => {
242                if pos_u {
243                    arms.get(&UType::U)
244                        .map(|body| make_match_arm(i, body, UType::U))
245                } else {
246                    arms.get(&UType::P)
247                        .map(|body| make_match_arm(i, body, UType::P))
248                }
249            }
250            _ => unreachable!(),
251        }
252    });
253
254    let fallback = arms
255        .get(&UType::None)
256        .map(|body| {
257            quote! {
258                _ => {
259                    #body
260                }
261            }
262        })
263        .unwrap_or_else(|| {
264            let first = range.first().unwrap_or(&0);
265            let last = range.last().unwrap_or(&0);
266            quote! {
267                i => unreachable!("{i} is not in range {}-{:?}", #first, #last)
268            }
269        });
270
271    let expanded = quote! {
272        match #expr {
273            #(#expanded_arms)*
274            #fallback
275        }
276    };
277
278    proc_macro::TokenStream::from(expanded)
279}