1use core::{cmp::Ordering, convert::TryInto, fmt, ops::Range};
2
3use proc_macro2::TokenStream;
4use quote::quote;
5use syn::{
6 parse::{Parse, ParseStream},
7 parse_macro_input,
8 punctuated::Punctuated,
9 Token,
10};
11
12#[proc_macro]
13pub fn bitstruct(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
14 let input = parse_macro_input!(tokens as BitStructInput);
15 expand_bitstruct(input)
16 .unwrap_or_else(|err| err.to_compile_error())
17 .into()
18}
19
20fn expand_bitstruct(input: BitStructInput) -> syn::Result<TokenStream> {
21 let attrs = &input.attrs;
22 let vis = &input.vis;
23 let name = &input.name;
24 let raw_vis = &input.raw_vis;
25 let raw = &input.raw.as_type();
26 let fields = input
27 .fields
28 .iter()
29 .map(|field| expand_field_methods(&input, field))
30 .collect::<syn::Result<Vec<TokenStream>>>()?;
31 Ok(quote! {
32 #(#attrs)*
33 #vis struct #name(#raw_vis #raw);
34 impl #name {
35 #(#fields)*
36 }
37 })
38}
39
40fn expand_field_methods(input: &BitStructInput, field: &FieldDef) -> syn::Result<TokenStream> {
41 let bitstruct_field_attrs = field
43 .attrs
44 .iter()
45 .find_map(|attr| {
46 let bitstruct: syn::Path = syn::parse_quote! {bitstruct};
47 match attr.parse_meta().ok()? {
48 syn::Meta::List(meta_list) if meta_list.path == bitstruct => Some(meta_list.nested),
49 _ => None,
50 }
51 })
52 .unwrap_or_default();
53
54 let getter_method = expand_field_getter(input, field);
55 let setter_methods = {
56 let omit_setter = bitstruct_field_attrs.iter().any(|nested_meta| {
57 let omit_setter: syn::NestedMeta = syn::parse_quote! {omit_setter};
58 nested_meta == &omit_setter
59 });
60
61 if omit_setter {
62 quote! {}
63 } else {
64 expand_field_setter(input, field)
65 }
66 };
67
68 Ok(quote! {
69 #getter_method
70 #setter_methods
71 })
72}
73
74fn expand_field_getter(input: &BitStructInput, field: &FieldDef) -> TokenStream {
75 let pass_thru_attrs = field.attrs.iter().filter(|&attr| {
77 let bitstruct: syn::Path = syn::parse_quote! {bitstruct};
78 attr.path != bitstruct
79 });
80
81 let target_ty = field.target.as_type();
82 let mask = hexlit(input.raw, field.bits.get_mask());
83 let start_bit = hexlit(input.raw, field.bits.0.start.into());
84 let mask_and_shift: syn::Expr = syn::parse_quote! {
85 ((self.0 & #mask) >> #start_bit)
86 };
87 let cast = from_raw(mask_and_shift, input.raw, &field.target, &field.bits);
88
89 let field_vis = &field.vis;
90 let field_name = &field.name;
91 let maybe_const_fn = if let Target::Convert(_) = field.target {
92 quote! {fn}
93 } else {
94 quote! {const fn}
95 };
96 quote! {
97 #(#pass_thru_attrs)*
98 #field_vis #maybe_const_fn #field_name(&self) -> #target_ty {
99 #cast
100 }
101 }
102}
103
104fn from_raw(raw_expr: syn::Expr, raw: RawDef, target: &Target, bitrange: &BitRange) -> syn::Expr {
105 match target {
106 Target::Int(raw_def) => {
107 let target_ty = raw_def.as_type();
108 syn::parse_quote! {
109 #raw_expr as #target_ty
110 }
111 }
112 Target::Bool => {
113 syn::parse_quote! {
114 #raw_expr != 0
115 }
116 }
117 Target::Convert(ty) => {
118 let bitlen = bitrange.0.end - bitrange.0.start;
119 let smallest_target = Target::smallest_target(bitlen);
120 let smallest_target_expr = from_raw(raw_expr, raw, &smallest_target, bitrange);
121 let smallest_target_ty = smallest_target.as_type();
122 syn::parse_quote! {
123 <Self as ::bitstruct::FromRaw<#smallest_target_ty, #ty>>::from_raw(#smallest_target_expr)
124 }
125 }
126 }
127}
128
129fn expand_field_setter(input: &BitStructInput, field: &FieldDef) -> TokenStream {
130 let pass_thru_attrs = field
132 .attrs
133 .iter()
134 .filter(|&attr| {
135 let bitstruct: syn::Path = syn::parse_quote! {bitstruct};
136 attr.path != bitstruct
137 })
138 .collect::<Vec<_>>();
139
140 let target_ty = field.target.as_type();
141 let mask = field.bits.get_mask();
142 let neg_mask = hexlit(input.raw, !mask);
143 let mask = hexlit(input.raw, mask);
144 let start_bit = hexlit(input.raw, field.bits.0.start.into());
145
146 let field_vis = &field.vis;
147 let field_name = &field.name;
148 let with_method = quote::format_ident!("with_{}", field_name);
149 let set_method = quote::format_ident!("set_{}", field_name);
150 let cast = into_raw(
151 syn::parse_quote! {value},
152 &field.target,
153 input.raw,
154 &field.bits,
155 );
156 let maybe_const_fn = if let Target::Convert(_) = field.target {
157 quote! {fn}
158 } else {
159 quote! {const fn}
160 };
161 quote! {
162 #[must_use]
163 #(#pass_thru_attrs)*
164 #field_vis #maybe_const_fn #with_method(mut self, value: #target_ty) -> Self {
165 self.0 = (self.0 & #neg_mask) | ((#cast << #start_bit) & #mask);
166 self
167 }
168
169 #(#pass_thru_attrs)*
170 #field_vis fn #set_method(&mut self, value: #target_ty) {
171 self.0 = (self.0 & #neg_mask) | ((#cast << #start_bit) & #mask);
172 }
173 }
174}
175
176fn into_raw(
177 target_expr: syn::Expr,
178 target: &Target,
179 raw: RawDef,
180 bitrange: &BitRange,
181) -> syn::Expr {
182 match target {
183 Target::Int(_) | Target::Bool => {
184 let raw = raw.as_type();
185 syn::parse_quote! {
186 (#target_expr as #raw)
187 }
188 }
189 Target::Convert(ty) => {
190 let bitlen = bitrange.0.end - bitrange.0.start;
191 let smallest_target = Target::smallest_target(bitlen);
192 let smallest_target_ty = smallest_target.as_type();
193 let smallest_target_expr = syn::parse_quote! {
194 <Self as ::bitstruct::IntoRaw<#smallest_target_ty, #ty>>::into_raw(#target_expr)
195 };
196 into_raw(smallest_target_expr, &smallest_target, raw, bitrange)
197 }
198 }
199}
200
201trait TryParse {
203 fn try_parse<T: Parse>(&self) -> syn::Result<T>;
204 fn try_call<T>(&self, function: fn(_: ParseStream<'_>) -> syn::Result<T>) -> syn::Result<T>;
205}
206
207impl TryParse for ParseStream<'_> {
208 fn try_parse<T: Parse>(&self) -> syn::Result<T> {
209 use syn::parse::discouraged::Speculative;
210 let fork = self.fork();
211 match fork.parse::<T>() {
212 Ok(value) => {
213 self.advance_to(&fork);
214 Ok(value)
215 }
216 err => err,
217 }
218 }
219
220 fn try_call<T>(&self, function: fn(_: ParseStream<'_>) -> syn::Result<T>) -> syn::Result<T> {
221 use syn::parse::discouraged::Speculative;
222 let fork = self.fork();
223 match fork.call(function) {
224 Ok(value) => {
225 self.advance_to(&fork);
226 Ok(value)
227 }
228 err => err,
229 }
230 }
231}
232
233#[derive(Debug)]
234struct BitStructInput {
235 attrs: Vec<syn::Attribute>,
236 vis: syn::Visibility,
237 name: syn::Ident,
238 raw_vis: syn::Visibility,
239 raw: RawDef,
240 fields: Punctuated<FieldDef, Token![;]>,
241}
242
243impl Parse for BitStructInput {
244 fn parse(input: ParseStream) -> syn::Result<Self> {
245 let attrs = input.call(syn::Attribute::parse_outer)?;
246 let vis = input.parse()?;
247 input.parse::<Token![struct]>()?;
248 let name = input.parse()?;
249 let within_parens;
250 syn::parenthesized!(within_parens in input);
251 let raw_vis = within_parens.parse()?;
252 let raw: RawDef = within_parens.parse()?;
253 let within_braces;
254 syn::braced!(within_braces in input);
255 let fields: Punctuated<FieldDef, _> = Punctuated::parse_terminated(&within_braces)?;
256 for field in fields.iter() {
257 if field.bits.0.end > raw.bit_len() {
258 return Err(syn::Error::new(
259 field.name.span(),
260 format!(
261 "field `{}` specifies a bitrange beyond `{}` range",
262 field.name,
263 raw.as_str()
264 ),
265 ));
266 }
267 }
268 Ok(BitStructInput {
269 attrs,
270 vis,
271 name,
272 raw_vis,
273 raw,
274 fields,
275 })
276 }
277}
278
279#[derive(Debug, Copy, Clone, Eq, PartialEq)]
280enum RawDef {
281 U8,
282 U16,
283 U32,
284 U64,
285 U128,
286}
287
288impl RawDef {
289 fn as_str(self) -> &'static str {
290 match self {
291 RawDef::U8 => "u8",
292 RawDef::U16 => "u16",
293 RawDef::U32 => "u32",
294 RawDef::U64 => "u64",
295 RawDef::U128 => "u128",
296 }
297 }
298
299 fn as_type(self) -> syn::Type {
300 syn::parse_str(self.as_str()).unwrap()
301 }
302
303 fn bit_len(self) -> u8 {
304 match self {
305 RawDef::U8 => 8,
306 RawDef::U16 => 16,
307 RawDef::U32 => 32,
308 RawDef::U64 => 64,
309 RawDef::U128 => 128,
310 }
311 }
312}
313
314impl Parse for RawDef {
315 fn parse(input: ParseStream) -> syn::Result<Self> {
316 let ident: syn::Ident = input.parse()?;
317 if ident == "u8" {
318 Ok(RawDef::U8)
319 } else if ident == "u16" {
320 Ok(RawDef::U16)
321 } else if ident == "u32" {
322 Ok(RawDef::U32)
323 } else if ident == "u64" {
324 Ok(RawDef::U64)
325 } else if ident == "u128" {
326 Ok(RawDef::U128)
327 } else {
328 Err(input.error(format!(
329 "`{}` is not supported; needs to be one of u8,u16,u32,u64,u128",
330 ident
331 )))
332 }
333 }
334}
335
336#[derive(Debug)]
337struct FieldDef {
338 attrs: Vec<syn::Attribute>,
339 vis: syn::Visibility,
340 name: syn::Ident,
341 target: Target,
342 bits: BitRange,
343}
344
345impl Parse for FieldDef {
346 fn parse(input: ParseStream) -> syn::Result<Self> {
347 let attrs = input.call(syn::Attribute::parse_outer)?;
348 let vis = input.parse()?;
349 let name = input.parse()?;
350 input.parse::<Token![:]>()?;
351 let target: Target = input.parse()?;
352 input.parse::<Token![=]>()?;
353 let bits: BitRange = input.parse()?;
354 if target.bit_len() < bits.bit_len() {
355 return Err(input.error(format!(
356 "target `{}` can only represent {} bits; {} specified",
357 target,
358 target.bit_len(),
359 bits.bit_len(),
360 )));
361 }
362 Ok(FieldDef {
363 attrs,
364 vis,
365 name,
366 target,
367 bits,
368 })
369 }
370}
371
372#[derive(Debug, Eq, PartialEq)]
373enum Target {
374 Int(RawDef),
376 Bool,
378 Convert(syn::Type),
380}
381
382impl Target {
383 fn smallest_target(bitlen: u8) -> Target {
384 match bitlen {
385 x if x == 1 => Target::Bool,
386 x if x <= 8 => Target::Int(RawDef::U8),
387 x if x <= 16 => Target::Int(RawDef::U16),
388 x if x <= 32 => Target::Int(RawDef::U32),
389 x if x <= 64 => Target::Int(RawDef::U64),
390 x if x <= 128 => Target::Int(RawDef::U128),
391 _ => unreachable!("invalid bitlen"),
392 }
393 }
394
395 fn bit_len(&self) -> u8 {
396 match self {
397 Target::Int(raw) => raw.bit_len(),
398 Target::Bool => 1,
399 Target::Convert(_) => u8::MAX,
400 }
401 }
402
403 fn as_type(&self) -> syn::Type {
404 match self {
405 Target::Int(raw) => raw.as_type(),
406 Target::Bool => syn::parse_quote! {bool},
407 Target::Convert(ty) => ty.clone().into(),
408 }
409 }
410}
411
412impl fmt::Display for Target {
413 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
414 match self {
415 Target::Int(rawdef) => write!(f, "{}", rawdef.as_str()),
416 Target::Bool => write!(f, "bool"),
417 Target::Convert(ty) => write!(f, "{:?}", ty),
418 }
419 }
420}
421
/// Custom keyword tokens used by the parser.
mod kw {
    // `bool` is a primitive type name rather than a reserved Rust keyword, so
    // syn's `Token!` macro has no token for it; this declares `kw::bool`.
    syn::custom_keyword!(bool);
}
425
426impl Parse for Target {
427 fn parse(input: ParseStream) -> syn::Result<Self> {
428 input
429 .try_parse::<RawDef>()
430 .map(|raw_def| Target::Int(raw_def))
431 .or_else(|_| input.try_parse::<kw::bool>().map(|_| Target::Bool))
432 .or_else(|_| input.try_parse::<syn::Type>().map(|ty| Target::Convert(ty)))
433 }
434}
435
/// Half-open range `start..end` of bit positions within the raw integer,
/// where bit 0 is the least significant bit.
#[derive(Debug, Eq, PartialEq)]
struct BitRange(Range<u8>);

impl BitRange {
    /// Number of bits covered by the range (0 if `end <= start`).
    fn bit_len(&self) -> u8 {
        let width = self.0.len();
        width.try_into().unwrap()
    }

    /// Returns a `u128` with exactly the bits `start..end` set.
    fn get_mask(&self) -> u128 {
        let (start, end) = (self.0.start, self.0.end);
        // All bits below `end` ...
        let below_end = !0u128 >> (128 - end);
        // ... intersected with all bits at or above `start`.
        let from_start = !0u128 << start;
        below_end & from_start
    }
}
452}
453
454impl Parse for BitRange {
455 fn parse(input: ParseStream) -> syn::Result<Self> {
456 fn parse_end_range(input: ParseStream) -> syn::Result<u8> {
457 let range_limits: syn::RangeLimits = input.parse()?;
458 let end_bit: u8 = input.parse::<syn::LitInt>()?.base10_parse()?;
459 Ok(match range_limits {
460 syn::RangeLimits::HalfOpen(_) => end_bit,
461 syn::RangeLimits::Closed(_) => end_bit + 1,
462 })
463 }
464
465 let start_bit: u8 = input.parse::<syn::LitInt>()?.base10_parse()?;
466 let range = match input.try_call(parse_end_range) {
467 Ok(end_bit) => start_bit..end_bit,
468 Err(_) => start_bit..start_bit + 1,
469 };
470 match range.start.cmp(&range.end) {
471 Ordering::Less => {}
472 Ordering::Equal => return Err(input.error("empty bit range specified")),
473 Ordering::Greater => {
474 return Err(input
475 .error("least significant bit must be specified before most significant bit"))
476 }
477 };
478 Ok(BitRange(range))
479 }
480}
481
482fn hexlit(typ: RawDef, value: u128) -> syn::LitInt {
483 let num_hex_chars = typ.bit_len() as usize / 4;
484 syn::LitInt::new(
485 &format!(
486 "0x{value:0width$x}{suffix:}",
487 value = value,
488 suffix = typ.as_str(),
489 width = num_hex_chars
490 ),
491 proc_macro2::Span::call_site(),
492 )
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_bitstruct_input() {
        let bitstruct: BitStructInput = syn::parse2(quote! {
            #[derive(Clone,Copy)]
            pub(crate) struct Foo(pub u16) {
                #[inline]
                pub f1: u8 = 0 .. 8;
                pub f2: u8 = 8 .. 12;
            }
        })
        .unwrap();
        assert_eq!(bitstruct.name, quote::format_ident!("Foo"));
        assert_eq!(bitstruct.fields.len(), 2);
        assert_eq!(bitstruct.fields[0].attrs.len(), 1);
        assert_eq!(bitstruct.fields[1].attrs.len(), 0);
    }

    #[test]
    fn parse_bitstruct_rejects_field_beyond_raw() {
        // 4..12 needs 12 bits of storage but the raw type only has 8.
        let result = syn::parse2::<BitStructInput>(quote! {
            struct Foo(u8) {
                f: u8 = 4..12;
            }
        });
        assert!(result.is_err());
    }

    #[test]
    fn parse_field_def() {
        let field_def: FieldDef = syn::parse2(quote! {
            pub field1: u8 = 3 .. 5
        })
        .unwrap();
        assert_eq!(field_def.name, quote::format_ident!("field1"));
        assert_eq!(field_def.target, Target::Int(RawDef::U8));
        assert_eq!(field_def.bits, BitRange(3..5));

        let field_def: FieldDef = syn::parse2(quote! {
            pub field1: bool = 3
        })
        .unwrap();
        assert_eq!(field_def.name, quote::format_ident!("field1"));
        assert_eq!(field_def.target, Target::Bool);
        assert_eq!(field_def.bits, BitRange(3..4));
    }

    #[test]
    fn parse_field_def_rejects_narrow_target() {
        assert!(syn::parse2::<FieldDef>(quote! { pub f: u8 = 0..12 }).is_err());
        assert!(syn::parse2::<FieldDef>(quote! { pub f: bool = 0..2 }).is_err());
    }

    #[test]
    fn parse_raw_def() {
        assert_eq!(syn::parse2::<RawDef>(quote! {u32}).unwrap(), RawDef::U32);
        assert!(syn::parse2::<RawDef>(quote! {u7}).is_err());
    }

    #[test]
    fn parse_target() {
        assert_eq!(
            Target::Int(RawDef::U8),
            syn::parse2::<Target>(quote! {u8}).unwrap(),
        );
        assert_eq!(
            Target::Int(RawDef::U16),
            syn::parse2::<Target>(quote! {u16}).unwrap(),
        );
        assert_eq!(
            Target::Int(RawDef::U128),
            syn::parse2::<Target>(quote! {u128}).unwrap(),
        );
        assert_eq!(Target::Bool, syn::parse2::<Target>(quote! {bool}).unwrap(),);
        assert_eq!(
            Target::Convert(syn::parse_quote! {MyEnum}),
            syn::parse2::<Target>(quote! {MyEnum}).unwrap(),
        );
        assert_eq!(
            Target::Convert(syn::parse_quote! {Vec<u32>}),
            syn::parse2::<Target>(quote! {Vec<u32>}).unwrap(),
        );
    }

    #[test]
    fn parse_bitrange() {
        assert_eq!(
            BitRange(0..10),
            syn::parse2::<BitRange>(quote! {0..10}).unwrap()
        );
        assert_eq!(
            BitRange(0..12),
            syn::parse2::<BitRange>(quote! {0..=11}).unwrap()
        );
        assert_eq!(
            BitRange(14..15),
            syn::parse2::<BitRange>(quote! {14}).unwrap()
        );
    }

    #[test]
    fn parse_bitrange_rejects_empty_and_reversed() {
        assert!(syn::parse2::<BitRange>(quote! {3..3}).is_err());
        assert!(syn::parse2::<BitRange>(quote! {5..3}).is_err());
    }

    #[test]
    fn bitrange_mask_and_len() {
        assert_eq!(BitRange(0..8).get_mask(), 0xff);
        assert_eq!(BitRange(4..8).get_mask(), 0xf0);
        assert_eq!(BitRange(0..128).get_mask(), !0u128);
        assert_eq!(BitRange(3..5).bit_len(), 2);
    }

    #[test]
    fn hexlit_formats_with_width_and_suffix() {
        assert_eq!(hexlit(RawDef::U16, 0xf0).to_string(), "0x00f0u16");
        assert_eq!(hexlit(RawDef::U8, 0x1).to_string(), "0x01u8");
    }

    #[test]
    fn omit_setter_suppresses_setters() {
        let input: BitStructInput = syn::parse2(quote! {
            struct Foo(u8) {
                #[bitstruct(omit_setter)]
                pub f1: u8 = 0..4;
                pub f2: bool = 4;
            }
        })
        .unwrap();
        let expanded = expand_bitstruct(input).unwrap().to_string();
        assert!(expanded.contains("fn f1"));
        assert!(!expanded.contains("set_f1"));
        assert!(!expanded.contains("with_f1"));
        assert!(expanded.contains("set_f2"));
        assert!(expanded.contains("with_f2"));
        // The control attribute itself must not leak into the output.
        assert!(!expanded.contains("omit_setter"));
    }
}
574}