bitpiece_macros/
lib.rs

1mod enums;
2mod named_structs;
3mod newtypes;
4mod utils;
5
6use std::{collections::HashSet, str::FromStr};
7
8use enum_all_values_const::AllValues;
9use enums::bitpiece_enum;
10use heck::{ToSnakeCase, ToUpperCamelCase};
11use itertools::Itertools;
12use named_structs::bitpiece_named_struct;
13use strum::{EnumString, VariantNames};
14use syn::{
15    parse::{Parse, ParseStream},
16    parse_macro_input,
17    punctuated::Punctuated,
18    token::Comma,
19    DeriveInput, LitInt,
20};
21use utils::{are_generics_empty, not_supported_err};
22
23/// an attribute for defining bitfields.
24#[proc_macro_attribute]
25pub fn bitpiece(
26    args: proc_macro::TokenStream,
27    input: proc_macro::TokenStream,
28) -> proc_macro::TokenStream {
29    impl_bitpiece(args, input)
30}
31
32#[derive(EnumString, VariantNames, AllValues, Hash, Clone, Copy, Debug, PartialEq, Eq)]
33enum OptIn {
34    Get,
35    Set,
36    With,
37    GetNoshift,
38    GetMut,
39    ConstEq,
40    FieldsStruct,
41    MutStruct,
42    MutStructFieldGet,
43    MutStructFieldSet,
44    MutStructFieldGetNoshift,
45    MutStructFieldMut,
46}
47
48#[derive(EnumString, VariantNames, Hash, Clone, Copy, Debug, PartialEq, Eq)]
49enum OptInPreset {
50    Basic,
51    All,
52    MutStructAll,
53}
54impl OptInPreset {
55    fn opt_ins(&self) -> &'static [OptIn] {
56        match self {
57            OptInPreset::Basic => &[OptIn::Get, OptIn::Set, OptIn::With],
58            OptInPreset::All => OptIn::ALL_VALUES.as_slice(),
59            OptInPreset::MutStructAll => &[
60                OptIn::MutStruct,
61                OptIn::MutStructFieldGet,
62                OptIn::MutStructFieldSet,
63                OptIn::MutStructFieldGetNoshift,
64                OptIn::MutStructFieldMut,
65            ],
66        }
67    }
68}
69
70struct ExplicitBitLengthArg {
71    bit_length: usize,
72    lit: LitInt,
73}
74
75struct OptInArg {
76    opt_in: OptIn,
77    ident: syn::Ident,
78}
79struct OptInPresetArg {
80    opt_in_preset: OptInPreset,
81    ident: syn::Ident,
82}
83
84enum MacroArg {
85    ExplicitBitLength(ExplicitBitLengthArg),
86    OptIn(OptInArg),
87    OptInPreset(OptInPresetArg),
88}
89impl Parse for MacroArg {
90    fn parse(input: ParseStream) -> syn::Result<Self> {
91        let opt_in_names: String = OptIn::VARIANTS
92            .iter()
93            .map(|v| format!("`{}`", v.to_snake_case()))
94            .join(", ");
95        let preset_names: String = OptInPreset::VARIANTS
96            .iter()
97            .map(|v| format!("`{}`", v.to_snake_case()))
98            .join(", ");
99        let unknown_macro_arg_err = format!(
100            "unknown macro argument, expected an integer bit-length (e.g. `32`), an opt-in flag ({opt_in_names}), or an opt-in preset ({preset_names})"
101        );
102
103        // explicit bit length
104        if input.peek(LitInt) {
105            let lit: LitInt = input.parse()?;
106            return Ok(MacroArg::ExplicitBitLength(ExplicitBitLengthArg {
107                bit_length: lit.base10_parse()?,
108                lit,
109            }));
110        }
111
112        // opt ins as identifiers
113        if input.peek(syn::Ident) {
114            let ident: syn::Ident = input.parse()?;
115
116            let ident_pascal_case = ident.to_string().to_upper_camel_case();
117
118            if let Ok(opt_in) = OptIn::from_str(&ident_pascal_case) {
119                return Ok(MacroArg::OptIn(OptInArg { opt_in, ident }));
120            } else if let Ok(opt_in_preset) = OptInPreset::from_str(&ident_pascal_case) {
121                return Ok(MacroArg::OptInPreset(OptInPresetArg {
122                    opt_in_preset,
123                    ident,
124                }));
125            } else {
126                return Err(syn::Error::new_spanned(&ident, unknown_macro_arg_err));
127            }
128        }
129
130        Err(input.error(unknown_macro_arg_err))
131    }
132}
133
134struct RawMacroArgs(Punctuated<MacroArg, Comma>);
135impl Parse for RawMacroArgs {
136    fn parse(input: ParseStream) -> syn::Result<Self> {
137        Punctuated::<MacroArg, Comma>::parse_terminated(input).map(Self)
138    }
139}
140
141struct OptInArgsCollector(Vec<OptInArg>);
142impl OptInArgsCollector {
143    fn new() -> Self {
144        Self(Vec::new())
145    }
146    fn add_opt_in(&mut self, arg: OptInArg) -> Result<(), syn::Error> {
147        if let Some(existing_arg) = self
148            .0
149            .iter()
150            .find(|existing_arg| existing_arg.opt_in == arg.opt_in)
151        {
152            let mut err = syn::Error::new_spanned(arg.ident, "duplicate opt in arg");
153            err.combine(syn::Error::new_spanned(
154                existing_arg.ident.clone(),
155                "conflicts with this previous opt in arg",
156            ));
157            return Err(err);
158        }
159        Ok(self.0.push(arg))
160    }
161}
162
163#[derive(Default)]
164struct MacroArgs {
165    explicit_bit_length: Option<usize>,
166    opt_ins: HashSet<OptIn>,
167}
168impl MacroArgs {
169    pub fn filter_opt_in_code(
170        &self,
171        opt_in: OptIn,
172        code: proc_macro2::TokenStream,
173    ) -> proc_macro2::TokenStream {
174        if self.opt_ins.contains(&opt_in) {
175            code
176        } else {
177            quote::quote! {}
178        }
179    }
180}
181impl Parse for MacroArgs {
182    fn parse(input: ParseStream) -> syn::Result<Self> {
183        let raw_args: RawMacroArgs = input.parse()?;
184
185        let mut explicit_bit_length_arg: Option<ExplicitBitLengthArg> = None;
186        let mut opt_in_args = OptInArgsCollector::new();
187        for arg in raw_args.0 {
188            match arg {
189                MacroArg::ExplicitBitLength(arg) => {
190                    if let Some(existing_arg) = explicit_bit_length_arg {
191                        let mut err = syn::Error::new_spanned(
192                            arg.lit,
193                            "found more than one explicit bit length argument but only one is allowed",
194                        );
195                        err.combine(syn::Error::new_spanned(
196                            existing_arg.lit,
197                            "conflicts with this previous explicit bit length argument",
198                        ));
199                        return Err(err);
200                    }
201                    explicit_bit_length_arg = Some(arg);
202                }
203                MacroArg::OptIn(arg) => {
204                    opt_in_args.add_opt_in(arg)?;
205                }
206                MacroArg::OptInPreset(opt_in_preset_arg) => {
207                    for opt_in in opt_in_preset_arg.opt_in_preset.opt_ins() {
208                        opt_in_args.add_opt_in(OptInArg {
209                            opt_in: *opt_in,
210                            ident: opt_in_preset_arg.ident.clone(),
211                        })?;
212                    }
213                }
214            }
215        }
216        Ok(MacroArgs {
217            explicit_bit_length: explicit_bit_length_arg.map(|arg| arg.bit_length),
218            opt_ins: if opt_in_args.0.is_empty() {
219                // if no opt ins are specified, use the basic preset
220                OptInPreset::Basic.opt_ins().iter().copied().collect()
221            } else {
222                opt_in_args.0.iter().map(|arg| arg.opt_in).collect()
223            },
224        })
225    }
226}
227
228fn impl_bitpiece(
229    args_tokens: proc_macro::TokenStream,
230    input_tokens: proc_macro::TokenStream,
231) -> proc_macro::TokenStream {
232    let macro_args = parse_macro_input!(args_tokens as MacroArgs);
233    let input = parse_macro_input!(input_tokens as DeriveInput);
234
235    if !are_generics_empty(&input.generics) {
236        return not_supported_err("generics");
237    }
238
239    match &input.data {
240        syn::Data::Struct(data_struct) => match &data_struct.fields {
241            syn::Fields::Named(fields) => bitpiece_named_struct(&input, fields, macro_args),
242            syn::Fields::Unnamed(_) => not_supported_err("unnamed structs"),
243            syn::Fields::Unit => not_supported_err("empty structs"),
244        },
245        syn::Data::Enum(data_enum) => bitpiece_enum(&input, data_enum, macro_args),
246        syn::Data::Union(_) => not_supported_err("unions"),
247    }
248}