bounded_integer_macro/
lib.rs

1//! A macro for generating bounded integer structs and enums.
2//!
3//! This crate is unstable and must not be used directly.
4#![warn(clippy::pedantic, rust_2018_idioms, unused_qualifications)]
5#![allow(clippy::single_match_else, clippy::match_bool)]
6#![allow(unused)]
7
8use std::borrow::Borrow;
9use std::cmp;
10use std::convert::TryInto;
11use std::fmt::{self, Display, Formatter};
12use std::ops::RangeInclusive;
13
14use proc_macro2::{Group, Ident, Literal, Span, TokenStream};
15use quote::{quote, ToTokens, TokenStreamExt as _};
16use syn::parse::{self, Parse, ParseStream};
17use syn::{braced, parse_macro_input, token::Brace, Token};
18use syn::{Attribute, Error, Expr, PathArguments, PathSegment, Visibility};
19use syn::{BinOp, ExprBinary, ExprRange, ExprUnary, RangeLimits, UnOp};
20use syn::{ExprGroup, ExprParen};
21use syn::{ExprLit, Lit, LitBool};
22
23use num_bigint::{BigInt, Sign, TryFromBigIntError};
24
25mod generate;
26
27#[proc_macro]
28#[doc(hidden)]
29pub fn bounded_integer(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
30    let mut item = parse_macro_input!(input as BoundedInteger);
31
32    // Hide in a module to prevent access to private parts.
33    let module_name = Ident::new(
34        &format!("__bounded_integer_private_{}", item.ident),
35        item.ident.span(),
36    );
37    let ident = &item.ident;
38    let original_visibility = item.vis;
39
40    let import = quote!(#original_visibility use #module_name::#ident);
41
42    item.vis = raise_one_level(original_visibility);
43    let mut result = TokenStream::new();
44    generate::generate(&item, &mut result);
45
46    quote!(
47        #[allow(non_snake_case)]
48        mod #module_name {
49            #result
50        }
51        #import;
52    )
53    .into()
54}
55
56#[allow(clippy::struct_excessive_bools)]
57struct BoundedInteger {
58    // $crate
59    crate_path: TokenStream,
60
61    // Optional features
62    alloc: bool,
63    arbitrary1: bool,
64    bytemuck1: bool,
65    serde1: bool,
66    std: bool,
67    zerocopy: bool,
68    step_trait: bool,
69
70    // The item itself
71    attrs: Vec<Attribute>,
72    repr: Repr,
73    vis: Visibility,
74    kind: Kind,
75    ident: Ident,
76    brace_token: Brace,
77    range: RangeInclusive<BigInt>,
78}
79
80impl Parse for BoundedInteger {
81    fn parse(input: ParseStream<'_>) -> parse::Result<Self> {
82        let crate_path = input.parse::<Group>()?.stream();
83
84        let alloc = input.parse::<LitBool>()?.value;
85        let arbitrary1 = input.parse::<LitBool>()?.value;
86        let bytemuck1 = input.parse::<LitBool>()?.value;
87        let serde1 = input.parse::<LitBool>()?.value;
88        let std = input.parse::<LitBool>()?.value;
89        let zerocopy = input.parse::<LitBool>()?.value;
90        let step_trait = input.parse::<LitBool>()?.value;
91
92        let mut attrs = input.call(Attribute::parse_outer)?;
93
94        let repr_pos = attrs.iter().position(|attr| attr.path().is_ident("repr"));
95        let repr = repr_pos
96            .map(|pos| attrs.remove(pos).parse_args::<Repr>())
97            .transpose()?;
98
99        let vis: Visibility = input.parse()?;
100
101        let kind: Kind = input.parse()?;
102
103        let ident: Ident = input.parse()?;
104
105        let range_tokens;
106        let brace_token = braced!(range_tokens in input);
107        let range: ExprRange = range_tokens.parse()?;
108
109        let Some((start_expr, end_expr)) = range.start.as_deref().zip(range.end.as_deref()) else {
110            return Err(Error::new_spanned(range, "Range must be closed"));
111        };
112        let start = eval_expr(start_expr)?;
113        let end = eval_expr(end_expr)?;
114        let end = if let RangeLimits::HalfOpen(_) = range.limits {
115            end - 1
116        } else {
117            end
118        };
119        if start >= end {
120            return Err(Error::new_spanned(
121                range,
122                "The start of the range must be before the end",
123            ));
124        }
125
126        let repr = match repr {
127            Some(explicit_repr) => {
128                if explicit_repr.sign == Unsigned && start.sign() == Sign::Minus {
129                    return Err(Error::new_spanned(
130                        start_expr,
131                        "An unsigned integer cannot hold a negative value",
132                    ));
133                }
134
135                if explicit_repr.minimum().is_some_and(|min| start < min) {
136                    return Err(Error::new_spanned(
137                        start_expr,
138                        format_args!(
139                            "Bound {start} is below the minimum value for the underlying type",
140                        ),
141                    ));
142                }
143                if explicit_repr.maximum().is_some_and(|max| end > max) {
144                    return Err(Error::new_spanned(
145                        end_expr,
146                        format_args!(
147                            "Bound {end} is above the maximum value for the underlying type",
148                        ),
149                    ));
150                }
151
152                explicit_repr
153            }
154            None => Repr::smallest_repr(&start, &end).ok_or_else(|| {
155                Error::new_spanned(range, "Range is too wide to fit in any integer primitive")
156            })?,
157        };
158
159        Ok(Self {
160            crate_path,
161            alloc,
162            arbitrary1,
163            bytemuck1,
164            serde1,
165            std,
166            zerocopy,
167            step_trait,
168            attrs,
169            repr,
170            vis,
171            kind,
172            ident,
173            brace_token,
174            range: start..=end,
175        })
176    }
177}
178
179enum Kind {
180    Struct(Token![struct]),
181    Enum(Token![enum]),
182}
183
184impl Parse for Kind {
185    fn parse(input: ParseStream<'_>) -> parse::Result<Self> {
186        Ok(if input.peek(Token![struct]) {
187            Self::Struct(input.parse()?)
188        } else {
189            Self::Enum(input.parse()?)
190        })
191    }
192}
193
/// Signedness of an integer representation: `i*` types are `Signed`, `u*` types are
/// `Unsigned`.
#[derive(PartialEq, Eq, Clone, Copy)]
enum ReprSign {
    /// `i8` … `i128`, `isize`.
    Signed,
    /// `u8` … `u128`, `usize`.
    Unsigned,
}
use ReprSign::{Signed, Unsigned};
200
201struct Repr {
202    sign: ReprSign,
203    size: ReprSize,
204    name: Ident,
205}
206
207impl Repr {
208    fn new(sign: ReprSign, size: ReprSize) -> Self {
209        let prefix = match sign {
210            Signed => 'i',
211            Unsigned => 'u',
212        };
213        Self {
214            sign,
215            size,
216            name: Ident::new(&format!("{prefix}{size}"), Span::call_site()),
217        }
218    }
219
220    fn smallest_repr(min: &BigInt, max: &BigInt) -> Option<Self> {
221        // NOTE: Never infer nonzero types, even if we can.
222        Some(if min.sign() == Sign::Minus {
223            Self::new(
224                Signed,
225                ReprSize::Fixed(cmp::max(
226                    ReprSizeFixed::from_bits((min + 1_u8).bits() + 1)?,
227                    ReprSizeFixed::from_bits(max.bits() + 1)?,
228                )),
229            )
230        } else {
231            Self::new(
232                Unsigned,
233                ReprSize::Fixed(ReprSizeFixed::from_bits(max.bits())?),
234            )
235        })
236    }
237
238    fn minimum(&self) -> Option<BigInt> {
239        Some(match (self.sign, self.size) {
240            (Unsigned, ReprSize::Fixed(_)) => BigInt::from(0u8),
241            (Signed, ReprSize::Fixed(size)) => -(BigInt::from(1u8) << (size.to_bits() - 1)),
242            (_, ReprSize::Pointer) => return None,
243        })
244    }
245
246    fn maximum(&self) -> Option<BigInt> {
247        Some(match (self.sign, self.size) {
248            (Unsigned, ReprSize::Fixed(size)) => (BigInt::from(1u8) << size.to_bits()) - 1,
249            (Signed, ReprSize::Fixed(size)) => (BigInt::from(1u8) << (size.to_bits() - 1)) - 1,
250            (_, ReprSize::Pointer) => return None,
251        })
252    }
253
254    fn try_number_literal(
255        &self,
256        value: impl Borrow<BigInt>,
257    ) -> Result<Literal, TryFromBigIntError<()>> {
258        macro_rules! match_repr {
259            ($($sign:ident $size:ident $(($fixed:ident))? => $f:ident,)*) => {
260                match (self.sign, self.size) {
261                    $(($sign, ReprSize::$size $((ReprSizeFixed::$fixed))?) => {
262                        Ok(Literal::$f(value.borrow().try_into()?))
263                    })*
264                }
265            }
266        }
267
268        match_repr! {
269            Unsigned Fixed(Fixed8) => u8_suffixed,
270            Unsigned Fixed(Fixed16) => u16_suffixed,
271            Unsigned Fixed(Fixed32) => u32_suffixed,
272            Unsigned Fixed(Fixed64) => u64_suffixed,
273            Unsigned Fixed(Fixed128) => u128_suffixed,
274            Unsigned Pointer => usize_suffixed,
275            Signed Fixed(Fixed8) => i8_suffixed,
276            Signed Fixed(Fixed16) => i16_suffixed,
277            Signed Fixed(Fixed32) => i32_suffixed,
278            Signed Fixed(Fixed64) => i64_suffixed,
279            Signed Fixed(Fixed128) => i128_suffixed,
280            Signed Pointer => isize_suffixed,
281        }
282    }
283
284    fn number_literal(&self, value: impl Borrow<BigInt>) -> Literal {
285        self.try_number_literal(value).unwrap()
286    }
287
288    fn larger_reprs(&self) -> impl Iterator<Item = Self> {
289        match self.sign {
290            Signed => Either::A(self.size.larger_reprs().map(|size| Self::new(Signed, size))),
291            Unsigned => Either::B(
292                self.size
293                    .larger_reprs()
294                    .map(|size| Self::new(Unsigned, size))
295                    .chain(
296                        self.size
297                            .larger_reprs()
298                            .skip(1)
299                            .map(|size| Self::new(Signed, size)),
300                    ),
301            ),
302        }
303    }
304
305    fn is_usize(&self) -> bool {
306        matches!((self.sign, self.size), (Unsigned, ReprSize::Pointer))
307    }
308}
309
310impl Parse for Repr {
311    fn parse(input: ParseStream<'_>) -> parse::Result<Self> {
312        let name = input.parse::<Ident>()?;
313        let span = name.span();
314        let s = name.to_string();
315
316        let (size, sign) = if let Some(size) = s.strip_prefix('i') {
317            (size, Signed)
318        } else if let Some(size) = s.strip_prefix('u') {
319            (size, Unsigned)
320        } else {
321            return Err(Error::new(span, "Repr must a primitive integer type"));
322        };
323
324        let size = match size {
325            "8" => ReprSize::Fixed(ReprSizeFixed::Fixed8),
326            "16" => ReprSize::Fixed(ReprSizeFixed::Fixed16),
327            "32" => ReprSize::Fixed(ReprSizeFixed::Fixed32),
328            "64" => ReprSize::Fixed(ReprSizeFixed::Fixed64),
329            "128" => ReprSize::Fixed(ReprSizeFixed::Fixed128),
330            "size" => ReprSize::Pointer,
331            unknown => {
332                return Err(Error::new(
333                    span,
334                    format_args!(
335                        "Unknown integer size {unknown}, must be one of 8, 16, 32, 64, 128 or size",
336                    ),
337                ));
338            }
339        };
340
341        Ok(Self { sign, size, name })
342    }
343}
344
345impl ToTokens for Repr {
346    fn to_tokens(&self, tokens: &mut TokenStream) {
347        tokens.append(self.name.clone());
348    }
349}
350
351#[derive(Clone, Copy)]
352enum ReprSize {
353    Fixed(ReprSizeFixed),
354
355    /// `usize`/`isize`
356    Pointer,
357}
358
359impl ReprSize {
360    fn larger_reprs(self) -> impl Iterator<Item = Self> {
361        match self {
362            Self::Fixed(fixed) => Either::A(fixed.larger_reprs().map(Self::Fixed)),
363            Self::Pointer => Either::B(std::iter::once(Self::Pointer)),
364        }
365    }
366}
367
368impl Display for ReprSize {
369    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
370        match self {
371            Self::Fixed(fixed) => fixed.fmt(f),
372            Self::Pointer => f.write_str("size"),
373        }
374    }
375}
376
/// A fixed integer bit width. Variants are declared narrowest first, so the derived
/// ordering compares widths.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
enum ReprSizeFixed {
    Fixed8,
    Fixed16,
    Fixed32,
    Fixed64,
    Fixed128,
}

