1#![warn(clippy::pedantic, rust_2018_idioms, unused_qualifications)]
5#![allow(clippy::single_match_else, clippy::match_bool)]
6#![allow(unused)]
7
8use std::borrow::Borrow;
9use std::cmp;
10use std::convert::TryInto;
11use std::fmt::{self, Display, Formatter};
12use std::ops::RangeInclusive;
13
14use proc_macro2::{Group, Ident, Literal, Span, TokenStream};
15use quote::{quote, ToTokens, TokenStreamExt as _};
16use syn::parse::{self, Parse, ParseStream};
17use syn::{braced, parse_macro_input, token::Brace, Token};
18use syn::{Attribute, Error, Expr, PathArguments, PathSegment, Visibility};
19use syn::{BinOp, ExprBinary, ExprRange, ExprUnary, RangeLimits, UnOp};
20use syn::{ExprGroup, ExprParen};
21use syn::{ExprLit, Lit, LitBool};
22
23use num_bigint::{BigInt, Sign, TryFromBigIntError};
24
25mod generate;
26
27#[proc_macro]
28#[doc(hidden)]
29pub fn bounded_integer(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
30 let mut item = parse_macro_input!(input as BoundedInteger);
31
32 let module_name = Ident::new(
34 &format!("__bounded_integer_private_{}", item.ident),
35 item.ident.span(),
36 );
37 let ident = &item.ident;
38 let original_visibility = item.vis;
39
40 let import = quote!(#original_visibility use #module_name::#ident);
41
42 item.vis = raise_one_level(original_visibility);
43 let mut result = TokenStream::new();
44 generate::generate(&item, &mut result);
45
46 quote!(
47 #[allow(non_snake_case)]
48 mod #module_name {
49 #result
50 }
51 #import;
52 )
53 .into()
54}
55
/// Parsed input of the `bounded_integer!` macro, handed to `generate::generate`.
#[allow(clippy::struct_excessive_bools)]
struct BoundedInteger {
    /// Token contents of the leading delimited group of the macro input —
    /// presumably the path to the `bounded_integer` crate for use in generated
    /// code (TODO confirm in `generate`).
    crate_path: TokenStream,

    // Boolean switches parsed from `true`/`false` literals, in this exact order.
    // Presumably they mirror cargo features of the facade crate and gate the
    // corresponding generated impls — confirm in `generate`.
    alloc: bool,
    arbitrary1: bool,
    bytemuck1: bool,
    serde1: bool,
    std: bool,
    zerocopy: bool,
    step_trait: bool,

