1#![doc = include_str!("../README.md")]
3#![warn(clippy::unwrap_used)]
4
5use proc_macro as pc;
6use proc_macro2::{Ident, TokenStream};
7use quote::{format_ident, quote, ToTokens};
8use std::{fmt, stringify};
9use syn::spanned::Spanned;
10
11mod attr;
12use attr::*;
13mod bitenum;
14mod traits;
15
16#[proc_macro_attribute]
52pub fn bitfield(args: pc::TokenStream, input: pc::TokenStream) -> pc::TokenStream {
53 match bitfield_inner(args.into(), input.into()) {
54 Ok(result) => result.into(),
55 Err(e) => e.into_compile_error().into(),
56 }
57}
58
59#[proc_macro_attribute]
92pub fn bitenum(args: pc::TokenStream, input: pc::TokenStream) -> pc::TokenStream {
93 match bitenum::bitenum_inner(args.into(), input.into()) {
94 Ok(result) => result.into(),
95 Err(e) => e.into_compile_error().into(),
96 }
97}
98
99fn bitfield_inner(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
100 let input = syn::parse2::<syn::ItemStruct>(input)?;
101 let Params {
102 ty,
103 repr,
104 into,
105 from,
106 bits,
107 binread,
108 binwrite,
109 new,
110 clone,
111 debug,
112 defmt,
113 default,
114 hash,
115 order,
116 conversion,
117 } = syn::parse2(args)?;
118
119 let span = input.fields.span();
120 let name = input.ident;
121 let vis = input.vis;
122
123 if input.generics.type_params().next().is_some() || input.generics.lifetimes().next().is_some()
128 {
129 return Err(s_err(
130 span,
131 "type parameters and lifetimes are not supported",
132 ));
133 }
134 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
135
136 let attrs: TokenStream = input.attrs.iter().map(ToTokens::to_token_stream).collect();
137 let derive = match clone {
138 Enable::No => None,
139 Enable::Yes => Some(quote! { #[derive(Copy, Clone)] }),
140 Enable::Cfg(cfg) => Some(quote! { #[cfg_attr(#cfg, derive(Copy, Clone))] }),
141 };
142
143 let syn::Fields::Named(fields) = input.fields else {
144 return Err(s_err(span, "only named fields are supported"));
145 };
146
147 let mut offset = 0;
148 let mut members = Vec::with_capacity(fields.named.len());
149 for field in fields.named {
150 let f = Member::new(
151 ty.clone(),
152 bits,
153 into.clone(),
154 from.clone(),
155 field,
156 offset,
157 order,
158 )?;
159 offset += f.bits;
160 members.push(f);
161 }
162
163 if offset < bits {
164 return Err(s_err(
165 span,
166 format!(
167 "The bitfield size ({bits} bits) has to be equal to the sum of its fields ({offset} bits). \
168 You might have to add padding (a {} bits large field prefixed with \"_\").",
169 bits - offset
170 ),
171 ));
172 }
173 if offset > bits {
174 return Err(s_err(
175 span,
176 format!(
177 "The size of the fields ({offset} bits) is larger than the type ({bits} bits)."
178 ),
179 ));
180 }
181
182 let mut impl_debug = TokenStream::new();
183 if let Some(cfg) = debug.cfg() {
184 impl_debug.extend(traits::debug(&name, &input.generics, &members, cfg));
185 }
186 if let Some(cfg) = defmt.cfg() {
187 impl_debug.extend(traits::defmt(&name, &input.generics, &members, cfg));
188 }
189 if let Some(cfg) = hash.cfg() {
190 impl_debug.extend(traits::hash(&name, &input.generics, &members, cfg));
191 }
192 if let Some(cfg) = binread.cfg() {
193 impl_debug.extend(traits::binread(&name, &input.generics, &repr, cfg));
194 }
195 if let Some(cfg) = binwrite.cfg() {
196 impl_debug.extend(traits::binwrite(&name, &input.generics, cfg));
197 }
198
199 let defaults = members.iter().map(Member::default).collect::<Vec<_>>();
200
201 let impl_new = new.cfg().map(|cfg| {
202 let attr = cfg.map(|cfg| quote!(#[cfg(#cfg)]));
203 quote! {
204 #attr
206 #vis const fn new() -> Self {
207 let mut this = Self(#from(0));
208 #( #defaults )*
209 this
210 }
211 }
212 });
213
214 let impl_default = default.cfg().map(|cfg| {
215 let attr = cfg.map(|cfg| quote!(#[cfg(#cfg)]));
216 quote! {
217 #attr
218 impl #impl_generics Default for #name #ty_generics #where_clause {
219 fn default() -> Self {
220 let mut this = Self(#from(0));
221 #( #defaults )*
222 this
223 }
224 }
225 }
226 });
227
228 let conversion = conversion.then(|| {
229 quote! {
230 #vis const fn from_bits(bits: #repr) -> Self {
232 Self(bits)
233 }
234 #vis const fn into_bits(self) -> #repr {
236 self.0
237 }
238 }
239 });
240
241 Ok(quote! {
242 #attrs
243 #derive
244 #[repr(transparent)]
245 #vis struct #name #impl_generics (#repr) #where_clause;
246
247 #[allow(unused_comparisons)]
248 #[allow(clippy::unnecessary_cast)]
249 #[allow(clippy::assign_op_pattern)]
250 #[allow(clippy::double_parens)]
251 impl #impl_generics #name #ty_generics #where_clause {
252 #impl_new
253
254 #conversion
255
256 #( #members )*
257 }
258
259 #[allow(unused_comparisons)]
260 #[allow(clippy::unnecessary_cast)]
261 #[allow(clippy::assign_op_pattern)]
262 #[allow(clippy::double_parens)]
263 #impl_default
264
265 impl #impl_generics From<#repr> for #name #ty_generics #where_clause {
266 fn from(v: #repr) -> Self {
267 Self(v)
268 }
269 }
270 impl #impl_generics From<#name #ty_generics> for #repr #where_clause {
271 fn from(v: #name #ty_generics) -> Self {
272 v.0
273 }
274 }
275
276 #impl_debug
277 })
278}
279
/// One field of the bitfield together with everything needed to generate its code.
struct Member {
    /// Bit position of the field inside the underlying integer, already
    /// adjusted for the configured bit order.
    offset: usize,
    /// Width of the field in bits.
    bits: usize,
    /// Integer type in which the generated shifting and masking is performed.
    base_ty: syn::Type,
    /// Optional path applied to the stored value (`self.0`) before bit operations.
    repr_into: Option<syn::Path>,
    /// Optional path applied to the computed bits before they are stored in `self.0`.
    repr_from: Option<syn::Path>,
    /// Expression for the initial value of the field.
    default: TokenStream,
    /// Accessor data; `None` for padding and for fields without any access.
    inner: Option<MemberInner>,
}
290
/// Data that is only needed for fields that get accessor functions.
struct MemberInner {
    /// Field name; the names of the generated items are derived from it.
    ident: syn::Ident,
    /// Type used in the signatures of the generated accessors.
    ty: syn::Type,
    /// Attributes of the field (with `#[bits]` removed), re-emitted on every accessor.
    attrs: Vec<syn::Attribute>,
    /// Visibility given to the generated accessors.
    vis: syn::Visibility,
    /// Expression converting `this` into the base type; empty if no setters are generated.
    into: TokenStream,
    /// Expression converting `this` into the field type; empty if no getter is generated.
    from: TokenStream,
}
299
300impl Member {
301 fn new(
302 base_ty: syn::Type,
303 base_bits: usize,
304 repr_into: Option<syn::Path>,
305 repr_from: Option<syn::Path>,
306 field: syn::Field,
307 offset: usize,
308 order: Order,
309 ) -> syn::Result<Self> {
310 let span = field.span();
311
312 let syn::Field {
313 mut attrs,
314 vis,
315 ident,
316 ty,
317 ..
318 } = field;
319
320 let ident = ident.ok_or_else(|| s_err(span, "Not supported"))?;
321 let ignore = ident.to_string().starts_with('_');
322
323 let Field {
324 bits,
325 ty,
326 mut default,
327 into,
328 from,
329 access,
330 } = parse_field(&base_ty, &attrs, &ty, ignore)?;
331
332 let ignore = ignore || access == Access::None;
333
334 let offset = if order == Order::Lsb {
336 offset
337 } else {
338 base_bits - offset - bits
339 };
340
341 if bits > 0 && !ignore {
342 if offset + bits > base_bits {
344 return Err(s_err(
345 ty.span(),
346 "The sum of the members overflows the type size",
347 ));
348 };
349
350 let (from, into) = match access {
352 Access::ReadWrite => (from, into),
353 Access::ReadOnly => (from, quote!()),
354 Access::WriteOnly => (from, into),
355 Access::None => (quote!(), quote!()),
356 };
357
358 if default.is_empty() {
360 if !from.is_empty() {
361 default = quote!({ let this = 0; #from });
362 } else {
363 default = quote!(0);
364 }
365 }
366
367 attrs.retain(|a| !a.path().is_ident("bits"));
369
370 Ok(Self {
371 offset,
372 bits,
373 base_ty,
374 repr_into,
375 repr_from,
376 default,
377 inner: Some(MemberInner {
378 ident,
379 ty,
380 attrs,
381 vis,
382 into,
383 from,
384 }),
385 })
386 } else {
387 if default.is_empty() {
388 default = quote!(0);
389 }
390
391 Ok(Self {
392 offset,
393 bits,
394 base_ty,
395 repr_into,
396 repr_from,
397 default,
398 inner: None,
399 })
400 }
401 }
402
403 fn default(&self) -> TokenStream {
404 let default = &self.default;
405
406 if let Some(inner) = &self.inner {
407 if !inner.into.is_empty() {
408 let ident = &inner.ident;
409 let with_ident = format_ident!("with_{}", ident);
410 return quote!(this = this.#with_ident(#default););
411 }
412 }
413
414 let offset = self.offset;
416 let base_ty = &self.base_ty;
417 let repr_into = &self.repr_into;
418 let repr_from = &self.repr_from;
419 let bits = self.bits as u32;
420
421 quote! {
422 let mask = #base_ty::MAX >> (#base_ty::BITS - #bits);
423 this.0 = #repr_from(#repr_into(this.0) | (((#default as #base_ty) & mask) << #offset));
424 }
425 }
426}
427
428impl ToTokens for Member {
429 fn to_tokens(&self, tokens: &mut TokenStream) {
430 let Self {
431 offset,
432 bits,
433 base_ty,
434 repr_into,
435 repr_from,
436 default: _,
437 inner:
438 Some(MemberInner {
439 ident,
440 ty,
441 attrs,
442 vis,
443 into,
444 from,
445 }),
446 } = self
447 else {
448 return Default::default();
449 };
450
451 let ident_str = ident.to_string().to_uppercase();
452 let ident_upper = Ident::new(
453 ident_str.strip_prefix("R#").unwrap_or(&ident_str),
454 ident.span(),
455 );
456
457 let with_ident = format_ident!("with_{}", ident);
458 let with_ident_checked = format_ident!("with_{}_checked", ident);
459 let set_ident = format_ident!("set_{}", ident);
460 let set_ident_checked = format_ident!("set_{}_checked", ident);
461 let bits_ident = format_ident!("{}_BITS", ident_upper);
462 let offset_ident = format_ident!("{}_OFFSET", ident_upper);
463
464 let location = format!("\n\nBits: {offset}..{}", offset + bits);
465
466 let doc: TokenStream = attrs
467 .iter()
468 .filter(|a| !a.path().is_ident("bits"))
469 .map(ToTokens::to_token_stream)
470 .collect();
471
472 tokens.extend(quote! {
473 const #bits_ident: usize = #bits;
474 const #offset_ident: usize = #offset;
475 });
476
477 if !from.is_empty() {
478 tokens.extend(quote! {
479 #doc
480 #[doc = #location]
481 #vis const fn #ident(&self) -> #ty {
482 let mask = #base_ty::MAX >> (#base_ty::BITS - Self::#bits_ident as u32);
483 let this = (#repr_into(self.0) >> Self::#offset_ident) & mask;
484 #from
485 }
486 });
487 }
488
489 if !into.is_empty() {
490 let (class, _) = type_info(ty);
491 let bounds = if class == TypeClass::SInt {
493 let min = -((u128::MAX >> (128 - (bits - 1))) as i128) - 1;
494 let max = u128::MAX >> (128 - (bits - 1));
495 format!("[{}, {}]", min, max)
496 } else {
497 format!("[0, {}]", u128::MAX >> (128 - bits))
498 };
499 let bounds_error = format!("value out of bounds {bounds}");
500
501 tokens.extend(quote! {
502 #doc
503 #[doc = #location]
504 #vis const fn #with_ident_checked(mut self, value: #ty) -> core::result::Result<Self, ()> {
505 match self.#set_ident_checked(value) {
506 Ok(_) => Ok(self),
507 Err(_) => Err(()),
508 }
509 }
510 #doc
511 #[doc = #location]
512 #[cfg_attr(debug_assertions, track_caller)]
513 #vis const fn #with_ident(mut self, value: #ty) -> Self {
514 self.#set_ident(value);
515 self
516 }
517
518 #doc
519 #[doc = #location]
520 #vis const fn #set_ident(&mut self, value: #ty) {
521 if let Err(_) = self.#set_ident_checked(value) {
522 panic!(#bounds_error)
523 }
524 }
525 #doc
526 #[doc = #location]
527 #vis const fn #set_ident_checked(&mut self, value: #ty) -> core::result::Result<(), ()> {
528 let this = value;
529 let value: #base_ty = #into;
530 let mask = #base_ty::MAX >> (#base_ty::BITS - Self::#bits_ident as u32);
531 if value > mask {
532 return Err(());
533 }
534 let bits = #repr_into(self.0) & !(mask << Self::#offset_ident) | (value & mask) << Self::#offset_ident;
535 self.0 = #repr_from(bits);
536 Ok(())
537 }
538 });
539 }
540 }
541}
542
/// Classification of a field type as determined by `type_info`; it selects
/// the conversion code that is generated for the field.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum TypeClass {
    /// `bool`, reported with a size of 1 bit.
    Bool,
    /// `u8` to `u128`; `type_info` also reports `usize` and `isize` here, with a size of 0.
    UInt,
    /// `i8` to `i128`.
    SInt,
    /// Everything else, reported with a size of 0.
    Other,
}
555
/// Result of `parse_field`: the defaults of the field type merged with the
/// settings of its `#[bits]` attributes.
struct Field {
    /// Width of the field in bits.
    bits: usize,
    /// The declared type of the field.
    ty: syn::Type,

