Skip to main content

bitfld_macro/
lib.rs

1// Copyright (c) 2025 Joshua Seaton
2//
3// Use of this source code is governed by a MIT-style
4// license that can be found in the LICENSE file or at
5// https://opensource.org/licenses/MIT#
6
7use std::collections::HashSet;
8use std::str::FromStr;
9
10use proc_macro::TokenStream;
11use proc_macro2::{Literal, Span, TokenStream as TokenStream2};
12use quote::{ToTokens, format_ident, quote};
13use syn::parse::discouraged::Speculative;
14use syn::parse::{Error, Parse, ParseStream, Result};
15use syn::spanned::Spanned;
16use syn::{
17    Attribute, Expr, ExprLit, Fields, GenericArgument, Ident, ItemStruct, Lit,
18    MetaNameValue, Pat, Path, PathArguments, Stmt, Type, braced,
19    parse_macro_input, parse_quote,
20};
21
22#[proc_macro_attribute]
23pub fn bitfield_repr(attr: TokenStream, item: TokenStream) -> TokenStream {
24    let attr: TokenStream2 = attr.into();
25    let item: TokenStream2 = item.into();
26    quote! {
27        #[repr(#attr)]
28        #[derive(
29            Debug,
30            Eq,
31            PartialEq,
32            ::zerocopy::Immutable,
33            ::zerocopy::IntoBytes,
34            ::zerocopy::TryFromBytes,
35        )]
36        #item
37    }
38    .into()
39}
40
41#[proc_macro]
42pub fn layout(item: TokenStream) -> TokenStream {
43    parse_macro_input!(item as Bitfields)
44        .to_token_stream()
45        .into()
46}
47
48//
49// Parsing of the bitfld type.
50//
51
/// The unsigned integral types that may back a layout.
enum BaseType {
    U8,
    U16,
    U32,
    U64,
    U128,
    /// Pointer-sized; its width depends on the compilation target.
    Usize,
}