    /// Outer attributes written on the item, with the first `#[repr(...)]`
    /// (if any) removed and parsed into `repr` instead.
    attrs: Vec<Attribute>,
    /// The explicit `#[repr(...)]` when given, otherwise the smallest fixed-size
    /// primitive that can hold `range`.
    repr: Repr,
    /// Visibility of the item. `bounded_integer` replaces the parsed value with
    /// a visibility raised by one level before calling `generate::generate`.
    vis: Visibility,
    /// Whether the item was declared with `struct` or `enum`.
    kind: Kind,
    /// Name of the type being defined.
    ident: Ident,
    /// The braces that surrounded the range expression.
    brace_token: Brace,
    /// Inclusive range of permitted values; a half-open input range `a..b` is
    /// stored as `a..=b - 1`.
    range: RangeInclusive<BigInt>,
}
79
80impl Parse for BoundedInteger {
81 fn parse(input: ParseStream<'_>) -> parse::Result<Self> {
82 let crate_path = input.parse::<Group>()?.stream();
83
84 let alloc = input.parse::<LitBool>()?.value;
85 let arbitrary1 = input.parse::<LitBool>()?.value;
86 let bytemuck1 = input.parse::<LitBool>()?.value;
87 let serde1 = input.parse::<LitBool>()?.value;
88 let std = input.parse::<LitBool>()?.value;
89 let zerocopy = input.parse::<LitBool>()?.value;
90 let step_trait = input.parse::<LitBool>()?.value;
91
92 let mut attrs = input.call(Attribute::parse_outer)?;
93
94 let repr_pos = attrs.iter().position(|attr| attr.path().is_ident("repr"));
95 let repr = repr_pos
96 .map(|pos| attrs.remove(pos).parse_args::<Repr>())
97 .transpose()?;
98
99 let vis: Visibility = input.parse()?;
100
101 let kind: Kind = input.parse()?;
102
103 let ident: Ident = input.parse()?;
104
105 let range_tokens;
106 let brace_token = braced!(range_tokens in input);
107 let range: ExprRange = range_tokens.parse()?;
108
109 let Some((start_expr, end_expr)) = range.start.as_deref().zip(range.end.as_deref()) else {
110 return Err(Error::new_spanned(range, "Range must be closed"));
111 };
112 let start = eval_expr(start_expr)?;
113 let end = eval_expr(end_expr)?;
114 let end = if let RangeLimits::HalfOpen(_) = range.limits {
115 end - 1
116 } else {
117 end
118 };
119 if start >= end {
120 return Err(Error::new_spanned(
121 range,
122 "The start of the range must be before the end",
123 ));
124 }
125
126 let repr = match repr {
127 Some(explicit_repr) => {
128 if explicit_repr.sign == Unsigned && start.sign() == Sign::Minus {
129 return Err(Error::new_spanned(
130 start_expr,
131 "An unsigned integer cannot hold a negative value",
132 ));
133 }
134
135 if explicit_repr.minimum().is_some_and(|min| start < min) {
136 return Err(Error::new_spanned(
137 start_expr,
138 format_args!(
139 "Bound {start} is below the minimum value for the underlying type",
140 ),
141 ));
142 }
143 if explicit_repr.maximum().is_some_and(|max| end > max) {
144 return Err(Error::new_spanned(
145 end_expr,
146 format_args!(
147 "Bound {end} is above the maximum value for the underlying type",
148 ),
149 ));
150 }
151
152 explicit_repr
153 }
154 None => Repr::smallest_repr(&start, &end).ok_or_else(|| {
155 Error::new_spanned(range, "Range is too wide to fit in any integer primitive")
156 })?,
157 };
158
159 Ok(Self {
160 crate_path,
161 alloc,
162 arbitrary1,
163 bytemuck1,
164 serde1,
165 std,
166 zerocopy,
167 step_trait,
168 attrs,
169 repr,
170 vis,
171 kind,
172 ident,
173 brace_token,
174 range: start..=end,
175 })
176 }
177}
178
179enum Kind {
180 Struct(Token![struct]),
181 Enum(Token![enum]),
182}
183
184impl Parse for Kind {
185 fn parse(input: ParseStream<'_>) -> parse::Result<Self> {
186 Ok(if input.peek(Token![struct]) {
187 Self::Struct(input.parse()?)
188 } else {
189 Self::Enum(input.parse()?)
190 })
191 }
192}
193
/// Signedness of an integer primitive.
#[derive(Clone, Copy, PartialEq, Eq)]
enum ReprSign {
    /// `i8`, `i16`, `i32`, `i64`, `i128`, `isize`.
    Signed,
    /// `u8`, `u16`, `u32`, `u64`, `u128`, `usize`.
    Unsigned,
}
// Bring the variants into scope so the rest of the file can write `Signed` /
// `Unsigned` unqualified.
use ReprSign::{Signed, Unsigned};
200
/// An integer primitive type used as the underlying representation.
struct Repr {
    /// Whether the primitive is signed (`i…`) or unsigned (`u…`).
    sign: ReprSign,
    /// Fixed bit width, or pointer-sized.
    size: ReprSize,
    /// Identifier of the primitive (e.g. `u8`, `isize`) emitted by `ToTokens`.
    /// When parsed from `#[repr(...)]` it keeps the user's span; when built by
    /// `Repr::new` it has a call-site span.
    name: Ident,
}
206
207impl Repr {
208 fn new(sign: ReprSign, size: ReprSize) -> Self {
209 let prefix = match sign {
210 Signed => 'i',
211 Unsigned => 'u',
212 };
213 Self {
214 sign,
215 size,
216 name: Ident::new(&format!("{prefix}{size}"), Span::call_site()),
217 }
218 }
219
220 fn smallest_repr(min: &BigInt, max: &BigInt) -> Option<Self> {
221 Some(if min.sign() == Sign::Minus {
223 Self::new(
224 Signed,
225 ReprSize::Fixed(cmp::max(
226 ReprSizeFixed::from_bits((min + 1_u8).bits() + 1)?,
227 ReprSizeFixed::from_bits(max.bits() + 1)?,
228 )),
229 )
230 } else {
231 Self::new(
232 Unsigned,
233 ReprSize::Fixed(ReprSizeFixed::from_bits(max.bits())?),
234 )
235 })
236 }
237
238 fn minimum(&self) -> Option<BigInt> {
239 Some(match (self.sign, self.size) {
240 (Unsigned, ReprSize::Fixed(_)) => BigInt::from(0u8),
241 (Signed, ReprSize::Fixed(size)) => -(BigInt::from(1u8) << (size.to_bits() - 1)),
242 (_, ReprSize::Pointer) => return None,
243 })
244 }
245
246 fn maximum(&self) -> Option<BigInt> {
247 Some(match (self.sign, self.size) {
248 (Unsigned, ReprSize::Fixed(size)) => (BigInt::from(1u8) << size.to_bits()) - 1,
249 (Signed, ReprSize::Fixed(size)) => (BigInt::from(1u8) << (size.to_bits() - 1)) - 1,
250 (_, ReprSize::Pointer) => return None,
251 })
252 }
253
254 fn try_number_literal(
255 &self,
256 value: impl Borrow<BigInt>,
257 ) -> Result<Literal, TryFromBigIntError<()>> {
258 macro_rules! match_repr {
259 ($($sign:ident $size:ident $(($fixed:ident))? => $f:ident,)*) => {
260 match (self.sign, self.size) {
261 $(($sign, ReprSize::$size $((ReprSizeFixed::$fixed))?) => {
262 Ok(Literal::$f(value.borrow().try_into()?))
263 })*
264 }
265 }
266 }
267
268 match_repr! {
269 Unsigned Fixed(Fixed8) => u8_suffixed,
270 Unsigned Fixed(Fixed16) => u16_suffixed,
271 Unsigned Fixed(Fixed32) => u32_suffixed,
272 Unsigned Fixed(Fixed64) => u64_suffixed,
273 Unsigned Fixed(Fixed128) => u128_suffixed,
274 Unsigned Pointer => usize_suffixed,
275 Signed Fixed(Fixed8) => i8_suffixed,
276 Signed Fixed(Fixed16) => i16_suffixed,
277 Signed Fixed(Fixed32) => i32_suffixed,
278 Signed Fixed(Fixed64) => i64_suffixed,
279 Signed Fixed(Fixed128) => i128_suffixed,
280 Signed Pointer => isize_suffixed,
281 }
282 }
283
284 fn number_literal(&self, value: impl Borrow<BigInt>) -> Literal {
285 self.try_number_literal(value).unwrap()
286 }
287
288 fn larger_reprs(&self) -> impl Iterator<Item = Self> {
289 match self.sign {
290 Signed => Either::A(self.size.larger_reprs().map(|size| Self::new(Signed, size))),
291 Unsigned => Either::B(
292 self.size
293 .larger_reprs()
294 .map(|size| Self::new(Unsigned, size))
295 .chain(
296 self.size
297 .larger_reprs()
298 .skip(1)
299 .map(|size| Self::new(Signed, size)),
300 ),
301 ),
302 }
303 }
304
305 fn is_usize(&self) -> bool {
306 matches!((self.sign, self.size), (Unsigned, ReprSize::Pointer))
307 }
308}
309
310impl Parse for Repr {
311 fn parse(input: ParseStream<'_>) -> parse::Result<Self> {
312 let name = input.parse::<Ident>()?;
313 let span = name.span();
314 let s = name.to_string();
315
316 let (size, sign) = if let Some(size) = s.strip_prefix('i') {
317 (size, Signed)
318 } else if let Some(size) = s.strip_prefix('u') {
319 (size, Unsigned)
320 } else {
321 return Err(Error::new(span, "Repr must a primitive integer type"));
322 };
323
324 let size = match size {
325 "8" => ReprSize::Fixed(ReprSizeFixed::Fixed8),
326 "16" => ReprSize::Fixed(ReprSizeFixed::Fixed16),
327 "32" => ReprSize::Fixed(ReprSizeFixed::Fixed32),
328 "64" => ReprSize::Fixed(ReprSizeFixed::Fixed64),
329 "128" => ReprSize::Fixed(ReprSizeFixed::Fixed128),
330 "size" => ReprSize::Pointer,
331 unknown => {
332 return Err(Error::new(
333 span,
334 format_args!(
335 "Unknown integer size {unknown}, must be one of 8, 16, 32, 64, 128 or size",
336 ),
337 ));
338 }
339 };
340
341 Ok(Self { sign, size, name })
342 }
343}
344
345impl ToTokens for Repr {
346 fn to_tokens(&self, tokens: &mut TokenStream) {
347 tokens.append(self.name.clone());
348 }
349}
350
351#[derive(Clone, Copy)]
352enum ReprSize {
353 Fixed(ReprSizeFixed),
354
355 Pointer,
357}
358
359impl ReprSize {
360 fn larger_reprs(self) -> impl Iterator<Item = Self> {
361 match self {
362 Self::Fixed(fixed) => Either::A(fixed.larger_reprs().map(Self::Fixed)),
363 Self::Pointer => Either::B(std::iter::once(Self::Pointer)),
364 }
365 }
366}
367
368impl Display for ReprSize {
369 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
370 match self {
371 Self::Fixed(fixed) => fixed.fmt(f),
372 Self::Pointer => f.write_str("size"),
373 }
374 }
375}
376
/// Bit width of a fixed-size integer primitive.
///
/// Variants are declared from narrowest to widest, and the derived
/// `PartialOrd`/`Ord` follow declaration order, so comparing two widths
/// compares their bit counts.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
enum ReprSizeFixed {
    Fixed8,
    Fixed16,
    Fixed32,
    Fixed64,
    Fixed128,
}