    /// Default value expression; empty if neither the type class nor the attribute provides one.
    default: TokenStream,
    /// Expression converting `this` (field type) into the base type.
    into: TokenStream,
    /// Expression converting `this` (base type) into the field type.
    from: TokenStream,

    /// Which accessors are requested for the field.
    access: Access,
}
567
568fn parse_field(
570 base_ty: &syn::Type,
571 attrs: &[syn::Attribute],
572 ty: &syn::Type,
573 ignore: bool,
574) -> syn::Result<Field> {
575 fn malformed(mut e: syn::Error, attr: &syn::Attribute) -> syn::Error {
576 e.combine(s_err(attr.span(), "malformed #[bits] attribute"));
577 e
578 }
579
580 let access = if ignore {
581 Access::None
582 } else {
583 Access::ReadWrite
584 };
585
586 let (class, ty_bits) = type_info(ty);
588 let mut ret = match class {
589 TypeClass::Bool => Field {
590 bits: ty_bits,
591 ty: ty.clone(),
592 default: quote!(false),
593 into: quote!(this as _),
594 from: quote!(this != 0),
595 access,
596 },
597 TypeClass::SInt => Field {
598 bits: ty_bits,
599 ty: ty.clone(),
600 default: quote!(0),
601 into: quote!(),
602 from: quote!(),
603 access,
604 },
605 TypeClass::UInt => Field {
606 bits: ty_bits,
607 ty: ty.clone(),
608 default: quote!(0),
609 into: quote!(this as _),
610 from: quote!(this as _),
611 access,
612 },
613 TypeClass::Other => Field {
614 bits: ty_bits,
615 ty: ty.clone(),
616 default: quote!(),
617 into: quote!(<#ty>::into_bits(this) as _),
618 from: quote!(<#ty>::from_bits(this as _)),
619 access,
620 },
621 };
622
623 for attr in attrs {
625 let syn::Attribute {
626 style: syn::AttrStyle::Outer,
627 meta: syn::Meta::List(syn::MetaList { path, tokens, .. }),
628 ..
629 } = attr
630 else {
631 continue;
632 };
633 if !path.is_ident("bits") {
634 continue;
635 }
636
637 let span = tokens.span();
638 let BitsAttr {
639 bits,
640 default,
641 into,
642 from,
643 access,
644 } = syn::parse2(tokens.clone()).map_err(|e| malformed(e, attr))?;
645
646 if let Some(bits) = bits {
648 if bits == 0 {
649 return Err(s_err(span, "bits cannot bit 0"));
650 }
651 if ty_bits != 0 && bits > ty_bits {
652 return Err(s_err(span, "overflowing field type"));
653 }
654 ret.bits = bits;
655 }
656
657 if let Some(access) = access {
659 if ignore {
660 return Err(s_err(
661 tokens.span(),
662 "'access' is not supported for padding",
663 ));
664 }
665 ret.access = access;
666 }
667
668 if let Some(into) = into {
670 if ret.access == Access::None {
671 return Err(s_err(into.span(), "'into' is not supported on padding"));
672 }
673 ret.into = quote!(#into(this) as _);
674 }
675 if let Some(from) = from {
676 if ret.access == Access::None {
677 return Err(s_err(from.span(), "'from' is not supported on padding"));
678 }
679 ret.from = quote!(#from(this as _));
680 }
681 if let Some(default) = default {
682 ret.default = default.into_token_stream();
683 }
684 }
685
686 if ret.bits == 0 {
687 return Err(s_err(
688 ty.span(),
689 "Custom types and isize/usize require an explicit bit size",
690 ));
691 }
692
693 if !ignore && ret.access != Access::None && class == TypeClass::SInt {
695 let bits = ret.bits as u32;
696 if ret.into.is_empty() {
697 ret.into = quote! {{
699 let m = #ty::MIN >> (#ty::BITS - #bits);
700 if !(m <= this && this <= -(m + 1)) {
701 return Err(())
702 }
703 let mask = #base_ty::MAX >> (#base_ty::BITS - #bits);
704 (this as #base_ty & mask)
705 }};
706 }
707 if ret.from.is_empty() {
708 ret.from = quote! {{
710 let shift = #ty::BITS - #bits;
711 ((this as #ty) << shift) >> shift
712 }};
713 }
714 }
715
716 Ok(ret)
717}
718
719fn type_info(ty: &syn::Type) -> (TypeClass, usize) {
721 let syn::Type::Path(syn::TypePath { path, .. }) = ty else {
722 return (TypeClass::Other, 0);
723 };
724 let Some(ident) = path.get_ident() else {
725 return (TypeClass::Other, 0);
726 };
727 if ident == "bool" {
728 return (TypeClass::Bool, 1);
729 }
730 if ident == "isize" || ident == "usize" {
731 return (TypeClass::UInt, 0); }
733 macro_rules! integer {
734 ($ident:ident => $($uint:ident),* ; $($sint:ident),*) => {
735 match ident {
736 $(_ if ident == stringify!($uint) => (TypeClass::UInt, $uint::BITS as _),)*
737 $(_ if ident == stringify!($sint) => (TypeClass::SInt, $sint::BITS as _),)*
738 _ => (TypeClass::Other, 0)
739 }
740 };
741 }
742 integer!(ident => u8, u16, u32, u64, u128 ; i8, i16, i32, i64, i128)
743}
744
/// Shorthand for `syn::Error::new`: creates an error with message `msg` located at `span`.
fn s_err(span: proc_macro2::Span, msg: impl fmt::Display) -> syn::Error {
    syn::Error::new(span, msg)
}
748
/// Parser tests for the macro arguments (`Params`), the `#[bits]` attribute
/// (`BitsAttr`) and the access modes (`Access`).
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use quote::quote;

    use crate::{Access, BitsAttr, Enable, Order, Params};

    /// Parses `tokens` as `T` and panics if the input is rejected.
    #[track_caller]
    fn parse<T: syn::parse::Parse>(tokens: proc_macro2::TokenStream) -> T {
        syn::parse2(tokens).unwrap()
    }

    #[test]
    fn parse_args() {
        // Only the type: `debug` is on, `defmt` is off.
        let params: Params = parse(quote!(u64));
        assert_eq!(params.bits, u64::BITS as usize);
        assert!(matches!(params.debug, Enable::Yes));
        assert!(matches!(params.defmt, Enable::No));

        // Explicitly disabled `debug`.
        let params: Params = parse(quote!(u32, debug = false));
        assert_eq!(params.bits, u32::BITS as usize);
        assert!(matches!(params.debug, Enable::No));
        assert!(matches!(params.defmt, Enable::No));

        // Explicitly enabled `defmt`.
        let params: Params = parse(quote!(u32, defmt = true));
        assert_eq!(params.bits, u32::BITS as usize);
        assert!(matches!(params.debug, Enable::Yes));
        assert!(matches!(params.defmt, Enable::Yes));

        // Both traits behind a `cfg`.
        let params: Params = parse(quote!(u32, defmt = cfg(test), debug = cfg(feature = "foo")));
        assert_eq!(params.bits, u32::BITS as usize);
        assert!(matches!(params.debug, Enable::Cfg(_)));
        assert!(matches!(params.defmt, Enable::Cfg(_)));

        // Bit order.
        let params: Params = parse(quote!(u32, order = Msb));
        assert!(params.bits == u32::BITS as usize && params.order == Order::Msb);
    }

    #[test]
    fn parse_bits() {
        // Size only.
        let attr: BitsAttr = parse(quote!(8));
        assert_eq!(attr.bits, Some(8));
        assert!(attr.default.is_none());
        assert!(attr.into.is_none());
        assert!(attr.from.is_none());
        assert!(attr.access.is_none());

        // Size, default and access.
        let attr: BitsAttr = parse(quote!(8, default = 8, access = RW));
        assert_eq!(attr.bits, Some(8));
        assert!(attr.default.is_some());
        assert!(attr.into.is_none());
        assert!(attr.from.is_none());
        assert_eq!(attr.access, Some(Access::ReadWrite));

        // Access only.
        let attr: BitsAttr = parse(quote!(access = RO));
        assert_eq!(attr.bits, None);
        assert!(attr.default.is_none());
        assert!(attr.into.is_none());
        assert!(attr.from.is_none());
        assert_eq!(attr.access, Some(Access::ReadOnly));

        // Default and access without a size.
        let attr: BitsAttr = parse(quote!(default = 8, access = WO));
        assert_eq!(attr.bits, None);
        assert!(attr.default.is_some());
        assert!(attr.into.is_none());
        assert!(attr.from.is_none());
        assert_eq!(attr.access, Some(Access::WriteOnly));

        // Everything at once, in arbitrary order.
        let attr: BitsAttr = parse(quote!(
            3,
            into = into_something,
            default = 1,
            from = from_something,
            access = None
        ));
        assert_eq!(attr.bits, Some(3));
        assert!(attr.default.is_some());
        assert!(attr.into.is_some());
        assert!(attr.from.is_some());
        assert_eq!(attr.access, Some(Access::None));
    }

    #[test]
    fn parse_access_mode() {
        let accepted = [
            (quote!(RW), Access::ReadWrite),
            (quote!(RO), Access::ReadOnly),
            (quote!(WO), Access::WriteOnly),
            (quote!(None), Access::None),
        ];
        for (tokens, expected) in accepted {
            assert_eq!(parse::<Access>(tokens), expected);
        }

        // Unknown modes are rejected.
        assert!(syn::parse2::<Access>(quote!(garbage)).is_err());
    }
}