Skip to main content

nexus_bits_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{
5    Data, DeriveInput, Error, Fields, Ident, Result, Type, parse::Parser, parse_macro_input,
6};
7
8// =============================================================================
9// IntEnum derive
10// =============================================================================
11
12#[proc_macro_derive(IntEnum)]
13pub fn derive_int_enum(input: TokenStream) -> TokenStream {
14    let input = parse_macro_input!(input as DeriveInput);
15
16    match derive_int_enum_impl(input) {
17        Ok(tokens) => tokens.into(),
18        Err(err) => err.to_compile_error().into(),
19    }
20}
21
22fn derive_int_enum_impl(input: DeriveInput) -> Result<TokenStream2> {
23    let variants = match &input.data {
24        Data::Enum(data) => &data.variants,
25        _ => {
26            return Err(Error::new_spanned(
27                &input,
28                "IntEnum can only be derived for enums",
29            ));
30        }
31    };
32
33    let repr = parse_repr(&input)?;
34
35    for variant in variants {
36        if !matches!(variant.fields, Fields::Unit) {
37            return Err(Error::new_spanned(
38                variant,
39                "IntEnum variants cannot have fields",
40            ));
41        }
42    }
43
44    let name = &input.ident;
45    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
46
47    let from_arms = variants.iter().map(|v| {
48        let variant_name = &v.ident;
49        quote! {
50            x if x == #name::#variant_name as #repr => Some(#name::#variant_name),
51        }
52    });
53
54    Ok(quote! {
55        impl #impl_generics nexus_bits::IntEnum for #name #ty_generics #where_clause {
56            type Repr = #repr;
57
58            #[inline]
59            fn into_repr(self) -> #repr {
60                self as #repr
61            }
62
63            #[inline]
64            fn try_from_repr(repr: #repr) -> Option<Self> {
65                match repr {
66                    #(#from_arms)*
67                    _ => None,
68                }
69            }
70        }
71    })
72}
73
74fn parse_repr(input: &DeriveInput) -> Result<Ident> {
75    for attr in &input.attrs {
76        if attr.path().is_ident("repr") {
77            let repr: Ident = attr.parse_args()?;
78            match repr.to_string().as_str() {
79                "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64"
80                | "i128" => {
81                    return Ok(repr);
82                }
83                _ => {
84                    return Err(Error::new_spanned(
85                        repr,
86                        "IntEnum requires a primitive integer repr (u8..u128, i8..i128)",
87                    ));
88                }
89            }
90        }
91    }
92
93    Err(Error::new_spanned(
94        input,
95        "IntEnum requires a #[repr(u8/u16/u32/u64/i8/i16/i32/i64)] attribute",
96    ))
97}
98
99// =============================================================================
100// BitStorage attribute macro
101// =============================================================================
102
103#[proc_macro_attribute]
104pub fn bit_storage(attr: TokenStream, item: TokenStream) -> TokenStream {
105    let attr = proc_macro2::TokenStream::from(attr);
106    let item = parse_macro_input!(item as DeriveInput);
107
108    match bit_storage_impl(attr, item) {
109        Ok(tokens) => tokens.into(),
110        Err(err) => err.to_compile_error().into(),
111    }
112}
113
114fn bit_storage_impl(attr: TokenStream2, input: DeriveInput) -> Result<TokenStream2> {
115    let storage_attr = parse_storage_attr_tokens(attr)?;
116
117    match &input.data {
118        Data::Struct(data) => derive_storage_struct(&input, data, storage_attr),
119        Data::Enum(data) => derive_storage_enum(&input, data, storage_attr),
120        Data::Union(_) => Err(Error::new_spanned(
121            &input,
122            "bit_storage cannot be applied to unions",
123        )),
124    }
125}
126
127// =============================================================================
128// Attribute types
129// =============================================================================
130
131/// Parsed #[bit_storage(repr = T)] or #[bit_storage(repr = T, discriminant(start = N, len = M))]
132struct StorageAttr {
133    repr: Ident,
134    discriminant: Option<BitRange>,
135}
136
/// Contiguous run of bits inside the storage integer, covering
/// `start .. start + len`.
struct BitRange {
    /// Bit index of the range's least-significant bit (the shift amount
    /// used when reading the value out).
    start: u32,
    /// Number of bits in the range; `parse_bit_range` rejects zero.
    len: u32,
}
142
143/// Parsed field/flag from struct
144enum MemberDef {
145    Field {
146        name: Ident,
147        ty: Type,
148        range: BitRange,
149    },
150    Flag {
151        name: Ident,
152        bit: u32,
153    },
154}
155
156impl MemberDef {
157    fn name(&self) -> &Ident {
158        match self {
159            MemberDef::Field { name, .. } => name,
160            MemberDef::Flag { name, .. } => name,
161        }
162    }
163}
164
165// =============================================================================
166// Attribute parsing
167// =============================================================================
168
169fn parse_storage_attr_tokens(attr: TokenStream2) -> Result<StorageAttr> {
170    let mut repr = None;
171    let mut discriminant = None;
172
173    let parser = syn::meta::parser(|meta| {
174        if meta.path.is_ident("repr") {
175            meta.input.parse::<syn::Token![=]>()?;
176            repr = Some(meta.input.parse::<Ident>()?);
177            Ok(())
178        } else if meta.path.is_ident("discriminant") {
179            let content;
180            syn::parenthesized!(content in meta.input);
181            discriminant = Some(parse_bit_range(&content)?);
182            Ok(())
183        } else {
184            Err(meta.error("expected `repr` or `discriminant`"))
185        }
186    });
187
188    parser.parse2(attr)?;
189
190    let repr = repr.ok_or_else(|| {
191        Error::new(
192            proc_macro2::Span::call_site(),
193            "bit_storage requires `repr = ...`",
194        )
195    })?;
196
197    // Validate repr
198    match repr.to_string().as_str() {
199        "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" => {}
200        _ => return Err(Error::new_spanned(&repr, "repr must be an integer type")),
201    }
202
203    Ok(StorageAttr { repr, discriminant })
204}
205
206fn parse_bit_range(input: syn::parse::ParseStream) -> Result<BitRange> {
207    let mut start = None;
208    let mut len = None;
209
210    while !input.is_empty() {
211        let ident: Ident = input.parse()?;
212        input.parse::<syn::Token![=]>()?;
213        let lit: syn::LitInt = input.parse()?;
214        let value: u32 = lit.base10_parse()?;
215
216        match ident.to_string().as_str() {
217            "start" => start = Some(value),
218            "len" => len = Some(value),
219            _ => return Err(Error::new_spanned(ident, "expected `start` or `len`")),
220        }
221
222        if input.peek(syn::Token![,]) {
223            input.parse::<syn::Token![,]>()?;
224        }
225    }
226
227    let start = start.ok_or_else(|| Error::new(input.span(), "missing `start`"))?;
228    let len = len.ok_or_else(|| Error::new(input.span(), "missing `len`"))?;
229
230    if len == 0 {
231        return Err(Error::new(input.span(), "len must be > 0"));
232    }
233
234    Ok(BitRange { start, len })
235}
236
237fn parse_member(field: &syn::Field) -> Result<MemberDef> {
238    let name = field
239        .ident
240        .clone()
241        .ok_or_else(|| Error::new_spanned(field, "tuple structs not supported"))?;
242    let ty = field.ty.clone();
243
244    for attr in &field.attrs {
245        if attr.path().is_ident("field") {
246            let range = attr.parse_args_with(parse_bit_range)?;
247            return Ok(MemberDef::Field { name, ty, range });
248        } else if attr.path().is_ident("flag") {
249            let bit: syn::LitInt = attr.parse_args()?;
250            let bit: u32 = bit.base10_parse()?;
251            return Ok(MemberDef::Flag { name, bit });
252        }
253    }
254
255    Err(Error::new_spanned(
256        field,
257        "field requires #[field(start = N, len = M)] or #[flag(N)] attribute",
258    ))
259}
260
261fn parse_variant_attr(attrs: &[syn::Attribute]) -> Result<u64> {
262    for attr in attrs {
263        if attr.path().is_ident("variant") {
264            let lit: syn::LitInt = attr.parse_args()?;
265            return lit.base10_parse();
266        }
267    }
268    Err(Error::new(
269        proc_macro2::Span::call_site(),
270        "enum variant requires #[variant(N)] attribute",
271    ))
272}
273
274// =============================================================================
275// Helpers
276// =============================================================================
277
278fn is_primitive(ty: &Type) -> bool {
279    if let Type::Path(type_path) = ty {
280        if let Some(ident) = type_path.path.get_ident() {
281            return matches!(
282                ident.to_string().as_str(),
283                "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128"
284            );
285        }
286    }
287    false
288}
289
290fn repr_bits(repr: &Ident) -> u32 {
291    match repr.to_string().as_str() {
292        "u8" | "i8" => 8,
293        "u16" | "i16" => 16,
294        "u32" | "i32" => 32,
295        "u64" | "i64" => 64,
296        "u128" | "i128" => 128,
297        _ => 0,
298    }
299}
300
301// =============================================================================
302// Validation
303// =============================================================================
304
305fn validate_members(members: &[MemberDef], repr: &Ident) -> Result<()> {
306    let bits = repr_bits(repr);
307
308    // Check each field fits
309    for member in members {
310        match member {
311            MemberDef::Field { name, range, .. } => {
312                if range.start + range.len > bits {
313                    return Err(Error::new_spanned(
314                        name,
315                        format!(
316                            "field exceeds {} bits (start {} + len {} = {})",
317                            bits,
318                            range.start,
319                            range.len,
320                            range.start + range.len
321                        ),
322                    ));
323                }
324            }
325            MemberDef::Flag { name, bit, .. } => {
326                if *bit >= bits {
327                    return Err(Error::new_spanned(
328                        name,
329                        format!("flag bit {} exceeds {} bits", bit, bits),
330                    ));
331                }
332            }
333        }
334    }
335
336    // Check no overlap (simple O(n²) for now)
337    for (i, a) in members.iter().enumerate() {
338        for b in members.iter().skip(i + 1) {
339            if ranges_overlap(a, b) {
340                return Err(Error::new_spanned(
341                    b.name(),
342                    format!("field '{}' overlaps with '{}'", b.name(), a.name()),
343                ));
344            }
345        }
346    }
347
348    Ok(())
349}
350
351fn ranges_overlap(a: &MemberDef, b: &MemberDef) -> bool {
352    let (a_start, a_end) = member_bit_range(a);
353    let (b_start, b_end) = member_bit_range(b);
354    a_start < b_end && b_start < a_end
355}
356
357fn member_bit_range(m: &MemberDef) -> (u32, u32) {
358    match m {
359        MemberDef::Field { range, .. } => (range.start, range.start + range.len),
360        MemberDef::Flag { bit, .. } => (*bit, bit + 1),
361    }
362}
363
364// =============================================================================
365// Struct codegen
366// =============================================================================
367
368fn derive_storage_struct(
369    input: &DeriveInput,
370    data: &syn::DataStruct,
371    storage_attr: StorageAttr,
372) -> Result<TokenStream2> {
373    let fields = match &data.fields {
374        Fields::Named(f) => &f.named,
375        _ => {
376            return Err(Error::new_spanned(
377                input,
378                "bit_storage requires named fields",
379            ));
380        }
381    };
382
383    if storage_attr.discriminant.is_some() {
384        return Err(Error::new_spanned(
385            input,
386            "discriminant is only valid for enums",
387        ));
388    }
389
390    let members: Vec<MemberDef> = fields.iter().map(parse_member).collect::<Result<_>>()?;
391
392    validate_members(&members, &storage_attr.repr)?;
393
394    let vis = &input.vis;
395    let name = &input.ident;
396    let repr = &storage_attr.repr;
397    let builder_name = Ident::new(&format!("{}Builder", name), name.span());
398
399    let newtype = generate_struct_newtype(vis, name, repr);
400    let builder_struct = generate_struct_builder_struct(vis, &builder_name, &members);
401    let newtype_impl = generate_struct_newtype_impl(name, &builder_name, repr, &members);
402    let builder_impl = generate_struct_builder_impl(name, &builder_name, repr, &members);
403
404    Ok(quote! {
405        #newtype
406        #builder_struct
407        #newtype_impl
408        #builder_impl
409    })
410}
411
412fn generate_struct_newtype(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
413    quote! {
414        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
415        #[repr(transparent)]
416        #vis struct #name(#vis #repr);
417    }
418}
419
420fn generate_struct_builder_struct(
421    vis: &syn::Visibility,
422    builder_name: &Ident,
423    members: &[MemberDef],
424) -> TokenStream2 {
425    let fields: Vec<TokenStream2> = members
426        .iter()
427        .map(|m| match m {
428            MemberDef::Field { name, ty, .. } => {
429                quote! { #name: Option<#ty>, }
430            }
431            MemberDef::Flag { name, .. } => {
432                quote! { #name: Option<bool>, }
433            }
434        })
435        .collect();
436
437    quote! {
438        #[derive(Debug, Clone, Copy, Default)]
439        #vis struct #builder_name {
440            #(#fields)*
441        }
442    }
443}
444
445fn generate_struct_newtype_impl(
446    name: &Ident,
447    builder_name: &Ident,
448    repr: &Ident,
449    members: &[MemberDef],
450) -> TokenStream2 {
451    let repr_bit_count = repr_bits(repr);
452
453    let accessors: Vec<TokenStream2> = members.iter().map(|m| {
454        match m {
455            MemberDef::Field { name: field_name, ty, range } => {
456                let start = range.start;
457                let len = range.len;
458                let mask = if len >= repr_bit_count {
459                    quote! { #repr::MAX }
460                } else {
461                    quote! { ((1 as #repr) << #len) - 1 }
462                };
463
464                if is_primitive(ty) {
465                    quote! {
466                        #[inline]
467                        pub const fn #field_name(&self) -> #ty {
468                            ((self.0 >> #start) & #mask) as #ty
469                        }
470                    }
471                } else {
472                    // IntEnum field - returns Result
473                    quote! {
474                        #[inline]
475                        pub fn #field_name(&self) -> Result<#ty, nexus_bits::UnknownDiscriminant<#repr>> {
476                            let field_repr = ((self.0 >> #start) & #mask);
477                            <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
478                                .ok_or(nexus_bits::UnknownDiscriminant {
479                                    field: stringify!(#field_name),
480                                    value: field_repr as #repr,
481                                })
482                        }
483                    }
484                }
485            }
486            MemberDef::Flag { name: field_name, bit } => {
487                quote! {
488                    #[inline]
489                    pub const fn #field_name(&self) -> bool {
490                        (self.0 >> #bit) & 1 != 0
491                    }
492                }
493            }
494        }
495    }).collect();
496
497    quote! {
498        impl #name {
499            /// Create from raw integer value.
500            #[inline]
501            pub const fn from_raw(raw: #repr) -> Self {
502                Self(raw)
503            }
504
505            /// Get the raw integer value.
506            #[inline]
507            pub const fn raw(self) -> #repr {
508                self.0
509            }
510
511            /// Create a builder for this type.
512            #[inline]
513            pub fn builder() -> #builder_name {
514                #builder_name::default()
515            }
516
517            #(#accessors)*
518        }
519    }
520}
521
522fn generate_struct_builder_impl(
523    name: &Ident,
524    builder_name: &Ident,
525    repr: &Ident,
526    members: &[MemberDef],
527) -> TokenStream2 {
528    let repr_bit_count = repr_bits(repr);
529
530    // Setters - wrap in Some
531    let setters: Vec<TokenStream2> = members
532        .iter()
533        .map(|m| match m {
534            MemberDef::Field {
535                name: field_name,
536                ty,
537                ..
538            } => {
539                quote! {
540                    #[inline]
541                    pub fn #field_name(mut self, val: #ty) -> Self {
542                        self.#field_name = Some(val);
543                        self
544                    }
545                }
546            }
547            MemberDef::Flag {
548                name: field_name, ..
549            } => {
550                quote! {
551                    #[inline]
552                    pub fn #field_name(mut self, val: bool) -> Self {
553                        self.#field_name = Some(val);
554                        self
555                    }
556                }
557            }
558        })
559        .collect();
560
561    // Validations
562    let validations: Vec<TokenStream2> = members
563        .iter()
564        .filter_map(|m| match m {
565            MemberDef::Field {
566                name: field_name,
567                ty,
568                range,
569            } => {
570                let field_str = field_name.to_string();
571                let len = range.len;
572
573                let max_val = if len >= repr_bit_count {
574                    quote! { #repr::MAX }
575                } else {
576                    quote! { ((1 as #repr) << #len) - 1 }
577                };
578
579                if is_primitive(ty) {
580                    let type_bits: u32 = match ty {
581                        Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
582                        Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
583                        Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
584                        Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
585                        Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
586                        _ => 128,
587                    };
588
589                    // Skip validation if field can hold entire type
590                    if len >= type_bits {
591                        return None;
592                    }
593
594                    // Check if this is a signed type
595                    let is_signed = matches!(ty,
596                        Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
597                                         p.path.is_ident("i32") || p.path.is_ident("i64") ||
598                                         p.path.is_ident("i128")
599                    );
600
601                    if is_signed {
602                        // For signed types, check that value fits in signed field range
603                        // A signed N-bit field can hold -(2^(N-1)) to (2^(N-1) - 1)
604                        let min_shift = len - 1;
605                        Some(quote! {
606                            if let Some(v) = self.#field_name {
607                                let min_val = -((1i128 << #min_shift) as i128);
608                                let max_val = ((1i128 << #min_shift) - 1) as i128;
609                                let v_i128 = v as i128;
610                                if v_i128 < min_val || v_i128 > max_val {
611                                    return Err(nexus_bits::FieldOverflow {
612                                        field: #field_str,
613                                        overflow: nexus_bits::Overflow {
614                                            value: (v as #repr),
615                                            max: #max_val,
616                                        },
617                                    });
618                                }
619                            }
620                        })
621                    } else {
622                        // Unsigned - simple max check
623                        Some(quote! {
624                            if let Some(v) = self.#field_name {
625                                if (v as #repr) > #max_val {
626                                    return Err(nexus_bits::FieldOverflow {
627                                        field: #field_str,
628                                        overflow: nexus_bits::Overflow {
629                                            value: v as #repr,
630                                            max: #max_val,
631                                        },
632                                    });
633                                }
634                            }
635                        })
636                    }
637                } else {
638                    // IntEnum field - validate repr value fits in field
639                    Some(quote! {
640                        const _: () = assert!(
641                            core::mem::size_of::<<#ty as nexus_bits::IntEnum>::Repr>() <= core::mem::size_of::<#repr>(),
642                            "IntEnum repr type is wider than storage repr — values may be truncated"
643                        );
644                        if let Some(v) = self.#field_name {
645                            let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
646                            if repr_val > #max_val {
647                                return Err(nexus_bits::FieldOverflow {
648                                    field: #field_str,
649                                    overflow: nexus_bits::Overflow {
650                                        value: repr_val,
651                                        max: #max_val,
652                                    },
653                                });
654                            }
655                        }
656                    })
657                }
658            }
659            _ => None,
660        })
661        .collect();
662
663    // Pack statements - ALWAYS mask to prevent sign extension corruption
664    let pack_statements: Vec<TokenStream2> = members
665        .iter()
666        .map(|m| {
667            match m {
668                MemberDef::Field {
669                    name: field_name,
670                    ty,
671                    range,
672                } => {
673                    let start = range.start;
674                    let len = range.len;
675                    let mask = if len >= repr_bit_count {
676                        quote! { #repr::MAX }
677                    } else {
678                        quote! { ((1 as #repr) << #len) - 1 }
679                    };
680
681                    if is_primitive(ty) {
682                        quote! {
683                            if let Some(v) = self.#field_name {
684                                val |= ((v as #repr) & #mask) << #start;
685                            }
686                        }
687                    } else {
688                        // IntEnum
689                        quote! {
690                            if let Some(v) = self.#field_name {
691                                val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
692                            }
693                        }
694                    }
695                }
696                MemberDef::Flag {
697                    name: field_name,
698                    bit,
699                } => {
700                    quote! {
701                        if let Some(true) = self.#field_name {
702                            val |= (1 as #repr) << #bit;
703                        }
704                    }
705                }
706            }
707        })
708        .collect();
709
710    quote! {
711        impl #builder_name {
712            #(#setters)*
713
714            /// Build the final value, validating all fields.
715            #[inline]
716            pub fn build(self) -> Result<#name, nexus_bits::FieldOverflow<#repr>> {
717                // Validate
718                #(#validations)*
719
720                // Pack
721                let mut val: #repr = 0;
722                #(#pack_statements)*
723
724                Ok(#name(val))
725            }
726        }
727    }
728}
729
730// =============================================================================
731// Enum codegen
732// =============================================================================
733
734/// Parsed variant for tagged enum
735struct ParsedVariant {
736    name: Ident,
737    discriminant: u64,
738    members: Vec<MemberDef>,
739}
740
741fn derive_storage_enum(
742    input: &DeriveInput,
743    data: &syn::DataEnum,
744    storage_attr: StorageAttr,
745) -> Result<TokenStream2> {
746    let discriminant = storage_attr.discriminant.ok_or_else(|| {
747        Error::new_spanned(
748            input,
749            "bit_storage enum requires discriminant: #[bit_storage(repr = T, discriminant(start = N, len = M))]",
750        )
751    })?;
752
753    let repr = &storage_attr.repr;
754    let bits = repr_bits(repr);
755
756    // Validate discriminant fits
757    if discriminant.start + discriminant.len > bits {
758        return Err(Error::new_spanned(
759            input,
760            format!(
761                "discriminant exceeds {} bits (start {} + len {} = {})",
762                bits,
763                discriminant.start,
764                discriminant.len,
765                discriminant.start + discriminant.len
766            ),
767        ));
768    }
769
770    let max_discriminant = if discriminant.len >= 64 {
771        u64::MAX
772    } else {
773        (1u64 << discriminant.len) - 1
774    };
775
776    // Parse all variants
777    let mut variants = Vec::new();
778    for variant in &data.variants {
779        let disc = parse_variant_attr(&variant.attrs)?;
780
781        if disc > max_discriminant {
782            return Err(Error::new_spanned(
783                &variant.ident,
784                format!(
785                    "variant discriminant {} exceeds max {} for {}-bit field",
786                    disc, max_discriminant, discriminant.len
787                ),
788            ));
789        }
790
791        // Check for duplicate discriminants
792        for existing in &variants {
793            let existing: &ParsedVariant = existing;
794            if existing.discriminant == disc {
795                return Err(Error::new_spanned(
796                    &variant.ident,
797                    format!(
798                        "duplicate discriminant {}: already used by '{}'",
799                        disc, existing.name
800                    ),
801                ));
802            }
803        }
804
805        let members: Vec<MemberDef> = match &variant.fields {
806            Fields::Named(fields) => fields
807                .named
808                .iter()
809                .map(parse_member)
810                .collect::<Result<_>>()?,
811            Fields::Unit => Vec::new(),
812            Fields::Unnamed(_) => {
813                return Err(Error::new_spanned(
814                    variant,
815                    "tuple variants not supported, use named fields",
816                ));
817            }
818        };
819
820        // Validate members don't overlap with discriminant
821        let disc_range = MemberDef::Field {
822            name: Ident::new("__discriminant", proc_macro2::Span::call_site()),
823            ty: syn::parse_quote!(u64),
824            range: BitRange {
825                start: discriminant.start,
826                len: discriminant.len,
827            },
828        };
829
830        for member in &members {
831            if ranges_overlap(&disc_range, member) {
832                return Err(Error::new_spanned(
833                    member.name(),
834                    format!("field '{}' overlaps with discriminant", member.name()),
835                ));
836            }
837        }
838
839        // Validate members within this variant
840        validate_members(&members, repr)?;
841
842        variants.push(ParsedVariant {
843            name: variant.ident.clone(),
844            discriminant: disc,
845            members,
846        });
847    }
848
849    let vis = &input.vis;
850    let name = &input.ident;
851
852    let parent_type = generate_enum_parent_type(vis, name, repr);
853    let variant_types = generate_enum_variant_types(vis, name, repr, &variants);
854    let kind_enum = generate_enum_kind(vis, name, &variants);
855    let builder_structs = generate_enum_builder_structs(vis, name, &variants);
856    let parent_impl = generate_enum_parent_impl(name, repr, &discriminant, &variants);
857    let variant_impls = generate_enum_variant_impls(name, repr, &variants);
858    let builder_impls = generate_enum_builder_impls(name, repr, &discriminant, &variants);
859    let from_impls = generate_enum_from_impls(name, &variants);
860
861    Ok(quote! {
862        #parent_type
863        #variant_types
864        #kind_enum
865        #builder_structs
866        #parent_impl
867        #variant_impls
868        #builder_impls
869        #from_impls
870    })
871}
872
873fn variant_type_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
874    Ident::new(
875        &format!("{}{}", parent_name, variant_name),
876        variant_name.span(),
877    )
878}
879
880fn variant_builder_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
881    Ident::new(
882        &format!("{}{}Builder", parent_name, variant_name),
883        variant_name.span(),
884    )
885}
886
887fn kind_enum_name(parent_name: &Ident) -> Ident {
888    Ident::new(&format!("{}Kind", parent_name), parent_name.span())
889}
890
891fn generate_enum_parent_type(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
892    quote! {
893        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
894        #[repr(transparent)]
895        #vis struct #name(#vis #repr);
896    }
897}
898
899fn generate_enum_variant_types(
900    vis: &syn::Visibility,
901    parent_name: &Ident,
902    repr: &Ident,
903    variants: &[ParsedVariant],
904) -> TokenStream2 {
905    let types: Vec<TokenStream2> = variants
906        .iter()
907        .map(|v| {
908            let type_name = variant_type_name(parent_name, &v.name);
909            quote! {
910                #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
911                #[repr(transparent)]
912                #vis struct #type_name(#repr);
913            }
914        })
915        .collect();
916
917    quote! { #(#types)* }
918}
919
920fn generate_enum_kind(
921    vis: &syn::Visibility,
922    parent_name: &Ident,
923    variants: &[ParsedVariant],
924) -> TokenStream2 {
925    let kind_name = kind_enum_name(parent_name);
926    let variant_names: Vec<&Ident> = variants.iter().map(|v| &v.name).collect();
927
928    quote! {
929        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
930        #vis enum #kind_name {
931            #(#variant_names),*
932        }
933    }
934}
935
936fn generate_enum_builder_structs(
937    vis: &syn::Visibility,
938    parent_name: &Ident,
939    variants: &[ParsedVariant],
940) -> TokenStream2 {
941    let builders: Vec<TokenStream2> = variants
942        .iter()
943        .map(|v| {
944            let builder_name = variant_builder_name(parent_name, &v.name);
945
946            let fields: Vec<TokenStream2> = v
947                .members
948                .iter()
949                .map(|m| match m {
950                    MemberDef::Field { name, ty, .. } => {
951                        quote! { #name: Option<#ty>, }
952                    }
953                    MemberDef::Flag { name, .. } => {
954                        quote! { #name: Option<bool>, }
955                    }
956                })
957                .collect();
958
959            quote! {
960                #[derive(Debug, Clone, Copy, Default)]
961                #vis struct #builder_name {
962                    #(#fields)*
963                }
964            }
965        })
966        .collect();
967
968    quote! { #(#builders)* }
969}
970
971fn generate_enum_parent_impl(
972    name: &Ident,
973    repr: &Ident,
974    discriminant: &BitRange,
975    variants: &[ParsedVariant],
976) -> TokenStream2 {
977    let repr_bit_count = repr_bits(repr);
978    let kind_name = kind_enum_name(name);
979    let disc_start = discriminant.start;
980    let disc_len = discriminant.len;
981
982    // Discriminant is extracted as u64 for matching. Wider discriminants
983    // would truncate silently, so reject them at macro expansion time.
984    assert!(
985        disc_len <= 64,
986        "discriminant length must be <= 64 bits (got {disc_len})"
987    );
988
989    let disc_mask = if disc_len >= repr_bit_count {
990        quote! { #repr::MAX }
991    } else {
992        quote! { ((1 as #repr) << #disc_len) - 1 }
993    };
994
995    // kind() match arms
996    let kind_arms: Vec<TokenStream2> = variants
997        .iter()
998        .map(|v| {
999            let variant_name = &v.name;
1000            let disc_val = v.discriminant;
1001            quote! {
1002                #disc_val => Ok(#kind_name::#variant_name),
1003            }
1004        })
1005        .collect();
1006
1007    // is_* methods
1008    let is_methods: Vec<TokenStream2> = variants
1009        .iter()
1010        .map(|v| {
1011            let variant_name = &v.name;
1012            let method_name = Ident::new(
1013                &format!("is_{}", to_snake_case(&variant_name.to_string())),
1014                variant_name.span(),
1015            );
1016            let disc_val = v.discriminant;
1017            quote! {
1018                #[inline]
1019                pub fn #method_name(&self) -> bool {
1020                    let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1021                    disc == #disc_val
1022                }
1023            }
1024        })
1025        .collect();
1026
1027    // as_* methods
1028    let as_methods: Vec<TokenStream2> = variants
1029        .iter()
1030        .map(|v| {
1031            let variant_name = &v.name;
1032            let variant_type = variant_type_name(name, variant_name);
1033            let method_name = Ident::new(
1034                &format!("as_{}", to_snake_case(&variant_name.to_string())),
1035                variant_name.span(),
1036            );
1037            let disc_val = v.discriminant;
1038
1039            // Validation for IntEnum fields
1040            let validations: Vec<TokenStream2> = v.members
1041                .iter()
1042                .filter_map(|m| {
1043                    if let MemberDef::Field { name: field_name, ty, range } = m {
1044                        if !is_primitive(ty) {
1045                            let start = range.start;
1046                            let len = range.len;
1047                            let repr_bit_count = repr_bits(repr);
1048                            let mask = if len >= repr_bit_count {
1049                                quote! { #repr::MAX }
1050                            } else {
1051                                quote! { ((1 as #repr) << #len) - 1 }
1052                            };
1053                            return Some(quote! {
1054                                let field_repr = ((self.0 >> #start) & #mask);
1055                                if <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _).is_none() {
1056                                    return Err(nexus_bits::UnknownDiscriminant {
1057                                        field: stringify!(#field_name),
1058                                        value: field_repr as #repr,
1059                                    });
1060                                }
1061                            });
1062                        }
1063                    }
1064                    None
1065                })
1066                .collect();
1067
1068            quote! {
1069                #[inline]
1070                pub fn #method_name(&self) -> Result<#variant_type, nexus_bits::UnknownDiscriminant<#repr>> {
1071                    let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1072                    if disc != #disc_val {
1073                        return Err(nexus_bits::UnknownDiscriminant {
1074                            field: "__discriminant",
1075                            value: disc as #repr,
1076                        });
1077                    }
1078                    #(#validations)*
1079                    Ok(#variant_type(self.0))
1080                }
1081            }
1082        })
1083        .collect();
1084
1085    // Builder shortcut methods
1086    let builder_methods: Vec<TokenStream2> = variants
1087        .iter()
1088        .map(|v| {
1089            let variant_name = &v.name;
1090            let builder_name = variant_builder_name(name, variant_name);
1091            let method_name = Ident::new(
1092                &to_snake_case(&variant_name.to_string()),
1093                variant_name.span(),
1094            );
1095            quote! {
1096                #[inline]
1097                pub fn #method_name() -> #builder_name {
1098                    #builder_name::default()
1099                }
1100            }
1101        })
1102        .collect();
1103
1104    quote! {
1105        impl #name {
1106            /// Create from raw integer value.
1107            #[inline]
1108            pub const fn from_raw(raw: #repr) -> Self {
1109                Self(raw)
1110            }
1111
1112            /// Get the raw integer value.
1113            #[inline]
1114            pub const fn raw(self) -> #repr {
1115                self.0
1116            }
1117
1118            /// Get the kind (discriminant) of this value.
1119            #[inline]
1120            pub fn kind(&self) -> Result<#kind_name, nexus_bits::UnknownDiscriminant<#repr>> {
1121                let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1122                match disc {
1123                    #(#kind_arms)*
1124                    _ => Err(nexus_bits::UnknownDiscriminant {
1125                        field: "__discriminant",
1126                        value: disc as #repr,
1127                    }),
1128                }
1129            }
1130
1131            #(#is_methods)*
1132
1133            #(#as_methods)*
1134
1135            #(#builder_methods)*
1136        }
1137    }
1138}
1139
1140fn generate_enum_variant_impls(
1141    parent_name: &Ident,
1142    repr: &Ident,
1143    variants: &[ParsedVariant],
1144) -> TokenStream2 {
1145    let repr_bit_count = repr_bits(repr);
1146
1147    let impls: Vec<TokenStream2> =
1148        variants
1149            .iter()
1150            .map(|v| {
1151                let variant_name = &v.name;
1152                let variant_type = variant_type_name(parent_name, variant_name);
1153                let builder_name = variant_builder_name(parent_name, variant_name);
1154
1155                // Accessors - infallible since variant is pre-validated
1156                let accessors: Vec<TokenStream2> = v.members
1157                .iter()
1158                .map(|m| {
1159                    match m {
1160                        MemberDef::Field { name: field_name, ty, range } => {
1161                            let start = range.start;
1162                            let len = range.len;
1163                            let mask = if len >= repr_bit_count {
1164                                quote! { #repr::MAX }
1165                            } else {
1166                                quote! { ((1 as #repr) << #len) - 1 }
1167                            };
1168
1169                            if is_primitive(ty) {
1170                                quote! {
1171                                    #[inline]
1172                                    pub const fn #field_name(&self) -> #ty {
1173                                        ((self.0 >> #start) & #mask) as #ty
1174                                    }
1175                                }
1176                            } else {
1177                                // IntEnum - infallible because already validated
1178                                quote! {
1179                                    #[inline]
1180                                    pub fn #field_name(&self) -> #ty {
1181                                        let field_repr = ((self.0 >> #start) & #mask);
1182                                        // SAFETY: This type was validated during construction
1183                                        <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
1184                                            .expect("variant type invariant violated")
1185                                    }
1186                                }
1187                            }
1188                        }
1189                        MemberDef::Flag { name: field_name, bit } => {
1190                            quote! {
1191                                #[inline]
1192                                pub const fn #field_name(&self) -> bool {
1193                                    (self.0 >> #bit) & 1 != 0
1194                                }
1195                            }
1196                        }
1197                    }
1198                })
1199                .collect();
1200
1201                quote! {
1202                    impl #variant_type {
1203                        /// Create a builder for this variant.
1204                        #[inline]
1205                        pub fn builder() -> #builder_name {
1206                            #builder_name::default()
1207                        }
1208
1209                        /// Get the raw integer value.
1210                        #[inline]
1211                        pub const fn raw(self) -> #repr {
1212                            self.0
1213                        }
1214
1215                        /// Convert to parent type.
1216                        #[inline]
1217                        pub const fn as_parent(self) -> #parent_name {
1218                            #parent_name(self.0)
1219                        }
1220
1221                        #(#accessors)*
1222                    }
1223                }
1224            })
1225            .collect();
1226
1227    quote! { #(#impls)* }
1228}
1229
1230fn generate_enum_builder_impls(
1231    parent_name: &Ident,
1232    repr: &Ident,
1233    discriminant: &BitRange,
1234    variants: &[ParsedVariant],
1235) -> TokenStream2 {
1236    let repr_bit_count = repr_bits(repr);
1237    let disc_start = discriminant.start;
1238
1239    let impls: Vec<TokenStream2> = variants
1240        .iter()
1241        .map(|v| {
1242            let variant_name = &v.name;
1243            let variant_type = variant_type_name(parent_name, variant_name);
1244            let builder_name = variant_builder_name(parent_name, variant_name);
1245            let disc_val = v.discriminant;
1246
1247            // Setters
1248            let setters: Vec<TokenStream2> = v.members
1249                .iter()
1250                .map(|m| match m {
1251                    MemberDef::Field { name: field_name, ty, .. } => {
1252                        quote! {
1253                            #[inline]
1254                            pub fn #field_name(mut self, val: #ty) -> Self {
1255                                self.#field_name = Some(val);
1256                                self
1257                            }
1258                        }
1259                    }
1260                    MemberDef::Flag { name: field_name, .. } => {
1261                        quote! {
1262                            #[inline]
1263                            pub fn #field_name(mut self, val: bool) -> Self {
1264                                self.#field_name = Some(val);
1265                                self
1266                            }
1267                        }
1268                    }
1269                })
1270                .collect();
1271
1272            // Validations
1273            let validations: Vec<TokenStream2> = v.members
1274                .iter()
1275                .filter_map(|m| match m {
1276                    MemberDef::Field { name: field_name, ty, range } => {
1277                        let field_str = field_name.to_string();
1278                        let len = range.len;
1279
1280                        let max_val = if len >= repr_bit_count {
1281                            quote! { #repr::MAX }
1282                        } else {
1283                            quote! { ((1 as #repr) << #len) - 1 }
1284                        };
1285
1286                        if is_primitive(ty) {
1287                            let type_bits: u32 = match ty {
1288                                Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
1289                                Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
1290                                Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
1291                                Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
1292                                Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
1293                                _ => 128,
1294                            };
1295
1296                            if len >= type_bits {
1297                                return None;
1298                            }
1299
1300                            let is_signed = matches!(ty,
1301                                Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
1302                                                 p.path.is_ident("i32") || p.path.is_ident("i64") ||
1303                                                 p.path.is_ident("i128")
1304                            );
1305
1306                            if is_signed {
1307                                let min_shift = len - 1;
1308                                Some(quote! {
1309                                    if let Some(v) = self.#field_name {
1310                                        let min_val = -((1i128 << #min_shift) as i128);
1311                                        let max_val = ((1i128 << #min_shift) - 1) as i128;
1312                                        let v_i128 = v as i128;
1313                                        if v_i128 < min_val || v_i128 > max_val {
1314                                            return Err(nexus_bits::FieldOverflow {
1315                                                field: #field_str,
1316                                                overflow: nexus_bits::Overflow {
1317                                                    value: (v as #repr),
1318                                                    max: #max_val,
1319                                                },
1320                                            });
1321                                        }
1322                                    }
1323                                })
1324                            } else {
1325                                Some(quote! {
1326                                    if let Some(v) = self.#field_name {
1327                                        if (v as #repr) > #max_val {
1328                                            return Err(nexus_bits::FieldOverflow {
1329                                                field: #field_str,
1330                                                overflow: nexus_bits::Overflow {
1331                                                    value: v as #repr,
1332                                                    max: #max_val,
1333                                                },
1334                                            });
1335                                        }
1336                                    }
1337                                })
1338                            }
1339                        } else {
1340                            // IntEnum field
1341                            Some(quote! {
1342                                if let Some(v) = self.#field_name {
1343                                    let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
1344                                    if repr_val > #max_val {
1345                                        return Err(nexus_bits::FieldOverflow {
1346                                            field: #field_str,
1347                                            overflow: nexus_bits::Overflow {
1348                                                value: repr_val,
1349                                                max: #max_val,
1350                                            },
1351                                        });
1352                                    }
1353                                }
1354                            })
1355                        }
1356                    }
1357                    _ => None,
1358                })
1359                .collect();
1360
1361            // Pack statements
1362            let pack_statements: Vec<TokenStream2> = v.members
1363                .iter()
1364                .map(|m| {
1365                    match m {
1366                        MemberDef::Field { name: field_name, ty, range } => {
1367                            let start = range.start;
1368                            let len = range.len;
1369                            let mask = if len >= repr_bit_count {
1370                                quote! { #repr::MAX }
1371                            } else {
1372                                quote! { ((1 as #repr) << #len) - 1 }
1373                            };
1374
1375                            if is_primitive(ty) {
1376                                quote! {
1377                                    if let Some(v) = self.#field_name {
1378                                        val |= ((v as #repr) & #mask) << #start;
1379                                    }
1380                                }
1381                            } else {
1382                                quote! {
1383                                    if let Some(v) = self.#field_name {
1384                                        val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
1385                                    }
1386                                }
1387                            }
1388                        }
1389                        MemberDef::Flag { name: field_name, bit } => {
1390                            quote! {
1391                                if let Some(true) = self.#field_name {
1392                                    val |= (1 as #repr) << #bit;
1393                                }
1394                            }
1395                        }
1396                    }
1397                })
1398                .collect();
1399
1400            quote! {
1401                impl #builder_name {
1402                    #(#setters)*
1403
1404                    /// Build the variant type, validating all fields.
1405                    #[inline]
1406                    pub fn build(self) -> Result<#variant_type, nexus_bits::FieldOverflow<#repr>> {
1407                        #(#validations)*
1408
1409                        let mut val: #repr = 0;
1410                        // Set discriminant
1411                        val |= (#disc_val as #repr) << #disc_start;
1412                        #(#pack_statements)*
1413
1414                        Ok(#variant_type(val))
1415                    }
1416
1417                    /// Build directly to parent type, validating all fields.
1418                    #[inline]
1419                    pub fn build_parent(self) -> Result<#parent_name, nexus_bits::FieldOverflow<#repr>> {
1420                        self.build().map(|v| v.as_parent())
1421                    }
1422                }
1423            }
1424        })
1425        .collect();
1426
1427    quote! { #(#impls)* }
1428}
1429
1430fn generate_enum_from_impls(parent_name: &Ident, variants: &[ParsedVariant]) -> TokenStream2 {
1431    let impls: Vec<TokenStream2> = variants
1432        .iter()
1433        .map(|v| {
1434            let variant_type = variant_type_name(parent_name, &v.name);
1435            quote! {
1436                impl From<#variant_type> for #parent_name {
1437                    #[inline]
1438                    fn from(v: #variant_type) -> Self {
1439                        v.as_parent()
1440                    }
1441                }
1442            }
1443        })
1444        .collect();
1445
1446    quote! { #(#impls)* }
1447}
1448
/// Convert an `UpperCamelCase` identifier to `snake_case`.
///
/// An underscore is inserted before every uppercase character except the
/// first, so acronyms are split letter by letter (`HTTPMode` -> `h_t_t_p_mode`).
/// All other characters (lowercase letters, digits, `_`) are copied unchanged.
fn to_snake_case(s: &str) -> String {
    // Worst case adds one underscore per char; half the length is a good guess.
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` may yield several chars (e.g. 'İ' -> "i\u{307}");
            // keep all of them instead of silently dropping the tail.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}