u_num_it/
lib.rs

1extern crate proc_macro;
2
3use std::{collections::HashMap, str::FromStr};
4
5use proc_macro2::{Group, Literal, Span, TokenStream, TokenTree};
6
7use quote::{quote, ToTokens};
8use syn::{
9    parse::Parse, parse_macro_input, spanned::Spanned, Expr, ExprMatch, Ident, Pat, PatRange,
10    RangeLimits, Token,
11};
12
/// Kind of pattern used in a `u_num_it!` match arm.
///
/// `N`, `P`, `U` and `False` are placeholder identifiers that get replaced by a
/// concrete `typenum::consts` type; `Literal` is an integer-literal arm and
/// `None` is the wildcard `_` arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum UType {
    N,
    P,
    U,
    False,
    None,
    Literal(isize),
}

/// Formats the `typenum::consts` name prefix for the placeholder kinds.
///
/// `None` and `Literal` have no associated type and format as an empty string.
impl std::fmt::Display for UType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let prefix = match self {
            UType::N => "N",
            UType::P => "P",
            UType::U => "U",
            UType::False => "False",
            UType::None | UType::Literal(_) => "",
        };
        f.write_str(prefix)
    }
}
35
36struct UNumIt {
37    range: Vec<isize>,
38    arms: HashMap<UType, Box<Expr>>,
39    expr: Box<Expr>,
40}
41
42fn range_boundary(val: &Option<Box<Expr>>) -> syn::Result<Option<isize>> {
43    if let Some(val) = val.clone() {
44        let string = val.to_token_stream().to_string().replace(' ', "");
45        let value = string
46            .parse::<isize>()
47            .map_err(|e| syn::Error::new(val.span(), format!("{e}: `{string}`").as_str()))?;
48
49        Ok(Some(value))
50    } else {
51        Ok(None)
52    }
53}
54
55impl Parse for UNumIt {
56    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
57        let range: PatRange = input.parse()?;
58
59        let start = range_boundary(&range.start)?.unwrap_or(0);
60        let end = range_boundary(&range.end)?.unwrap_or(isize::MAX);
61
62        let range = match &range.limits {
63            RangeLimits::HalfOpen(_) => (start..end).collect(),
64            RangeLimits::Closed(_) => (start..=end).collect(),
65        };
66
67        input.parse::<Token![,]>()?;
68        let matcher: ExprMatch = input.parse()?;
69
70        let mut arms = HashMap::new();
71
72        for arm in matcher.arms.iter() {
73            let u_type = match &arm.pat {
74                Pat::Ident(t) => match t.ident.to_token_stream().to_string().as_str() {
75                    "N" => UType::N,
76                    "P" => UType::P,
77                    "U" => UType::U,
78                    "False" => UType::False,
79                    _ => {
80                        return Err(syn::Error::new(
81                            t.span(),
82                            "exepected idents N | P | U, False or _",
83                        ))
84                    }
85                },
86                Pat::Lit(lit_expr) => {
87                    // Parse literal numbers in match arms
88                    let lit_str = lit_expr.to_token_stream().to_string();
89                    let value = lit_str.parse::<isize>().map_err(|e| {
90                        syn::Error::new(lit_expr.span(), format!("invalid literal: {e}"))
91                    })?;
92                    UType::Literal(value)
93                }
94                Pat::Wild(_) => UType::None,
95                _ => return Err(syn::Error::new(arm.pat.span(), "exepected ident")),
96            };
97            let arm_expr = arm.body.clone();
98            if arms.insert(u_type, arm_expr.clone()).is_some() {
99                return Err(syn::Error::new(arm_expr.span(), "duplicate type"));
100            }
101        }
102
103        if arms.get(&UType::P).and(arms.get(&UType::U)).is_some() {
104            return Err(syn::Error::new(
105                matcher.span(),
106                "ambiguous type, don't use P and U in the same macro call",
107            ));
108        }
109
110        // Check for conflict between literal 0 and False (they represent the same value in typenum)
111        if arms.get(&UType::Literal(0)).and(arms.get(&UType::False)).is_some() {
112            return Err(syn::Error::new(
113                matcher.span(),
114                "ambiguous type, don't use literal 0 and False in the same macro call (they represent the same value)",
115            ));
116        }
117
118        let expr = matcher.expr;
119
120        Ok(UNumIt { range, arms, expr })
121    }
122}
123
124fn make_body_variant(body: TokenStream, type_variant: TokenStream, u_type: UType) -> TokenStream {
125    let tokens = body.into_iter().fold(vec![], |mut acc, token| {
126        let type_variant = type_variant.clone();
127        match token {
128            TokenTree::Ident(ref ident) => {
129                if *ident == u_type.to_string() {
130                    acc.extend(quote!(#type_variant).to_token_stream());
131                } else {
132                    acc.push(token);
133                }
134            }
135            TokenTree::Group(ref group) => {
136                let inner = make_body_variant(group.stream(), type_variant, u_type);
137                acc.push(TokenTree::Group(Group::new(group.delimiter(), inner)));
138            }
139            _ => acc.push(token),
140        };
141        acc
142    });
143
144    quote! {#(#tokens)*}
145}
146
147fn make_match_arm(i: &isize, body: &Expr, u_type: UType) -> TokenStream {
148    let match_expr = TokenTree::Literal(Literal::from_str(i.to_string().as_str()).unwrap());
149    
150    // For literal types, use the body as-is without type replacement
151    if let UType::Literal(_) = u_type {
152        let body_tokens = body.to_token_stream();
153        return quote! {
154            #match_expr => {
155                #body_tokens
156            },
157        };
158    }
159    
160    // For type patterns (N, P, U, False), perform type replacement
161    let i_str = if *i != 0 {
162        i.abs().to_string()
163    } else {
164        Default::default()
165    };
166    let typenum_type = TokenTree::Ident(Ident::new(
167        format!("{}{}", u_type, i_str).as_str(),
168        Span::mixed_site(),
169    ));
170    let type_variant = quote!(typenum::consts::#typenum_type);
171    let body_variant = make_body_variant(body.to_token_stream(), type_variant, u_type);
172
173    quote! {
174        #match_expr => {
175            #body_variant
176        },
177    }
178}
179
180/// matches `typenum::consts` in a given range
181///
182/// use with an open or closed range
183///
184/// use `P` | `N` | `U` | `False` | `_` as match arms
185///
186/// ## Example
187///
188/// ```
189/// let x = 3;
190///
191/// u_num_it::u_num_it!(1..10, match x {
192///     U => {
193///         let val = U::new();
194///         println!("{:?}", val);
195///         // UInt { msb: UInt { msb: UTerm, lsb: B1 }, lsb: B1 }
196///     }
197/// })
198/// ```
199#[proc_macro]
200pub fn u_num_it(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
201    let UNumIt { range, arms, expr } = parse_macro_input!(tokens as UNumIt);
202
203    let pos_u = arms.get(&UType::U).is_some();
204
205    let expanded_arms = range.iter().filter_map(|i| {
206        // First check if there's a specific literal match for this number
207        if let Some(body) = arms.get(&UType::Literal(*i)) {
208            return Some(make_match_arm(i, body, UType::Literal(*i)));
209        }
210        
211        // Otherwise, use the general type patterns
212        match i {
213            0 => arms
214                .get(&UType::False)
215                .map(|body| make_match_arm(i, body, UType::False)),
216            i if *i < 0 => arms
217                .get(&UType::N)
218                .map(|body| make_match_arm(i, body, UType::N)),
219            i if *i > 0 => {
220                if pos_u {
221                    arms.get(&UType::U)
222                        .map(|body| make_match_arm(i, body, UType::U))
223                } else {
224                    arms.get(&UType::P)
225                        .map(|body| make_match_arm(i, body, UType::P))
226                }
227            }
228            _ => unreachable!(),
229        }
230    });
231
232    let fallback = arms
233        .get(&UType::None)
234        .map(|body| {
235            quote! {
236                _ => {
237                    #body
238                }
239            }
240        })
241        .unwrap_or_else(|| {
242            let first = range.first().unwrap_or(&0);
243            let last = range.last().unwrap_or(&0);
244            quote! {
245                i => unreachable!("{i} is not in range {}-{:?}", #first, #last)
246            }
247        });
248
249    let expanded = quote! {
250        match #expr {
251            #(#expanded_arms)*
252            #fallback
253        }
254    };
255
256    proc_macro::TokenStream::from(expanded)
257}