bounded_integer_macro/
lib.rs

//! A macro for generating bounded integer structs and enums.
//!
//! This crate is unstable and must not be used directly.
4#![warn(clippy::pedantic, rust_2018_idioms, unused_qualifications)]
5#![allow(clippy::single_match_else, clippy::match_bool)]
6
7use std::array;
8use std::fmt::Debug;
9
10use proc_macro2::{Delimiter, Ident, Literal, Span, TokenStream, TokenTree};
11use quote::{ToTokens, quote, quote_spanned};
12
13#[proc_macro]
14#[doc(hidden)]
15#[expect(clippy::too_many_lines)]
16pub fn bounded_integer(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
17    let input = TokenStream::from(input).into_iter().map(|t| {
18        let TokenTree::Group(group) = t else {
19            panic!("non-group in input")
20        };
21        assert_eq!(group.delimiter(), Delimiter::Bracket);
22        group.stream()
23    });
24    let [
25        zerocopy,
26        outer_attr,
27        mut attrs,
28        vis,
29        super_vis,
30        is_named,
31        item_kind,
32        name,
33        min_or_variants,
34        max_or_none,
35        crate_path,
36    ] = to_array(input);
37
38    let zerocopy = match to_array(zerocopy) {
39        [TokenTree::Punct(p)] if p.as_char() == '-' => false,
40        [TokenTree::Punct(p)] if p.as_char() == '+' => true,
41        [t] => panic!("zerocopy ({t})"),
42    };
43
44    let [TokenTree::Ident(item_kind)] = to_array(item_kind) else {
45        panic!("item kind")
46    };
47    let is_enum = match &*item_kind.to_string() {
48        "struct" => false,
49        "enum" => true,
50        s => panic!("unknown item kind {s}"),
51    };
52    let [TokenTree::Ident(name)] = to_array(name) else {
53        panic!("name")
54    };
55
56    let mut new_attrs = TokenStream::new();
57    let mut import_attrs = TokenStream::new();
58    let mut maybe_repr = None;
59    for attr in attrs {
60        let TokenTree::Group(group) = &attr else {
61            panic!("attr ({attr})")
62        };
63        let tokens = group.stream().into_iter().collect::<Vec<_>>();
64        if let Some(TokenTree::Ident(i)) = tokens.first() {
65            let name = i.to_string();
66
67            if name == "repr"
68                && let [_, TokenTree::Group(g)] = &*tokens
69                && g.delimiter() == Delimiter::Parenthesis
70            {
71                if maybe_repr.is_some() {
72                    return error!(i.span(), "duplicate `repr` attribute");
73                }
74                maybe_repr = Some(g.stream());
75                continue;
76            } else if ["allow", "expect", "warn", "deny", "forbid"].contains(&&*name)
77                && let [_, TokenTree::Group(g)] = &*tokens
78                && g.delimiter() == Delimiter::Parenthesis
79                && let [Some(TokenTree::Ident(lint)), None] = {
80                    let mut iter = g.stream().into_iter();
81                    [iter.next(), iter.next()]
82                }
83                && (lint == "unused" || lint == "unused_imports")
84            {
85                import_attrs.extend(quote!(# #attr));
86                continue;
87            }
88        }
89        new_attrs.extend(quote!(# #attr));
90    }
91    attrs = new_attrs;
92
93    let (variants, min, max, min_val, max_val);
94    match to_array(is_named) {
95        // Unnamed
96        [TokenTree::Punct(p)] if p.as_char() == '-' => {
97            [min, max] = [min_or_variants, max_or_none].map(ungroup_none);
98            [min_val, max_val] = [&min, &max].map(|lit| {
99                parse_literal(lit.clone()).map(|(lit, repr)| {
100                    // if there is an existing repr, Rust will cause an error anyway later on
101                    if let Some(repr) = repr
102                        && maybe_repr.is_none()
103                    {
104                        maybe_repr = Some(quote!(#repr));
105                    }
106                    lit
107                })
108            });
109
110            variants = match is_enum {
111                false => None,
112                true => {
113                    let Some(min_val) = min_val else {
114                        return error!(min, "`enum` requires bound to be statically known");
115                    };
116                    let Some(max_val) = max_val else {
117                        return error!(max, "`enum` requires bound to be statically known");
118                    };
119                    let Some(range) = range(min_val, max_val) else {
120                        return error!(min, "refusing to generate this many `enum` variants");
121                    };
122                    let mut variants = TokenStream::new();
123                    let min_span = stream_span(min.clone());
124                    for int in range {
125                        let enum_variant_name = int.enum_variant_name(min_span);
126                        if int == min_val {
127                            variants.extend(quote!(#[allow(dead_code)] #enum_variant_name = #min,));
128                        } else {
129                            variants.extend(quote!(#[allow(dead_code)] #enum_variant_name,));
130                        }
131                    }
132                    Some(variants)
133                }
134            };
135        }
136        // Named
137        [TokenTree::Punct(p)] if p.as_char() == '+' => {
138            assert!(is_enum);
139            assert!(max_or_none.into_iter().next().is_none());
140
141            // ((min_val, min), current_val, current_span)
142            let mut min_current = None::<((Int, TokenStream), Int, Span)>;
143            let mut variant_list = TokenStream::new();
144            for variant in min_or_variants {
145                let TokenTree::Group(variant) = variant else {
146                    panic!("variant")
147                };
148                let [
149                    TokenTree::Group(attrs),
150                    TokenTree::Ident(variant_name),
151                    TokenTree::Group(variant_val),
152                ] = to_array(variant.stream())
153                else {
154                    panic!("variant inner")
155                };
156                let attrs = attrs.stream();
157                let variant_val = variant_val.stream();
158                min_current = Some(if variant_val.is_empty() {
159                    variant_list.extend(quote!(#attrs #variant_name,));
160                    match min_current {
161                        Some((min, current, current_span)) => match current.succ() {
162                            Some(current) => (min, current, current_span),
163                            None => {
164                                return error!(
165                                    variant_name.span(),
166                                    "too many variants (overflows a u128)"
167                                );
168                            }
169                        },
170                        None => (
171                            (Int::new(true, 0), quote_spanned!(variant_name.span()=> 0)),
172                            Int::new(true, 0),
173                            variant_name.span(),
174                        ),
175                    }
176                } else {
177                    variant_list.extend(quote!(#attrs #variant_name = #variant_val,));
178                    let variant_val = ungroup_none(variant_val);
179                    let Some((int, _)) = parse_literal(variant_val.clone()) else {
180                        return error!(variant_val, "could not parse variant value");
181                    };
182                    match min_current {
183                        Some((min, current, _)) if current.succ() == Some(int) => {
184                            (min, int, stream_span(variant_val))
185                        }
186                        Some(_) => return error!(variant_val, "enum not contiguous"),
187                        None => ((int, variant_val.clone()), int, stream_span(variant_val)),
188                    }
189                });
190            }
191            variants = Some(variant_list);
192            [(min_val, min), (max_val, max)] = match min_current {
193                Some(((min_val, min), current, current_span)) => [
194                    (Some(min_val), min),
195                    (Some(current), current.literal(current_span)),
196                ],
197                None => [
198                    (Some(Int::new(true, 1)), quote!(1)),
199                    (Some(Int::new(true, 0)), quote!(0)),
200                ],
201            };
202        }
203        [t] => panic!("named ({t})"),
204    }
205
206    let zero = min_val
207        .zip(max_val)
208        .map(|(min, max)| (min..=max).contains(&Int::new(true, 0)));
209    let one = min_val
210        .zip(max_val)
211        .map(|(min, max)| (min..=max).contains(&Int::new(true, 1)));
212    if zero == Some(true) && zerocopy {
213        attrs.extend(quote!(#[derive(#crate_path::__private::zerocopy::FromZeros)]));
214    }
215    let zero_token = match zero {
216        Some(true) => quote!(zero,),
217        Some(false) | None => quote!(),
218    };
219    let one_token = match one {
220        Some(true) => quote!(one,),
221        Some(false) | None => quote!(),
222    };
223
224    let repr = match (maybe_repr, min_val, max_val) {
225        (Some(repr), _, _) => repr,
226        (None, Some(min_val), Some(max_val)) => match infer_repr(min_val, max_val) {
227            Some(repr) => {
228                let repr = Ident::new(&repr, stream_span(min.clone()));
229                quote!(#repr)
230            }
231            None => return error!(min, "range too large for any integer type"),
232        },
233        (None, _, _) => {
234            let msg = "no #[repr] attribute found, and could not infer";
235            return error!(min, "{msg}");
236        }
237    };
238
239    match is_enum {
240        false => attrs.extend(quote!(#[repr(transparent)])),
241        true => attrs.extend(quote!(#[repr(#repr)])),
242    }
243
244    if matches!(repr.to_string().trim(), "u8" | "i8") && zerocopy {
245        attrs.extend(quote!(#[derive(#crate_path::__private::zerocopy::Unaligned)]));
246    }
247
248    let item = match variants {
249        Some(variants) => quote!({ #variants }),
250        None if zero == Some(false) => quote!((::core::num::NonZero<#repr>);),
251        None => quote!((#repr);),
252    };
253
254    // Hide in a module to prevent access to private parts.
255    let module_name = Ident::new(&format!("__bounded_integer_private_{name}"), name.span());
256
257    let res = quote!(
258        #[allow(non_snake_case)]
259        #outer_attr
260        mod #module_name {
261            #attrs
262            #super_vis #item_kind #name #item
263
264            #crate_path::unsafe_api! {
265                for #name,
266                unsafe repr: #repr,
267                min: #min,
268                max: #max,
269                #zero_token
270                #one_token
271            }
272        }
273        #import_attrs #vis use #module_name::#name;
274    );
275
276    res.into()
277}
278
/// Collects `iter` into an array of exactly `N` items.
///
/// # Panics
///
/// Panics with `iterator too short` if `iter` yields fewer than `N` items, and with
/// `iterator too long: found …` (showing the first surplus item) if it yields more.
fn to_array<I: IntoIterator<Item: Debug>, const N: usize>(iter: I) -> [I::Item; N] {
    let mut items = iter.into_iter();
    let array = array::from_fn(|_| items.next().expect("iterator too short"));
    // The array is full; anything left over means the caller expected the wrong length.
    match items.next() {
        None => array,
        Some(item) => panic!("iterator too long: found {item:?}"),
    }
}
287
288#[derive(Debug, Clone, Copy, PartialEq, Eq)]
289struct Int {
290    nonnegative: bool,
291    magnitude: u128,
292}
293
294impl Int {
295    fn new(nonnegative: bool, magnitude: u128) -> Self {
296        Self {
297            nonnegative,
298            magnitude,
299        }
300    }
301    fn succ(self) -> Option<Self> {
302        Some(match self.nonnegative {
303            true => Self::new(true, self.magnitude.checked_add(1)?),
304            false if self.magnitude == 1 => Self::new(true, 0),
305            false => Self::new(false, self.magnitude - 1),
306        })
307    }
308    fn enum_variant_name(self, span: Span) -> Ident {
309        if self.magnitude == 0 {
310            Ident::new("Z", span)
311        } else if self.nonnegative {
312            Ident::new(&format!("P{}", self.magnitude), span)
313        } else {
314            Ident::new(&format!("N{}", self.magnitude), span)
315        }
316    }
317    fn literal(self, span: Span) -> TokenStream {
318        let mut magnitude = Literal::u128_unsuffixed(self.magnitude);
319        magnitude.set_span(span);
320        match self.nonnegative {
321            true => quote!(#magnitude),
322            false => quote!(-#magnitude),
323        }
324    }
325}
326
327impl PartialOrd for Int {
328    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
329        Some(self.cmp(other))
330    }
331}
332
333impl Ord for Int {
334    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
335        match (self.nonnegative, other.nonnegative) {
336            (true, true) => self.magnitude.cmp(&other.magnitude),
337            (true, false) => std::cmp::Ordering::Greater,
338            (false, true) => std::cmp::Ordering::Less,
339            (false, false) => other.magnitude.cmp(&self.magnitude),
340        }
341    }
342}
343
344fn parse_literal(e: TokenStream) -> Option<(Int, Option<Ident>)> {
345    let mut tokens = e.into_iter().peekable();
346    let minus = tokens
347        .next_if(|t| matches!(t, TokenTree::Punct(p) if p.as_char() == '-'))
348        .is_some();
349    let Some(TokenTree::Literal(lit)) = tokens.next() else {
350        return None;
351    };
352    if tokens.next().is_some() {
353        return None;
354    }
355
356    // Algorithm reference:
357    // https://docs.rs/syn/2.0.104/src/syn/lit.rs.html#1679-1767
358
359    let mut lit_chars = &*lit.to_string();
360
361    let (base, base_len) = match lit_chars.get(..2) {
362        Some("0x") => (16, 2),
363        Some("0o") => (8, 2),
364        Some("0b") => (2, 2),
365        _ => (10, 0),
366    };
367    lit_chars = &lit_chars[base_len..];
368
369    let mut magnitude = 0_u128;
370    let mut has_digit = None;
371
372    let suffix = loop {
373        lit_chars = lit_chars.trim_start_matches('_');
374        let Some(c) = lit_chars.chars().next() else {
375            has_digit?;
376            break None;
377        };
378        if let 'i' | 'u' = c {
379            let ("8" | "16" | "32" | "64" | "128" | "size") = &lit_chars[1..] else {
380                return None;
381            };
382            break Some(Ident::new(lit_chars, lit.span()));
383        }
384        let digit = c.to_digit(base)?;
385        lit_chars = &lit_chars[1..];
386        magnitude = magnitude
387            .checked_mul(base.into())?
388            .checked_add(digit.into())?;
389        has_digit = Some(());
390    };
391
392    let lit = Int::new(!minus || magnitude == 0, magnitude);
393    Some((lit, suffix))
394}
395
396fn range(min: Int, max: Int) -> Option<impl Iterator<Item = Int>> {
397    let range_minus_one = match (max.nonnegative, min.nonnegative) {
398        (true, true) => max.magnitude.saturating_sub(min.magnitude),
399        (true, false) => max.magnitude.saturating_add(min.magnitude),
400        (false, true) => 0,
401        (false, false) => min.magnitude.saturating_sub(max.magnitude),
402    };
403    if 100_000 <= range_minus_one {
404        return None;
405    }
406    #[expect(clippy::reversed_empty_ranges)]
407    let (negative_part, nonnegative_part) = match (min.nonnegative, max.nonnegative) {
408        (true, true) => (1..=0, min.magnitude..=max.magnitude),
409        (false, true) => (1..=min.magnitude, 0..=max.magnitude),
410        (true, false) => (1..=0, 1..=0),
411        (false, false) => (max.magnitude..=min.magnitude, 1..=0),
412    };
413    let negative_part = negative_part.map(|i| Int::new(false, i));
414    let nonnegative_part = nonnegative_part.map(|i| Int::new(true, i));
415    Some(negative_part.rev().chain(nonnegative_part))
416}
417
418fn infer_repr(min: Int, max: Int) -> Option<String> {
419    for bits in [8, 16, 32, 64, 128] {
420        let fits_unsigned =
421            |lit: Int| lit.nonnegative && lit.magnitude <= (u128::MAX >> (128 - bits));
422        let fits_signed = |lit: Int| {
423            (lit.nonnegative && lit.magnitude < (1 << (bits - 1)))
424                || (!lit.nonnegative && lit.magnitude <= (1 << (bits - 1)))
425        };
426        if fits_unsigned(min) && fits_unsigned(max) {
427            return Some(format!("u{bits}"));
428        } else if fits_signed(min) && fits_signed(max) {
429            return Some(format!("i{bits}"));
430        }
431    }
432    None
433}
434
435fn ungroup_none(tokens: TokenStream) -> TokenStream {
436    let mut tokens = tokens.into_iter().peekable();
437    if let Some(TokenTree::Group(g)) =
438        tokens.next_if(|t| matches!(t, TokenTree::Group(g) if g.delimiter() == Delimiter::None))
439    {
440        return g.stream();
441    }
442    // Sigh… make it opportunistic to get it to work on rust-analyzer
443    // https://github.com/rust-lang/rust-analyzer/issues/18211
444    tokens.collect()
445}
446
/// Builds a `proc_macro::TokenStream` containing a `compile_error!` invocation.
///
/// The first argument selects where the error is reported: either a `Span`, or a
/// `TokenStream`, in which case the span of its first token is used (see `SpanHelper`). A
/// `TokenStream` argument is taken by value. The remaining arguments are forwarded to
/// `format!` to produce the message.
macro_rules! error {
    ($span:expr, $($fmt:tt)*) => {{
        let span = SpanHelper($span).span_helper();
        let msg = format!($($fmt)*);
        proc_macro::TokenStream::from(quote_spanned!(span=> compile_error!(#msg);))
    }};
}
// `macro_rules!` definitions are only textually visible *after* their definition; importing
// the macro by path makes it usable from the code above as well.
use error;
455
456struct SpanHelper<T>(T);
457impl SpanHelper<TokenStream> {
458    fn span_helper(self) -> Span {
459        stream_span(self.0.into_token_stream())
460    }
461}
462trait SpanHelperTrait {
463    fn span_helper(self) -> Span;
464}
465impl SpanHelperTrait for SpanHelper<Span> {
466    fn span_helper(self) -> Span {
467        self.0
468    }
469}
470
471fn stream_span(stream: TokenStream) -> Span {
472    stream
473        .into_iter()
474        .next()
475        .map_or_else(Span::call_site, |token| token.span())
476}