bitfield_struct/
lib.rs

1// Generate docs from readme
2#![doc = include_str!("../README.md")]
3#![warn(clippy::unwrap_used)]
4
5use proc_macro as pc;
6use proc_macro2::{Ident, TokenStream};
7use quote::{format_ident, quote, ToTokens};
8use std::{fmt, stringify};
9use syn::spanned::Spanned;
10
11mod attr;
12use attr::*;
13mod traits;
14
15fn s_err(span: proc_macro2::Span, msg: impl fmt::Display) -> syn::Error {
16    syn::Error::new(span, msg)
17}
18
19/// Creates a bitfield for this struct.
20///
21/// The arguments first, have to begin with the integer type of the bitfield:
22/// For example: `#[bitfield(u64)]`.
23///
24/// It can contain the following additional parameters, like the `debug` argument
25/// for disabling the `Debug` trait generation (`#[bitfield(u64, debug = false)]`).
26///
27/// Parameters of the `bitfield` attribute:
28/// - the bitfield integer type (required)
29/// - `repr` specifies the bitfield's representation in memory
30/// - `from` to specify a conversion function from repr to the bitfield's integer type
31/// - `into` to specify a conversion function from the bitfield's integer type to repr
32/// - `new` to disable the `new` function generation
33/// - `clone` to disable the `Clone` trait generation
34/// - `debug` to disable the `Debug` trait generation
35/// - `defmt` to enable the `defmt::Format` trait generation
36/// - `default` to disable the `Default` trait generation
37/// - `hash` to generate the `Hash` trait
38/// - `order` to specify the bit order (Lsb, Msb)
39/// - `conversion` to disable the generation of `into_bits` and `from_bits`
40///
41/// > For `new`, `clone`, `debug`, `defmt` or `default`, you can either use booleans
42/// > (`#[bitfield(u8, debug = false)]`) or cfg attributes
43/// > (`#[bitfield(u8, debug = cfg(test))]`) to enable/disable them.
44///
45/// Parameters of the `bits` attribute (for fields):
46/// - the number of bits
47/// - `access` to specify the access mode (RW, RO, WO, None)
48/// - `default` to set a default value
49/// - `into` to specify a conversion function from the field type to the bitfield type
50/// - `from` to specify a conversion function from the bitfield type to the field type
51#[proc_macro_attribute]
52pub fn bitfield(args: pc::TokenStream, input: pc::TokenStream) -> pc::TokenStream {
53    match bitfield_inner(args.into(), input.into()) {
54        Ok(result) => result.into(),
55        Err(e) => e.into_compile_error().into(),
56    }
57}
58
59fn bitfield_inner(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
60    let input = syn::parse2::<syn::ItemStruct>(input)?;
61    let Params {
62        ty,
63        repr,
64        into,
65        from,
66        bits,
67        new,
68        clone,
69        debug,
70        defmt,
71        default,
72        hash,
73        order,
74        conversion,
75    } = syn::parse2(args)?;
76
77    let span = input.fields.span();
78    let name = input.ident;
79    let vis = input.vis;
80    let attrs: TokenStream = input.attrs.iter().map(ToTokens::to_token_stream).collect();
81    let derive = match clone {
82        Enable::No => None,
83        Enable::Yes => Some(quote! { #[derive(Copy, Clone)] }),
84        Enable::Cfg(cfg) => Some(quote! { #[cfg_attr(#cfg, derive(Copy, Clone))] }),
85    };
86
87    let syn::Fields::Named(fields) = input.fields else {
88        return Err(s_err(span, "only named fields are supported"));
89    };
90
91    let mut offset = 0;
92    let mut members = Vec::with_capacity(fields.named.len());
93    for field in fields.named {
94        let f = Member::new(
95            ty.clone(),
96            bits,
97            into.clone(),
98            from.clone(),
99            field,
100            offset,
101            order,
102        )?;
103        offset += f.bits;
104        members.push(f);
105    }
106
107    if offset < bits {
108        return Err(s_err(
109            span,
110            format!(
111                "The bitfield size ({bits} bits) has to be equal to the sum of its fields ({offset} bits). \
112                You might have to add padding (a {} bits large field prefixed with \"_\").",
113                bits - offset
114            ),
115        ));
116    }
117    if offset > bits {
118        return Err(s_err(
119            span,
120            format!(
121                "The size of the fields ({offset} bits) is larger than the type ({bits} bits)."
122            ),
123        ));
124    }
125
126    let mut impl_debug = TokenStream::new();
127    if let Some(cfg) = debug.cfg() {
128        impl_debug.extend(traits::debug(&name, &members, cfg));
129    }
130    if let Some(cfg) = defmt.cfg() {
131        impl_debug.extend(traits::defmt(&name, &members, cfg));
132    }
133    if let Some(cfg) = hash.cfg() {
134        impl_debug.extend(traits::hash(&name, &members, cfg));
135    }
136
137    let defaults = members.iter().map(Member::default).collect::<Vec<_>>();
138
139    let impl_new = new.cfg().map(|cfg| {
140        let attr = cfg.map(|cfg| quote!(#[cfg(#cfg)]));
141        quote! {
142            /// Creates a new default initialized bitfield.
143            #attr
144            #vis const fn new() -> Self {
145                let mut this = Self(#from(0));
146                #( #defaults )*
147                this
148            }
149        }
150    });
151
152    let impl_default = default.cfg().map(|cfg| {
153        let attr = cfg.map(|cfg| quote!(#[cfg(#cfg)]));
154        quote! {
155            #attr
156            impl Default for #name {
157                fn default() -> Self {
158                    let mut this = Self(#from(0));
159                    #( #defaults )*
160                    this
161                }
162            }
163        }
164    });
165
166    let conversion = conversion.then(|| {
167        quote! {
168            /// Convert from bits.
169            #vis const fn from_bits(bits: #repr) -> Self {
170                Self(bits)
171            }
172            /// Convert into bits.
173            #vis const fn into_bits(self) -> #repr {
174                self.0
175            }
176        }
177    });
178
179    Ok(quote! {
180        #attrs
181        #derive
182        #[repr(transparent)]
183        #vis struct #name(#repr);
184
185        #[allow(unused_comparisons)]
186        #[allow(clippy::unnecessary_cast)]
187        #[allow(clippy::assign_op_pattern)]
188        impl #name {
189            #impl_new
190
191            #conversion
192
193            #( #members )*
194        }
195
196        #[allow(unused_comparisons)]
197        #[allow(clippy::unnecessary_cast)]
198        #[allow(clippy::assign_op_pattern)]
199        #impl_default
200
201        impl From<#repr> for #name {
202            fn from(v: #repr) -> Self {
203                Self(v)
204            }
205        }
206        impl From<#name> for #repr {
207            fn from(v: #name) -> #repr {
208                v.0
209            }
210        }
211
212        #impl_debug
213    })
214}
215
216/// Represents a member where accessor functions should be generated for.
217struct Member {
218    offset: usize,
219    bits: usize,
220    base_ty: syn::Type,
221    repr_into: Option<syn::Path>,
222    repr_from: Option<syn::Path>,
223    default: TokenStream,
224    inner: Option<MemberInner>,
225}
226
227struct MemberInner {
228    ident: syn::Ident,
229    ty: syn::Type,
230    attrs: Vec<syn::Attribute>,
231    vis: syn::Visibility,
232    into: TokenStream,
233    from: TokenStream,
234}
235
236impl Member {
237    fn new(
238        base_ty: syn::Type,
239        base_bits: usize,
240        repr_into: Option<syn::Path>,
241        repr_from: Option<syn::Path>,
242        field: syn::Field,
243        offset: usize,
244        order: Order,
245    ) -> syn::Result<Self> {
246        let span = field.span();
247
248        let syn::Field {
249            mut attrs,
250            vis,
251            ident,
252            ty,
253            ..
254        } = field;
255
256        let ident = ident.ok_or_else(|| s_err(span, "Not supported"))?;
257        let ignore = ident.to_string().starts_with('_');
258
259        let Field {
260            bits,
261            ty,
262            mut default,
263            into,
264            from,
265            access,
266        } = parse_field(&base_ty, &attrs, &ty, ignore)?;
267
268        let ignore = ignore || access == Access::None;
269
270        // compute the offset
271        let offset = if order == Order::Lsb {
272            offset
273        } else {
274            base_bits - offset - bits
275        };
276
277        if bits > 0 && !ignore {
278            // overflow check
279            if offset + bits > base_bits {
280                return Err(s_err(
281                    ty.span(),
282                    "The sum of the members overflows the type size",
283                ));
284            };
285
286            // clear conversion expr if not needed
287            let (from, into) = match access {
288                Access::ReadWrite => (from, into),
289                Access::ReadOnly => (from, quote!()),
290                Access::WriteOnly => (quote!(), into),
291                Access::None => (quote!(), quote!()),
292            };
293
294            // auto-conversion from zero
295            if default.is_empty() {
296                if !from.is_empty() {
297                    default = quote!({ let this = 0; #from });
298                } else {
299                    default = quote!(0);
300                }
301            }
302
303            // remove our attribute
304            attrs.retain(|a| !a.path().is_ident("bits"));
305
306            Ok(Self {
307                offset,
308                bits,
309                base_ty,
310                repr_into,
311                repr_from,
312                default,
313                inner: Some(MemberInner {
314                    ident,
315                    ty,
316                    attrs,
317                    vis,
318                    into,
319                    from,
320                }),
321            })
322        } else {
323            if default.is_empty() {
324                default = quote!(0);
325            }
326
327            Ok(Self {
328                offset,
329                bits,
330                base_ty,
331                repr_into,
332                repr_from,
333                default,
334                inner: None,
335            })
336        }
337    }
338
339    fn default(&self) -> TokenStream {
340        let default = &self.default;
341
342        if let Some(inner) = &self.inner {
343            if !inner.into.is_empty() {
344                let ident = &inner.ident;
345                let with_ident = format_ident!("with_{}", ident);
346                return quote!(this = this.#with_ident(#default););
347            }
348        }
349
350        // fallback when there is no setter
351        let offset = self.offset;
352        let base_ty = &self.base_ty;
353        let repr_into = &self.repr_into;
354        let repr_from = &self.repr_from;
355        let bits = self.bits as u32;
356
357        quote! {
358            let mask = #base_ty::MAX >> (#base_ty::BITS - #bits);
359            this.0 = #repr_from(#repr_into(this.0) | (((#default as #base_ty) & mask) << #offset));
360        }
361    }
362}
363
364impl ToTokens for Member {
365    fn to_tokens(&self, tokens: &mut TokenStream) {
366        let Self {
367            offset,
368            bits,
369            base_ty,
370            repr_into,
371            repr_from,
372            default: _,
373            inner:
374                Some(MemberInner {
375                    ident,
376                    ty,
377                    attrs,
378                    vis,
379                    into,
380                    from,
381                }),
382        } = self
383        else {
384            return Default::default();
385        };
386
387        let ident_str = ident.to_string().to_uppercase();
388        let ident_upper = Ident::new(
389            ident_str.strip_prefix("R#").unwrap_or(&ident_str),
390            ident.span(),
391        );
392
393        let with_ident = format_ident!("with_{}", ident);
394        let with_ident_checked = format_ident!("with_{}_checked", ident);
395        let set_ident = format_ident!("set_{}", ident);
396        let set_ident_checked = format_ident!("set_{}_checked", ident);
397        let bits_ident = format_ident!("{}_BITS", ident_upper);
398        let offset_ident = format_ident!("{}_OFFSET", ident_upper);
399
400        let location = format!("\n\nBits: {offset}..{}", offset + bits);
401
402        let doc: TokenStream = attrs
403            .iter()
404            .filter(|a| !a.path().is_ident("bits"))
405            .map(ToTokens::to_token_stream)
406            .collect();
407
408        tokens.extend(quote! {
409            const #bits_ident: usize = #bits;
410            const #offset_ident: usize = #offset;
411        });
412
413        if !from.is_empty() {
414            tokens.extend(quote! {
415                #doc
416                #[doc = #location]
417                #vis const fn #ident(&self) -> #ty {
418                    let mask = #base_ty::MAX >> (#base_ty::BITS - Self::#bits_ident as u32);
419                    let this = (#repr_into(self.0) >> Self::#offset_ident) & mask;
420                    #from
421                }
422            });
423        }
424
425        if !into.is_empty() {
426            let (class, _) = type_info(ty);
427            // generate static strings for the error messages (due to const)
428            let bounds = if class == TypeClass::SInt {
429                let min = -((u128::MAX >> (128 - (bits - 1))) as i128) - 1;
430                let max = u128::MAX >> (128 - (bits - 1));
431                format!("[{}, {}]", min, max)
432            } else {
433                format!("[0, {}]", u128::MAX >> (128 - bits))
434            };
435            let bounds_error = format!("value out of bounds {bounds}");
436
437            tokens.extend(quote! {
438                #doc
439                #[doc = #location]
440                #vis const fn #with_ident_checked(mut self, value: #ty) -> core::result::Result<Self, ()> {
441                    match self.#set_ident_checked(value) {
442                        Ok(_) => Ok(self),
443                        Err(_) => Err(()),
444                    }
445                }
446                #doc
447                #[doc = #location]
448                #[cfg_attr(debug_assertions, track_caller)]
449                #vis const fn #with_ident(mut self, value: #ty) -> Self {
450                    self.#set_ident(value);
451                    self
452                }
453
454                #doc
455                #[doc = #location]
456                #vis const fn #set_ident(&mut self, value: #ty) {
457                    if let Err(_) = self.#set_ident_checked(value) {
458                        panic!(#bounds_error)
459                    }
460                }
461                #doc
462                #[doc = #location]
463                #vis const fn #set_ident_checked(&mut self, value: #ty) -> core::result::Result<(), ()> {
464                    let this = value;
465                    let value: #base_ty = #into;
466                    let mask = #base_ty::MAX >> (#base_ty::BITS - Self::#bits_ident as u32);
467                    if value > mask {
468                        return Err(());
469                    }
470                    let bits = #repr_into(self.0) & !(mask << Self::#offset_ident) | (value & mask) << Self::#offset_ident;
471                    self.0 = #repr_from(bits);
472                    Ok(())
473                }
474            });
475        }
476    }
477}
478
/// Coarse classification of a field type, used to pick its default conversions.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum TypeClass {
    /// `bool`, stored in a single bit.
    Bool,
    /// Unsigned integers with fixed sizes (`u8` ... `u128`); `isize`/`usize` are also
    /// classified here.
    UInt,
    /// Signed integers with fixed sizes (`i8` ... `i128`).
    SInt,
    /// Any other (custom) type.
    Other,
}
491
492/// Field information, including the `bits` attribute
493struct Field {
494    bits: usize,
495    ty: syn::Type,
496
497    default: TokenStream,
498    into: TokenStream,
499    from: TokenStream,
500
501    access: Access,
502}
503
504/// Parses the `bits` attribute that allows specifying a custom number of bits.
505fn parse_field(
506    base_ty: &syn::Type,
507    attrs: &[syn::Attribute],
508    ty: &syn::Type,
509    ignore: bool,
510) -> syn::Result<Field> {
511    fn malformed(mut e: syn::Error, attr: &syn::Attribute) -> syn::Error {
512        e.combine(s_err(attr.span(), "malformed #[bits] attribute"));
513        e
514    }
515
516    let access = if ignore {
517        Access::None
518    } else {
519        Access::ReadWrite
520    };
521
522    // Defaults for the different types
523    let (class, ty_bits) = type_info(ty);
524    let mut ret = match class {
525        TypeClass::Bool => Field {
526            bits: ty_bits,
527            ty: ty.clone(),
528            default: quote!(false),
529            into: quote!(this as _),
530            from: quote!(this != 0),
531            access,
532        },
533        TypeClass::SInt => Field {
534            bits: ty_bits,
535            ty: ty.clone(),
536            default: quote!(0),
537            into: quote!(),
538            from: quote!(),
539            access,
540        },
541        TypeClass::UInt => Field {
542            bits: ty_bits,
543            ty: ty.clone(),
544            default: quote!(0),
545            into: quote!(this as _),
546            from: quote!(this as _),
547            access,
548        },
549        TypeClass::Other => Field {
550            bits: ty_bits,
551            ty: ty.clone(),
552            default: quote!(),
553            into: quote!(<#ty>::into_bits(this) as _),
554            from: quote!(<#ty>::from_bits(this as _)),
555            access,
556        },
557    };
558
559    // Find and parse the bits attribute
560    for attr in attrs {
561        let syn::Attribute {
562            style: syn::AttrStyle::Outer,
563            meta: syn::Meta::List(syn::MetaList { path, tokens, .. }),
564            ..
565        } = attr
566        else {
567            continue;
568        };
569        if !path.is_ident("bits") {
570            continue;
571        }
572
573        let span = tokens.span();
574        let BitsAttr {
575            bits,
576            default,
577            into,
578            from,
579            access,
580        } = syn::parse2(tokens.clone()).map_err(|e| malformed(e, attr))?;
581
582        // bit size
583        if let Some(bits) = bits {
584            if bits == 0 {
585                return Err(s_err(span, "bits cannot bit 0"));
586            }
587            if ty_bits != 0 && bits > ty_bits {
588                return Err(s_err(span, "overflowing field type"));
589            }
590            ret.bits = bits;
591        }
592
593        // read/write access
594        if let Some(access) = access {
595            if ignore {
596                return Err(s_err(
597                    tokens.span(),
598                    "'access' is not supported for padding",
599                ));
600            }
601            ret.access = access;
602        }
603
604        // conversion
605        if let Some(into) = into {
606            if ret.access == Access::None {
607                return Err(s_err(into.span(), "'into' is not supported on padding"));
608            }
609            ret.into = quote!(#into(this) as _);
610        }
611        if let Some(from) = from {
612            if ret.access == Access::None {
613                return Err(s_err(from.span(), "'from' is not supported on padding"));
614            }
615            ret.from = quote!(#from(this as _));
616        }
617        if let Some(default) = default {
618            ret.default = default.into_token_stream();
619        }
620    }
621
622    if ret.bits == 0 {
623        return Err(s_err(
624            ty.span(),
625            "Custom types and isize/usize require an explicit bit size",
626        ));
627    }
628
629    // Signed integers need some special handling...
630    if !ignore && ret.access != Access::None && class == TypeClass::SInt {
631        let bits = ret.bits as u32;
632        if ret.into.is_empty() {
633            // Bounds check and remove leading ones from negative values
634            ret.into = quote! {{
635                let m = #ty::MIN >> (#ty::BITS - #bits);
636                if !(m <= this && this <= -(m + 1)) {
637                    return Err(())
638                }
639                let mask = #base_ty::MAX >> (#base_ty::BITS - #bits);
640                (this as #base_ty & mask)
641            }};
642        }
643        if ret.from.is_empty() {
644            // Sign extend negative values
645            ret.from = quote! {{
646                let shift = #ty::BITS - #bits;
647                ((this as #ty) << shift) >> shift
648            }};
649        }
650    }
651
652    Ok(ret)
653}
654
655
656/// Returns the number of bits for a given type
657fn type_info(ty: &syn::Type) -> (TypeClass, usize) {
658    let syn::Type::Path(syn::TypePath { path, .. }) = ty else {
659        return (TypeClass::Other, 0);
660    };
661    let Some(ident) = path.get_ident() else {
662        return (TypeClass::Other, 0);
663    };
664    if ident == "bool" {
665        return (TypeClass::Bool, 1);
666    }
667    if ident == "isize" || ident == "usize" {
668        return (TypeClass::UInt, 0); // they have architecture dependend sizes
669    }
670    macro_rules! integer {
671        ($ident:ident => $($uint:ident),* ; $($sint:ident),*) => {
672            match ident {
673                $(_ if ident == stringify!($uint) => (TypeClass::UInt, $uint::BITS as _),)*
674                $(_ if ident == stringify!($sint) => (TypeClass::SInt, $sint::BITS as _),)*
675                _ => (TypeClass::Other, 0)
676            }
677        };
678    }
679    integer!(ident => u8, u16, u32, u64, u128 ; i8, i16, i32, i64, i128)
680}
681
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use quote::quote;

    use crate::{Access, BitsAttr, Enable, Order, Params};

    /// Parses `#[bitfield(...)]` arguments, panicking if they are rejected.
    fn params(args: proc_macro2::TokenStream) -> Params {
        syn::parse2(args).unwrap()
    }

    /// Parses `#[bits(...)]` arguments, panicking if they are rejected.
    fn bits_attr(args: proc_macro2::TokenStream) -> BitsAttr {
        syn::parse2(args).unwrap()
    }

    #[test]
    fn parse_args() {
        let p = params(quote!(u64));
        assert_eq!(p.bits, u64::BITS as usize);
        assert!(matches!(p.debug, Enable::Yes));
        assert!(matches!(p.defmt, Enable::No));

        let p = params(quote!(u32, debug = false));
        assert_eq!(p.bits, u32::BITS as usize);
        assert!(matches!(p.debug, Enable::No));
        assert!(matches!(p.defmt, Enable::No));

        let p = params(quote!(u32, defmt = true));
        assert_eq!(p.bits, u32::BITS as usize);
        assert!(matches!(p.debug, Enable::Yes));
        assert!(matches!(p.defmt, Enable::Yes));

        let p = params(quote!(u32, defmt = cfg(test), debug = cfg(feature = "foo")));
        assert_eq!(p.bits, u32::BITS as usize);
        assert!(matches!(p.debug, Enable::Cfg(_)));
        assert!(matches!(p.defmt, Enable::Cfg(_)));

        let p = params(quote!(u32, order = Msb));
        assert_eq!(p.bits, u32::BITS as usize);
        assert!(p.order == Order::Msb);
    }

    #[test]
    fn parse_bits() {
        let attr = bits_attr(quote!(8));
        assert_eq!(attr.bits, Some(8));
        assert!(attr.default.is_none());
        assert!(attr.into.is_none() && attr.from.is_none());
        assert!(attr.access.is_none());

        let attr = bits_attr(quote!(8, default = 8, access = RW));
        assert_eq!(attr.bits, Some(8));
        assert!(attr.default.is_some());
        assert!(attr.into.is_none() && attr.from.is_none());
        assert_eq!(attr.access, Some(Access::ReadWrite));

        let attr = bits_attr(quote!(access = RO));
        assert_eq!(attr.bits, None);
        assert!(attr.default.is_none());
        assert!(attr.into.is_none() && attr.from.is_none());
        assert_eq!(attr.access, Some(Access::ReadOnly));

        let attr = bits_attr(quote!(default = 8, access = WO));
        assert_eq!(attr.bits, None);
        assert!(attr.default.is_some());
        assert!(attr.into.is_none() && attr.from.is_none());
        assert_eq!(attr.access, Some(Access::WriteOnly));

        let attr = bits_attr(quote!(
            3,
            into = into_something,
            default = 1,
            from = from_something,
            access = None
        ));
        assert_eq!(attr.bits, Some(3));
        assert!(attr.default.is_some());
        assert!(attr.into.is_some() && attr.from.is_some());
        assert_eq!(attr.access, Some(Access::None));
    }

    #[test]
    fn parse_access_mode() {
        let cases = [
            (quote!(RW), Access::ReadWrite),
            (quote!(RO), Access::ReadOnly),
            (quote!(WO), Access::WriteOnly),
            (quote!(None), Access::None),
        ];
        for (tokens, expected) in cases {
            assert_eq!(syn::parse2::<Access>(tokens).unwrap(), expected);
        }

        assert!(syn::parse2::<Access>(quote!(garbage)).is_err());
    }
}