1mod enums;
2mod named_structs;
3mod newtypes;
4mod utils;
5
6use std::{collections::HashSet, str::FromStr};
7
8use enum_all_values_const::AllValues;
9use enums::bitpiece_enum;
10use heck::{ToSnakeCase, ToUpperCamelCase};
11use itertools::Itertools;
12use named_structs::bitpiece_named_struct;
13use strum::{EnumString, VariantNames};
14use syn::{
15 parse::{Parse, ParseStream},
16 parse_macro_input,
17 punctuated::Punctuated,
18 token::Comma,
19 DeriveInput, LitInt,
20};
21use utils::{are_generics_empty, not_supported_err};
22
/// Entry point of the `#[bitpiece(...)]` attribute macro.
///
/// Delegates to `impl_bitpiece`, which parses the attribute arguments and dispatches to the
/// named-struct or enum code generator depending on the annotated item.
///
/// Accepted arguments (comma-separated):
/// - an integer literal giving an explicit bit length (e.g. `32`), at most once;
/// - opt-in flags written in snake_case (e.g. `get`, `set`, `with`, `get_mut`);
/// - opt-in presets written in snake_case (`basic`, `all`, `mut_struct_all`).
///
/// If no opt-in flag or preset is given, the `basic` preset is used.
#[proc_macro_attribute]
pub fn bitpiece(
    args: proc_macro::TokenStream,
    input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    impl_bitpiece(args, input)
}
31
/// Individual opt-in flags that can be enabled through `#[bitpiece(...)]` arguments.
///
/// A macro argument selects a variant when its identifier, converted to UpperCamelCase,
/// equals the variant name. The canonical spelling is therefore the snake_case form,
/// e.g. `get_noshift` for `GetNoshift`.
///
/// Generated code tied to a flag is only emitted when that flag is enabled
/// (see `MacroArgs::filter_opt_in_code`).
///
/// NOTE: variant order is observable. It determines the order of `VARIANTS` (listed in the
/// "unknown macro argument" error) and of `ALL_VALUES` (expanded by the `all` preset).
#[derive(EnumString, VariantNames, AllValues, Hash, Clone, Copy, Debug, PartialEq, Eq)]
enum OptIn {
    Get,
    Set,
    With,
    GetNoshift,
    GetMut,
    ConstEq,
    FieldsStruct,
    // NOTE(review): the `MutStruct*` flags presumably gate a mutable view struct and its
    // per-field accessors; confirm against the generators in `named_structs`.
    MutStruct,
    MutStructFieldGet,
    MutStructFieldSet,
    MutStructFieldGetNoshift,
    MutStructFieldMut,
}
47
48#[derive(EnumString, VariantNames, Hash, Clone, Copy, Debug, PartialEq, Eq)]
49enum OptInPreset {
50 Basic,
51 All,
52 MutStructAll,
53}
54impl OptInPreset {
55 fn opt_ins(&self) -> &'static [OptIn] {
56 match self {
57 OptInPreset::Basic => &[OptIn::Get, OptIn::Set, OptIn::With],
58 OptInPreset::All => OptIn::ALL_VALUES.as_slice(),
59 OptInPreset::MutStructAll => &[
60 OptIn::MutStruct,
61 OptIn::MutStructFieldGet,
62 OptIn::MutStructFieldSet,
63 OptIn::MutStructFieldGetNoshift,
64 OptIn::MutStructFieldMut,
65 ],
66 }
67 }
68}
69
/// An explicit bit-length macro argument, written as an integer literal such as `32`.
struct ExplicitBitLengthArg {
    /// The literal's value, parsed as a base-10 `usize`.
    bit_length: usize,
    /// The original literal, kept so that duplicate-argument errors can point at it.
    lit: LitInt,
}
74
/// A single resolved opt-in flag, together with the identifier that requested it.
struct OptInArg {
    /// The flag being enabled.
    opt_in: OptIn,
    /// Identifier used as the error span. For flags expanded from a preset, this is the
    /// preset's identifier.
    ident: syn::Ident,
}
/// A preset macro argument, together with the identifier that requested it.
struct OptInPresetArg {
    /// The preset being requested.
    opt_in_preset: OptInPreset,
    /// Identifier used as the error span for every flag the preset expands to.
    ident: syn::Ident,
}
83
84enum MacroArg {
85 ExplicitBitLength(ExplicitBitLengthArg),
86 OptIn(OptInArg),
87 OptInPreset(OptInPresetArg),
88}
89impl Parse for MacroArg {
90 fn parse(input: ParseStream) -> syn::Result<Self> {
91 let opt_in_names: String = OptIn::VARIANTS
92 .iter()
93 .map(|v| format!("`{}`", v.to_snake_case()))
94 .join(", ");
95 let preset_names: String = OptInPreset::VARIANTS
96 .iter()
97 .map(|v| format!("`{}`", v.to_snake_case()))
98 .join(", ");
99 let unknown_macro_arg_err = format!(
100 "unknown macro argument, expected an integer bit-length (e.g. `32`), an opt-in flag ({opt_in_names}), or an opt-in preset ({preset_names})"
101 );
102
103 if input.peek(LitInt) {
105 let lit: LitInt = input.parse()?;
106 return Ok(MacroArg::ExplicitBitLength(ExplicitBitLengthArg {
107 bit_length: lit.base10_parse()?,
108 lit,
109 }));
110 }
111
112 if input.peek(syn::Ident) {
114 let ident: syn::Ident = input.parse()?;
115
116 let ident_pascal_case = ident.to_string().to_upper_camel_case();
117
118 if let Ok(opt_in) = OptIn::from_str(&ident_pascal_case) {
119 return Ok(MacroArg::OptIn(OptInArg { opt_in, ident }));
120 } else if let Ok(opt_in_preset) = OptInPreset::from_str(&ident_pascal_case) {
121 return Ok(MacroArg::OptInPreset(OptInPresetArg {
122 opt_in_preset,
123 ident,
124 }));
125 } else {
126 return Err(syn::Error::new_spanned(&ident, unknown_macro_arg_err));
127 }
128 }
129
130 Err(input.error(unknown_macro_arg_err))
131 }
132}
133
134struct RawMacroArgs(Punctuated<MacroArg, Comma>);
135impl Parse for RawMacroArgs {
136 fn parse(input: ParseStream) -> syn::Result<Self> {
137 Punctuated::<MacroArg, Comma>::parse_terminated(input).map(Self)
138 }
139}
140
141struct OptInArgsCollector(Vec<OptInArg>);
142impl OptInArgsCollector {
143 fn new() -> Self {
144 Self(Vec::new())
145 }
146 fn add_opt_in(&mut self, arg: OptInArg) -> Result<(), syn::Error> {
147 if let Some(existing_arg) = self
148 .0
149 .iter()
150 .find(|existing_arg| existing_arg.opt_in == arg.opt_in)
151 {
152 let mut err = syn::Error::new_spanned(arg.ident, "duplicate opt in arg");
153 err.combine(syn::Error::new_spanned(
154 existing_arg.ident.clone(),
155 "conflicts with this previous opt in arg",
156 ));
157 return Err(err);
158 }
159 Ok(self.0.push(arg))
160 }
161}
162
163#[derive(Default)]
164struct MacroArgs {
165 explicit_bit_length: Option<usize>,
166 opt_ins: HashSet<OptIn>,
167}
168impl MacroArgs {
169 pub fn filter_opt_in_code(
170 &self,
171 opt_in: OptIn,
172 code: proc_macro2::TokenStream,
173 ) -> proc_macro2::TokenStream {
174 if self.opt_ins.contains(&opt_in) {
175 code
176 } else {
177 quote::quote! {}
178 }
179 }
180}
181impl Parse for MacroArgs {
182 fn parse(input: ParseStream) -> syn::Result<Self> {
183 let raw_args: RawMacroArgs = input.parse()?;
184
185 let mut explicit_bit_length_arg: Option<ExplicitBitLengthArg> = None;
186 let mut opt_in_args = OptInArgsCollector::new();
187 for arg in raw_args.0 {
188 match arg {
189 MacroArg::ExplicitBitLength(arg) => {
190 if let Some(existing_arg) = explicit_bit_length_arg {
191 let mut err = syn::Error::new_spanned(
192 arg.lit,
193 "found more than one explicit bit length argument but only one is allowed",
194 );
195 err.combine(syn::Error::new_spanned(
196 existing_arg.lit,
197 "conflicts with this previous explicit bit length argument",
198 ));
199 return Err(err);
200 }
201 explicit_bit_length_arg = Some(arg);
202 }
203 MacroArg::OptIn(arg) => {
204 opt_in_args.add_opt_in(arg)?;
205 }
206 MacroArg::OptInPreset(opt_in_preset_arg) => {
207 for opt_in in opt_in_preset_arg.opt_in_preset.opt_ins() {
208 opt_in_args.add_opt_in(OptInArg {
209 opt_in: *opt_in,
210 ident: opt_in_preset_arg.ident.clone(),
211 })?;
212 }
213 }
214 }
215 }
216 Ok(MacroArgs {
217 explicit_bit_length: explicit_bit_length_arg.map(|arg| arg.bit_length),
218 opt_ins: if opt_in_args.0.is_empty() {
219 OptInPreset::Basic.opt_ins().iter().copied().collect()
221 } else {
222 opt_in_args.0.iter().map(|arg| arg.opt_in).collect()
223 },
224 })
225 }
226}
227
228fn impl_bitpiece(
229 args_tokens: proc_macro::TokenStream,
230 input_tokens: proc_macro::TokenStream,
231) -> proc_macro::TokenStream {
232 let macro_args = parse_macro_input!(args_tokens as MacroArgs);
233 let input = parse_macro_input!(input_tokens as DeriveInput);
234
235 if !are_generics_empty(&input.generics) {
236 return not_supported_err("generics");
237 }
238
239 match &input.data {
240 syn::Data::Struct(data_struct) => match &data_struct.fields {
241 syn::Fields::Named(fields) => bitpiece_named_struct(&input, fields, macro_args),
242 syn::Fields::Unnamed(_) => not_supported_err("unnamed structs"),
243 syn::Fields::Unit => not_supported_err("empty structs"),
244 },
245 syn::Data::Enum(data_enum) => bitpiece_enum(&input, data_enum, macro_args),
246 syn::Data::Union(_) => not_supported_err("unions"),
247 }
248}