1use litrs::IntegerLit;
2use proc_macro2::{Group, Literal, TokenStream, TokenTree};
3use quote::{format_ident, quote_spanned, ToTokens};
4use syn::parse::{Parse, ParseBuffer, Parser};
5use syn::{parenthesized, parse_quote, token, Error, LitInt, Path, Result, Token};
6
7#[proc_macro]
8pub fn bitint(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
9 bitint_impl(tokens.into()).into()
10}
11
12#[proc_macro_attribute]
13pub fn bitint_literals(
14 attr: proc_macro::TokenStream,
15 item: proc_macro::TokenStream,
16) -> proc_macro::TokenStream {
17 bitint_literals_impl(attr.into(), item.into()).into()
18}
19
20struct BitintInput {
21 _paren_token: token::Paren,
22 crate_path: Path,
23 _comma_token: Token![,],
24 lit: LitInt,
25}
26
27impl Parse for BitintInput {
28 fn parse(input: &ParseBuffer) -> Result<Self> {
29 let content;
30 Ok(Self {
31 _paren_token: parenthesized!(content in input),
32 crate_path: content.parse()?,
33 _comma_token: content.parse()?,
34 lit: content.parse()?,
35 })
36 }
37}
38
39fn bitint_impl(tokens: TokenStream) -> TokenStream {
40 let input: BitintInput = match syn::parse2(tokens) {
41 Ok(input) => input,
42 Err(e) => return e.into_compile_error(),
43 };
44 match rewrite_literal(&input.crate_path, input.lit.token()) {
45 RewriteResult::Rewritten(tokens) => tokens,
46 RewriteResult::UnrecognizedSuffix(literal) => Error::new(
47 literal.span(),
48 "literal must have a suffix: 'U' followed by an integer in 1..=128",
49 )
50 .into_compile_error(),
51 RewriteResult::ValueError(e) => e.into_compile_error(),
52 }
53}
54
55enum RewriteResult {
56 Rewritten(TokenStream),
57 UnrecognizedSuffix(Literal),
58 ValueError(Error),
59}
60
61fn rewrite_literal(crate_path: &Path, literal: Literal) -> RewriteResult {
62 let Ok(integer_lit) = IntegerLit::try_from(literal.clone()) else {
64 return RewriteResult::UnrecognizedSuffix(literal);
65 };
66 let Some(width) = parse_suffix(integer_lit.suffix()) else {
67 return RewriteResult::UnrecognizedSuffix(literal);
68 };
69
70 let span = literal.span();
72 let Some(value) = integer_lit.value::<u128>() else {
73 return RewriteResult::ValueError(
74 Error::new(span, "could not parse integer literal")
75 );
76 };
77 if width < 128 {
78 let max: u128 = (1 << width) - 1;
79 if value > max {
80 return RewriteResult::ValueError(Error::new(
81 span,
82 format!("integer literal value {value} out of range for U{width}"),
83 ));
84 }
85 }
86
87 let type_name = format_ident!("U{width}", span = span);
89 let mut new_literal = Literal::u128_unsuffixed(value);
90 new_literal.set_span(span);
91 RewriteResult::Rewritten(
92 quote_spanned! {span=> #crate_path::#type_name::new_masked(#new_literal) },
93 )
94}
95
/// Parses a literal suffix of the form `U<n>` and returns the bit width `n`.
///
/// Only the canonical spelling is accepted: an uppercase `U` followed by a plain decimal number
/// in `1..=128` with no sign and no leading zeros. Anything else (`u8`, `U0`, `U129`, `U08`, …)
/// yields `None`, which callers treat as "not a bitint literal".
fn parse_suffix(suffix: &str) -> Option<u8> {
    let digits = suffix.strip_prefix('U')?;
    // `str::parse::<u8>` would also accept a leading `+` and leading zeros, silently mapping
    // e.g. `U08` to the `U8` type. Like rustc's built-in suffixes, require the canonical form.
    if digits.starts_with('0') || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    // An empty string or a number too large for u8 fails to parse here.
    let width: u8 = digits.parse().ok()?;
    (1..=128).contains(&width).then_some(width)
}
106
107fn map_token_stream_literals(
108 stream: TokenStream,
109 f: &mut impl FnMut(Literal) -> TokenStream,
110) -> TokenStream {
111 stream
112 .into_iter()
113 .flat_map(|tt| map_token_tree_literals(tt, f))
114 .collect()
115}
116
117fn map_token_tree_literals(
118 tt: TokenTree,
119 f: &mut impl FnMut(Literal) -> TokenStream,
120) -> TokenStream {
121 match tt {
122 TokenTree::Group(group) => {
123 let mut new_group = Group::new(
124 group.delimiter(),
125 map_token_stream_literals(group.stream(), f),
126 );
127 new_group.set_span(group.span());
128 TokenTree::Group(new_group).into()
129 }
130 TokenTree::Ident(_) => tt.into(),
131 TokenTree::Punct(_) => tt.into(),
132 TokenTree::Literal(lit) => f(lit),
133 }
134}
135
136#[derive(Default)]
137struct ConfigBuilder {
138 crate_path: Option<Path>,
139}
140
141impl ConfigBuilder {
142 fn parser(&mut self) -> impl Parser<Output = ()> + '_ {
143 syn::meta::parser(|meta| {
144 if meta.path.is_ident("crate_path") {
145 self.crate_path = Some(meta.value()?.parse()?);
146 Ok(())
147 } else {
148 Err(meta.error("unsupported property"))
149 }
150 })
151 }
152
153 fn build(self) -> Config {
154 Config {
155 crate_path: self.crate_path.unwrap_or_else(|| parse_quote! { ::bitint }),
156 }
157 }
158}
159
160struct Config {
161 crate_path: Path,
162}
163
164impl Config {
165 fn new(attr: TokenStream) -> (Self, Errors) {
166 let mut errors = Errors::new();
167 let mut builder = ConfigBuilder::default();
168 if !attr.is_empty() {
169 errors.record(builder.parser().parse2(attr));
170 }
171 (builder.build(), errors)
172 }
173}
174
175#[derive(Default)]
176struct Errors {
177 error: Option<Error>,
178}
179
180impl Errors {
181 fn new() -> Self {
182 Default::default()
183 }
184
185 fn push(&mut self, e: Error) {
186 match &mut self.error {
187 None => self.error = Some(e),
188 Some(error) => error.combine(e),
189 }
190 }
191
192 fn record(&mut self, result: Result<()>) {
193 if let Err(e) = result {
194 self.push(e);
195 }
196 }
197}
198
199impl ToTokens for Errors {
200 fn to_tokens(&self, tokens: &mut TokenStream) {
201 if let Some(error) = &self.error {
202 tokens.extend(error.to_compile_error());
203 }
204 }
205}
206
207fn bitint_literals_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
208 let (cfg, cfg_errors) = Config::new(attr);
209 let mut result = cfg_errors.into_token_stream();
210
211 result.extend(map_token_stream_literals(
212 item,
213 &mut |literal| match rewrite_literal(&cfg.crate_path, literal) {
214 RewriteResult::Rewritten(tokens) => tokens,
215 RewriteResult::UnrecognizedSuffix(literal) => TokenTree::Literal(literal).into(),
216 RewriteResult::ValueError(e) => e.into_compile_error(),
217 },
218 ));
219
220 result
221}
222
#[cfg(test)]
mod tests {
    use quote::{quote, ToTokens};
    use std::fmt::{self, Debug, Formatter};
    use syn::parse::{Parse, ParseStream};
    use syn::{Expr, Item, Result};