impl ReprSizeFixed {
    /// Every width, narrowest first.
    const ALL: [Self; 5] = [
        Self::Fixed8,
        Self::Fixed16,
        Self::Fixed32,
        Self::Fixed64,
        Self::Fixed128,
    ];

    /// Number of bits in this width.
    fn to_bits(self) -> u64 {
        match self {
            Self::Fixed8 => 8,
            Self::Fixed16 => 16,
            Self::Fixed32 => 32,
            Self::Fixed64 => 64,
            Self::Fixed128 => 128,
        }
    }

    /// The narrowest width with at least `bits` bits, or `None` when over 128 are needed.
    fn from_bits(bits: u64) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|width| width.to_bits() >= bits)
    }

    /// This width followed by every wider one, narrowest first.
    fn larger_reprs(self) -> impl Iterator<Item = Self> {
        IntoIterator::into_iter(Self::ALL).skip_while(move |&width| width < self)
    }
}

impl Display for ReprSizeFixed {
    /// Writes the bit count, e.g. `32`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.to_bits())
    }
}
438
439fn eval_expr(expr: &Expr) -> syn::Result<BigInt> {
440    Ok(match expr {
441        Expr::Lit(ExprLit { lit, .. }) => match lit {
442            Lit::Byte(byte) => byte.value().into(),
443            Lit::Int(int) => int.base10_parse()?,
444            _ => {
445                return Err(Error::new_spanned(lit, "literal must be integer"));
446            }
447        },
448        Expr::Unary(ExprUnary { op, expr, .. }) => {
449            let expr = eval_expr(expr)?;
450            match op {
451                UnOp::Not(_) => !expr,
452                UnOp::Neg(_) => -expr,
453                _ => return Err(Error::new_spanned(op, "unary operator must be ! or -")),
454            }
455        }
456        Expr::Binary(ExprBinary {
457            left, op, right, ..
458        }) => {
459            let left = eval_expr(left)?;
460            let right = eval_expr(right)?;
461            match op {
462                BinOp::Add(_) => left + right,
463                BinOp::Sub(_) => left - right,
464                BinOp::Mul(_) => left * right,
465                BinOp::Div(_) => left
466                    .checked_div(&right)
467                    .ok_or_else(|| Error::new_spanned(op, "Attempted to divide by zero"))?,
468                BinOp::Rem(_) => left % right,
469                BinOp::BitXor(_) => left ^ right,
470                BinOp::BitAnd(_) => left & right,
471                BinOp::BitOr(_) => left | right,
472                _ => {
473                    return Err(Error::new_spanned(
474                        op,
475                        "operator not supported in this context",
476                    ));
477                }
478            }
479        }
480        Expr::Group(ExprGroup { expr, .. }) | Expr::Paren(ExprParen { expr, .. }) => {
481            eval_expr(expr)?
482        }
483        _ => return Err(Error::new_spanned(expr, "expected simple expression")),
484    })
485}
486
487/// Raise a visibility one level.
488///
489/// ```text
490/// no visibility -> pub(super)
491/// pub(self) -> pub(super)
492/// pub(in self) -> pub(in super)
493/// pub(in self::some::path) -> pub(in super::some::path)
494/// pub(super) -> pub(in super::super)
495/// pub(in super) -> pub(in super::super)
496/// pub(in super::some::path) -> pub(in super::super::some::path)
497/// ```
498fn raise_one_level(vis: Visibility) -> Visibility {
499    match vis {
500        Visibility::Inherited => syn::parse2(quote!(pub(super))).unwrap(),
501        Visibility::Restricted(mut restricted)
502            if restricted.path.segments.first().unwrap().ident == "self" =>
503        {
504            let first = &mut restricted.path.segments.first_mut().unwrap().ident;
505            *first = Ident::new("super", first.span());
506            Visibility::Restricted(restricted)
507        }
508        Visibility::Restricted(mut restricted)
509            if restricted.path.segments.first().unwrap().ident == "super" =>
510        {
511            restricted
512                .in_token
513                .get_or_insert_with(<Token![in]>::default);
514            let first = PathSegment {
515                ident: restricted.path.segments.first().unwrap().ident.clone(),
516                arguments: PathArguments::None,
517            };
518            restricted.path.segments.insert(0, first);
519            Visibility::Restricted(restricted)
520        }
521        absolute_visibility => absolute_visibility,
522    }
523}
524
#[test]
fn test_raise_one_level() {
    // (visibility as written, expected visibility one module deeper)
    let cases = [
        // Relative visibilities move one level up.
        (TokenStream::new(), quote!(pub(super))),
        (quote!(pub(self)), quote!(pub(super))),
        (quote!(pub(in self)), quote!(pub(in super))),
        (
            quote!(pub(in self::some::path)),
            quote!(pub(in super::some::path)),
        ),
        (quote!(pub(super)), quote!(pub(in super::super))),
        (quote!(pub(in super)), quote!(pub(in super::super))),
        (
            quote!(pub(in super::some::path)),
            quote!(pub(in super::super::some::path)),
        ),
        // Absolute visibilities are left alone.
        (quote!(pub), quote!(pub)),
        (quote!(pub(crate)), quote!(pub(crate))),
        (quote!(pub(in crate)), quote!(pub(in crate))),
        (
            quote!(pub(in crate::some::path)),
            quote!(pub(in crate::some::path)),
        ),
    ];

    for (input, expected) in cases {
        let description = input.to_string();
        let raised = raise_one_level(syn::parse2(input).unwrap());
        assert_eq!(
            raised.into_token_stream().to_string(),
            expected.to_string(),
            "raising visibility `{description}`",
        );
    }
}
556
/// One of two values. Used to return either of two iterator types from a single
/// `impl Iterator` function.
enum Either<A, B> {
    A(A),
    B(B),
}

impl<T, A: Iterator<Item = T>, B: Iterator<Item = T>> Iterator for Either<A, B> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::A(a) => a.next(),
            Self::B(b) => b.next(),
        }
    }

    // Forward the inner iterator's bounds. The default would report `(0, None)` and
    // keep consumers such as `collect` from preallocating.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::A(a) => a.size_hint(),
            Self::B(b) => b.size_hint(),
        }
    }
}