1#![warn(clippy::pedantic, rust_2018_idioms, unused_qualifications)]
5#![allow(clippy::single_match_else, clippy::match_bool)]
6
7use std::array;
8use std::fmt::Debug;
9
10use proc_macro2::{Delimiter, Ident, Literal, Span, TokenStream, TokenTree};
11use quote::{ToTokens, quote, quote_spanned};
12
13#[proc_macro]
14#[doc(hidden)]
15#[expect(clippy::too_many_lines)]
16pub fn bounded_integer(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
17 let input = TokenStream::from(input).into_iter().map(|t| {
18 let TokenTree::Group(group) = t else {
19 panic!("non-group in input")
20 };
21 assert_eq!(group.delimiter(), Delimiter::Bracket);
22 group.stream()
23 });
24 let [
25 zerocopy,
26 outer_attr,
27 mut attrs,
28 vis,
29 super_vis,
30 is_named,
31 item_kind,
32 name,
33 min_or_variants,
34 max_or_none,
35 crate_path,
36 ] = to_array(input);
37
38 let zerocopy = match to_array(zerocopy) {
39 [TokenTree::Punct(p)] if p.as_char() == '-' => false,
40 [TokenTree::Punct(p)] if p.as_char() == '+' => true,
41 [t] => panic!("zerocopy ({t})"),
42 };
43
44 let [TokenTree::Ident(item_kind)] = to_array(item_kind) else {
45 panic!("item kind")
46 };
47 let is_enum = match &*item_kind.to_string() {
48 "struct" => false,
49 "enum" => true,
50 s => panic!("unknown item kind {s}"),
51 };
52 let [TokenTree::Ident(name)] = to_array(name) else {
53 panic!("name")
54 };
55
56 let mut new_attrs = TokenStream::new();
57 let mut import_attrs = TokenStream::new();
58 let mut maybe_repr = None;
59 for attr in attrs {
60 let TokenTree::Group(group) = &attr else {
61 panic!("attr ({attr})")
62 };
63 let tokens = group.stream().into_iter().collect::<Vec<_>>();
64 if let Some(TokenTree::Ident(i)) = tokens.first() {
65 let name = i.to_string();
66
67 if name == "repr"
68 && let [_, TokenTree::Group(g)] = &*tokens
69 && g.delimiter() == Delimiter::Parenthesis
70 {
71 if maybe_repr.is_some() {
72 return error!(i.span(), "duplicate `repr` attribute");
73 }
74 maybe_repr = Some(g.stream());
75 continue;
76 } else if ["allow", "expect", "warn", "deny", "forbid"].contains(&&*name)
77 && let [_, TokenTree::Group(g)] = &*tokens
78 && g.delimiter() == Delimiter::Parenthesis
79 && let [Some(TokenTree::Ident(lint)), None] = {
80 let mut iter = g.stream().into_iter();
81 [iter.next(), iter.next()]
82 }
83 && (lint == "unused" || lint == "unused_imports")
84 {
85 import_attrs.extend(quote!(# #attr));
86 continue;
87 }
88 }
89 new_attrs.extend(quote!(# #attr));
90 }
91 attrs = new_attrs;
92
93 let (variants, min, max, min_val, max_val);
94 match to_array(is_named) {
95 [TokenTree::Punct(p)] if p.as_char() == '-' => {
97 [min, max] = [min_or_variants, max_or_none].map(ungroup_none);
98 [min_val, max_val] = [&min, &max].map(|lit| {
99 parse_literal(lit.clone()).map(|(lit, repr)| {
100 if let Some(repr) = repr
102 && maybe_repr.is_none()
103 {
104 maybe_repr = Some(quote!(#repr));
105 }
106 lit
107 })
108 });
109
110 variants = match is_enum {
111 false => None,
112 true => {
113 let Some(min_val) = min_val else {
114 return error!(min, "`enum` requires bound to be statically known");
115 };
116 let Some(max_val) = max_val else {
117 return error!(max, "`enum` requires bound to be statically known");
118 };
119 let Some(range) = range(min_val, max_val) else {
120 return error!(min, "refusing to generate this many `enum` variants");
121 };
122 let mut variants = TokenStream::new();
123 let min_span = stream_span(min.clone());
124 for int in range {
125 let enum_variant_name = int.enum_variant_name(min_span);
126 if int == min_val {
127 variants.extend(quote!(#[allow(dead_code)] #enum_variant_name = #min,));
128 } else {
129 variants.extend(quote!(#[allow(dead_code)] #enum_variant_name,));
130 }
131 }
132 Some(variants)
133 }
134 };
135 }
136 [TokenTree::Punct(p)] if p.as_char() == '+' => {
138 assert!(is_enum);
139 assert!(max_or_none.into_iter().next().is_none());
140
141 let mut min_current = None::<((Int, TokenStream), Int, Span)>;
143 let mut variant_list = TokenStream::new();
144 for variant in min_or_variants {
145 let TokenTree::Group(variant) = variant else {
146 panic!("variant")
147 };
148 let [
149 TokenTree::Group(attrs),
150 TokenTree::Ident(variant_name),
151 TokenTree::Group(variant_val),
152 ] = to_array(variant.stream())
153 else {
154 panic!("variant inner")
155 };
156 let attrs = attrs.stream();
157 let variant_val = variant_val.stream();
158 min_current = Some(if variant_val.is_empty() {
159 variant_list.extend(quote!(#attrs #variant_name,));
160 match min_current {
161 Some((min, current, current_span)) => match current.succ() {
162 Some(current) => (min, current, current_span),
163 None => {
164 return error!(
165 variant_name.span(),
166 "too many variants (overflows a u128)"
167 );
168 }
169 },
170 None => (
171 (Int::new(true, 0), quote_spanned!(variant_name.span()=> 0)),
172 Int::new(true, 0),
173 variant_name.span(),
174 ),
175 }
176 } else {
177 variant_list.extend(quote!(#attrs #variant_name = #variant_val,));
178 let variant_val = ungroup_none(variant_val);
179 let Some((int, _)) = parse_literal(variant_val.clone()) else {
180 return error!(variant_val, "could not parse variant value");
181 };
182 match min_current {
183 Some((min, current, _)) if current.succ() == Some(int) => {
184 (min, int, stream_span(variant_val))
185 }
186 Some(_) => return error!(variant_val, "enum not contiguous"),
187 None => ((int, variant_val.clone()), int, stream_span(variant_val)),
188 }
189 });
190 }
191 variants = Some(variant_list);
192 [(min_val, min), (max_val, max)] = match min_current {
193 Some(((min_val, min), current, current_span)) => [
194 (Some(min_val), min),
195 (Some(current), current.literal(current_span)),
196 ],
197 None => [
198 (Some(Int::new(true, 1)), quote!(1)),
199 (Some(Int::new(true, 0)), quote!(0)),
200 ],
201 };
202 }
203 [t] => panic!("named ({t})"),
204 }
205
206 let zero = min_val
207 .zip(max_val)
208 .map(|(min, max)| (min..=max).contains(&Int::new(true, 0)));
209 let one = min_val
210 .zip(max_val)
211 .map(|(min, max)| (min..=max).contains(&Int::new(true, 1)));
212 if zero == Some(true) && zerocopy {
213 attrs.extend(quote!(#[derive(#crate_path::__private::zerocopy::FromZeros)]));
214 }
215 let zero_token = match zero {
216 Some(true) => quote!(zero,),
217 Some(false) | None => quote!(),
218 };
219 let one_token = match one {
220 Some(true) => quote!(one,),
221 Some(false) | None => quote!(),
222 };
223
224 let repr = match (maybe_repr, min_val, max_val) {
225 (Some(repr), _, _) => repr,
226 (None, Some(min_val), Some(max_val)) => match infer_repr(min_val, max_val) {
227 Some(repr) => {
228 let repr = Ident::new(&repr, stream_span(min.clone()));
229 quote!(#repr)
230 }
231 None => return error!(min, "range too large for any integer type"),
232 },
233 (None, _, _) => {
234 let msg = "no #[repr] attribute found, and could not infer";
235 return error!(min, "{msg}");
236 }
237 };
238
239 match is_enum {
240 false => attrs.extend(quote!(#[repr(transparent)])),
241 true => attrs.extend(quote!(#[repr(#repr)])),
242 }
243
244 if matches!(repr.to_string().trim(), "u8" | "i8") && zerocopy {
245 attrs.extend(quote!(#[derive(#crate_path::__private::zerocopy::Unaligned)]));
246 }
247
248 let item = match variants {
249 Some(variants) => quote!({ #variants }),
250 None if zero == Some(false) => quote!((::core::num::NonZero<#repr>);),
251 None => quote!((#repr);),
252 };
253
254 let module_name = Ident::new(&format!("__bounded_integer_private_{name}"), name.span());
256
257 let res = quote!(
258 #[allow(non_snake_case)]
259 #outer_attr
260 mod #module_name {
261 #attrs
262 #super_vis #item_kind #name #item
263
264 #crate_path::unsafe_api! {
265 for #name,
266 unsafe repr: #repr,
267 min: #min,
268 max: #max,
269 #zero_token
270 #one_token
271 }
272 }
273 #import_attrs #vis use #module_name::#name;
274 );
275
276 res.into()
277}
278
/// Collects exactly `N` items from `iter` into a fixed-size array.
///
/// # Panics
///
/// Panics if `iter` yields fewer than `N` items. Also panics if it yields more than `N`; the
/// first surplus item is included in the message.
fn to_array<I: IntoIterator<Item: Debug>, const N: usize>(iter: I) -> [I::Item; N] {
    let mut source = iter.into_iter();
    let collected: [I::Item; N] =
        array::from_fn(|_| source.next().expect("iterator too short"));
    match source.next() {
        None => collected,
        Some(item) => panic!("iterator too long: found {item:?}"),
    }
}
287
288#[derive(Debug, Clone, Copy, PartialEq, Eq)]
289struct Int {
290 nonnegative: bool,
291 magnitude: u128,
292}
293
294impl Int {
295 fn new(nonnegative: bool, magnitude: u128) -> Self {
296 Self {
297 nonnegative,
298 magnitude,
299 }
300 }
301 fn succ(self) -> Option<Self> {
302 Some(match self.nonnegative {
303 true => Self::new(true, self.magnitude.checked_add(1)?),
304 false if self.magnitude == 1 => Self::new(true, 0),
305 false => Self::new(false, self.magnitude - 1),
306 })
307 }
308 fn enum_variant_name(self, span: Span) -> Ident {
309 if self.magnitude == 0 {
310 Ident::new("Z", span)
311 } else if self.nonnegative {
312 Ident::new(&format!("P{}", self.magnitude), span)
313 } else {
314 Ident::new(&format!("N{}", self.magnitude), span)
315 }
316 }
317 fn literal(self, span: Span) -> TokenStream {
318 let mut magnitude = Literal::u128_unsuffixed(self.magnitude);
319 magnitude.set_span(span);
320 match self.nonnegative {
321 true => quote!(#magnitude),
322 false => quote!(-#magnitude),
323 }
324 }
325}
326
327impl PartialOrd for Int {
328 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
329 Some(self.cmp(other))
330 }
331}
332
333impl Ord for Int {
334 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
335 match (self.nonnegative, other.nonnegative) {
336 (true, true) => self.magnitude.cmp(&other.magnitude),
337 (true, false) => std::cmp::Ordering::Greater,
338 (false, true) => std::cmp::Ordering::Less,
339 (false, false) => other.magnitude.cmp(&self.magnitude),
340 }
341 }
342}
343
344fn parse_literal(e: TokenStream) -> Option<(Int, Option<Ident>)> {
345 let mut tokens = e.into_iter().peekable();
346 let minus = tokens
347 .next_if(|t| matches!(t, TokenTree::Punct(p) if p.as_char() == '-'))
348 .is_some();
349 let Some(TokenTree::Literal(lit)) = tokens.next() else {
350 return None;
351 };
352 if tokens.next().is_some() {
353 return None;
354 }
355
356 let mut lit_chars = &*lit.to_string();
360
361 let (base, base_len) = match lit_chars.get(..2) {
362 Some("0x") => (16, 2),
363 Some("0o") => (8, 2),
364 Some("0b") => (2, 2),
365 _ => (10, 0),
366 };
367 lit_chars = &lit_chars[base_len..];
368
369 let mut magnitude = 0_u128;
370 let mut has_digit = None;
371
372 let suffix = loop {
373 lit_chars = lit_chars.trim_start_matches('_');
374 let Some(c) = lit_chars.chars().next() else {
375 has_digit?;
376 break None;
377 };
378 if let 'i' | 'u' = c {
379 let ("8" | "16" | "32" | "64" | "128" | "size") = &lit_chars[1..] else {
380 return None;
381 };
382 break Some(Ident::new(lit_chars, lit.span()));
383 }
384 let digit = c.to_digit(base)?;
385 lit_chars = &lit_chars[1..];
386 magnitude = magnitude
387 .checked_mul(base.into())?
388 .checked_add(digit.into())?;
389 has_digit = Some(());
390 };
391
392 let lit = Int::new(!minus || magnitude == 0, magnitude);
393 Some((lit, suffix))
394}
395
396fn range(min: Int, max: Int) -> Option<impl Iterator<Item = Int>> {
397 let range_minus_one = match (max.nonnegative, min.nonnegative) {
398 (true, true) => max.magnitude.saturating_sub(min.magnitude),
399 (true, false) => max.magnitude.saturating_add(min.magnitude),
400 (false, true) => 0,
401 (false, false) => min.magnitude.saturating_sub(max.magnitude),
402 };
403 if 100_000 <= range_minus_one {
404 return None;
405 }
406 #[expect(clippy::reversed_empty_ranges)]
407 let (negative_part, nonnegative_part) = match (min.nonnegative, max.nonnegative) {
408 (true, true) => (1..=0, min.magnitude..=max.magnitude),
409 (false, true) => (1..=min.magnitude, 0..=max.magnitude),
410 (true, false) => (1..=0, 1..=0),
411 (false, false) => (max.magnitude..=min.magnitude, 1..=0),
412 };
413 let negative_part = negative_part.map(|i| Int::new(false, i));
414 let nonnegative_part = nonnegative_part.map(|i| Int::new(true, i));
415 Some(negative_part.rev().chain(nonnegative_part))
416}
417
418fn infer_repr(min: Int, max: Int) -> Option<String> {
419 for bits in [8, 16, 32, 64, 128] {
420 let fits_unsigned =
421 |lit: Int| lit.nonnegative && lit.magnitude <= (u128::MAX >> (128 - bits));
422 let fits_signed = |lit: Int| {
423 (lit.nonnegative && lit.magnitude < (1 << (bits - 1)))
424 || (!lit.nonnegative && lit.magnitude <= (1 << (bits - 1)))
425 };
426 if fits_unsigned(min) && fits_unsigned(max) {
427 return Some(format!("u{bits}"));
428 } else if fits_signed(min) && fits_signed(max) {
429 return Some(format!("i{bits}"));
430 }
431 }
432 None
433}
434
435fn ungroup_none(tokens: TokenStream) -> TokenStream {
436 let mut tokens = tokens.into_iter().peekable();
437 if let Some(TokenTree::Group(g)) =
438 tokens.next_if(|t| matches!(t, TokenTree::Group(g) if g.delimiter() == Delimiter::None))
439 {
440 return g.stream();
441 }
442 tokens.collect()
445}
446
/// Evaluates to a `proc_macro::TokenStream` holding `compile_error!(<message>);`.
///
/// The first argument is either a `Span` or a `TokenStream` (resolved through `SpanHelper`);
/// the remaining arguments are passed to `format!` to build the message.
macro_rules! error {
    ($target:expr, $($message:tt)*) => {{
        let error_span = SpanHelper($target).span_helper();
        let error_text = format!($($message)*);
        let diagnostic = quote_spanned!(error_span=> compile_error!(#error_text););
        proc_macro::TokenStream::from(diagnostic)
    }};
}
// Make the macro path-scoped so it can be used above its textual definition.
use error;
455
456struct SpanHelper<T>(T);
457impl SpanHelper<TokenStream> {
458 fn span_helper(self) -> Span {
459 stream_span(self.0.into_token_stream())
460 }
461}
462trait SpanHelperTrait {
463 fn span_helper(self) -> Span;
464}
465impl SpanHelperTrait for SpanHelper<Span> {
466 fn span_helper(self) -> Span {
467 self.0
468 }
469}
470
471fn stream_span(stream: TokenStream) -> Span {
472 stream
473 .into_iter()
474 .next()
475 .map_or_else(Span::call_site, |token| token.span())
476}