    use super::{bitint_impl, bitint_literals_impl};

    #[test]
    fn bitint_simple() {
        assert_eq!(
            syn::parse2::<Expr>(bitint_impl(quote! { (some::path::to, 7_U3) })).unwrap(),
            syn::parse2::<Expr>(quote! { some::path::to::U3::new_masked(7) }).unwrap(),
        );
    }

    #[test]
    fn bitint_missing_suffix_is_error() {
        let out = bitint_impl(quote! { (some::path::to, 7) }).to_string();
        assert!(out.contains("compile_error"), "{out}");
        assert!(out.contains("literal must have a suffix"), "{out}");
    }

    #[test]
    fn bitint_out_of_range_is_error() {
        let out = bitint_impl(quote! { (some::path::to, 8_U3) }).to_string();
        assert!(out.contains("integer literal value 8 out of range for U3"), "{out}");
    }

    /// A sequence of items, compared structurally (spans are ignored by syn's `PartialEq`).
    #[derive(PartialEq, Eq)]
    struct ParseItems(Vec<Item>);

    impl Parse for ParseItems {
        fn parse(input: ParseStream) -> Result<Self> {
            let mut items = Vec::new();
            while !input.is_empty() {
                items.push(input.parse()?);
            }
            Ok(Self(items))
        }
    }