impl BaseType {
    /// Index of the most significant bit of the type.
    ///
    /// Returns `None` for `usize`, whose width is not known when the macro
    /// expands.
    const fn high_bit(&self) -> Option<usize> {
        let width: usize = match self {
            Self::U8 => 8,
            Self::U16 => 16,
            Self::U32 => 32,
            Self::U64 => 64,
            Self::U128 => 128,
            Self::Usize => return None,
        };
        Some(width - 1)
    }
}
73
74struct BaseTypeDef {
75    def: Type,
76    ty: BaseType,
77}
78
79impl TryFrom<Type> for BaseTypeDef {
80    type Error = Error;
81
82    fn try_from(type_def: Type) -> Result<Self> {
83        const INVALID_BASE_TYPE: &str =
84            "base type must be an unsigned integral type";
85        let Type::Path(ref path_ty) = type_def else {
86            return Err(Error::new_spanned(type_def, INVALID_BASE_TYPE));
87        };
88        let path = &path_ty.path;
89        let ty = if path.is_ident("u8") {
90            BaseType::U8
91        } else if path.is_ident("u16") {
92            BaseType::U16
93        } else if path.is_ident("u32") {
94            BaseType::U32
95        } else if path.is_ident("u64") {
96            BaseType::U64
97        } else if path.is_ident("u128") {
98            BaseType::U128
99        } else if path.is_ident("usize") {
100            BaseType::Usize
101        } else {
102            return Err(Error::new_spanned(path, INVALID_BASE_TYPE));
103        };
104        Ok(Self { def: type_def, ty })
105    }
106}
107
108struct TypeDef {
109    def: ItemStruct,
110    base: BaseTypeDef,
111}
112
113impl Parse for TypeDef {
114    fn parse(input: ParseStream) -> Result<Self> {
115        let mut strct: ItemStruct = input.parse()?;
116
117        // Check for any redundant derives; all other derives are forwarded. If
118        // no repr is specified, then we default to repr(transparent).
119        let mut repr = false;
120        for attr in &strct.attrs {
121            if attr.path().is_ident("derive") {
122                attr.parse_nested_meta(|meta| {
123                    for t in &[
124                        "Copy",
125                        "Clone",
126                        "Debug",
127                        "Default",
128                        "Eq",
129                        "PartialEq",
130                    ] {
131                        if meta.path.is_ident(t) {
132                            return Err(Error::new_spanned(
133                                meta.path,
134                                format!("layout! already derives {t}"),
135                            ));
136                        }
137                    }
138                    Ok(())
139                })?;
140                continue;
141            }
142
143            if attr.path().is_ident("repr") {
144                repr = true;
145            }
146        }
147        if !repr {
148            strct.attrs.push(parse_quote!(#[repr(transparent)]));
149        }
150
151        //
152        let base_type = if let Fields::Unnamed(fields) = &strct.fields {
153            if fields.unnamed.is_empty() {
154                return Err(Error::new_spanned(
155                    &fields.unnamed,
156                    "no base type provided",
157                ));
158            }
159            if fields.unnamed.len() > 1 {
160                return Err(Error::new_spanned(
161                    &fields.unnamed,
162                    "too many tuple fields; only the base type should be provided",
163                ));
164            }
165            BaseTypeDef::try_from(fields.unnamed.first().unwrap().ty.clone())?
166        } else {
167            return Err(Error::new_spanned(
168                &strct.fields,
169                "bitfld type must be defined as a tuple struct",
170            ));
171        };
172
173        if let Some(param) = strct.generics.type_params().next() {
174            return Err(Error::new_spanned(
175                &param.ident,
176                "only const generic parameters are supported",
177            ));
178        }
179        if let Some(param) = strct.generics.lifetimes().next() {
180            return Err(Error::new_spanned(
181                param,
182                "only const generic parameters are supported",
183            ));
184        }
185
186        Ok(Self {
187            def: strct,
188            base: base_type,
189        })
190    }
191}
192
193//
194// Parsing and binding for an individual bitfield.
195//
196
197struct Bitfield {
198    span: Span,
199    name: Option<Ident>,
200    high_bit: usize,
201    low_bit: usize,
202    repr: Option<Type>,
203    cfg_pointer_width: Option<String>,
204    doc_attrs: Vec<Attribute>,
205    unshifted: bool,
206
207    //
208    default: Option<Box<Expr>>,
209
210    //
211    shifted_mask: TokenStream2,
212}
213
214impl Bitfield {
215    fn new(
216        span: Span,
217        name: Option<Ident>,
218        high_bit: usize,
219        low_bit: usize,
220        repr: Option<Type>,
221        cfg_pointer_width: Option<String>,
222        doc_attrs: Vec<Attribute>,
223        unshifted: bool,
224        default: Option<Box<Expr>>,
225    ) -> Self {
226        let shifted_mask = {
227            let num_ones = high_bit - low_bit + 1;
228            let mut mask_str = "0x".to_string();
229
230            // If the first hex digit is not 'f', write that out now. After that,
231            // the remaining width will be a multiple of four, and the remaining
232            // digits will be either f or 0.
233            if !num_ones.is_multiple_of(4) {
234                mask_str.push(match num_ones % 4 {
235                    1 => '1',
236                    2 => '3',
237                    3 => '7',
238                    _ => unreachable!(),
239                });
240            }
241
242            let mut remaining = num_ones / 4;
243            while remaining > 0 {
244                if mask_str.len() > 2 && remaining.is_multiple_of(4) {
245                    mask_str.push('_');
246                }
247                mask_str.push('f');
248                remaining -= 1;
249            }
250            TokenStream2::from_str(&mask_str).unwrap()
251        };
252
253        Self {
254            span,
255            name,
256            high_bit,
257            low_bit,
258            repr,
259            cfg_pointer_width,
260            doc_attrs,
261            unshifted,
262            default,
263            shifted_mask,
264        }
265    }
266
267    const fn is_reserved(&self) -> bool {
268        self.name.is_none()
269    }
270
271    fn display_name(&self) -> String {
272        match &self.name {
273            Some(name) => format!("`{name}`"),
274            None => "reserved".to_string(),
275        }
276    }
277
278    fn display_kind(&self) -> &'static str {
279        if self.bit_width() == 1 {
280            "bit"
281        } else {
282            "field"
283        }
284    }
285
286    fn display_range(&self) -> String {
287        if self.bit_width() == 1 {
288            format!("{}", self.low_bit)
289        } else {
290            format!("[{}:{}]", self.high_bit, self.low_bit)
291        }
292    }
293
294    const fn bit_width(&self) -> usize {
295        self.high_bit - self.low_bit + 1
296    }
297
298    fn minimum_width_integral_type(&self) -> TokenStream2 {
299        match self.bit_width() {
300            2..=8 => quote! {u8},
301            9..=16 => quote! {u16},
302            17..=32 => quote! {u32},
303            33..=64 => quote! {u64},
304            65..=128 => quote! {u128},
305            width => panic!("unexpected integral bit width: {width}"),
306        }
307    }
308
309    fn getter_and_setter(&self, ty: &TypeDef) -> TokenStream2 {
310        debug_assert!(!self.is_reserved());
311
312        let cfg_attr = cfg_attr(self);
313        let doc_attrs = &self.doc_attrs;
314        let name = self.name.as_ref().unwrap();
315        let setter_name = format_ident!("set_{}", name);
316        let type_name = &ty.def.ident;
317
318        let base_type = &ty.base.def;
319        let low_bit = Literal::usize_unsuffixed(self.low_bit);
320        let bit_width = self.bit_width();
321
322        if self.unshifted {
323            debug_assert!(self.repr.is_none());
324
325            let shifted_mask = &self.shifted_mask;
326            let mask = quote! { (#shifted_mask << #low_bit) };
327            let range = self.display_range();
328
329            let get_doc =
330                format!("The unshifted value of `{type_name}{range}`.",);
331            let set_doc =
332                format!("Sets the unshifted value of `{type_name}{range}`.",);
333
334            if bit_width == 1 {
335                let bit_mask = quote! { (1 << #low_bit) };
336                return quote! {
337                    #cfg_attr
338                    #(#doc_attrs)*
339                    #[doc = #get_doc]
340                    #[inline]
341                    pub const fn #name(&self) -> #base_type {
342                        self.0 & #bit_mask
343                    }
344
345                    #cfg_attr
346                    #[doc = #set_doc]
347                    #[inline]
348                    pub const fn #setter_name(&mut self, value: #base_type) -> &mut Self {
349                        debug_assert!((value & !#bit_mask) == 0);
350                        self.0 = (self.0 & !#bit_mask) | (value & #bit_mask);
351                        self
352                    }
353                };
354            }
355
356            return quote! {
357                #cfg_attr
358                #(#doc_attrs)*
359                #[doc = #get_doc]
360                #[inline]
361                pub const fn #name(&self) -> #base_type {
362                    self.0 & #mask
363                }
364
365                #cfg_attr
366                #[doc = #set_doc]
367                #[inline]
368                pub const fn #setter_name(&mut self, value: #base_type) -> &mut Self {
369                    debug_assert!((value & !#mask) == 0);
370                    self.0 = (self.0 & !#mask) | (value & #mask);
371                    self
372                }
373            };
374        }
375
376        if bit_width == 1 {
377            let get_doc =
378                format!("The value of `{type_name}[{}]`.", self.low_bit,);
379            let set_doc =
380                format!("Sets the value of `{type_name}[{}]`.", self.low_bit,);
381            return quote! {
382                #cfg_attr
383                #(#doc_attrs)*
384                #[doc = #get_doc]
385                #[inline]
386                pub const fn #name(&self) -> bool {
387                    (self.0 & (1 << #low_bit)) != 0
388                }
389
390                #cfg_attr
391                #[doc = #set_doc]
392                #[inline]
393                pub const fn #setter_name(&mut self, value: bool) -> &mut Self {
394                    if value {
395                        self.0 |= (1 << #low_bit);
396                    } else {
397                        self.0 &= !(1 << #low_bit);
398                    }
399                    self
400                }
401            };
402        }
403
404        let get_doc = format!(
405            "The value of `{type_name}[{}:{}]`.",
406            self.high_bit, self.low_bit,
407        );
408        let set_doc = format!(
409            "Sets the value of `{type_name}[{}:{}]`.",
410            self.high_bit, self.low_bit,
411        );
412
413        let min_width = self.minimum_width_integral_type();
414        let shifted_mask = &self.shifted_mask;
415        let get_value =
416            quote! { ((self.0 >> #low_bit) & #shifted_mask) as #min_width };
417        let getter = if let Some(repr) = &self.repr {
418            quote! {
419                #(#doc_attrs)*
420                #[doc = #get_doc]
421                #[inline]
422                pub fn #name(&self)
423                    -> ::core::result::Result<#repr, ::bitfld::InvalidBits<#min_width>>
424                where
425                    #repr: ::zerocopy::TryFromBytes,
426                {
427                    use ::zerocopy::IntoBytes;
428                    use ::zerocopy::TryFromBytes;
429                    let value = #get_value;
430                    #repr::try_read_from_bytes(value.as_bytes())
431                        .map_err(|_| ::bitfld::InvalidBits(value))
432                }
433            }
434        } else {
435            quote! {
436                #(#doc_attrs)*
437                #[doc = #get_doc]
438                #[inline]
439                pub const fn #name(&self) -> #min_width {
440                    #get_value
441                }
442            }
443        };
444
445        let set_value = {
446            let value_check = if bit_width >= 8 && bit_width.is_power_of_two() {
447                quote! {}
448            } else {
449                quote! { debug_assert!((value & !#shifted_mask) == 0); }
450            };
451            quote! {
452                #value_check
453                self.0 &= !(#shifted_mask << #low_bit);
454                self.0 |= ((value & #shifted_mask) as #base_type) << #low_bit;
455            }
456        };
457
458        let setter = if let Some(repr) = &self.repr {
459            quote! {
460                #[doc = #set_doc]
461                #[inline]
462                pub fn #setter_name(&mut self, value: #repr) -> &mut Self
463                where
464                    #repr: ::zerocopy::IntoBytes + ::zerocopy::Immutable
465                 {
466                    use ::zerocopy::IntoBytes;
467                    use ::zerocopy::FromBytes;
468                    const { assert!(::core::mem::size_of::<#repr>() == ::core::mem::size_of::<#min_width>()) }
469                    let value = #min_width::read_from_bytes(value.as_bytes()).unwrap();
470                    #set_value
471                    self
472                }
473            }
474        } else {
475            quote! {
476                #[doc = #set_doc]
477                #[inline]
478                pub const fn #setter_name(&mut self, value: #min_width) -> &mut Self {
479                    #set_value
480                    self
481                }
482            }
483        };
484
485        quote! {
486            #cfg_attr
487            #getter
488            #cfg_attr
489            #setter
490        }
491    }
492}
493
494fn cfg_attr(field: &Bitfield) -> Option<TokenStream2> {
495    field.cfg_pointer_width.as_ref().map(|w| {
496        quote! { #[cfg(target_pointer_width = #w)] }
497    })
498}
499
500impl Parse for Bitfield {
501    fn parse(input: ParseStream) -> Result<Self> {
502        const INVALID_BITFIELD_DECL_FORM: &str = "bitfield declaration should take one of the following forms:\n\
503            * `let $name: Bit<$bit> (= $default)?;`\n\
504            * `let $name: Bits<$high, $low (, $repr)?> (= $default)?;`\n\
505            * `let _: Bit<$bit> (= $value)?;`\n\
506            * `let _: Bits<$high, $low> (= $value)?;`";
507        let err = |spanned: &dyn ToTokens| {
508            Error::new_spanned(spanned, INVALID_BITFIELD_DECL_FORM)
509        };
510
511        let stmt = input.parse::<Stmt>()?;
512        let Stmt::Local(ref local) = stmt else {
513            return Err(err(&stmt));
514        };
515
516        let mut doc_attrs = Vec::new();
517        let mut unshifted = false;
518        for attr in &local.attrs {
519            if attr.path().is_ident("doc") {
520                doc_attrs.push(attr.clone());
521            } else if attr.path().is_ident("unshifted") {
522                if unshifted {
523                    return Err(Error::new_spanned(
524                        attr,
525                        "duplicate `#[unshifted]` attribute",
526                    ));
527                }
528                unshifted = true;
529            } else {
530                return Err(Error::new_spanned(
531                    attr,
532                    "attributes are not permitted on individual fields",
533                ));
534            }
535        }
536
537        let Pat::Type(ref pat_type) = local.pat else {
538            return Err(err(&local));
539        };
540
541        let name: Option<Ident> = match *pat_type.pat {
542            Pat::Ident(ref pat_ident) => {
543                if let Some(by_ref) = &pat_ident.by_ref {
544                    return Err(err(by_ref));
545                }
546                if let Some(mutability) = &pat_ident.mutability {
547                    return Err(err(mutability));
548                }
549                if let Some(subpat) = &pat_ident.subpat {
550                    return Err(err(&subpat.0));
551                }
552                Some(pat_ident.ident.clone())
553            }
554            Pat::Wild(_) => None,
555            _ => return Err(err(&*pat_type.pat)),
556        };
557
558        let path: &Path = if let Type::Path(ref type_path) = *pat_type.ty {
559            if type_path.qself.is_some() {
560                return Err(err(&*pat_type.ty));
561            }
562            &type_path.path
563        } else {
564            return Err(err(&*pat_type.ty));
565        };
566
567        let get_bits_and_repr = |bits: &mut [usize]| -> Result<Option<Type>> {
568            let args = &path.segments.first().unwrap().arguments;
569            let args = if let PathArguments::AngleBracketed(bracketed) = args {
570                &bracketed.args
571            } else {
572                return Err(err(&args));
573            };
574            if args.len() < bits.len() || args.len() > bits.len() + 1 {
575                return Err(err(&args));
576            }
577            for (i, bit) in bits.iter_mut().enumerate() {
578                let arg = args.get(i).unwrap();
579                match arg {
580                    GenericArgument::Const(Expr::Lit(ExprLit {
581                        lit: Lit::Int(b),
582                        ..
583                    })) => {
584                        *bit = b.base10_parse()?;
585                    }
586                    _ => return Err(err(&arg)),
587                }
588            }
589            if args.len() == bits.len() + 1 {
590                let arg = args.last().unwrap();
591                if let GenericArgument::Type(repr) = arg {
592                    Ok(Some(repr.clone()))
593                } else {
594                    Err(err(&arg))
595                }
596            } else {
597                Ok(None)
598            }
599        };
600
601        let type_ident = &path.segments.first().unwrap().ident;
602        let (high, low, repr) = if type_ident == "Bits" {
603            let mut bits = [0usize; 2];
604            let repr = get_bits_and_repr(&mut bits)?;
605            if bits[0] < bits[1] {
606                Err(Error::new_spanned(
607                    &path.segments,
608                    "first high bit, then low",
609                ))
610            } else {
611                Ok((bits[0], bits[1], repr))
612            }
613        } else if type_ident == "Bit" {
614            let mut bit = [0usize; 1];
615            let repr = get_bits_and_repr(&mut bit)?;
616            Ok((bit[0], bit[0], repr))
617        } else {
618            Err(err(path))
619        }?;
620
621        let default_or_value = if let Some(ref init) = local.init {
622            if init.diverge.is_some() {
623                return Err(err(local));
624            }
625            Some(init.expr.clone())
626        } else {
627            None
628        };
629
630        if !doc_attrs.is_empty() && name.is_none() {
631            return Err(Error::new_spanned(
632                &doc_attrs[0],
633                "doc comments are not permitted on reserved fields",
634            ));
635        }
636
637        if repr.is_some() {
638            if name.is_none() {
639                return Err(Error::new_spanned(
640                    repr,
641                    "custom representations are not permitted for reserved fields",
642                ));
643            }
644            if high == low {
645                return Err(Error::new_spanned(
646                    repr,
647                    "custom representations are not permitted for bits",
648                ));
649            }
650        }
651
652        if unshifted && name.is_none() {
653            return Err(Error::new_spanned(
654                &local.pat,
655                "`#[unshifted]` is not permitted on reserved fields",
656            ));
657        }
658        if unshifted && repr.is_some() {
659            return Err(Error::new_spanned(
660                repr,
661                "`#[unshifted]` is not permitted on fields with custom representations",
662            ));
663        }
664
665        Ok(Bitfield::new(
666            stmt.span(),
667            name,
668            high,
669            low,
670            repr,
671            None,
672            doc_attrs,
673            unshifted,
674            default_or_value,
675        ))
676    }
677}
678
679/// Parses a `#[cfg(target_pointer_width = "...")]` attribute, returning the
680/// width value on success.
681fn parse_target_pointer_width_cfg(attr: &Attribute) -> Result<String> {
682    const ERROR_MSG: &str = "expected #[cfg(target_pointer_width = \"...\")]";
683    let meta = attr
684        .meta
685        .require_list()
686        .map_err(|_| Error::new_spanned(attr, ERROR_MSG))?;
687    let cfg: MetaNameValue = meta
688        .parse_args()
689        .map_err(|_| Error::new_spanned(attr, ERROR_MSG))?;
690    if !cfg.path.is_ident("target_pointer_width") {
691        return Err(Error::new_spanned(attr, ERROR_MSG));
692    }
693    match cfg.value {
694        Expr::Lit(ExprLit {
695            lit: Lit::Str(s), ..
696        }) => Ok(s.value()),
697        _ => Err(Error::new_spanned(attr, ERROR_MSG)),
698    }
699}
700
701struct Bitfields {
702    ty: TypeDef,
703    named: Vec<Bitfield>,
704    reserved: Vec<Bitfield>,
705    errors: Vec<Error>,
706}
707
708impl Bitfields {
709    fn constants(&self) -> TokenStream2 {
710        let base = &self.ty.base.def;
711        let is_usize = matches!(self.ty.base.ty, BaseType::Usize);
712
713        let mut field_constants = Vec::new();
714        let mut field_metadata = Vec::new();
715        let mut num_field_stmts = Vec::new();
716        let mut checks = Vec::new();
717        let mut default_stmts = Vec::new();
718
719        for field in &self.named {
720            let cfg_attr = cfg_attr(field);
721            let name_lower = field.name.as_ref().unwrap().to_string();
722            let name_upper = name_lower.to_uppercase();
723            let low_bit = Literal::usize_unsuffixed(field.low_bit);
724            let shifted_mask = &field.shifted_mask;
725
726            let mask_name = format_ident!("{name_upper}_MASK");
727            let mask_doc = format!("Unshifted bitmask of `{name_lower}`.");
728            let shift_name = format_ident!("{name_upper}_SHIFT");
729            let shift_doc =
730                format!("Bit shift (i.e., the low bit) of `{name_lower}`.");
731
732            field_constants.push(quote! {
733                #cfg_attr
734                #[doc = #mask_doc]
735                pub const #mask_name: #base = (#shifted_mask << #low_bit);
736                #cfg_attr
737                #[doc = #shift_doc]
738                pub const #shift_name: usize = #low_bit;
739            });
740
741            if let Some(default) = &field.default {
742                let default_name = format_ident!("{name_upper}_DEFAULT");
743                let doc = format!(
744                    "Pre-shifted default value of the `{name_lower}` field.",
745                );
746                field_constants.push(quote! {
747                    #cfg_attr
748                    #[doc = #doc]
749                    pub const #default_name: #base = ((#default) as #base) << #low_bit;
750                });
751                checks.push(quote! {
752                    #cfg_attr
753                    const { assert!(((#default) as #base) << #low_bit & !(#shifted_mask << #low_bit) == 0) }
754                });
755                default_stmts.push(quote! {
756                    #cfg_attr
757                    { v |= Self::#default_name; }
758                });
759            }
760
761            if is_usize {
762                let high_bit = Literal::usize_unsuffixed(field.high_bit);
763                checks.push(quote! {
764                    #cfg_attr
765                    const { assert!(#high_bit < usize::BITS as usize) }
766                });
767            }
768
769            let high_bit = Literal::usize_unsuffixed(field.high_bit);
770            let default = if let Some(default) = &field.default {
771                quote! { #default }
772            } else {
773                quote! { 0 }
774            };
775            field_metadata.push(quote! {
776                #cfg_attr
777                ::bitfld::FieldMetadata::<#base>{
778                    name: #name_lower,
779                    high_bit: #high_bit,
780                    low_bit: #low_bit,
781                    default: #default as #base,
782                },
783            });
784            num_field_stmts.push(quote! {
785                #cfg_attr
786                { n += 1; }
787            });
788        }
789
790        let mut rsvd1_stmts = Vec::new();
791        let mut rsvd0_stmts = Vec::new();
792        for rsvd in &self.reserved {
793            let cfg_attr = cfg_attr(rsvd);
794            let rsvd_value = rsvd.default.as_ref().unwrap();
795            let low_bit = Literal::usize_unsuffixed(rsvd.low_bit);
796            let shifted_mask = &rsvd.shifted_mask;
797            let name = format_ident!("RSVD_{}_{}", rsvd.high_bit, rsvd.low_bit);
798
799            field_constants.push(quote! {
800                #cfg_attr
801                const #name: #base = (#rsvd_value as #base) << #low_bit;
802            });
803            checks.push(quote! {
804                #cfg_attr
805                const { assert!((#rsvd_value as #base) << #low_bit & !(#shifted_mask << #low_bit) == 0) }
806            });
807            rsvd1_stmts.push(quote! {
808                #cfg_attr
809                { v |= Self::#name; }
810            });
811            rsvd0_stmts.push(quote! {
812                #cfg_attr
813                { v |= !Self::#name & (#shifted_mask << #low_bit); }
814            });
815
816            if is_usize {
817                let high_bit = Literal::usize_unsuffixed(rsvd.high_bit);
818                checks.push(quote! {
819                    #cfg_attr
820                    const { assert!(#high_bit < usize::BITS as usize) }
821                });
822            }
823        }
824
825        let num_fields_expr = quote! {
826            {
827                let mut n = 0usize;
828                #(#num_field_stmts)*
829                n
830            }
831        };
832        field_constants.push(quote! {
833            #[doc(hidden)]
834            const NUM_FIELDS: usize = #num_fields_expr;
835            /// Metadata of all named fields in the layout.
836            pub const FIELDS: [::bitfld::FieldMetadata::<#base>; #num_fields_expr] = [
837                #(#field_metadata)*
838            ];
839        });
840
841        let check_fn = if checks.is_empty() {
842            quote! {}
843        } else {
844            let checks = checks.into_iter();
845            quote! {
846                #[forbid(overflowing_literals)]
847                const fn check_defaults() -> () {
848                    #(#checks)*
849                }
850            }
851        };
852
853        quote! {
854            /// Mask of all reserved-as-1 bits.
855            pub const RSVD1_MASK: #base = {
856                let mut v: #base = 0;
857                #(#rsvd1_stmts)*
858                v
859            };
860            /// Mask of all reserved-as-0 bits.
861            pub const RSVD0_MASK: #base = {
862                let mut v: #base = 0;
863                #(#rsvd0_stmts)*
864                v
865            };
866            /// The default value of the layout, combining all field
867            /// defaults and reserved-as values.
868            pub const DEFAULT: #base = {
869                let mut v: #base = Self::RSVD1_MASK;
870                #(#default_stmts)*
871                v
872            };
873
874            #(#field_constants)*
875
876            #check_fn
877        }
878    }
879
880    fn iter_impl(&self) -> TokenStream2 {
881        let ty = &self.ty.def.ident;
882        let base = &self.ty.base.def;
883        let iter_type = format_ident!("{}Iter", ty);
884        let vis = &self.ty.def.vis;
885
886        let generics = &self.ty.def.generics;
887        let (impl_generics, ty_generics, where_clause) =
888            generics.split_for_impl();
889
890        let mut ref_generics = generics.clone();
891        ref_generics.params.insert(0, parse_quote!('a));
892        let (ref_impl_generics, _, _) = ref_generics.split_for_impl();
893
894        quote! {
895            #[doc(hidden)]
896            #vis struct #iter_type #impl_generics (#base, usize, usize) #where_clause;
897
898            impl #impl_generics ::core::iter::Iterator for #iter_type #ty_generics #where_clause {
899                type Item = (&'static ::bitfld::FieldMetadata<#base>, #base);
900
901                fn next(&mut self) -> Option<Self::Item> {
902                    if self.1 >= self.2 {
903                        return None;
904                    }
905                    let metadata = &<#ty #ty_generics>::FIELDS[self.1];
906                    let shifted_mask = (1 << (metadata.high_bit - metadata.low_bit + 1)) - 1;
907                    let value = (self.0 >> metadata.low_bit) & shifted_mask;
908                    self.1 += 1;
909                    Some((metadata, value))
910                }
911            }
912
913            impl #impl_generics ::core::iter::DoubleEndedIterator for #iter_type #ty_generics #where_clause {
914                fn next_back(&mut self) -> Option<Self::Item> {
915                    if self.1 >= self.2 {
916                        return None;
917                    }
918                    self.2 -= 1;
919                    let metadata = &<#ty #ty_generics>::FIELDS[self.2];
920                    let shifted_mask = (1 << (metadata.high_bit - metadata.low_bit + 1)) - 1;
921                    let value = (self.0 >> metadata.low_bit) & shifted_mask;
922                    Some((metadata, value))
923                }
924            }
925
926            impl #impl_generics #ty #ty_generics #where_clause {
927                /// Returns an iterator over
928                /// ([metadata][`bitfld::FieldMetadata`], value) pairs for each
929                /// field.
930                pub fn iter(&self) -> #iter_type #ty_generics {
931                    #iter_type(self.0, 0, Self::NUM_FIELDS)
932                }
933            }
934
935            impl #impl_generics ::core::iter::IntoIterator for #ty #ty_generics #where_clause {
936                type Item = (&'static ::bitfld::FieldMetadata<#base>, #base);
937                type IntoIter = #iter_type #ty_generics;
938
939                fn into_iter(self) -> Self::IntoIter { #iter_type(self.0, 0, Self::NUM_FIELDS) }
940            }
941
942            impl #ref_impl_generics ::core::iter::IntoIterator for &'a #ty #ty_generics #where_clause {
943                type Item = (&'static ::bitfld::FieldMetadata<#base>, #base);
944                type IntoIter = #iter_type #ty_generics;
945
946                fn into_iter(self) -> Self::IntoIter { #iter_type(self.0, 0, <#ty #ty_generics>::NUM_FIELDS) }
947            }
948        }
949    }
950
951    fn getters_and_setters(&self) -> impl Iterator<Item = TokenStream2> + '_ {
952        self.named
953            .iter()
954            .map(|field| field.getter_and_setter(&self.ty))
955    }
956
957    fn fmt_fn(&self, integral_specifier: &str) -> TokenStream2 {
958        let ty_str = &self.ty.def.ident.to_string();
959
960        let mut custom_repr_fields = self
961            .named
962            .iter()
963            .filter(|field| field.repr.is_some())
964            .peekable();
965
966        let where_clause = if custom_repr_fields.peek().is_some() {
967            let bounds = custom_repr_fields.map(|field| {
968                let repr = field.repr.as_ref().unwrap();
969                quote! {#repr: ::core::fmt::Debug,}
970            });
971            quote! {
972                where
973                    #(#bounds)*
974            }
975        } else {
976            quote! {}
977        };
978
979        let fmt_fields = self.named.iter().map(|field| {
980            let cfg_attr = cfg_attr(field);
981            let name = &field.name;
982            let name_str = name.as_ref().unwrap().to_string();
983            let default_specifier = if field.bit_width() == 1 {
984                ""
985            } else {
986                integral_specifier
987            };
988            if field.repr.is_some() {
989                let ok_format_string =
990                    format!("{{indent}}{name_str}: {{:#?}},{{sep}}");
991                let ok_format_string = Literal::string(&ok_format_string);
992                let err_format_string = format!(
993                    "{{indent}}{name_str}: InvalidBits({{{default_specifier}}}),{{sep}}"
994                );
995                let err_format_string = Literal::string(&err_format_string);
996                quote! {
997                    #cfg_attr
998                    {
999                        match self.#name() {
1000                            Ok(value) => write!(f, #ok_format_string, value),
1001                            Err(invalid) => write!(f, #err_format_string, invalid.0),
1002                        }?;
1003                    }
1004                }
1005            } else {
1006                let format_string = format!(
1007                    "{{indent}}{name_str}: {{{default_specifier}}},{{sep}}"
1008                );
1009                let format_string = Literal::string(&format_string);
1010                quote! {
1011                    #cfg_attr
1012                    { write!(f, #format_string, self.#name())?; }
1013                }
1014            }
1015        });
1016
1017        quote! {
1018            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result
1019            #where_clause
1020            {
1021                let (sep, indent) = if f.alternate() {
1022                    ('\n', "    ")
1023                } else {
1024                    (' ', "")
1025                };
1026                write!(f, "{} {{{sep}", #ty_str)?;
1027                #(#fmt_fields)*
1028                write!(f, "}}")
1029            }
1030        }
1031    }
1032
1033    fn fmt_impls(&self) -> TokenStream2 {
1034        let ty = &self.ty.def.ident;
1035        let (impl_generics, ty_generics, where_clause) =
1036            self.ty.def.generics.split_for_impl();
1037        let lower_hex_fmt = self.fmt_fn(":#x");
1038        let upper_hex_fmt = self.fmt_fn(":#X");
1039        let binary_fmt = self.fmt_fn(":#b");
1040        let octal_fmt = self.fmt_fn(":#o");
1041        quote! {
1042            impl #impl_generics ::core::fmt::Debug for #ty #ty_generics #where_clause {
1043                fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
1044                    ::core::fmt::LowerHex::fmt(self, f)
1045                }
1046            }
1047
1048            impl #impl_generics ::core::fmt::Binary for #ty #ty_generics #where_clause {
1049                #binary_fmt
1050            }
1051
1052            impl #impl_generics ::core::fmt::LowerHex for #ty #ty_generics #where_clause {
1053                #lower_hex_fmt
1054            }
1055
1056            impl #impl_generics ::core::fmt::UpperHex for #ty #ty_generics #where_clause {
1057                #upper_hex_fmt
1058            }
1059
1060            impl #impl_generics ::core::fmt::Octal for #ty #ty_generics #where_clause {
1061                #octal_fmt
1062            }
1063        }
1064    }
1065}
1066
1067impl Parse for Bitfields {
1068    fn parse(input: ParseStream) -> Result<Self> {
1069        let input = {
1070            let content;
1071            braced!(content in input);
1072            content
1073        };
1074
1075        let ty = input.parse::<TypeDef>()?;
1076
1077        let input = {
1078            let content;
1079            braced!(content in input);
1080            content
1081        };
1082
1083        let mut fields = Vec::new();
1084        let mut errors = Vec::new();
1085        let is_usize = matches!(ty.base.ty, BaseType::Usize);
1086
1087        // Phase 1: parse cfg blocks. Each is `#[cfg(...)] { fields }`.
1088        // We use a fork to distinguish a cfg block from a bare field with
1089        // a stray attribute (which the field parser will reject).
1090        let mut seen_widths = HashSet::new();
1091        while input.peek(syn::Token![#]) {
1092            let fork = input.fork();
1093            let attrs = fork.call(Attribute::parse_outer)?;
1094            if !fork.peek(syn::token::Brace) {
1095                break;
1096            }
1097            input.advance_to(&fork);
1098
1099            let attr = &attrs[0];
1100            if attrs.len() > 1 {
1101                return Err(Error::new_spanned(
1102                    &attrs[1],
1103                    "expected `{` after cfg attribute",
1104                ));
1105            }
1106            let width = parse_target_pointer_width_cfg(attr)?;
1107
1108            if !is_usize {
1109                errors.push(Error::new_spanned(
1110                    attr,
1111                    "#[cfg] blocks are only permitted in `usize`-based layouts",
1112                ));
1113            }
1114            if seen_widths.contains(width.as_str()) {
1115                errors.push(Error::new_spanned(
1116                    attr,
1117                    format!(
1118                        "duplicate cfg block for target_pointer_width = \"{width}\""
1119                    ),
1120                ));
1121            }
1122
1123            let block;
1124            braced!(block in input);
1125            if block.is_empty() {
1126                errors.push(Error::new_spanned(
1127                    attr,
1128                    "cfg block must contain at least one field",
1129                ));
1130            }
1131            while !block.is_empty() {
1132                let mut field = block.parse::<Bitfield>()?;
1133                field.cfg_pointer_width = Some(width.clone());
1134                fields.push(field);
1135            }
1136            seen_widths.insert(width);
1137        }
1138
1139        // Phase 2: bare fields.
1140        while !input.is_empty() {
1141            fields.push(input.parse::<Bitfield>()?);
1142        }
1143
1144        // Propagate doc comments and `unshifted` across cfg-gated field
1145        // variants: if one variant has them and the other doesn't, copy over.
1146        for i in 0..fields.len() {
1147            if fields[i].name.is_none()
1148                || fields[i].cfg_pointer_width.is_none()
1149                || (fields[i].doc_attrs.is_empty() && !fields[i].unshifted)
1150            {
1151                continue;
1152            }
1153            for j in (i + 1)..fields.len() {
1154                if fields[j].name == fields[i].name
1155                    && fields[j].cfg_pointer_width.is_some()
1156                    && fields[j].cfg_pointer_width
1157                        != fields[i].cfg_pointer_width
1158                {
1159                    if !fields[i].doc_attrs.is_empty()
1160                        && fields[j].doc_attrs.is_empty()
1161                    {
1162                        #[allow(clippy::assigning_clones)]
1163                        {
1164                            fields[j].doc_attrs = fields[i].doc_attrs.clone();
1165                        }
1166                    }
1167                    if fields[i].unshifted {
1168                        fields[j].unshifted = true;
1169                    }
1170                }
1171            }
1172        }
1173
1174        fields.sort_by_key(|field| field.low_bit);
1175
1176        for i in 0..fields.len() {
1177            let curr = &fields[i];
1178
1179            // The might be multiple overlapping fields, and some might be valid
1180            // if mutually excluded due to differing target pointer cfg
1181            // conditions.
1182            for next in fields.iter().skip(i + 1) {
1183                if curr.high_bit < next.low_bit {
1184                    break;
1185                }
1186                let can_overlap = curr.cfg_pointer_width.is_some()
1187                    && next.cfg_pointer_width.is_some()
1188                    && curr.cfg_pointer_width != next.cfg_pointer_width;
1189                if !can_overlap {
1190                    // TODO(https://github.com/rust-lang/rust/issues/54725): It
1191                    // would be nice to Span::join() the two spans, but that's still
1192                    // experimental.
1193                    errors.push(Error::new(
1194                        next.span,
1195                        format!(
1196                            "{} ({} {}) overlaps with {} ({} {})",
1197                            next.display_name(),
1198                            next.display_kind(),
1199                            next.display_range(),
1200                            curr.display_name(),
1201                            curr.display_kind(),
1202                            curr.display_range(),
1203                        ),
1204                    ));
1205                }
1206            }
1207        }
1208
1209        if let Some(highest_possible) = ty.base.ty.high_bit()
1210            && let Some(highest) = fields.last()
1211            && highest.high_bit > highest_possible
1212        {
1213            errors.push(Error::new(
1214                highest.span,
1215                format!(
1216                    "high bit {} exceeds the highest possible value \
1217                     of {highest_possible}",
1218                    highest.high_bit
1219                ),
1220            ));
1221        }
1222
1223        let mut bitfld = Self {
1224            ty,
1225            named: vec![],
1226            reserved: vec![],
1227            errors,
1228        };
1229
1230        while let Some(field) = fields.pop() {
1231            if field.is_reserved() {
1232                if field.default.is_some() {
1233                    bitfld.reserved.push(field);
1234                }
1235            } else {
1236                bitfld.named.push(field);
1237            }
1238        }
1239
1240        Ok(bitfld)
1241    }
1242}
1243
1244impl ToTokens for Bitfields {
1245    fn to_tokens(&self, tokens: &mut TokenStream2) {
1246        let type_def = &self.ty.def;
1247        let type_name = &type_def.ident;
1248        let base = &self.ty.base.def;
1249
1250        let (impl_generics, ty_generics, where_clause) =
1251            type_def.generics.split_for_impl();
1252
1253        if !self.errors.is_empty() {
1254            let errors = self.errors.iter().map(Error::to_compile_error);
1255            quote! {
1256                #[derive(Copy, Clone, Eq, PartialEq)]
1257                #type_def
1258                #(#errors)*
1259            }
1260            .to_tokens(tokens);
1261            return;
1262        }
1263
1264        let constants = self.constants();
1265        let getters_and_setters = self.getters_and_setters();
1266        let iter_impl = self.iter_impl();
1267        let fmt_impls = self.fmt_impls();
1268        quote! {
1269            #[derive(Copy, Clone, Eq, PartialEq)]
1270            #type_def
1271
1272            impl #impl_generics #type_name #ty_generics #where_clause {
1273                #constants
1274
1275                /// Creates a new instance with reserved-as-1 bits set and
1276                /// all other bits zeroed (i.e., with a value of
1277                /// [`Self::RSVD1_MASK`]).
1278                pub const fn new() -> Self {
1279                    Self(Self::RSVD1_MASK)
1280                }
1281
1282                #(#getters_and_setters)*
1283            }
1284
1285            impl #impl_generics ::core::default::Default for #type_name #ty_generics #where_clause {
1286                /// Returns an instance with the default bits set (i.e,. with a
1287                /// value of [`Self::DEFAULT`].
1288                fn default() -> Self {
1289                    Self(Self::DEFAULT)
1290                }
1291            }
1292
1293            impl #impl_generics ::core::convert::From<#base> for #type_name #ty_generics #where_clause {
1294                // `RSVD{0,1}_MASK` may be zero, in which case the following
1295                // mask conditions might be trivially true.
1296                #[allow(clippy::bad_bit_mask)]
1297                fn from(value: #base) -> Self {
1298                    debug_assert!(
1299                        value & Self::RSVD1_MASK == Self::RSVD1_MASK,
1300                        "from(): Invalid base value ({value:#x}) has reserved-as-1 bits ({:#x}) unset",
1301                        Self::RSVD1_MASK,
1302                    );
1303                    debug_assert!(
1304                        !value & Self::RSVD0_MASK == Self::RSVD0_MASK,
1305                        "from(): Invalid base value ({value:#x}) has reserved-as-0 bits ({:#x}) set",
1306                        Self::RSVD0_MASK,
1307                    );
1308                    Self(value)
1309                }
1310            }
1311
1312            impl #impl_generics ::core::ops::Deref for #type_name #ty_generics #where_clause {
1313                type Target = #base;
1314
1315                fn deref(&self) -> &Self::Target {
1316                    &self.0
1317                }
1318            }
1319
1320            impl #impl_generics ::core::ops::DerefMut for #type_name #ty_generics #where_clause {
1321                fn deref_mut(&mut self) -> &mut Self::Target {
1322                    &mut self.0
1323                }
1324            }
1325
1326            #iter_impl
1327
1328            #fmt_impls
1329        }
1330        .to_tokens(tokens);
1331    }
1332}