1#![doc = include_str!("../README.md")]
3#![warn(clippy::unwrap_used)]
4
5use proc_macro as pc;
6use proc_macro2::{Ident, TokenStream};
7use quote::{format_ident, quote, ToTokens};
8use std::{fmt, stringify};
9use syn::{spanned::Spanned, AttrStyle};
10
11mod attr;
12use attr::*;
13mod traits;
14
15#[proc_macro_attribute]
51pub fn bitfield(args: pc::TokenStream, input: pc::TokenStream) -> pc::TokenStream {
52 match bitfield_inner(args.into(), input.into()) {
53 Ok(result) => result.into(),
54 Err(e) => e.into_compile_error().into(),
55 }
56}
57
58#[proc_macro_attribute]
91pub fn bitenum(args: pc::TokenStream, input: pc::TokenStream) -> pc::TokenStream {
92 match bitenum_inner(args.into(), input.into()) {
93 Ok(result) => result.into(),
94 Err(e) => e.into_compile_error().into(),
95 }
96}
97
98fn bitfield_inner(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
99 let input = syn::parse2::<syn::ItemStruct>(input)?;
100 let Params {
101 ty,
102 repr,
103 into,
104 from,
105 bits,
106 binread,
107 binwrite,
108 new,
109 clone,
110 debug,
111 defmt,
112 default,
113 hash,
114 order,
115 conversion,
116 } = syn::parse2(args)?;
117
118 let span = input.fields.span();
119 let name = input.ident;
120 let vis = input.vis;
121 let attrs: TokenStream = input.attrs.iter().map(ToTokens::to_token_stream).collect();
122 let derive = match clone {
123 Enable::No => None,
124 Enable::Yes => Some(quote! { #[derive(Copy, Clone)] }),
125 Enable::Cfg(cfg) => Some(quote! { #[cfg_attr(#cfg, derive(Copy, Clone))] }),
126 };
127
128 let syn::Fields::Named(fields) = input.fields else {
129 return Err(s_err(span, "only named fields are supported"));
130 };
131
132 let mut offset = 0;
133 let mut members = Vec::with_capacity(fields.named.len());
134 for field in fields.named {
135 let f = Member::new(
136 ty.clone(),
137 bits,
138 into.clone(),
139 from.clone(),
140 field,
141 offset,
142 order,
143 )?;
144 offset += f.bits;
145 members.push(f);
146 }
147
148 if offset < bits {
149 return Err(s_err(
150 span,
151 format!(
152 "The bitfield size ({bits} bits) has to be equal to the sum of its fields ({offset} bits). \
153 You might have to add padding (a {} bits large field prefixed with \"_\").",
154 bits - offset
155 ),
156 ));
157 }
158 if offset > bits {
159 return Err(s_err(
160 span,
161 format!(
162 "The size of the fields ({offset} bits) is larger than the type ({bits} bits)."
163 ),
164 ));
165 }
166
167 let mut impl_debug = TokenStream::new();
168 if let Some(cfg) = debug.cfg() {
169 impl_debug.extend(traits::debug(&name, &members, cfg));
170 }
171 if let Some(cfg) = defmt.cfg() {
172 impl_debug.extend(traits::defmt(&name, &members, cfg));
173 }
174 if let Some(cfg) = hash.cfg() {
175 impl_debug.extend(traits::hash(&name, &members, cfg));
176 }
177 if let Some(cfg) = binread.cfg() {
178 impl_debug.extend(traits::binread(&name, &repr, cfg));
179 }
180 if let Some(cfg) = binwrite.cfg() {
181 impl_debug.extend(traits::binwrite(&name, cfg));
182 }
183
184 let defaults = members.iter().map(Member::default).collect::<Vec<_>>();
185
186 let impl_new = new.cfg().map(|cfg| {
187 let attr = cfg.map(|cfg| quote!(#[cfg(#cfg)]));
188 quote! {
189 #attr
191 #vis const fn new() -> Self {
192 let mut this = Self(#from(0));
193 #( #defaults )*
194 this
195 }
196 }
197 });
198
199 let impl_default = default.cfg().map(|cfg| {
200 let attr = cfg.map(|cfg| quote!(#[cfg(#cfg)]));
201 quote! {
202 #attr
203 impl Default for #name {
204 fn default() -> Self {
205 let mut this = Self(#from(0));
206 #( #defaults )*
207 this
208 }
209 }
210 }
211 });
212
213 let conversion = conversion.then(|| {
214 quote! {
215 #vis const fn from_bits(bits: #repr) -> Self {
217 Self(bits)
218 }
219 #vis const fn into_bits(self) -> #repr {
221 self.0
222 }
223 }
224 });
225
226 Ok(quote! {
227 #attrs
228 #derive
229 #[repr(transparent)]
230 #vis struct #name(#repr);
231
232 #[allow(unused_comparisons)]
233 #[allow(clippy::unnecessary_cast)]
234 #[allow(clippy::assign_op_pattern)]
235 #[allow(clippy::double_parens)]
236 impl #name {
237 #impl_new
238
239 #conversion
240
241 #( #members )*
242 }
243
244 #[allow(unused_comparisons)]
245 #[allow(clippy::unnecessary_cast)]
246 #[allow(clippy::assign_op_pattern)]
247 #[allow(clippy::double_parens)]
248 #impl_default
249
250 impl From<#repr> for #name {
251 fn from(v: #repr) -> Self {
252 Self(v)
253 }
254 }
255 impl From<#name> for #repr {
256 fn from(v: #name) -> #repr {
257 v.0
258 }
259 }
260
261 #impl_debug
262 })
263}
264
265struct Member {
267 offset: usize,
268 bits: usize,
269 base_ty: syn::Type,
270 repr_into: Option<syn::Path>,
271 repr_from: Option<syn::Path>,
272 default: TokenStream,
273 inner: Option<MemberInner>,
274}
275
276struct MemberInner {
277 ident: syn::Ident,
278 ty: syn::Type,
279 attrs: Vec<syn::Attribute>,
280 vis: syn::Visibility,
281 into: TokenStream,
282 from: TokenStream,
283}
284
285impl Member {
286 fn new(
287 base_ty: syn::Type,
288 base_bits: usize,
289 repr_into: Option<syn::Path>,
290 repr_from: Option<syn::Path>,
291 field: syn::Field,
292 offset: usize,
293 order: Order,
294 ) -> syn::Result<Self> {
295 let span = field.span();
296
297 let syn::Field {
298 mut attrs,
299 vis,
300 ident,
301 ty,
302 ..
303 } = field;
304
305 let ident = ident.ok_or_else(|| s_err(span, "Not supported"))?;
306 let ignore = ident.to_string().starts_with('_');
307
308 let Field {
309 bits,
310 ty,
311 mut default,
312 into,
313 from,
314 access,
315 } = parse_field(&base_ty, &attrs, &ty, ignore)?;
316
317 let ignore = ignore || access == Access::None;
318
319 let offset = if order == Order::Lsb {
321 offset
322 } else {
323 base_bits - offset - bits
324 };
325
326 if bits > 0 && !ignore {
327 if offset + bits > base_bits {
329 return Err(s_err(
330 ty.span(),
331 "The sum of the members overflows the type size",
332 ));
333 };
334
335 let (from, into) = match access {
337 Access::ReadWrite => (from, into),
338 Access::ReadOnly => (from, quote!()),
339 Access::WriteOnly => (from, into),
340 Access::None => (quote!(), quote!()),
341 };
342
343 if default.is_empty() {
345 if !from.is_empty() {
346 default = quote!({ let this = 0; #from });
347 } else {
348 default = quote!(0);
349 }
350 }
351
352 attrs.retain(|a| !a.path().is_ident("bits"));
354
355 Ok(Self {
356 offset,
357 bits,
358 base_ty,
359 repr_into,
360 repr_from,
361 default,
362 inner: Some(MemberInner {
363 ident,
364 ty,
365 attrs,
366 vis,
367 into,
368 from,
369 }),
370 })
371 } else {
372 if default.is_empty() {
373 default = quote!(0);
374 }
375
376 Ok(Self {
377 offset,
378 bits,
379 base_ty,
380 repr_into,
381 repr_from,
382 default,
383 inner: None,
384 })
385 }
386 }
387
388 fn default(&self) -> TokenStream {
389 let default = &self.default;
390
391 if let Some(inner) = &self.inner {
392 if !inner.into.is_empty() {
393 let ident = &inner.ident;
394 let with_ident = format_ident!("with_{}", ident);
395 return quote!(this = this.#with_ident(#default););
396 }
397 }
398
399 let offset = self.offset;
401 let base_ty = &self.base_ty;
402 let repr_into = &self.repr_into;
403 let repr_from = &self.repr_from;
404 let bits = self.bits as u32;
405
406 quote! {
407 let mask = #base_ty::MAX >> (#base_ty::BITS - #bits);
408 this.0 = #repr_from(#repr_into(this.0) | (((#default as #base_ty) & mask) << #offset));
409 }
410 }
411}
412
413impl ToTokens for Member {
414 fn to_tokens(&self, tokens: &mut TokenStream) {
415 let Self {
416 offset,
417 bits,
418 base_ty,
419 repr_into,
420 repr_from,
421 default: _,
422 inner:
423 Some(MemberInner {
424 ident,
425 ty,
426 attrs,
427 vis,
428 into,
429 from,
430 }),
431 } = self
432 else {
433 return Default::default();
434 };
435
436 let ident_str = ident.to_string().to_uppercase();
437 let ident_upper = Ident::new(
438 ident_str.strip_prefix("R#").unwrap_or(&ident_str),
439 ident.span(),
440 );
441
442 let with_ident = format_ident!("with_{}", ident);
443 let with_ident_checked = format_ident!("with_{}_checked", ident);
444 let set_ident = format_ident!("set_{}", ident);
445 let set_ident_checked = format_ident!("set_{}_checked", ident);
446 let bits_ident = format_ident!("{}_BITS", ident_upper);
447 let offset_ident = format_ident!("{}_OFFSET", ident_upper);
448
449 let location = format!("\n\nBits: {offset}..{}", offset + bits);
450
451 let doc: TokenStream = attrs
452 .iter()
453 .filter(|a| !a.path().is_ident("bits"))
454 .map(ToTokens::to_token_stream)
455 .collect();
456
457 tokens.extend(quote! {
458 const #bits_ident: usize = #bits;
459 const #offset_ident: usize = #offset;
460 });
461
462 if !from.is_empty() {
463 tokens.extend(quote! {
464 #doc
465 #[doc = #location]
466 #vis const fn #ident(&self) -> #ty {
467 let mask = #base_ty::MAX >> (#base_ty::BITS - Self::#bits_ident as u32);
468 let this = (#repr_into(self.0) >> Self::#offset_ident) & mask;
469 #from
470 }
471 });
472 }
473
474 if !into.is_empty() {
475 let (class, _) = type_info(ty);
476 let bounds = if class == TypeClass::SInt {
478 let min = -((u128::MAX >> (128 - (bits - 1))) as i128) - 1;
479 let max = u128::MAX >> (128 - (bits - 1));
480 format!("[{}, {}]", min, max)
481 } else {
482 format!("[0, {}]", u128::MAX >> (128 - bits))
483 };
484 let bounds_error = format!("value out of bounds {bounds}");
485
486 tokens.extend(quote! {
487 #doc
488 #[doc = #location]
489 #vis const fn #with_ident_checked(mut self, value: #ty) -> core::result::Result<Self, ()> {
490 match self.#set_ident_checked(value) {
491 Ok(_) => Ok(self),
492 Err(_) => Err(()),
493 }
494 }
495 #doc
496 #[doc = #location]
497 #[cfg_attr(debug_assertions, track_caller)]
498 #vis const fn #with_ident(mut self, value: #ty) -> Self {
499 self.#set_ident(value);
500 self
501 }
502
503 #doc
504 #[doc = #location]
505 #vis const fn #set_ident(&mut self, value: #ty) {
506 if let Err(_) = self.#set_ident_checked(value) {
507 panic!(#bounds_error)
508 }
509 }
510 #doc
511 #[doc = #location]
512 #vis const fn #set_ident_checked(&mut self, value: #ty) -> core::result::Result<(), ()> {
513 let this = value;
514 let value: #base_ty = #into;
515 let mask = #base_ty::MAX >> (#base_ty::BITS - Self::#bits_ident as u32);
516 if value > mask {
517 return Err(());
518 }
519 let bits = #repr_into(self.0) & !(mask << Self::#offset_ident) | (value & mask) << Self::#offset_ident;
520 self.0 = #repr_from(bits);
521 Ok(())
522 }
523 });
524 }
525 }
526}
527
/// Coarse classification of a field type; selects the default conversions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TypeClass {
    /// `bool`, stored in a single bit.
    Bool,
    /// Unsigned integers (`type_info` also maps `usize`/`isize` here).
    UInt,
    /// Signed integers, which are sign-extended when read back.
    SInt,
    /// Any other type, converted via `into_bits`/`from_bits` unless overridden.
    Other,
}
540
541struct Field {
543 bits: usize,
544 ty: syn::Type,
545
546 default: TokenStream,
547 into: TokenStream,
548 from: TokenStream,
549
550 access: Access,
551}
552
553fn parse_field(
555 base_ty: &syn::Type,
556 attrs: &[syn::Attribute],
557 ty: &syn::Type,
558 ignore: bool,
559) -> syn::Result<Field> {
560 fn malformed(mut e: syn::Error, attr: &syn::Attribute) -> syn::Error {
561 e.combine(s_err(attr.span(), "malformed #[bits] attribute"));
562 e
563 }
564
565 let access = if ignore {
566 Access::None
567 } else {
568 Access::ReadWrite
569 };
570
571 let (class, ty_bits) = type_info(ty);
573 let mut ret = match class {
574 TypeClass::Bool => Field {
575 bits: ty_bits,
576 ty: ty.clone(),
577 default: quote!(false),
578 into: quote!(this as _),
579 from: quote!(this != 0),
580 access,
581 },
582 TypeClass::SInt => Field {
583 bits: ty_bits,
584 ty: ty.clone(),
585 default: quote!(0),
586 into: quote!(),
587 from: quote!(),
588 access,
589 },
590 TypeClass::UInt => Field {
591 bits: ty_bits,
592 ty: ty.clone(),
593 default: quote!(0),
594 into: quote!(this as _),
595 from: quote!(this as _),
596 access,
597 },
598 TypeClass::Other => Field {
599 bits: ty_bits,
600 ty: ty.clone(),
601 default: quote!(),
602 into: quote!(<#ty>::into_bits(this) as _),
603 from: quote!(<#ty>::from_bits(this as _)),
604 access,
605 },
606 };
607
608 for attr in attrs {
610 let syn::Attribute {
611 style: syn::AttrStyle::Outer,
612 meta: syn::Meta::List(syn::MetaList { path, tokens, .. }),
613 ..
614 } = attr
615 else {
616 continue;
617 };
618 if !path.is_ident("bits") {
619 continue;
620 }
621
622 let span = tokens.span();
623 let BitsAttr {
624 bits,
625 default,
626 into,
627 from,
628 access,
629 } = syn::parse2(tokens.clone()).map_err(|e| malformed(e, attr))?;
630
631 if let Some(bits) = bits {
633 if bits == 0 {
634 return Err(s_err(span, "bits cannot bit 0"));
635 }
636 if ty_bits != 0 && bits > ty_bits {
637 return Err(s_err(span, "overflowing field type"));
638 }
639 ret.bits = bits;
640 }
641
642 if let Some(access) = access {
644 if ignore {
645 return Err(s_err(
646 tokens.span(),
647 "'access' is not supported for padding",
648 ));
649 }
650 ret.access = access;
651 }
652
653 if let Some(into) = into {
655 if ret.access == Access::None {
656 return Err(s_err(into.span(), "'into' is not supported on padding"));
657 }
658 ret.into = quote!(#into(this) as _);
659 }
660 if let Some(from) = from {
661 if ret.access == Access::None {
662 return Err(s_err(from.span(), "'from' is not supported on padding"));
663 }
664 ret.from = quote!(#from(this as _));
665 }
666 if let Some(default) = default {
667 ret.default = default.into_token_stream();
668 }
669 }
670
671 if ret.bits == 0 {
672 return Err(s_err(
673 ty.span(),
674 "Custom types and isize/usize require an explicit bit size",
675 ));
676 }
677
678 if !ignore && ret.access != Access::None && class == TypeClass::SInt {
680 let bits = ret.bits as u32;
681 if ret.into.is_empty() {
682 ret.into = quote! {{
684 let m = #ty::MIN >> (#ty::BITS - #bits);
685 if !(m <= this && this <= -(m + 1)) {
686 return Err(())
687 }
688 let mask = #base_ty::MAX >> (#base_ty::BITS - #bits);
689 (this as #base_ty & mask)
690 }};
691 }
692 if ret.from.is_empty() {
693 ret.from = quote! {{
695 let shift = #ty::BITS - #bits;
696 ((this as #ty) << shift) >> shift
697 }};
698 }
699 }
700
701 Ok(ret)
702}
703
704fn type_info(ty: &syn::Type) -> (TypeClass, usize) {
706 let syn::Type::Path(syn::TypePath { path, .. }) = ty else {
707 return (TypeClass::Other, 0);
708 };
709 let Some(ident) = path.get_ident() else {
710 return (TypeClass::Other, 0);
711 };
712 if ident == "bool" {
713 return (TypeClass::Bool, 1);
714 }
715 if ident == "isize" || ident == "usize" {
716 return (TypeClass::UInt, 0); }
718 macro_rules! integer {
719 ($ident:ident => $($uint:ident),* ; $($sint:ident),*) => {
720 match ident {
721 $(_ if ident == stringify!($uint) => (TypeClass::UInt, $uint::BITS as _),)*
722 $(_ if ident == stringify!($sint) => (TypeClass::SInt, $sint::BITS as _),)*
723 _ => (TypeClass::Other, 0)
724 }
725 };
726 }
727 integer!(ident => u8, u16, u32, u64, u128 ; i8, i16, i32, i64, i128)
728}
729
730fn bitenum_inner(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
731 let params = syn::parse2::<EnumParams>(args)?;
732 let span = input.span();
733 let mut input = syn::parse2::<syn::ItemEnum>(input)?;
734
735 let Some(repr) = input.attrs.iter().find_map(|attr| {
736 if attr.style == AttrStyle::Outer && attr.path().is_ident("repr") {
737 if let syn::Meta::List(syn::MetaList { tokens, .. }) = &attr.meta {
738 let ty = syn::parse2::<syn::Type>(tokens.clone()).ok()?;
739 return Some(ty);
740 }
741 }
742 None
743 }) else {
744 return Err(s_err(span, "missing #[repr(...)] attribute"));
745 };
746
747 let mut fallback = None;
748 let mut variant_names = Vec::new();
749 for variant in &mut input.variants {
750 let len = variant.attrs.len();
751 variant.attrs.retain(|a| !a.path().is_ident("fallback"));
752 if len != variant.attrs.len() {
753 if fallback.is_some() {
754 return Err(s_err(
755 variant.span(),
756 "only one #[fallback] attribute is allowed",
757 ));
758 }
759 fallback = Some(variant.ident.clone());
760 }
761 variant_names.push(variant.ident.clone());
762 }
763 if fallback.is_none() {
764 return Err(s_err(span, "missing #[fallback] attribute on one variant"));
765 }
766 let ident = &input.ident;
767 let vis = &input.vis;
768 let variant_len = variant_names.len();
769
770 let mut impls = TokenStream::new();
771 if let Some(cfg) = params.from.cfg() {
772 impls.extend(quote! {
773 #cfg
775 #vis const fn from_bits(bits: #repr) -> Self {
776 match bits {
777 #( x if x == Self::#variant_names as #repr => Self::#variant_names, )*
778 _ => Self::#fallback,
779 }
780 }
781 });
782 }
783 if let Some(cfg) = params.into.cfg() {
784 impls.extend(quote! {
785 #cfg
787 #vis const fn into_bits(self) -> #repr {
788 self as #repr
789 }
790 });
791 }
792 if let Some(cfg) = params.all.cfg() {
793 impls.extend(quote! {
794 #cfg
796 #vis const fn all() -> [Self; #variant_len] {
797 [ #( Self::#variant_names, )* ]
798 }
799 });
800 }
801 Ok(quote! {
802 #input
803 impl #ident { #impls }
804 })
805}
806
807fn s_err(span: proc_macro2::Span, msg: impl fmt::Display) -> syn::Error {
808 syn::Error::new(span, msg)
809}
810
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use quote::quote;

    use crate::{type_info, Access, BitsAttr, Enable, Order, Params, TypeClass};

    #[test]
    fn parse_args() {
        let args = quote!(u64);
        let params = syn::parse2::<Params>(args).unwrap();
        assert_eq!(params.bits, u64::BITS as usize);
        assert!(matches!(params.debug, Enable::Yes));
        assert!(matches!(params.defmt, Enable::No));

        let args = quote!(u32, debug = false);
        let params = syn::parse2::<Params>(args).unwrap();
        assert_eq!(params.bits, u32::BITS as usize);
        assert!(matches!(params.debug, Enable::No));
        assert!(matches!(params.defmt, Enable::No));

        let args = quote!(u32, defmt = true);
        let params = syn::parse2::<Params>(args).unwrap();
        assert_eq!(params.bits, u32::BITS as usize);
        assert!(matches!(params.debug, Enable::Yes));
        assert!(matches!(params.defmt, Enable::Yes));

        let args = quote!(u32, defmt = cfg(test), debug = cfg(feature = "foo"));
        let params = syn::parse2::<Params>(args).unwrap();
        assert_eq!(params.bits, u32::BITS as usize);
        assert!(matches!(params.debug, Enable::Cfg(_)));
        assert!(matches!(params.defmt, Enable::Cfg(_)));

        let args = quote!(u32, order = Msb);
        let params = syn::parse2::<Params>(args).unwrap();
        assert!(params.bits == u32::BITS as usize && params.order == Order::Msb);
    }

    #[test]
    fn parse_bits() {
        let args = quote!(8);
        let attr = syn::parse2::<BitsAttr>(args).unwrap();
        assert_eq!(attr.bits, Some(8));
        assert!(attr.default.is_none());
        assert!(attr.into.is_none());
        assert!(attr.from.is_none());
        assert!(attr.access.is_none());

        let args = quote!(8, default = 8, access = RW);
        let attr = syn::parse2::<BitsAttr>(args).unwrap();
        assert_eq!(attr.bits, Some(8));
        assert!(attr.default.is_some());
        assert!(attr.into.is_none());
        assert!(attr.from.is_none());
        assert_eq!(attr.access, Some(Access::ReadWrite));

        let args = quote!(access = RO);
        let attr = syn::parse2::<BitsAttr>(args).unwrap();
        assert_eq!(attr.bits, None);
        assert!(attr.default.is_none());
        assert!(attr.into.is_none());
        assert!(attr.from.is_none());
        assert_eq!(attr.access, Some(Access::ReadOnly));

        let args = quote!(default = 8, access = WO);
        let attr = syn::parse2::<BitsAttr>(args).unwrap();
        assert_eq!(attr.bits, None);
        assert!(attr.default.is_some());
        assert!(attr.into.is_none());
        assert!(attr.from.is_none());
        assert_eq!(attr.access, Some(Access::WriteOnly));

        let args = quote!(
            3,
            into = into_something,
            default = 1,
            from = from_something,
            access = None
        );
        let attr = syn::parse2::<BitsAttr>(args).unwrap();
        assert_eq!(attr.bits, Some(3));
        assert!(attr.default.is_some());
        assert!(attr.into.is_some());
        assert!(attr.from.is_some());
        assert_eq!(attr.access, Some(Access::None));
    }

    #[test]
    fn parse_access_mode() {
        let args = quote!(RW);
        let mode = syn::parse2::<Access>(args).unwrap();
        assert_eq!(mode, Access::ReadWrite);

        let args = quote!(RO);
        let mode = syn::parse2::<Access>(args).unwrap();
        assert_eq!(mode, Access::ReadOnly);

        let args = quote!(WO);
        let mode = syn::parse2::<Access>(args).unwrap();
        assert_eq!(mode, Access::WriteOnly);

        let args = quote!(None);
        let mode = syn::parse2::<Access>(args).unwrap();
        assert_eq!(mode, Access::None);

        let args = quote!(garbage);
        let mode = syn::parse2::<Access>(args);
        assert!(mode.is_err());
    }

    /// `type_info` picks the class and natural width of field types.
    #[test]
    fn classify_types() {
        let info = |tokens: proc_macro2::TokenStream| {
            type_info(&syn::parse2::<syn::Type>(tokens).unwrap())
        };
        assert_eq!(info(quote!(bool)), (TypeClass::Bool, 1));
        assert_eq!(info(quote!(u8)), (TypeClass::UInt, 8));
        assert_eq!(info(quote!(u128)), (TypeClass::UInt, 128));
        assert_eq!(info(quote!(i16)), (TypeClass::SInt, 16));
        assert_eq!(info(quote!(i128)), (TypeClass::SInt, 128));
        // Architecture-dependent widths must be given explicitly.
        assert_eq!(info(quote!(usize)), (TypeClass::UInt, 0));
        assert_eq!(info(quote!(isize)), (TypeClass::UInt, 0));
        // Anything that is not a single known identifier is a custom type.
        assert_eq!(info(quote!(MyType)), (TypeClass::Other, 0));
        assert_eq!(info(quote!(core::primitive::u8)), (TypeClass::Other, 0));
        assert_eq!(info(quote!([u8; 4])), (TypeClass::Other, 0));
    }
}
920}