Skip to main content

nexus_bits_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{
5    Data, DeriveInput, Error, Fields, Ident, Result, Type, parse::Parser, parse_macro_input,
6};
7
8// =============================================================================
9// IntEnum derive
10// =============================================================================
11
12#[proc_macro_derive(IntEnum)]
13pub fn derive_int_enum(input: TokenStream) -> TokenStream {
14    let input = parse_macro_input!(input as DeriveInput);
15
16    match derive_int_enum_impl(&input) {
17        Ok(tokens) => tokens.into(),
18        Err(err) => err.to_compile_error().into(),
19    }
20}
21
22fn derive_int_enum_impl(input: &DeriveInput) -> Result<TokenStream2> {
23    let variants = match &input.data {
24        Data::Enum(data) => &data.variants,
25        _ => {
26            return Err(Error::new_spanned(
27                input,
28                "IntEnum can only be derived for enums",
29            ));
30        }
31    };
32
33    let repr = parse_repr(input)?;
34
35    for variant in variants {
36        if !matches!(variant.fields, Fields::Unit) {
37            return Err(Error::new_spanned(
38                variant,
39                "IntEnum variants cannot have fields",
40            ));
41        }
42    }
43
44    let name = &input.ident;
45    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
46
47    let from_arms = variants.iter().map(|v| {
48        let variant_name = &v.ident;
49        quote! {
50            x if x == #name::#variant_name as #repr => Some(#name::#variant_name),
51        }
52    });
53
54    Ok(quote! {
55        impl #impl_generics nexus_bits::IntEnum for #name #ty_generics #where_clause {
56            type Repr = #repr;
57
58            #[inline]
59            fn into_repr(self) -> #repr {
60                self as #repr
61            }
62
63            #[inline]
64            fn try_from_repr(repr: #repr) -> Option<Self> {
65                match repr {
66                    #(#from_arms)*
67                    _ => None,
68                }
69            }
70        }
71    })
72}
73
74fn parse_repr(input: &DeriveInput) -> Result<Ident> {
75    for attr in &input.attrs {
76        if attr.path().is_ident("repr") {
77            let repr: Ident = attr.parse_args()?;
78            match repr.to_string().as_str() {
79                "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" => {
80                    return Ok(repr);
81                }
82                _ => {
83                    return Err(Error::new_spanned(
84                        repr,
85                        "IntEnum requires a primitive integer repr (u8..u128, i8..i128)",
86                    ));
87                }
88            }
89        }
90    }
91
92    Err(Error::new_spanned(
93        input,
94        "IntEnum requires a #[repr(u8/u16/u32/u64/i8/i16/i32/i64)] attribute",
95    ))
96}
97
98// =============================================================================
99// BitStorage attribute macro
100// =============================================================================
101
102#[proc_macro_attribute]
103pub fn bit_storage(attr: TokenStream, item: TokenStream) -> TokenStream {
104    let attr = proc_macro2::TokenStream::from(attr);
105    let item = parse_macro_input!(item as DeriveInput);
106
107    match bit_storage_impl(attr, &item) {
108        Ok(tokens) => tokens.into(),
109        Err(err) => err.to_compile_error().into(),
110    }
111}
112
113fn bit_storage_impl(attr: TokenStream2, input: &DeriveInput) -> Result<TokenStream2> {
114    let storage_attr = parse_storage_attr_tokens(attr)?;
115
116    match &input.data {
117        Data::Struct(data) => derive_storage_struct(input, data, &storage_attr),
118        Data::Enum(data) => derive_storage_enum(input, data, &storage_attr),
119        Data::Union(_) => Err(Error::new_spanned(
120            input,
121            "bit_storage cannot be applied to unions",
122        )),
123    }
124}
125
126// =============================================================================
127// Attribute types
128// =============================================================================
129
130/// Parsed #[bit_storage(repr = T)] or #[bit_storage(repr = T, discriminant(start = N, len = M))]
131struct StorageAttr {
132    repr: Ident,
133    discriminant: Option<BitRange>,
134}
135
/// A contiguous run of `len` bits whose lowest bit is at index `start`.
/// Bit 0 is the least-significant bit of the storage integer.
#[derive(Copy, Clone)]
struct BitRange {
    /// Index of the lowest bit in the run.
    start: u32,
    /// Number of bits in the run.
    len: u32,
}
142
143/// Parsed field/flag from struct
144// Field variant holds syn::Type (~256 bytes) vs Flag (~28 bytes). Boxing isn't
145// worth it — this is compile-time proc-macro code, not runtime.
146#[allow(clippy::large_enum_variant)]
147enum MemberDef {
148    Field {
149        name: Ident,
150        ty: Type,
151        range: BitRange,
152    },
153    Flag {
154        name: Ident,
155        bit: u32,
156    },
157}
158
159impl MemberDef {
160    fn name(&self) -> &Ident {
161        match self {
162            MemberDef::Field { name, .. } | MemberDef::Flag { name, .. } => name,
163        }
164    }
165}
166
167// =============================================================================
168// Attribute parsing
169// =============================================================================
170
171fn parse_storage_attr_tokens(attr: TokenStream2) -> Result<StorageAttr> {
172    let mut repr = None;
173    let mut discriminant = None;
174
175    let parser = syn::meta::parser(|meta| {
176        if meta.path.is_ident("repr") {
177            meta.input.parse::<syn::Token![=]>()?;
178            repr = Some(meta.input.parse::<Ident>()?);
179            Ok(())
180        } else if meta.path.is_ident("discriminant") {
181            let content;
182            syn::parenthesized!(content in meta.input);
183            discriminant = Some(parse_bit_range(&content)?);
184            Ok(())
185        } else {
186            Err(meta.error("expected `repr` or `discriminant`"))
187        }
188    });
189
190    parser.parse2(attr)?;
191
192    let repr = repr.ok_or_else(|| {
193        Error::new(
194            proc_macro2::Span::call_site(),
195            "bit_storage requires `repr = ...`",
196        )
197    })?;
198
199    // Validate repr
200    match repr.to_string().as_str() {
201        "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" => {}
202        _ => return Err(Error::new_spanned(&repr, "repr must be an integer type")),
203    }
204
205    Ok(StorageAttr { repr, discriminant })
206}
207
208fn parse_bit_range(input: syn::parse::ParseStream) -> Result<BitRange> {
209    let mut start = None;
210    let mut len = None;
211
212    while !input.is_empty() {
213        let ident: Ident = input.parse()?;
214        input.parse::<syn::Token![=]>()?;
215        let lit: syn::LitInt = input.parse()?;
216        let value: u32 = lit.base10_parse()?;
217
218        match ident.to_string().as_str() {
219            "start" => start = Some(value),
220            "len" => len = Some(value),
221            _ => return Err(Error::new_spanned(ident, "expected `start` or `len`")),
222        }
223
224        if input.peek(syn::Token![,]) {
225            input.parse::<syn::Token![,]>()?;
226        }
227    }
228
229    let start = start.ok_or_else(|| Error::new(input.span(), "missing `start`"))?;
230    let len = len.ok_or_else(|| Error::new(input.span(), "missing `len`"))?;
231
232    if len == 0 {
233        return Err(Error::new(input.span(), "len must be > 0"));
234    }
235
236    Ok(BitRange { start, len })
237}
238
239fn parse_member(field: &syn::Field) -> Result<MemberDef> {
240    let name = field
241        .ident
242        .clone()
243        .ok_or_else(|| Error::new_spanned(field, "tuple structs not supported"))?;
244    let ty = field.ty.clone();
245
246    for attr in &field.attrs {
247        if attr.path().is_ident("field") {
248            let range = attr.parse_args_with(parse_bit_range)?;
249            return Ok(MemberDef::Field { name, ty, range });
250        } else if attr.path().is_ident("flag") {
251            let bit: syn::LitInt = attr.parse_args()?;
252            let bit: u32 = bit.base10_parse()?;
253            return Ok(MemberDef::Flag { name, bit });
254        }
255    }
256
257    Err(Error::new_spanned(
258        field,
259        "field requires #[field(start = N, len = M)] or #[flag(N)] attribute",
260    ))
261}
262
263fn parse_variant_attr(attrs: &[syn::Attribute]) -> Result<u64> {
264    for attr in attrs {
265        if attr.path().is_ident("variant") {
266            let lit: syn::LitInt = attr.parse_args()?;
267            return lit.base10_parse();
268        }
269    }
270    Err(Error::new(
271        proc_macro2::Span::call_site(),
272        "enum variant requires #[variant(N)] attribute",
273    ))
274}
275
276// =============================================================================
277// Helpers
278// =============================================================================
279
280fn is_primitive(ty: &Type) -> bool {
281    if let Type::Path(type_path) = ty {
282        if let Some(ident) = type_path.path.get_ident() {
283            return matches!(
284                ident.to_string().as_str(),
285                "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128"
286            );
287        }
288    }
289    false
290}
291
292fn repr_bits(repr: &Ident) -> u32 {
293    match repr.to_string().as_str() {
294        "u8" | "i8" => 8,
295        "u16" | "i16" => 16,
296        "u32" | "i32" => 32,
297        "u64" | "i64" => 64,
298        "u128" | "i128" => 128,
299        _ => 0,
300    }
301}
302
303// =============================================================================
304// Validation
305// =============================================================================
306
307fn validate_members(members: &[MemberDef], repr: &Ident) -> Result<()> {
308    let bits = repr_bits(repr);
309
310    // Check each field fits
311    for member in members {
312        match member {
313            MemberDef::Field { name, range, .. } => {
314                if range.start + range.len > bits {
315                    return Err(Error::new_spanned(
316                        name,
317                        format!(
318                            "field exceeds {} bits (start {} + len {} = {})",
319                            bits,
320                            range.start,
321                            range.len,
322                            range.start + range.len
323                        ),
324                    ));
325                }
326            }
327            MemberDef::Flag { name, bit, .. } => {
328                if *bit >= bits {
329                    return Err(Error::new_spanned(
330                        name,
331                        format!("flag bit {} exceeds {} bits", bit, bits),
332                    ));
333                }
334            }
335        }
336    }
337
338    // Check no overlap (simple O(n²) for now)
339    for (i, a) in members.iter().enumerate() {
340        for b in members.iter().skip(i + 1) {
341            if ranges_overlap(a, b) {
342                return Err(Error::new_spanned(
343                    b.name(),
344                    format!("field '{}' overlaps with '{}'", b.name(), a.name()),
345                ));
346            }
347        }
348    }
349
350    Ok(())
351}
352
353fn ranges_overlap(a: &MemberDef, b: &MemberDef) -> bool {
354    let (a_start, a_end) = member_bit_range(a);
355    let (b_start, b_end) = member_bit_range(b);
356    a_start < b_end && b_start < a_end
357}
358
359fn member_bit_range(m: &MemberDef) -> (u32, u32) {
360    match m {
361        MemberDef::Field { range, .. } => (range.start, range.start + range.len),
362        MemberDef::Flag { bit, .. } => (*bit, bit + 1),
363    }
364}
365
366// =============================================================================
367// Struct codegen
368// =============================================================================
369
370fn derive_storage_struct(
371    input: &DeriveInput,
372    data: &syn::DataStruct,
373    storage_attr: &StorageAttr,
374) -> Result<TokenStream2> {
375    let fields = match &data.fields {
376        Fields::Named(f) => &f.named,
377        _ => {
378            return Err(Error::new_spanned(
379                input,
380                "bit_storage requires named fields",
381            ));
382        }
383    };
384
385    if storage_attr.discriminant.is_some() {
386        return Err(Error::new_spanned(
387            input,
388            "discriminant is only valid for enums",
389        ));
390    }
391
392    let members: Vec<MemberDef> = fields.iter().map(parse_member).collect::<Result<_>>()?;
393
394    validate_members(&members, &storage_attr.repr)?;
395
396    let vis = &input.vis;
397    let name = &input.ident;
398    let repr = &storage_attr.repr;
399    let builder_name = Ident::new(&format!("{}Builder", name), name.span());
400
401    let newtype = generate_struct_newtype(vis, name, repr);
402    let builder_struct = generate_struct_builder_struct(vis, &builder_name, &members);
403    let newtype_impl = generate_struct_newtype_impl(name, &builder_name, repr, &members);
404    let builder_impl = generate_struct_builder_impl(name, &builder_name, repr, &members);
405
406    Ok(quote! {
407        #newtype
408        #builder_struct
409        #newtype_impl
410        #builder_impl
411    })
412}
413
414fn generate_struct_newtype(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
415    quote! {
416        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
417        #[repr(transparent)]
418        #vis struct #name(#vis #repr);
419    }
420}
421
422fn generate_struct_builder_struct(
423    vis: &syn::Visibility,
424    builder_name: &Ident,
425    members: &[MemberDef],
426) -> TokenStream2 {
427    let fields: Vec<TokenStream2> = members
428        .iter()
429        .map(|m| match m {
430            MemberDef::Field { name, ty, .. } => {
431                quote! { #name: Option<#ty>, }
432            }
433            MemberDef::Flag { name, .. } => {
434                quote! { #name: Option<bool>, }
435            }
436        })
437        .collect();
438
439    quote! {
440        #[derive(Debug, Clone, Copy, Default)]
441        #vis struct #builder_name {
442            #(#fields)*
443        }
444    }
445}
446
447fn generate_struct_newtype_impl(
448    name: &Ident,
449    builder_name: &Ident,
450    repr: &Ident,
451    members: &[MemberDef],
452) -> TokenStream2 {
453    let repr_bit_count = repr_bits(repr);
454
455    let accessors: Vec<TokenStream2> = members.iter().map(|m| {
456        match m {
457            MemberDef::Field { name: field_name, ty, range } => {
458                let start = range.start;
459                let len = range.len;
460                let mask = if len >= repr_bit_count {
461                    quote! { #repr::MAX }
462                } else {
463                    quote! { ((1 as #repr) << #len) - 1 }
464                };
465
466                if is_primitive(ty) {
467                    quote! {
468                        #[inline]
469                        pub const fn #field_name(&self) -> #ty {
470                            ((self.0 >> #start) & #mask) as #ty
471                        }
472                    }
473                } else {
474                    // IntEnum field - returns Result
475                    quote! {
476                        #[inline]
477                        pub fn #field_name(&self) -> Result<#ty, nexus_bits::UnknownDiscriminant<#repr>> {
478                            let field_repr = ((self.0 >> #start) & #mask);
479                            <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
480                                .ok_or(nexus_bits::UnknownDiscriminant {
481                                    field: stringify!(#field_name),
482                                    value: field_repr as #repr,
483                                })
484                        }
485                    }
486                }
487            }
488            MemberDef::Flag { name: field_name, bit } => {
489                quote! {
490                    #[inline]
491                    pub const fn #field_name(&self) -> bool {
492                        (self.0 >> #bit) & 1 != 0
493                    }
494                }
495            }
496        }
497    }).collect();
498
499    quote! {
500        impl #name {
501            /// Create from raw integer value.
502            #[inline]
503            pub const fn from_raw(raw: #repr) -> Self {
504                Self(raw)
505            }
506
507            /// Get the raw integer value.
508            #[inline]
509            pub const fn raw(self) -> #repr {
510                self.0
511            }
512
513            /// Create a builder for this type.
514            #[inline]
515            pub fn builder() -> #builder_name {
516                #builder_name::default()
517            }
518
519            #(#accessors)*
520        }
521    }
522}
523
524fn generate_struct_builder_impl(
525    name: &Ident,
526    builder_name: &Ident,
527    repr: &Ident,
528    members: &[MemberDef],
529) -> TokenStream2 {
530    let repr_bit_count = repr_bits(repr);
531
532    // Setters - wrap in Some
533    let setters: Vec<TokenStream2> = members
534        .iter()
535        .map(|m| match m {
536            MemberDef::Field {
537                name: field_name,
538                ty,
539                ..
540            } => {
541                quote! {
542                    #[inline]
543                    pub fn #field_name(mut self, val: #ty) -> Self {
544                        self.#field_name = Some(val);
545                        self
546                    }
547                }
548            }
549            MemberDef::Flag {
550                name: field_name, ..
551            } => {
552                quote! {
553                    #[inline]
554                    pub fn #field_name(mut self, val: bool) -> Self {
555                        self.#field_name = Some(val);
556                        self
557                    }
558                }
559            }
560        })
561        .collect();
562
563    // Validations
564    let validations: Vec<TokenStream2> = members
565        .iter()
566        .filter_map(|m| match m {
567            MemberDef::Field {
568                name: field_name,
569                ty,
570                range,
571            } => {
572                let field_str = field_name.to_string();
573                let len = range.len;
574
575                let max_val = if len >= repr_bit_count {
576                    quote! { #repr::MAX }
577                } else {
578                    quote! { ((1 as #repr) << #len) - 1 }
579                };
580
581                if is_primitive(ty) {
582                    let type_bits: u32 = match ty {
583                        Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
584                        Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
585                        Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
586                        Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
587                        Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
588                        _ => 128,
589                    };
590
591                    // Skip validation if field can hold entire type
592                    if len >= type_bits {
593                        return None;
594                    }
595
596                    // Check if this is a signed type
597                    let is_signed = matches!(ty,
598                        Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
599                                         p.path.is_ident("i32") || p.path.is_ident("i64") ||
600                                         p.path.is_ident("i128")
601                    );
602
603                    if is_signed {
604                        // For signed types, check that value fits in signed field range
605                        // A signed N-bit field can hold -(2^(N-1)) to (2^(N-1) - 1)
606                        // Note: len < 128 is guaranteed here — the early return at
607                        // line 596 (len >= type_bits) catches len >= 128 for i128.
608                        let min_shift = len - 1;
609                        Some(quote! {
610                            if let Some(v) = self.#field_name {
611                                let min_val = -((1i128 << #min_shift) as i128);
612                                let max_val = ((1i128 << #min_shift) - 1) as i128;
613                                let v_i128 = v as i128;
614                                if v_i128 < min_val || v_i128 > max_val {
615                                    return Err(nexus_bits::FieldOverflow {
616                                        field: #field_str,
617                                        overflow: nexus_bits::Overflow {
618                                            value: (v as #repr),
619                                            max: #max_val,
620                                        },
621                                    });
622                                }
623                            }
624                        })
625                    } else {
626                        // Unsigned - simple max check
627                        Some(quote! {
628                            if let Some(v) = self.#field_name {
629                                if (v as #repr) > #max_val {
630                                    return Err(nexus_bits::FieldOverflow {
631                                        field: #field_str,
632                                        overflow: nexus_bits::Overflow {
633                                            value: v as #repr,
634                                            max: #max_val,
635                                        },
636                                    });
637                                }
638                            }
639                        })
640                    }
641                } else {
642                    // IntEnum field - validate repr value fits in field
643                    Some(quote! {
644                        const _: () = assert!(
645                            core::mem::size_of::<<#ty as nexus_bits::IntEnum>::Repr>() <= core::mem::size_of::<#repr>(),
646                            "IntEnum repr type is wider than storage repr — values may be truncated"
647                        );
648                        if let Some(v) = self.#field_name {
649                            let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
650                            if repr_val > #max_val {
651                                return Err(nexus_bits::FieldOverflow {
652                                    field: #field_str,
653                                    overflow: nexus_bits::Overflow {
654                                        value: repr_val,
655                                        max: #max_val,
656                                    },
657                                });
658                            }
659                        }
660                    })
661                }
662            }
663            MemberDef::Flag { .. } => None,
664        })
665        .collect();
666
667    // Pack statements - ALWAYS mask to prevent sign extension corruption
668    let pack_statements: Vec<TokenStream2> = members
669        .iter()
670        .map(|m| {
671            match m {
672                MemberDef::Field {
673                    name: field_name,
674                    ty,
675                    range,
676                } => {
677                    let start = range.start;
678                    let len = range.len;
679                    let mask = if len >= repr_bit_count {
680                        quote! { #repr::MAX }
681                    } else {
682                        quote! { ((1 as #repr) << #len) - 1 }
683                    };
684
685                    if is_primitive(ty) {
686                        quote! {
687                            if let Some(v) = self.#field_name {
688                                val |= ((v as #repr) & #mask) << #start;
689                            }
690                        }
691                    } else {
692                        // IntEnum
693                        quote! {
694                            if let Some(v) = self.#field_name {
695                                val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
696                            }
697                        }
698                    }
699                }
700                MemberDef::Flag {
701                    name: field_name,
702                    bit,
703                } => {
704                    quote! {
705                        if let Some(true) = self.#field_name {
706                            val |= (1 as #repr) << #bit;
707                        }
708                    }
709                }
710            }
711        })
712        .collect();
713
714    quote! {
715        impl #builder_name {
716            #(#setters)*
717
718            /// Build the final value, validating all fields.
719            #[inline]
720            pub fn build(self) -> Result<#name, nexus_bits::FieldOverflow<#repr>> {
721                // Validate
722                #(#validations)*
723
724                // Pack
725                let mut val: #repr = 0;
726                #(#pack_statements)*
727
728                Ok(#name(val))
729            }
730        }
731    }
732}
733
734// =============================================================================
735// Enum codegen
736// =============================================================================
737
738/// Parsed variant for tagged enum
739struct ParsedVariant {
740    name: Ident,
741    discriminant: u64,
742    members: Vec<MemberDef>,
743}
744
745fn derive_storage_enum(
746    input: &DeriveInput,
747    data: &syn::DataEnum,
748    storage_attr: &StorageAttr,
749) -> Result<TokenStream2> {
750    let discriminant = storage_attr.discriminant.ok_or_else(|| {
751        Error::new_spanned(
752            input,
753            "bit_storage enum requires discriminant: #[bit_storage(repr = T, discriminant(start = N, len = M))]",
754        )
755    })?;
756
757    let repr = &storage_attr.repr;
758    let bits = repr_bits(repr);
759
760    // Validate discriminant fits
761    if discriminant.start + discriminant.len > bits {
762        return Err(Error::new_spanned(
763            input,
764            format!(
765                "discriminant exceeds {} bits (start {} + len {} = {})",
766                bits,
767                discriminant.start,
768                discriminant.len,
769                discriminant.start + discriminant.len
770            ),
771        ));
772    }
773
774    let max_discriminant = if discriminant.len >= 64 {
775        u64::MAX
776    } else {
777        (1u64 << discriminant.len) - 1
778    };
779
780    // Parse all variants
781    let mut variants = Vec::new();
782    for variant in &data.variants {
783        let disc = parse_variant_attr(&variant.attrs)?;
784
785        if disc > max_discriminant {
786            return Err(Error::new_spanned(
787                &variant.ident,
788                format!(
789                    "variant discriminant {} exceeds max {} for {}-bit field",
790                    disc, max_discriminant, discriminant.len
791                ),
792            ));
793        }
794
795        // Check for duplicate discriminants
796        for existing in &variants {
797            let existing: &ParsedVariant = existing;
798            if existing.discriminant == disc {
799                return Err(Error::new_spanned(
800                    &variant.ident,
801                    format!(
802                        "duplicate discriminant {}: already used by '{}'",
803                        disc, existing.name
804                    ),
805                ));
806            }
807        }
808
809        let members: Vec<MemberDef> = match &variant.fields {
810            Fields::Named(fields) => fields
811                .named
812                .iter()
813                .map(parse_member)
814                .collect::<Result<_>>()?,
815            Fields::Unit => Vec::new(),
816            Fields::Unnamed(_) => {
817                return Err(Error::new_spanned(
818                    variant,
819                    "tuple variants not supported, use named fields",
820                ));
821            }
822        };
823
824        // Validate members don't overlap with discriminant
825        let disc_range = MemberDef::Field {
826            name: Ident::new("__discriminant", proc_macro2::Span::call_site()),
827            ty: syn::parse_quote!(u64),
828            range: discriminant,
829        };
830
831        for member in &members {
832            if ranges_overlap(&disc_range, member) {
833                return Err(Error::new_spanned(
834                    member.name(),
835                    format!("field '{}' overlaps with discriminant", member.name()),
836                ));
837            }
838        }
839
840        // Validate members within this variant
841        validate_members(&members, repr)?;
842
843        variants.push(ParsedVariant {
844            name: variant.ident.clone(),
845            discriminant: disc,
846            members,
847        });
848    }
849
850    let vis = &input.vis;
851    let name = &input.ident;
852
853    let parent_type = generate_enum_parent_type(vis, name, repr);
854    let variant_types = generate_enum_variant_types(vis, name, repr, &variants);
855    let kind_enum = generate_enum_kind(vis, name, &variants);
856    let builder_structs = generate_enum_builder_structs(vis, name, &variants);
857    let parent_impl = generate_enum_parent_impl(name, repr, discriminant, &variants);
858    let variant_impls = generate_enum_variant_impls(name, repr, &variants);
859    let builder_impls = generate_enum_builder_impls(name, repr, discriminant, &variants);
860    let from_impls = generate_enum_from_impls(name, &variants);
861
862    Ok(quote! {
863        #parent_type
864        #variant_types
865        #kind_enum
866        #builder_structs
867        #parent_impl
868        #variant_impls
869        #builder_impls
870        #from_impls
871    })
872}
873
874fn variant_type_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
875    Ident::new(
876        &format!("{}{}", parent_name, variant_name),
877        variant_name.span(),
878    )
879}
880
881fn variant_builder_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
882    Ident::new(
883        &format!("{}{}Builder", parent_name, variant_name),
884        variant_name.span(),
885    )
886}
887
888fn kind_enum_name(parent_name: &Ident) -> Ident {
889    Ident::new(&format!("{}Kind", parent_name), parent_name.span())
890}
891
892fn generate_enum_parent_type(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
893    quote! {
894        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
895        #[repr(transparent)]
896        #vis struct #name(#vis #repr);
897    }
898}
899
900fn generate_enum_variant_types(
901    vis: &syn::Visibility,
902    parent_name: &Ident,
903    repr: &Ident,
904    variants: &[ParsedVariant],
905) -> TokenStream2 {
906    let types: Vec<TokenStream2> = variants
907        .iter()
908        .map(|v| {
909            let type_name = variant_type_name(parent_name, &v.name);
910            quote! {
911                #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
912                #[repr(transparent)]
913                #vis struct #type_name(#repr);
914            }
915        })
916        .collect();
917
918    quote! { #(#types)* }
919}
920
921fn generate_enum_kind(
922    vis: &syn::Visibility,
923    parent_name: &Ident,
924    variants: &[ParsedVariant],
925) -> TokenStream2 {
926    let kind_name = kind_enum_name(parent_name);
927    let variant_names: Vec<&Ident> = variants.iter().map(|v| &v.name).collect();
928
929    quote! {
930        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
931        #vis enum #kind_name {
932            #(#variant_names),*
933        }
934    }
935}
936
937fn generate_enum_builder_structs(
938    vis: &syn::Visibility,
939    parent_name: &Ident,
940    variants: &[ParsedVariant],
941) -> TokenStream2 {
942    let builders: Vec<TokenStream2> = variants
943        .iter()
944        .map(|v| {
945            let builder_name = variant_builder_name(parent_name, &v.name);
946
947            let fields: Vec<TokenStream2> = v
948                .members
949                .iter()
950                .map(|m| match m {
951                    MemberDef::Field { name, ty, .. } => {
952                        quote! { #name: Option<#ty>, }
953                    }
954                    MemberDef::Flag { name, .. } => {
955                        quote! { #name: Option<bool>, }
956                    }
957                })
958                .collect();
959
960            quote! {
961                #[derive(Debug, Clone, Copy, Default)]
962                #vis struct #builder_name {
963                    #(#fields)*
964                }
965            }
966        })
967        .collect();
968
969    quote! { #(#builders)* }
970}
971
972fn generate_enum_parent_impl(
973    name: &Ident,
974    repr: &Ident,
975    discriminant: BitRange,
976    variants: &[ParsedVariant],
977) -> TokenStream2 {
978    let repr_bit_count = repr_bits(repr);
979    let kind_name = kind_enum_name(name);
980    let disc_start = discriminant.start;
981    let disc_len = discriminant.len;
982
983    // Discriminant is extracted as u64 for matching. Wider discriminants
984    // would truncate silently, so reject them at macro expansion time.
985    assert!(
986        disc_len <= 64,
987        "discriminant length must be <= 64 bits (got {disc_len})"
988    );
989
990    let disc_mask = if disc_len >= repr_bit_count {
991        quote! { #repr::MAX }
992    } else {
993        quote! { ((1 as #repr) << #disc_len) - 1 }
994    };
995
996    // kind() match arms
997    let kind_arms: Vec<TokenStream2> = variants
998        .iter()
999        .map(|v| {
1000            let variant_name = &v.name;
1001            let disc_val = v.discriminant;
1002            quote! {
1003                #disc_val => Ok(#kind_name::#variant_name),
1004            }
1005        })
1006        .collect();
1007
1008    // is_* methods
1009    let is_methods: Vec<TokenStream2> = variants
1010        .iter()
1011        .map(|v| {
1012            let variant_name = &v.name;
1013            let method_name = Ident::new(
1014                &format!("is_{}", to_snake_case(&variant_name.to_string())),
1015                variant_name.span(),
1016            );
1017            let disc_val = v.discriminant;
1018            quote! {
1019                #[inline]
1020                pub fn #method_name(&self) -> bool {
1021                    let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1022                    disc == #disc_val
1023                }
1024            }
1025        })
1026        .collect();
1027
1028    // as_* methods
1029    let as_methods: Vec<TokenStream2> = variants
1030        .iter()
1031        .map(|v| {
1032            let variant_name = &v.name;
1033            let variant_type = variant_type_name(name, variant_name);
1034            let method_name = Ident::new(
1035                &format!("as_{}", to_snake_case(&variant_name.to_string())),
1036                variant_name.span(),
1037            );
1038            let disc_val = v.discriminant;
1039
1040            // Validation for IntEnum fields
1041            let validations: Vec<TokenStream2> = v.members
1042                .iter()
1043                .filter_map(|m| {
1044                    if let MemberDef::Field { name: field_name, ty, range } = m {
1045                        if !is_primitive(ty) {
1046                            let start = range.start;
1047                            let len = range.len;
1048                            let repr_bit_count = repr_bits(repr);
1049                            let mask = if len >= repr_bit_count {
1050                                quote! { #repr::MAX }
1051                            } else {
1052                                quote! { ((1 as #repr) << #len) - 1 }
1053                            };
1054                            return Some(quote! {
1055                                let field_repr = ((self.0 >> #start) & #mask);
1056                                if <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _).is_none() {
1057                                    return Err(nexus_bits::UnknownDiscriminant {
1058                                        field: stringify!(#field_name),
1059                                        value: field_repr as #repr,
1060                                    });
1061                                }
1062                            });
1063                        }
1064                    }
1065                    None
1066                })
1067                .collect();
1068
1069            quote! {
1070                #[inline]
1071                pub fn #method_name(&self) -> Result<#variant_type, nexus_bits::UnknownDiscriminant<#repr>> {
1072                    let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1073                    if disc != #disc_val {
1074                        return Err(nexus_bits::UnknownDiscriminant {
1075                            field: "__discriminant",
1076                            value: disc as #repr,
1077                        });
1078                    }
1079                    #(#validations)*
1080                    Ok(#variant_type(self.0))
1081                }
1082            }
1083        })
1084        .collect();
1085
1086    // Builder shortcut methods
1087    let builder_methods: Vec<TokenStream2> = variants
1088        .iter()
1089        .map(|v| {
1090            let variant_name = &v.name;
1091            let builder_name = variant_builder_name(name, variant_name);
1092            let method_name = Ident::new(
1093                &to_snake_case(&variant_name.to_string()),
1094                variant_name.span(),
1095            );
1096            quote! {
1097                #[inline]
1098                pub fn #method_name() -> #builder_name {
1099                    #builder_name::default()
1100                }
1101            }
1102        })
1103        .collect();
1104
1105    quote! {
1106        impl #name {
1107            /// Create from raw integer value.
1108            #[inline]
1109            pub const fn from_raw(raw: #repr) -> Self {
1110                Self(raw)
1111            }
1112
1113            /// Get the raw integer value.
1114            #[inline]
1115            pub const fn raw(self) -> #repr {
1116                self.0
1117            }
1118
1119            /// Get the kind (discriminant) of this value.
1120            #[inline]
1121            pub fn kind(&self) -> Result<#kind_name, nexus_bits::UnknownDiscriminant<#repr>> {
1122                let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1123                match disc {
1124                    #(#kind_arms)*
1125                    _ => Err(nexus_bits::UnknownDiscriminant {
1126                        field: "__discriminant",
1127                        value: disc as #repr,
1128                    }),
1129                }
1130            }
1131
1132            #(#is_methods)*
1133
1134            #(#as_methods)*
1135
1136            #(#builder_methods)*
1137        }
1138    }
1139}
1140
1141fn generate_enum_variant_impls(
1142    parent_name: &Ident,
1143    repr: &Ident,
1144    variants: &[ParsedVariant],
1145) -> TokenStream2 {
1146    let repr_bit_count = repr_bits(repr);
1147
1148    let impls: Vec<TokenStream2> =
1149        variants
1150            .iter()
1151            .map(|v| {
1152                let variant_name = &v.name;
1153                let variant_type = variant_type_name(parent_name, variant_name);
1154                let builder_name = variant_builder_name(parent_name, variant_name);
1155
1156                // Accessors - infallible since variant is pre-validated
1157                let accessors: Vec<TokenStream2> = v.members
1158                .iter()
1159                .map(|m| {
1160                    match m {
1161                        MemberDef::Field { name: field_name, ty, range } => {
1162                            let start = range.start;
1163                            let len = range.len;
1164                            let mask = if len >= repr_bit_count {
1165                                quote! { #repr::MAX }
1166                            } else {
1167                                quote! { ((1 as #repr) << #len) - 1 }
1168                            };
1169
1170                            if is_primitive(ty) {
1171                                quote! {
1172                                    #[inline]
1173                                    pub const fn #field_name(&self) -> #ty {
1174                                        ((self.0 >> #start) & #mask) as #ty
1175                                    }
1176                                }
1177                            } else {
1178                                // IntEnum - infallible because already validated
1179                                quote! {
1180                                    #[inline]
1181                                    pub fn #field_name(&self) -> #ty {
1182                                        let field_repr = ((self.0 >> #start) & #mask);
1183                                        // SAFETY: This type was validated during construction
1184                                        <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
1185                                            .expect("variant type invariant violated")
1186                                    }
1187                                }
1188                            }
1189                        }
1190                        MemberDef::Flag { name: field_name, bit } => {
1191                            quote! {
1192                                #[inline]
1193                                pub const fn #field_name(&self) -> bool {
1194                                    (self.0 >> #bit) & 1 != 0
1195                                }
1196                            }
1197                        }
1198                    }
1199                })
1200                .collect();
1201
1202                quote! {
1203                    impl #variant_type {
1204                        /// Create a builder for this variant.
1205                        #[inline]
1206                        pub fn builder() -> #builder_name {
1207                            #builder_name::default()
1208                        }
1209
1210                        /// Get the raw integer value.
1211                        #[inline]
1212                        pub const fn raw(self) -> #repr {
1213                            self.0
1214                        }
1215
1216                        /// Convert to parent type.
1217                        #[inline]
1218                        pub const fn as_parent(self) -> #parent_name {
1219                            #parent_name(self.0)
1220                        }
1221
1222                        #(#accessors)*
1223                    }
1224                }
1225            })
1226            .collect();
1227
1228    quote! { #(#impls)* }
1229}
1230
1231fn generate_enum_builder_impls(
1232    parent_name: &Ident,
1233    repr: &Ident,
1234    discriminant: BitRange,
1235    variants: &[ParsedVariant],
1236) -> TokenStream2 {
1237    let repr_bit_count = repr_bits(repr);
1238    let disc_start = discriminant.start;
1239
1240    let impls: Vec<TokenStream2> = variants
1241        .iter()
1242        .map(|v| {
1243            let variant_name = &v.name;
1244            let variant_type = variant_type_name(parent_name, variant_name);
1245            let builder_name = variant_builder_name(parent_name, variant_name);
1246            let disc_val = v.discriminant;
1247
1248            // Setters
1249            let setters: Vec<TokenStream2> = v.members
1250                .iter()
1251                .map(|m| match m {
1252                    MemberDef::Field { name: field_name, ty, .. } => {
1253                        quote! {
1254                            #[inline]
1255                            pub fn #field_name(mut self, val: #ty) -> Self {
1256                                self.#field_name = Some(val);
1257                                self
1258                            }
1259                        }
1260                    }
1261                    MemberDef::Flag { name: field_name, .. } => {
1262                        quote! {
1263                            #[inline]
1264                            pub fn #field_name(mut self, val: bool) -> Self {
1265                                self.#field_name = Some(val);
1266                                self
1267                            }
1268                        }
1269                    }
1270                })
1271                .collect();
1272
1273            // Validations
1274            let validations: Vec<TokenStream2> = v.members
1275                .iter()
1276                .filter_map(|m| match m {
1277                    MemberDef::Field { name: field_name, ty, range } => {
1278                        let field_str = field_name.to_string();
1279                        let len = range.len;
1280
1281                        let max_val = if len >= repr_bit_count {
1282                            quote! { #repr::MAX }
1283                        } else {
1284                            quote! { ((1 as #repr) << #len) - 1 }
1285                        };
1286
1287                        if is_primitive(ty) {
1288                            let type_bits: u32 = match ty {
1289                                Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
1290                                Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
1291                                Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
1292                                Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
1293                                Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
1294                                _ => 128,
1295                            };
1296
1297                            if len >= type_bits {
1298                                return None;
1299                            }
1300
1301                            let is_signed = matches!(ty,
1302                                Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
1303                                                 p.path.is_ident("i32") || p.path.is_ident("i64") ||
1304                                                 p.path.is_ident("i128")
1305                            );
1306
1307                            if is_signed {
1308                                let min_shift = len - 1;
1309                                Some(quote! {
1310                                    if let Some(v) = self.#field_name {
1311                                        let min_val = -((1i128 << #min_shift) as i128);
1312                                        let max_val = ((1i128 << #min_shift) - 1) as i128;
1313                                        let v_i128 = v as i128;
1314                                        if v_i128 < min_val || v_i128 > max_val {
1315                                            return Err(nexus_bits::FieldOverflow {
1316                                                field: #field_str,
1317                                                overflow: nexus_bits::Overflow {
1318                                                    value: (v as #repr),
1319                                                    max: #max_val,
1320                                                },
1321                                            });
1322                                        }
1323                                    }
1324                                })
1325                            } else {
1326                                Some(quote! {
1327                                    if let Some(v) = self.#field_name {
1328                                        if (v as #repr) > #max_val {
1329                                            return Err(nexus_bits::FieldOverflow {
1330                                                field: #field_str,
1331                                                overflow: nexus_bits::Overflow {
1332                                                    value: v as #repr,
1333                                                    max: #max_val,
1334                                                },
1335                                            });
1336                                        }
1337                                    }
1338                                })
1339                            }
1340                        } else {
1341                            // IntEnum field
1342                            Some(quote! {
1343                                if let Some(v) = self.#field_name {
1344                                    let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
1345                                    if repr_val > #max_val {
1346                                        return Err(nexus_bits::FieldOverflow {
1347                                            field: #field_str,
1348                                            overflow: nexus_bits::Overflow {
1349                                                value: repr_val,
1350                                                max: #max_val,
1351                                            },
1352                                        });
1353                                    }
1354                                }
1355                            })
1356                        }
1357                    }
1358                    MemberDef::Flag { .. } => None,
1359                })
1360                .collect();
1361
1362            // Pack statements
1363            let pack_statements: Vec<TokenStream2> = v.members
1364                .iter()
1365                .map(|m| {
1366                    match m {
1367                        MemberDef::Field { name: field_name, ty, range } => {
1368                            let start = range.start;
1369                            let len = range.len;
1370                            let mask = if len >= repr_bit_count {
1371                                quote! { #repr::MAX }
1372                            } else {
1373                                quote! { ((1 as #repr) << #len) - 1 }
1374                            };
1375
1376                            if is_primitive(ty) {
1377                                quote! {
1378                                    if let Some(v) = self.#field_name {
1379                                        val |= ((v as #repr) & #mask) << #start;
1380                                    }
1381                                }
1382                            } else {
1383                                quote! {
1384                                    if let Some(v) = self.#field_name {
1385                                        val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
1386                                    }
1387                                }
1388                            }
1389                        }
1390                        MemberDef::Flag { name: field_name, bit } => {
1391                            quote! {
1392                                if let Some(true) = self.#field_name {
1393                                    val |= (1 as #repr) << #bit;
1394                                }
1395                            }
1396                        }
1397                    }
1398                })
1399                .collect();
1400
1401            quote! {
1402                impl #builder_name {
1403                    #(#setters)*
1404
1405                    /// Build the variant type, validating all fields.
1406                    #[inline]
1407                    pub fn build(self) -> Result<#variant_type, nexus_bits::FieldOverflow<#repr>> {
1408                        #(#validations)*
1409
1410                        let mut val: #repr = 0;
1411                        // Set discriminant
1412                        val |= (#disc_val as #repr) << #disc_start;
1413                        #(#pack_statements)*
1414
1415                        Ok(#variant_type(val))
1416                    }
1417
1418                    /// Build directly to parent type, validating all fields.
1419                    #[inline]
1420                    pub fn build_parent(self) -> Result<#parent_name, nexus_bits::FieldOverflow<#repr>> {
1421                        self.build().map(|v| v.as_parent())
1422                    }
1423                }
1424            }
1425        })
1426        .collect();
1427
1428    quote! { #(#impls)* }
1429}
1430
1431fn generate_enum_from_impls(parent_name: &Ident, variants: &[ParsedVariant]) -> TokenStream2 {
1432    let impls: Vec<TokenStream2> = variants
1433        .iter()
1434        .map(|v| {
1435            let variant_type = variant_type_name(parent_name, &v.name);
1436            quote! {
1437                impl From<#variant_type> for #parent_name {
1438                    #[inline]
1439                    fn from(v: #variant_type) -> Self {
1440                        v.as_parent()
1441                    }
1442                }
1443            }
1444        })
1445        .collect();
1446
1447    quote! { #(#impls)* }
1448}
1449
/// Converts an `UpperCamelCase` identifier to `snake_case`.
///
/// An underscore is inserted before every uppercase character except the
/// first, so acronyms are split letter by letter (`HTTPOk` -> `h_t_t_p_ok`).
/// Each uppercase character is replaced by its *full* lowercase mapping, which
/// may span more than one `char` (e.g. `İ` -> `i\u{307}`). ASCII input is
/// unaffected by this.
fn to_snake_case(s: &str) -> String {
    // Room for the original text plus a few separators; avoids most regrowth.
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}