impl ReprSizeFixed {
    /// Number of bits in this width.
    fn to_bits(self) -> u64 {
        match self {
            Self::Fixed8 => 8,
            Self::Fixed16 => 16,
            Self::Fixed32 => 32,
            Self::Fixed64 => 64,
            Self::Fixed128 => 128,
        }
    }

    /// Narrowest width with at least `bits` bits, or `None` when more than 128
    /// bits are requested.
    fn from_bits(bits: u64) -> Option<Self> {
        match bits {
            0..=8 => Some(Self::Fixed8),
            9..=16 => Some(Self::Fixed16),
            17..=32 => Some(Self::Fixed32),
            33..=64 => Some(Self::Fixed64),
            65..=128 => Some(Self::Fixed128),
            _ => None,
        }
    }

    /// This width followed by every wider one, in increasing order.
    fn larger_reprs(self) -> impl Iterator<Item = Self> {
        const ALL_WIDTHS: [ReprSizeFixed; 5] = [
            ReprSizeFixed::Fixed8,
            ReprSizeFixed::Fixed16,
            ReprSizeFixed::Fixed32,
            ReprSizeFixed::Fixed64,
            ReprSizeFixed::Fixed128,
        ];
        IntoIterator::into_iter(ALL_WIDTHS).filter(move |&width| width >= self)
    }
}