    impl Debug for ParseItems {
        /// Prints each item as its token string. `debug_list` also emits the brackets for an
        /// empty list; the previous hand-rolled loop printed only `]` in that case.
        fn fmt(&self, f: &mut Formatter) -> fmt::Result {
            f.debug_list()
                .entries(self.0.iter().map(|item| item.to_token_stream().to_string()))
                .finish()
        }
    }

    /// Parses `tokens` as a sequence of items, panicking on malformed output.
    fn parse_items(tokens: proc_macro2::TokenStream) -> ParseItems {
        syn::parse2(tokens).unwrap()
    }

    #[test]
    fn bitint_literals_simple() {
        assert_eq!(
            parse_items(bitint_literals_impl(
                quote! {},
                quote! { fn foo() { 1234567_U24 } },
            )),
            parse_items(quote! {
                fn foo() { ::bitint::U24::new_masked(1234567) }
            }),
        );
    }

    #[test]
    fn bitint_literals_with_crate_path() {
        assert_eq!(
            parse_items(bitint_literals_impl(
                quote! { crate_path = path::to::bitint_crate },
                quote! { fn foo() { 1234567_U24 } },
            )),
            parse_items(quote! {
                fn foo() { path::to::bitint_crate::U24::new_masked(1234567) }
            }),
        );
    }

    #[test]
    fn bitint_literals_recurses_into_groups() {
        assert_eq!(
            parse_items(bitint_literals_impl(
                quote! {},
                quote! { fn foo() { [1_U4, (2_U4)] } },
            )),
            parse_items(quote! {
                fn foo() { [::bitint::U4::new_masked(1), (::bitint::U4::new_masked(2))] }
            }),
        );
    }

    #[test]
    fn bitint_literals_leaves_other_literals_untouched() {
        let item = quote! { fn foo() -> (u8, &'static str, f32) { (5u8, "U8", 1.5) } };
        assert_eq!(
            parse_items(bitint_literals_impl(quote! {}, item.clone())),
            parse_items(item),
        );
    }

    #[test]
    fn bitint_literals_out_of_range_is_error() {
        let out = bitint_literals_impl(quote! {}, quote! { fn foo() { 8_U3 } }).to_string();
        assert!(out.contains("integer literal value 8 out of range for U3"), "{out}");
    }

    #[test]
    fn bitint_literals_unsupported_property_is_error() {
        let out = bitint_literals_impl(quote! { bogus = x }, quote! { fn foo() {} }).to_string();
        assert!(out.contains("unsupported property"), "{out}");
    }

    #[test]
    fn parse_items_debug_is_bracketed() {
        assert_eq!(format!("{:?}", ParseItems(Vec::new())), "[]");
    }
}