bitint_macros/
lib.rs

1use litrs::IntegerLit;
2use proc_macro2::{Group, Literal, TokenStream, TokenTree};
3use quote::{format_ident, quote_spanned, ToTokens};
4use syn::parse::{Parse, ParseBuffer, Parser};
5use syn::{parenthesized, parse_quote, token, Error, LitInt, Path, Result, Token};
6
7#[proc_macro]
8pub fn bitint(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
9    bitint_impl(tokens.into()).into()
10}
11
12#[proc_macro_attribute]
13pub fn bitint_literals(
14    attr: proc_macro::TokenStream,
15    item: proc_macro::TokenStream,
16) -> proc_macro::TokenStream {
17    bitint_literals_impl(attr.into(), item.into()).into()
18}
19
20struct BitintInput {
21    _paren_token: token::Paren,
22    crate_path: Path,
23    _comma_token: Token![,],
24    lit: LitInt,
25}
26
27impl Parse for BitintInput {
28    fn parse(input: &ParseBuffer) -> Result<Self> {
29        let content;
30        Ok(Self {
31            _paren_token: parenthesized!(content in input),
32            crate_path: content.parse()?,
33            _comma_token: content.parse()?,
34            lit: content.parse()?,
35        })
36    }
37}
38
39fn bitint_impl(tokens: TokenStream) -> TokenStream {
40    let input: BitintInput = match syn::parse2(tokens) {
41        Ok(input) => input,
42        Err(e) => return e.into_compile_error(),
43    };
44    match rewrite_literal(&input.crate_path, input.lit.token()) {
45        RewriteResult::Rewritten(tokens) => tokens,
46        RewriteResult::UnrecognizedSuffix(literal) => Error::new(
47            literal.span(),
48            "literal must have a suffix: 'U' followed by an integer in 1..=128",
49        )
50        .into_compile_error(),
51        RewriteResult::ValueError(e) => e.into_compile_error(),
52    }
53}
54
55enum RewriteResult {
56    Rewritten(TokenStream),
57    UnrecognizedSuffix(Literal),
58    ValueError(Error),
59}
60
61fn rewrite_literal(crate_path: &Path, literal: Literal) -> RewriteResult {
62    // Only rewrite integer literals with a recognized suffix.
63    let Ok(integer_lit) = IntegerLit::try_from(literal.clone()) else {
64        return RewriteResult::UnrecognizedSuffix(literal);
65    };
66    let Some(width) = parse_suffix(integer_lit.suffix()) else {
67        return RewriteResult::UnrecognizedSuffix(literal);
68    };
69
70    // Parse the value and enforce bounds.
71    let span = literal.span();
72    let Some(value) = integer_lit.value::<u128>() else {
73        return RewriteResult::ValueError(
74            Error::new(span, "could not parse integer literal")
75        );
76    };
77    if width < 128 {
78        let max: u128 = (1 << width) - 1;
79        if value > max {
80            return RewriteResult::ValueError(Error::new(
81                span,
82                format!("integer literal value {value} out of range for U{width}"),
83            ));
84        }
85    }
86
87    // Build the rewritten literal.
88    let type_name = format_ident!("U{width}", span = span);
89    let mut new_literal = Literal::u128_unsuffixed(value);
90    new_literal.set_span(span);
91    RewriteResult::Rewritten(
92        quote_spanned! {span=> #crate_path::#type_name::new_masked(#new_literal) },
93    )
94}
95
/// Parses a literal suffix of the form `U<width>` and returns `width`.
///
/// Returns `None` unless the suffix is exactly `U` followed by the canonical
/// decimal spelling of an integer in `1..=128`. Leading zeros and signs are
/// rejected (e.g. `U08`, `U+8`). The rewritten code names the type
/// `U{width}`, so accepting those spellings would silently substitute a
/// differently spelled type (`U8`) for what the user wrote.
fn parse_suffix(suffix: &str) -> Option<u8> {
    let digits = suffix.strip_prefix('U')?;
    if digits.starts_with('0') || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    // Empty `digits` and values above `u8::MAX` fail to parse here.
    let width: u8 = digits.parse().ok()?;
    (1..=128).contains(&width).then_some(width)
}
106
107fn map_token_stream_literals(
108    stream: TokenStream,
109    f: &mut impl FnMut(Literal) -> TokenStream,
110) -> TokenStream {
111    stream
112        .into_iter()
113        .flat_map(|tt| map_token_tree_literals(tt, f))
114        .collect()
115}
116
117fn map_token_tree_literals(
118    tt: TokenTree,
119    f: &mut impl FnMut(Literal) -> TokenStream,
120) -> TokenStream {
121    match tt {
122        TokenTree::Group(group) => {
123            let mut new_group = Group::new(
124                group.delimiter(),
125                map_token_stream_literals(group.stream(), f),
126            );
127            new_group.set_span(group.span());
128            TokenTree::Group(new_group).into()
129        }
130        TokenTree::Ident(_) => tt.into(),
131        TokenTree::Punct(_) => tt.into(),
132        TokenTree::Literal(lit) => f(lit),
133    }
134}
135
136#[derive(Default)]
137struct ConfigBuilder {
138    crate_path: Option<Path>,
139}
140
141impl ConfigBuilder {
142    fn parser(&mut self) -> impl Parser<Output = ()> + '_ {
143        syn::meta::parser(|meta| {
144            if meta.path.is_ident("crate_path") {
145                self.crate_path = Some(meta.value()?.parse()?);
146                Ok(())
147            } else {
148                Err(meta.error("unsupported property"))
149            }
150        })
151    }
152
153    fn build(self) -> Config {
154        Config {
155            crate_path: self.crate_path.unwrap_or_else(|| parse_quote! { ::bitint }),
156        }
157    }
158}
159
160struct Config {
161    crate_path: Path,
162}
163
164impl Config {
165    fn new(attr: TokenStream) -> (Self, Errors) {
166        let mut errors = Errors::new();
167        let mut builder = ConfigBuilder::default();
168        if !attr.is_empty() {
169            errors.record(builder.parser().parse2(attr));
170        }
171        (builder.build(), errors)
172    }
173}
174
175#[derive(Default)]
176struct Errors {
177    error: Option<Error>,
178}
179
180impl Errors {
181    fn new() -> Self {
182        Default::default()
183    }
184
185    fn push(&mut self, e: Error) {
186        match &mut self.error {
187            None => self.error = Some(e),
188            Some(error) => error.combine(e),
189        }
190    }
191
192    fn record(&mut self, result: Result<()>) {
193        if let Err(e) = result {
194            self.push(e);
195        }
196    }
197}
198
199impl ToTokens for Errors {
200    fn to_tokens(&self, tokens: &mut TokenStream) {
201        if let Some(error) = &self.error {
202            tokens.extend(error.to_compile_error());
203        }
204    }
205}
206
207fn bitint_literals_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
208    let (cfg, cfg_errors) = Config::new(attr);
209    let mut result = cfg_errors.into_token_stream();
210
211    result.extend(map_token_stream_literals(
212        item,
213        &mut |literal| match rewrite_literal(&cfg.crate_path, literal) {
214            RewriteResult::Rewritten(tokens) => tokens,
215            RewriteResult::UnrecognizedSuffix(literal) => TokenTree::Literal(literal).into(),
216            RewriteResult::ValueError(e) => e.into_compile_error(),
217        },
218    ));
219
220    result
221}
222
#[cfg(test)]
mod tests {
    use quote::{quote, ToTokens};
    use std::fmt::{self, Debug, Formatter};
    use syn::parse::{Parse, ParseStream};
    use syn::{Expr, Item, Result};