impl Display for ReprSizeFixed {
    /// Writes the bit count, e.g. `32`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.to_bits())
    }
}
438
439fn eval_expr(expr: &Expr) -> syn::Result<BigInt> {
440 Ok(match expr {
441 Expr::Lit(ExprLit { lit, .. }) => match lit {
442 Lit::Byte(byte) => byte.value().into(),
443 Lit::Int(int) => int.base10_parse()?,
444 _ => {
445 return Err(Error::new_spanned(lit, "literal must be integer"));
446 }
447 },
448 Expr::Unary(ExprUnary { op, expr, .. }) => {
449 let expr = eval_expr(expr)?;
450 match op {
451 UnOp::Not(_) => !expr,
452 UnOp::Neg(_) => -expr,
453 _ => return Err(Error::new_spanned(op, "unary operator must be ! or -")),
454 }
455 }
456 Expr::Binary(ExprBinary {
457 left, op, right, ..
458 }) => {
459 let left = eval_expr(left)?;
460 let right = eval_expr(right)?;
461 match op {
462 BinOp::Add(_) => left + right,
463 BinOp::Sub(_) => left - right,
464 BinOp::Mul(_) => left * right,
465 BinOp::Div(_) => left
466 .checked_div(&right)
467 .ok_or_else(|| Error::new_spanned(op, "Attempted to divide by zero"))?,
468 BinOp::Rem(_) => left % right,
469 BinOp::BitXor(_) => left ^ right,
470 BinOp::BitAnd(_) => left & right,
471 BinOp::BitOr(_) => left | right,
472 _ => {
473 return Err(Error::new_spanned(
474 op,
475 "operator not supported in this context",
476 ));
477 }
478 }
479 }
480 Expr::Group(ExprGroup { expr, .. }) | Expr::Paren(ExprParen { expr, .. }) => {
481 eval_expr(expr)?
482 }
483 _ => return Err(Error::new_spanned(expr, "expected simple expression")),
484 })
485}
486
487fn raise_one_level(vis: Visibility) -> Visibility {
499 match vis {
500 Visibility::Inherited => syn::parse2(quote!(pub(super))).unwrap(),
501 Visibility::Restricted(mut restricted)
502 if restricted.path.segments.first().unwrap().ident == "self" =>
503 {
504 let first = &mut restricted.path.segments.first_mut().unwrap().ident;
505 *first = Ident::new("super", first.span());
506 Visibility::Restricted(restricted)
507 }
508 Visibility::Restricted(mut restricted)
509 if restricted.path.segments.first().unwrap().ident == "super" =>
510 {
511 restricted
512 .in_token
513 .get_or_insert_with(<Token![in]>::default);
514 let first = PathSegment {
515 ident: restricted.path.segments.first().unwrap().ident.clone(),
516 arguments: PathArguments::None,
517 };
518 restricted.path.segments.insert(0, first);
519 Visibility::Restricted(restricted)
520 }
521 absolute_visibility => absolute_visibility,
522 }
523}
524
#[test]
fn test_raise_one_level() {
    // (visibility as written, expected visibility after raising one level)
    let cases = [
        // Private and `self`-relative visibilities become `super`-relative.
        (TokenStream::new(), quote!(pub(super))),
        (quote!(pub(self)), quote!(pub(super))),
        (quote!(pub(in self)), quote!(pub(in super))),
        (
            quote!(pub(in self::some::path)),
            quote!(pub(in super::some::path)),
        ),
        // `super`-relative visibilities gain another `super`.
        (quote!(pub(super)), quote!(pub(in super::super))),
        (quote!(pub(in super)), quote!(pub(in super::super))),
        (
            quote!(pub(in super::some::path)),
            quote!(pub(in super::super::some::path)),
        ),
        // Absolute visibilities are left alone.
        (quote!(pub), quote!(pub)),
        (quote!(pub(crate)), quote!(pub(crate))),
        (quote!(pub(in crate)), quote!(pub(in crate))),
        (
            quote!(pub(in crate::some::path)),
            quote!(pub(in crate::some::path)),
        ),
    ];

    for (input, expected) in cases {
        let written = input.to_string();
        let raised = raise_one_level(syn::parse2(input).unwrap()).into_token_stream();
        assert_eq!(
            raised.to_string(),
            expected.to_string(),
            "raising `{written}`"
        );
    }
}
556
/// An iterator that is one of two iterator types with the same item type.
///
/// Lets a function returning `impl Iterator` choose between differently-typed
/// iterators at runtime (as `ReprSize::larger_reprs` and `Repr::larger_reprs`
/// do).
enum Either<A, B> {
    A(A),
    B(B),
}

impl<T, A: Iterator<Item = T>, B: Iterator<Item = T>> Iterator for Either<A, B> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::A(a) => a.next(),
            Self::B(b) => b.next(),
        }
    }

    /// Forwards the active iterator's bounds so adaptors and `collect` keep
    /// exact size information instead of falling back to `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::A(a) => a.size_hint(),
            Self::B(b) => b.size_hint(),
        }
    }
}