    use super::{bitint_impl, bitint_literals_impl};

    #[test]
    fn bitint_simple() {
        assert_eq!(
            syn::parse2::<Expr>(bitint_impl(quote! { (some::path::to, 7_U3) })).unwrap(),
            syn::parse2::<Expr>(quote! { some::path::to::U3::new_masked(7) }).unwrap(),
        );
    }

    /// The largest `u128` value is accepted for `U128`, which has no
    /// narrower bound check.
    #[test]
    fn bitint_u128_max() {
        assert_eq!(
            syn::parse2::<Expr>(bitint_impl(quote! {
                (some::path::to, 340282366920938463463374607431768211455_U128)
            }))
            .unwrap(),
            syn::parse2::<Expr>(quote! {
                some::path::to::U128::new_masked(340282366920938463463374607431768211455)
            })
            .unwrap(),
        );
    }

    /// A value wider than the requested width is reported as a compile error.
    #[test]
    fn bitint_out_of_range() {
        let output = bitint_impl(quote! { (some::path::to, 8_U3) }).to_string();
        assert!(output.contains("compile_error"), "{output}");
        assert!(output.contains("out of range for U3"), "{output}");
    }

    /// A suffix outside `U1..=U128` is reported as a compile error by `bitint!`.
    #[test]
    fn bitint_unrecognized_suffix() {
        let output = bitint_impl(quote! { (some::path::to, 1_U129) }).to_string();
        assert!(output.contains("compile_error"), "{output}");
    }

    /// Wrapper that parses a sequence of items, for comparing macro output.
    #[derive(PartialEq, Eq)]
    struct ParseItems(Vec<Item>);

    impl Parse for ParseItems {
        fn parse(input: ParseStream) -> Result<Self> {
            let mut items = Vec::new();
            while !input.is_empty() {
                items.push(input.parse()?);
            }
            Ok(Self(items))
        }
    }

    impl Debug for ParseItems {
        fn fmt(&self, f: &mut Formatter) -> fmt::Result {
            let mut delim = "[";
            for item in &self.0 {
                write!(f, "{delim}")?;
                delim = ", ";
                write!(f, "{:?}", item.to_token_stream().to_string())?;
            }
            write!(f, "]")
        }
    }

    #[test]
    fn bitint_literals_simple() {
        assert_eq!(
            syn::parse2::<ParseItems>(bitint_literals_impl(
                quote! {},
                quote! { fn foo() { 1234567_U24 } },
            ))
            .unwrap(),
            syn::parse2::<ParseItems>(quote! {
                fn foo() { ::bitint::U24::new_masked(1234567) }
            })
            .unwrap(),
        );
    }

    #[test]
    fn bitint_literals_with_crate_path() {
        assert_eq!(
            syn::parse2::<ParseItems>(bitint_literals_impl(
                quote! { crate_path = path::to::bitint_crate },
                quote! { fn foo() { 1234567_U24 } },
            ))
            .unwrap(),
            syn::parse2::<ParseItems>(quote! {
                fn foo() { path::to::bitint_crate::U24::new_masked(1234567) }
            })
            .unwrap(),
        );
    }

    /// Literals that are not `U<width>`-suffixed integers pass through the
    /// attribute unchanged: strings, other integer suffixes, floats, and
    /// unsupported widths.
    #[test]
    fn bitint_literals_leaves_other_literals() {
        assert_eq!(
            syn::parse2::<ParseItems>(bitint_literals_impl(
                quote! {},
                quote! { fn foo() { ("text", 5u8, 1.5, 3_U0) } },
            ))
            .unwrap(),
            syn::parse2::<ParseItems>(quote! {
                fn foo() { ("text", 5u8, 1.5, 3_U0) }
            })
            .unwrap(),
        );
    }
}