Skip to main content

nexus_bits_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{
5    Data, DeriveInput, Error, Fields, Ident, Result, Type, parse::Parser, parse_macro_input,
6};
7
8// =============================================================================
9// IntEnum derive
10// =============================================================================
11
12#[proc_macro_derive(IntEnum)]
13pub fn derive_int_enum(input: TokenStream) -> TokenStream {
14    let input = parse_macro_input!(input as DeriveInput);
15
16    match derive_int_enum_impl(&input) {
17        Ok(tokens) => tokens.into(),
18        Err(err) => err.to_compile_error().into(),
19    }
20}
21
22fn derive_int_enum_impl(input: &DeriveInput) -> Result<TokenStream2> {
23    let variants = match &input.data {
24        Data::Enum(data) => &data.variants,
25        _ => {
26            return Err(Error::new_spanned(
27                input,
28                "IntEnum can only be derived for enums",
29            ));
30        }
31    };
32
33    let repr = parse_repr(input)?;
34
35    for variant in variants {
36        if !matches!(variant.fields, Fields::Unit) {
37            return Err(Error::new_spanned(
38                variant,
39                "IntEnum variants cannot have fields",
40            ));
41        }
42    }
43
44    let name = &input.ident;
45    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
46
47    let from_arms = variants.iter().map(|v| {
48        let variant_name = &v.ident;
49        quote! {
50            x if x == #name::#variant_name as #repr => Some(#name::#variant_name),
51        }
52    });
53
54    Ok(quote! {
55        impl #impl_generics nexus_bits::IntEnum for #name #ty_generics #where_clause {
56            type Repr = #repr;
57
58            #[inline]
59            fn into_repr(self) -> #repr {
60                self as #repr
61            }
62
63            #[inline]
64            fn try_from_repr(repr: #repr) -> Option<Self> {
65                match repr {
66                    #(#from_arms)*
67                    _ => None,
68                }
69            }
70        }
71    })
72}
73
74fn parse_repr(input: &DeriveInput) -> Result<Ident> {
75    for attr in &input.attrs {
76        if attr.path().is_ident("repr") {
77            let repr: Ident = attr.parse_args()?;
78            match repr.to_string().as_str() {
79                "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" => {
80                    return Ok(repr);
81                }
82                _ => {
83                    return Err(Error::new_spanned(
84                        repr,
85                        "IntEnum requires a primitive integer repr (u8..u128, i8..i128)",
86                    ));
87                }
88            }
89        }
90    }
91
92    Err(Error::new_spanned(
93        input,
94        "IntEnum requires a #[repr(u8/u16/u32/u64/i8/i16/i32/i64)] attribute",
95    ))
96}
97
98// =============================================================================
99// BitStorage attribute macro
100// =============================================================================
101
102#[proc_macro_attribute]
103pub fn bit_storage(attr: TokenStream, item: TokenStream) -> TokenStream {
104    let attr = proc_macro2::TokenStream::from(attr);
105    let item = parse_macro_input!(item as DeriveInput);
106
107    match bit_storage_impl(attr, &item) {
108        Ok(tokens) => tokens.into(),
109        Err(err) => err.to_compile_error().into(),
110    }
111}
112
113fn bit_storage_impl(attr: TokenStream2, input: &DeriveInput) -> Result<TokenStream2> {
114    let storage_attr = parse_storage_attr_tokens(attr)?;
115
116    match &input.data {
117        Data::Struct(data) => derive_storage_struct(input, data, &storage_attr),
118        Data::Enum(data) => derive_storage_enum(input, data, &storage_attr),
119        Data::Union(_) => Err(Error::new_spanned(
120            input,
121            "bit_storage cannot be applied to unions",
122        )),
123    }
124}
125
126// =============================================================================
127// Attribute types
128// =============================================================================
129
130/// Parsed #[bit_storage(repr = T)] or #[bit_storage(repr = T, discriminant(start = N, len = M))]
131struct StorageAttr {
132    repr: Ident,
133    discriminant: Option<BitRange>,
134}
135
/// A contiguous run of `len` bits beginning at bit index `start`.
///
/// Describes both `#[field(...)]` members and the enum discriminant.
#[derive(Clone, Copy)]
struct BitRange {
    /// Index of the lowest bit in the range (the accessor shift amount).
    start: u32,
    /// Number of bits in the range; `parse_bit_range` rejects zero.
    len: u32,
}
142
143/// Parsed field/flag from struct
144// Field variant holds syn::Type (~256 bytes) vs Flag (~28 bytes). Boxing isn't
145// worth it — this is compile-time proc-macro code, not runtime.
146#[allow(clippy::large_enum_variant)]
147enum MemberDef {
148    Field {
149        name: Ident,
150        ty: Type,
151        range: BitRange,
152    },
153    Flag {
154        name: Ident,
155        bit: u32,
156    },
157}
158
159impl MemberDef {
160    fn name(&self) -> &Ident {
161        match self {
162            MemberDef::Field { name, .. } | MemberDef::Flag { name, .. } => name,
163        }
164    }
165}
166
167// =============================================================================
168// Attribute parsing
169// =============================================================================
170
171fn parse_storage_attr_tokens(attr: TokenStream2) -> Result<StorageAttr> {
172    let mut repr = None;
173    let mut discriminant = None;
174
175    let parser = syn::meta::parser(|meta| {
176        if meta.path.is_ident("repr") {
177            meta.input.parse::<syn::Token![=]>()?;
178            repr = Some(meta.input.parse::<Ident>()?);
179            Ok(())
180        } else if meta.path.is_ident("discriminant") {
181            let content;
182            syn::parenthesized!(content in meta.input);
183            discriminant = Some(parse_bit_range(&content)?);
184            Ok(())
185        } else {
186            Err(meta.error("expected `repr` or `discriminant`"))
187        }
188    });
189
190    parser.parse2(attr)?;
191
192    let repr = repr.ok_or_else(|| {
193        Error::new(
194            proc_macro2::Span::call_site(),
195            "bit_storage requires `repr = ...`",
196        )
197    })?;
198
199    // Validate repr
200    match repr.to_string().as_str() {
201        "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" => {}
202        _ => return Err(Error::new_spanned(&repr, "repr must be an integer type")),
203    }
204
205    Ok(StorageAttr { repr, discriminant })
206}
207
208fn parse_bit_range(input: syn::parse::ParseStream) -> Result<BitRange> {
209    let mut start = None;
210    let mut len = None;
211
212    while !input.is_empty() {
213        let ident: Ident = input.parse()?;
214        input.parse::<syn::Token![=]>()?;
215        let lit: syn::LitInt = input.parse()?;
216        let value: u32 = lit.base10_parse()?;
217
218        match ident.to_string().as_str() {
219            "start" => start = Some(value),
220            "len" => len = Some(value),
221            _ => return Err(Error::new_spanned(ident, "expected `start` or `len`")),
222        }
223
224        if input.peek(syn::Token![,]) {
225            input.parse::<syn::Token![,]>()?;
226        }
227    }
228
229    let start = start.ok_or_else(|| Error::new(input.span(), "missing `start`"))?;
230    let len = len.ok_or_else(|| Error::new(input.span(), "missing `len`"))?;
231
232    if len == 0 {
233        return Err(Error::new(input.span(), "len must be > 0"));
234    }
235
236    Ok(BitRange { start, len })
237}
238
239fn parse_member(field: &syn::Field) -> Result<MemberDef> {
240    let name = field
241        .ident
242        .clone()
243        .ok_or_else(|| Error::new_spanned(field, "tuple structs not supported"))?;
244    let ty = field.ty.clone();
245
246    for attr in &field.attrs {
247        if attr.path().is_ident("field") {
248            let range = attr.parse_args_with(parse_bit_range)?;
249            return Ok(MemberDef::Field { name, ty, range });
250        } else if attr.path().is_ident("flag") {
251            let bit: syn::LitInt = attr.parse_args()?;
252            let bit: u32 = bit.base10_parse()?;
253            return Ok(MemberDef::Flag { name, bit });
254        }
255    }
256
257    Err(Error::new_spanned(
258        field,
259        "field requires #[field(start = N, len = M)] or #[flag(N)] attribute",
260    ))
261}
262
263fn parse_variant_attr(attrs: &[syn::Attribute]) -> Result<u64> {
264    for attr in attrs {
265        if attr.path().is_ident("variant") {
266            let lit: syn::LitInt = attr.parse_args()?;
267            return lit.base10_parse();
268        }
269    }
270    Err(Error::new(
271        proc_macro2::Span::call_site(),
272        "enum variant requires #[variant(N)] attribute",
273    ))
274}
275
276// =============================================================================
277// Helpers
278// =============================================================================
279
280fn is_primitive(ty: &Type) -> bool {
281    if let Type::Path(type_path) = ty {
282        if let Some(ident) = type_path.path.get_ident() {
283            return matches!(
284                ident.to_string().as_str(),
285                "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128"
286            );
287        }
288    }
289    false
290}
291
292fn is_signed_primitive(ty: &Type) -> bool {
293    if let Type::Path(type_path) = ty {
294        if let Some(ident) = type_path.path.get_ident() {
295            return matches!(
296                ident.to_string().as_str(),
297                "i8" | "i16" | "i32" | "i64" | "i128"
298            );
299        }
300    }
301    false
302}
303
304fn primitive_bits(ty: &Type) -> u32 {
305    if let Type::Path(type_path) = ty {
306        if let Some(ident) = type_path.path.get_ident() {
307            return match ident.to_string().as_str() {
308                "u8" | "i8" => 8,
309                "u16" | "i16" => 16,
310                "u32" | "i32" => 32,
311                "u64" | "i64" => 64,
312                "u128" | "i128" => 128,
313                _ => 0,
314            };
315        }
316    }
317    0
318}
319
320fn repr_bits(repr: &Ident) -> u32 {
321    match repr.to_string().as_str() {
322        "u8" | "i8" => 8,
323        "u16" | "i16" => 16,
324        "u32" | "i32" => 32,
325        "u64" | "i64" => 64,
326        "u128" | "i128" => 128,
327        _ => 0,
328    }
329}
330
331/// Generate a bitmask of `len` ones for the given repr type.
332///
333/// For full-width fields (`len >= repr_bits`), uses `!0` which is all-ones
334/// for both signed and unsigned reprs. For partial-width, computes the mask
335/// in u128 to avoid signed overflow, then casts to the repr type.
336fn field_mask(repr: &Ident, len: u32, repr_bit_count: u32) -> TokenStream2 {
337    if len >= repr_bit_count {
338        quote! { (!0 as #repr) }
339    } else {
340        quote! { (((1u128 << #len) - 1) as #repr) }
341    }
342}
343
344// =============================================================================
345// Validation
346// =============================================================================
347
348fn validate_members(members: &[MemberDef], repr: &Ident) -> Result<()> {
349    let bits = repr_bits(repr);
350
351    // Check each field fits
352    for member in members {
353        match member {
354            MemberDef::Field { name, range, .. } => {
355                if range.start + range.len > bits {
356                    return Err(Error::new_spanned(
357                        name,
358                        format!(
359                            "field exceeds {} bits (start {} + len {} = {})",
360                            bits,
361                            range.start,
362                            range.len,
363                            range.start + range.len
364                        ),
365                    ));
366                }
367            }
368            MemberDef::Flag { name, bit, .. } => {
369                if *bit >= bits {
370                    return Err(Error::new_spanned(
371                        name,
372                        format!("flag bit {} exceeds {} bits", bit, bits),
373                    ));
374                }
375            }
376        }
377    }
378
379    // Check no overlap (simple O(n²) for now)
380    for (i, a) in members.iter().enumerate() {
381        for b in members.iter().skip(i + 1) {
382            if ranges_overlap(a, b) {
383                return Err(Error::new_spanned(
384                    b.name(),
385                    format!("field '{}' overlaps with '{}'", b.name(), a.name()),
386                ));
387            }
388        }
389    }
390
391    Ok(())
392}
393
394fn ranges_overlap(a: &MemberDef, b: &MemberDef) -> bool {
395    let (a_start, a_end) = member_bit_range(a);
396    let (b_start, b_end) = member_bit_range(b);
397    a_start < b_end && b_start < a_end
398}
399
400fn member_bit_range(m: &MemberDef) -> (u32, u32) {
401    match m {
402        MemberDef::Field { range, .. } => (range.start, range.start + range.len),
403        MemberDef::Flag { bit, .. } => (*bit, bit + 1),
404    }
405}
406
407// =============================================================================
408// Struct codegen
409// =============================================================================
410
411fn derive_storage_struct(
412    input: &DeriveInput,
413    data: &syn::DataStruct,
414    storage_attr: &StorageAttr,
415) -> Result<TokenStream2> {
416    let fields = match &data.fields {
417        Fields::Named(f) => &f.named,
418        _ => {
419            return Err(Error::new_spanned(
420                input,
421                "bit_storage requires named fields",
422            ));
423        }
424    };
425
426    if storage_attr.discriminant.is_some() {
427        return Err(Error::new_spanned(
428            input,
429            "discriminant is only valid for enums",
430        ));
431    }
432
433    let members: Vec<MemberDef> = fields.iter().map(parse_member).collect::<Result<_>>()?;
434
435    validate_members(&members, &storage_attr.repr)?;
436
437    let vis = &input.vis;
438    let name = &input.ident;
439    let repr = &storage_attr.repr;
440    let builder_name = Ident::new(&format!("{}Builder", name), name.span());
441
442    let newtype = generate_struct_newtype(vis, name, repr);
443    let builder_struct = generate_struct_builder_struct(vis, &builder_name, &members);
444    let newtype_impl = generate_struct_newtype_impl(name, &builder_name, repr, &members);
445    let builder_impl = generate_struct_builder_impl(name, &builder_name, repr, &members);
446
447    Ok(quote! {
448        #newtype
449        #builder_struct
450        #newtype_impl
451        #builder_impl
452    })
453}
454
455fn generate_struct_newtype(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
456    quote! {
457        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
458        #[repr(transparent)]
459        #vis struct #name(#vis #repr);
460    }
461}
462
463fn generate_struct_builder_struct(
464    vis: &syn::Visibility,
465    builder_name: &Ident,
466    members: &[MemberDef],
467) -> TokenStream2 {
468    let fields: Vec<TokenStream2> = members
469        .iter()
470        .map(|m| match m {
471            MemberDef::Field { name, ty, .. } => {
472                quote! { #name: Option<#ty>, }
473            }
474            MemberDef::Flag { name, .. } => {
475                quote! { #name: Option<bool>, }
476            }
477        })
478        .collect();
479
480    quote! {
481        #[derive(Debug, Clone, Copy, Default)]
482        #vis struct #builder_name {
483            #(#fields)*
484        }
485    }
486}
487
488fn generate_struct_newtype_impl(
489    name: &Ident,
490    builder_name: &Ident,
491    repr: &Ident,
492    members: &[MemberDef],
493) -> TokenStream2 {
494    let repr_bit_count = repr_bits(repr);
495
496    let accessors: Vec<TokenStream2> = members.iter().map(|m| {
497        match m {
498            MemberDef::Field { name: field_name, ty, range } => {
499                let start = range.start;
500                let len = range.len;
501                let mask = field_mask(repr, len, repr_bit_count);
502
503                if is_primitive(ty) {
504                    let type_bits = primitive_bits(ty);
505                    if is_signed_primitive(ty) && len < type_bits {
506                        // Sign-extend: shift left to MSB, arithmetic right shift back
507                        let shift = type_bits - len;
508                        quote! {
509                            #[inline]
510                            pub const fn #field_name(&self) -> #ty {
511                                let raw = ((self.0 >> #start) & #mask) as #ty;
512                                (raw << #shift) >> #shift
513                            }
514                        }
515                    } else {
516                        quote! {
517                            #[inline]
518                            pub const fn #field_name(&self) -> #ty {
519                                ((self.0 >> #start) & #mask) as #ty
520                            }
521                        }
522                    }
523                } else {
524                    // IntEnum field - returns Result
525                    quote! {
526                        #[inline]
527                        pub fn #field_name(&self) -> Result<#ty, nexus_bits::UnknownDiscriminant<#repr>> {
528                            let field_repr = ((self.0 >> #start) & #mask);
529                            <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
530                                .ok_or(nexus_bits::UnknownDiscriminant {
531                                    field: stringify!(#field_name),
532                                    value: field_repr as #repr,
533                                })
534                        }
535                    }
536                }
537            }
538            MemberDef::Flag { name: field_name, bit } => {
539                quote! {
540                    #[inline]
541                    pub const fn #field_name(&self) -> bool {
542                        (self.0 >> #bit) & 1 != 0
543                    }
544                }
545            }
546        }
547    }).collect();
548
549    quote! {
550        impl #name {
551            /// Create from raw integer value.
552            #[inline]
553            pub const fn from_raw(raw: #repr) -> Self {
554                Self(raw)
555            }
556
557            /// Get the raw integer value.
558            #[inline]
559            pub const fn raw(self) -> #repr {
560                self.0
561            }
562
563            /// Create a builder for this type.
564            #[inline]
565            pub fn builder() -> #builder_name {
566                #builder_name::default()
567            }
568
569            #(#accessors)*
570        }
571    }
572}
573
574fn generate_struct_builder_impl(
575    name: &Ident,
576    builder_name: &Ident,
577    repr: &Ident,
578    members: &[MemberDef],
579) -> TokenStream2 {
580    let repr_bit_count = repr_bits(repr);
581
582    // Setters - wrap in Some
583    let setters: Vec<TokenStream2> = members
584        .iter()
585        .map(|m| match m {
586            MemberDef::Field {
587                name: field_name,
588                ty,
589                ..
590            } => {
591                quote! {
592                    #[inline]
593                    pub fn #field_name(mut self, val: #ty) -> Self {
594                        self.#field_name = Some(val);
595                        self
596                    }
597                }
598            }
599            MemberDef::Flag {
600                name: field_name, ..
601            } => {
602                quote! {
603                    #[inline]
604                    pub fn #field_name(mut self, val: bool) -> Self {
605                        self.#field_name = Some(val);
606                        self
607                    }
608                }
609            }
610        })
611        .collect();
612
613    // Validations
614    let validations: Vec<TokenStream2> = members
615        .iter()
616        .filter_map(|m| match m {
617            MemberDef::Field {
618                name: field_name,
619                ty,
620                range,
621            } => {
622                let field_str = field_name.to_string();
623                let len = range.len;
624
625                let max_val = field_mask(repr, len, repr_bit_count);
626
627                if is_primitive(ty) {
628                    let type_bits: u32 = match ty {
629                        Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
630                        Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
631                        Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
632                        Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
633                        Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
634                        _ => 128,
635                    };
636
637                    // Skip validation if field can hold entire type
638                    if len >= type_bits {
639                        return None;
640                    }
641
642                    // Check if this is a signed type
643                    let is_signed = matches!(ty,
644                        Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
645                                         p.path.is_ident("i32") || p.path.is_ident("i64") ||
646                                         p.path.is_ident("i128")
647                    );
648
649                    if is_signed {
650                        // For signed types, check that value fits in signed field range
651                        // A signed N-bit field can hold -(2^(N-1)) to (2^(N-1) - 1)
652                        // Note: len < 128 is guaranteed here — the early return at
653                        // line 596 (len >= type_bits) catches len >= 128 for i128.
654                        let min_shift = len - 1;
655                        Some(quote! {
656                            if let Some(v) = self.#field_name {
657                                let min_val = -((1i128 << #min_shift) as i128);
658                                let max_val = ((1i128 << #min_shift) - 1) as i128;
659                                let v_i128 = v as i128;
660                                if v_i128 < min_val || v_i128 > max_val {
661                                    return Err(nexus_bits::FieldOverflow {
662                                        field: #field_str,
663                                        overflow: nexus_bits::Overflow {
664                                            value: (v as #repr),
665                                            max: #max_val,
666                                        },
667                                    });
668                                }
669                            }
670                        })
671                    } else {
672                        // Unsigned - simple max check
673                        Some(quote! {
674                            if let Some(v) = self.#field_name {
675                                if (v as #repr) > #max_val {
676                                    return Err(nexus_bits::FieldOverflow {
677                                        field: #field_str,
678                                        overflow: nexus_bits::Overflow {
679                                            value: v as #repr,
680                                            max: #max_val,
681                                        },
682                                    });
683                                }
684                            }
685                        })
686                    }
687                } else {
688                    // IntEnum field - validate repr value fits in field
689                    Some(quote! {
690                        const _: () = assert!(
691                            core::mem::size_of::<<#ty as nexus_bits::IntEnum>::Repr>() <= core::mem::size_of::<#repr>(),
692                            "IntEnum repr type is wider than storage repr — values may be truncated"
693                        );
694                        if let Some(v) = self.#field_name {
695                            let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
696                            if repr_val > #max_val {
697                                return Err(nexus_bits::FieldOverflow {
698                                    field: #field_str,
699                                    overflow: nexus_bits::Overflow {
700                                        value: repr_val,
701                                        max: #max_val,
702                                    },
703                                });
704                            }
705                        }
706                    })
707                }
708            }
709            MemberDef::Flag { .. } => None,
710        })
711        .collect();
712
713    // Pack statements - ALWAYS mask to prevent sign extension corruption
714    let pack_statements: Vec<TokenStream2> = members
715        .iter()
716        .map(|m| {
717            match m {
718                MemberDef::Field {
719                    name: field_name,
720                    ty,
721                    range,
722                } => {
723                    let start = range.start;
724                    let len = range.len;
725                    let mask = field_mask(repr, len, repr_bit_count);
726
727                    if is_primitive(ty) {
728                        quote! {
729                            if let Some(v) = self.#field_name {
730                                val |= ((v as #repr) & #mask) << #start;
731                            }
732                        }
733                    } else {
734                        // IntEnum
735                        quote! {
736                            if let Some(v) = self.#field_name {
737                                val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
738                            }
739                        }
740                    }
741                }
742                MemberDef::Flag {
743                    name: field_name,
744                    bit,
745                } => {
746                    quote! {
747                        if let Some(true) = self.#field_name {
748                            val |= (1 as #repr) << #bit;
749                        }
750                    }
751                }
752            }
753        })
754        .collect();
755
756    quote! {
757        impl #builder_name {
758            #(#setters)*
759
760            /// Build the final value, validating all fields.
761            #[inline]
762            pub fn build(self) -> Result<#name, nexus_bits::FieldOverflow<#repr>> {
763                // Validate
764                #(#validations)*
765
766                // Pack
767                let mut val: #repr = 0;
768                #(#pack_statements)*
769
770                Ok(#name(val))
771            }
772        }
773    }
774}
775
776// =============================================================================
777// Enum codegen
778// =============================================================================
779
780/// Parsed variant for tagged enum
781struct ParsedVariant {
782    name: Ident,
783    discriminant: u64,
784    members: Vec<MemberDef>,
785}
786
787fn derive_storage_enum(
788    input: &DeriveInput,
789    data: &syn::DataEnum,
790    storage_attr: &StorageAttr,
791) -> Result<TokenStream2> {
792    let discriminant = storage_attr.discriminant.ok_or_else(|| {
793        Error::new_spanned(
794            input,
795            "bit_storage enum requires discriminant: #[bit_storage(repr = T, discriminant(start = N, len = M))]",
796        )
797    })?;
798
799    let repr = &storage_attr.repr;
800    let bits = repr_bits(repr);
801
802    // Validate discriminant fits
803    if discriminant.start + discriminant.len > bits {
804        return Err(Error::new_spanned(
805            input,
806            format!(
807                "discriminant exceeds {} bits (start {} + len {} = {})",
808                bits,
809                discriminant.start,
810                discriminant.len,
811                discriminant.start + discriminant.len
812            ),
813        ));
814    }
815
816    let max_discriminant = if discriminant.len >= 64 {
817        u64::MAX
818    } else {
819        (1u64 << discriminant.len) - 1
820    };
821
822    // Parse all variants
823    let mut variants = Vec::new();
824    for variant in &data.variants {
825        let disc = parse_variant_attr(&variant.attrs)?;
826
827        if disc > max_discriminant {
828            return Err(Error::new_spanned(
829                &variant.ident,
830                format!(
831                    "variant discriminant {} exceeds max {} for {}-bit field",
832                    disc, max_discriminant, discriminant.len
833                ),
834            ));
835        }
836
837        // Check for duplicate discriminants
838        for existing in &variants {
839            let existing: &ParsedVariant = existing;
840            if existing.discriminant == disc {
841                return Err(Error::new_spanned(
842                    &variant.ident,
843                    format!(
844                        "duplicate discriminant {}: already used by '{}'",
845                        disc, existing.name
846                    ),
847                ));
848            }
849        }
850
851        let members: Vec<MemberDef> = match &variant.fields {
852            Fields::Named(fields) => fields
853                .named
854                .iter()
855                .map(parse_member)
856                .collect::<Result<_>>()?,
857            Fields::Unit => Vec::new(),
858            Fields::Unnamed(_) => {
859                return Err(Error::new_spanned(
860                    variant,
861                    "tuple variants not supported, use named fields",
862                ));
863            }
864        };
865
866        // Validate members don't overlap with discriminant
867        let disc_range = MemberDef::Field {
868            name: Ident::new("__discriminant", proc_macro2::Span::call_site()),
869            ty: syn::parse_quote!(u64),
870            range: discriminant,
871        };
872
873        for member in &members {
874            if ranges_overlap(&disc_range, member) {
875                return Err(Error::new_spanned(
876                    member.name(),
877                    format!("field '{}' overlaps with discriminant", member.name()),
878                ));
879            }
880        }
881
882        // Validate members within this variant
883        validate_members(&members, repr)?;
884
885        variants.push(ParsedVariant {
886            name: variant.ident.clone(),
887            discriminant: disc,
888            members,
889        });
890    }
891
892    let vis = &input.vis;
893    let name = &input.ident;
894
895    let parent_type = generate_enum_parent_type(vis, name, repr);
896    let variant_types = generate_enum_variant_types(vis, name, repr, &variants);
897    let kind_enum = generate_enum_kind(vis, name, &variants);
898    let builder_structs = generate_enum_builder_structs(vis, name, &variants);
899    let parent_impl = generate_enum_parent_impl(name, repr, discriminant, &variants);
900    let variant_impls = generate_enum_variant_impls(name, repr, &variants);
901    let builder_impls = generate_enum_builder_impls(name, repr, discriminant, &variants);
902    let from_impls = generate_enum_from_impls(name, &variants);
903
904    Ok(quote! {
905        #parent_type
906        #variant_types
907        #kind_enum
908        #builder_structs
909        #parent_impl
910        #variant_impls
911        #builder_impls
912        #from_impls
913    })
914}
915
916fn variant_type_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
917    Ident::new(
918        &format!("{}{}", parent_name, variant_name),
919        variant_name.span(),
920    )
921}
922
923fn variant_builder_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
924    Ident::new(
925        &format!("{}{}Builder", parent_name, variant_name),
926        variant_name.span(),
927    )
928}
929
930fn kind_enum_name(parent_name: &Ident) -> Ident {
931    Ident::new(&format!("{}Kind", parent_name), parent_name.span())
932}
933
934fn generate_enum_parent_type(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
935    quote! {
936        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
937        #[repr(transparent)]
938        #vis struct #name(#vis #repr);
939    }
940}
941
942fn generate_enum_variant_types(
943    vis: &syn::Visibility,
944    parent_name: &Ident,
945    repr: &Ident,
946    variants: &[ParsedVariant],
947) -> TokenStream2 {
948    let types: Vec<TokenStream2> = variants
949        .iter()
950        .map(|v| {
951            let type_name = variant_type_name(parent_name, &v.name);
952            quote! {
953                #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
954                #[repr(transparent)]
955                #vis struct #type_name(#repr);
956            }
957        })
958        .collect();
959
960    quote! { #(#types)* }
961}
962
963fn generate_enum_kind(
964    vis: &syn::Visibility,
965    parent_name: &Ident,
966    variants: &[ParsedVariant],
967) -> TokenStream2 {
968    let kind_name = kind_enum_name(parent_name);
969    let variant_names: Vec<&Ident> = variants.iter().map(|v| &v.name).collect();
970
971    quote! {
972        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
973        #vis enum #kind_name {
974            #(#variant_names),*
975        }
976    }
977}
978
979fn generate_enum_builder_structs(
980    vis: &syn::Visibility,
981    parent_name: &Ident,
982    variants: &[ParsedVariant],
983) -> TokenStream2 {
984    let builders: Vec<TokenStream2> = variants
985        .iter()
986        .map(|v| {
987            let builder_name = variant_builder_name(parent_name, &v.name);
988
989            let fields: Vec<TokenStream2> = v
990                .members
991                .iter()
992                .map(|m| match m {
993                    MemberDef::Field { name, ty, .. } => {
994                        quote! { #name: Option<#ty>, }
995                    }
996                    MemberDef::Flag { name, .. } => {
997                        quote! { #name: Option<bool>, }
998                    }
999                })
1000                .collect();
1001
1002            quote! {
1003                #[derive(Debug, Clone, Copy, Default)]
1004                #vis struct #builder_name {
1005                    #(#fields)*
1006                }
1007            }
1008        })
1009        .collect();
1010
1011    quote! { #(#builders)* }
1012}
1013
1014fn generate_enum_parent_impl(
1015    name: &Ident,
1016    repr: &Ident,
1017    discriminant: BitRange,
1018    variants: &[ParsedVariant],
1019) -> TokenStream2 {
1020    let repr_bit_count = repr_bits(repr);
1021    let kind_name = kind_enum_name(name);
1022    let disc_start = discriminant.start;
1023    let disc_len = discriminant.len;
1024
1025    // Discriminant is extracted as u64 for matching. Wider discriminants
1026    // would truncate silently, so reject them at macro expansion time.
1027    assert!(
1028        disc_len <= 64,
1029        "discriminant length must be <= 64 bits (got {disc_len})"
1030    );
1031
1032    let disc_mask = field_mask(repr, disc_len, repr_bit_count);
1033
1034    // kind() match arms
1035    let kind_arms: Vec<TokenStream2> = variants
1036        .iter()
1037        .map(|v| {
1038            let variant_name = &v.name;
1039            let disc_val = v.discriminant;
1040            quote! {
1041                #disc_val => Ok(#kind_name::#variant_name),
1042            }
1043        })
1044        .collect();
1045
1046    // is_* methods
1047    let is_methods: Vec<TokenStream2> = variants
1048        .iter()
1049        .map(|v| {
1050            let variant_name = &v.name;
1051            let method_name = Ident::new(
1052                &format!("is_{}", to_snake_case(&variant_name.to_string())),
1053                variant_name.span(),
1054            );
1055            let disc_val = v.discriminant;
1056            quote! {
1057                #[inline]
1058                pub fn #method_name(&self) -> bool {
1059                    let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1060                    disc == #disc_val
1061                }
1062            }
1063        })
1064        .collect();
1065
1066    // as_* methods
1067    let as_methods: Vec<TokenStream2> = variants
1068        .iter()
1069        .map(|v| {
1070            let variant_name = &v.name;
1071            let variant_type = variant_type_name(name, variant_name);
1072            let method_name = Ident::new(
1073                &format!("as_{}", to_snake_case(&variant_name.to_string())),
1074                variant_name.span(),
1075            );
1076            let disc_val = v.discriminant;
1077
1078            // Validation for IntEnum fields
1079            let validations: Vec<TokenStream2> = v.members
1080                .iter()
1081                .filter_map(|m| {
1082                    if let MemberDef::Field { name: field_name, ty, range } = m {
1083                        if !is_primitive(ty) {
1084                            let start = range.start;
1085                            let len = range.len;
1086                            let repr_bit_count = repr_bits(repr);
1087                            let mask = field_mask(repr, len, repr_bit_count);
1088                            return Some(quote! {
1089                                let field_repr = ((self.0 >> #start) & #mask);
1090                                if <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _).is_none() {
1091                                    return Err(nexus_bits::UnknownDiscriminant {
1092                                        field: stringify!(#field_name),
1093                                        value: field_repr as #repr,
1094                                    });
1095                                }
1096                            });
1097                        }
1098                    }
1099                    None
1100                })
1101                .collect();
1102
1103            quote! {
1104                #[inline]
1105                pub fn #method_name(&self) -> Result<#variant_type, nexus_bits::UnknownDiscriminant<#repr>> {
1106                    let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1107                    if disc != #disc_val {
1108                        return Err(nexus_bits::UnknownDiscriminant {
1109                            field: "__discriminant",
1110                            value: disc as #repr,
1111                        });
1112                    }
1113                    #(#validations)*
1114                    Ok(#variant_type(self.0))
1115                }
1116            }
1117        })
1118        .collect();
1119
1120    // Builder shortcut methods
1121    let builder_methods: Vec<TokenStream2> = variants
1122        .iter()
1123        .map(|v| {
1124            let variant_name = &v.name;
1125            let builder_name = variant_builder_name(name, variant_name);
1126            let method_name = Ident::new(
1127                &to_snake_case(&variant_name.to_string()),
1128                variant_name.span(),
1129            );
1130            quote! {
1131                #[inline]
1132                pub fn #method_name() -> #builder_name {
1133                    #builder_name::default()
1134                }
1135            }
1136        })
1137        .collect();
1138
1139    quote! {
1140        impl #name {
1141            /// Create from raw integer value.
1142            #[inline]
1143            pub const fn from_raw(raw: #repr) -> Self {
1144                Self(raw)
1145            }
1146
1147            /// Get the raw integer value.
1148            #[inline]
1149            pub const fn raw(self) -> #repr {
1150                self.0
1151            }
1152
1153            /// Get the kind (discriminant) of this value.
1154            #[inline]
1155            pub fn kind(&self) -> Result<#kind_name, nexus_bits::UnknownDiscriminant<#repr>> {
1156                let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1157                match disc {
1158                    #(#kind_arms)*
1159                    _ => Err(nexus_bits::UnknownDiscriminant {
1160                        field: "__discriminant",
1161                        value: disc as #repr,
1162                    }),
1163                }
1164            }
1165
1166            #(#is_methods)*
1167
1168            #(#as_methods)*
1169
1170            #(#builder_methods)*
1171        }
1172    }
1173}
1174
1175fn generate_enum_variant_impls(
1176    parent_name: &Ident,
1177    repr: &Ident,
1178    variants: &[ParsedVariant],
1179) -> TokenStream2 {
1180    let repr_bit_count = repr_bits(repr);
1181
1182    let impls: Vec<TokenStream2> =
1183        variants
1184            .iter()
1185            .map(|v| {
1186                let variant_name = &v.name;
1187                let variant_type = variant_type_name(parent_name, variant_name);
1188                let builder_name = variant_builder_name(parent_name, variant_name);
1189
1190                // Accessors - infallible since variant is pre-validated
1191                let accessors: Vec<TokenStream2> = v.members
1192                .iter()
1193                .map(|m| {
1194                    match m {
1195                        MemberDef::Field { name: field_name, ty, range } => {
1196                            let start = range.start;
1197                            let len = range.len;
1198                            let mask = field_mask(repr, len, repr_bit_count);
1199
1200                            if is_primitive(ty) {
1201                                let type_bits = primitive_bits(ty);
1202                                if is_signed_primitive(ty) && len < type_bits {
1203                                    let shift = type_bits - len;
1204                                    quote! {
1205                                        #[inline]
1206                                        pub const fn #field_name(&self) -> #ty {
1207                                            let raw = ((self.0 >> #start) & #mask) as #ty;
1208                                            (raw << #shift) >> #shift
1209                                        }
1210                                    }
1211                                } else {
1212                                    quote! {
1213                                        #[inline]
1214                                        pub const fn #field_name(&self) -> #ty {
1215                                            ((self.0 >> #start) & #mask) as #ty
1216                                        }
1217                                    }
1218                                }
1219                            } else {
1220                                // IntEnum - infallible because already validated
1221                                quote! {
1222                                    #[inline]
1223                                    pub fn #field_name(&self) -> #ty {
1224                                        let field_repr = ((self.0 >> #start) & #mask);
1225                                        // SAFETY: This type was validated during construction
1226                                        <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
1227                                            .expect("variant type invariant violated")
1228                                    }
1229                                }
1230                            }
1231                        }
1232                        MemberDef::Flag { name: field_name, bit } => {
1233                            quote! {
1234                                #[inline]
1235                                pub const fn #field_name(&self) -> bool {
1236                                    (self.0 >> #bit) & 1 != 0
1237                                }
1238                            }
1239                        }
1240                    }
1241                })
1242                .collect();
1243
1244                quote! {
1245                    impl #variant_type {
1246                        /// Create a builder for this variant.
1247                        #[inline]
1248                        pub fn builder() -> #builder_name {
1249                            #builder_name::default()
1250                        }
1251
1252                        /// Get the raw integer value.
1253                        #[inline]
1254                        pub const fn raw(self) -> #repr {
1255                            self.0
1256                        }
1257
1258                        /// Convert to parent type.
1259                        #[inline]
1260                        pub const fn as_parent(self) -> #parent_name {
1261                            #parent_name(self.0)
1262                        }
1263
1264                        #(#accessors)*
1265                    }
1266                }
1267            })
1268            .collect();
1269
1270    quote! { #(#impls)* }
1271}
1272
1273fn generate_enum_builder_impls(
1274    parent_name: &Ident,
1275    repr: &Ident,
1276    discriminant: BitRange,
1277    variants: &[ParsedVariant],
1278) -> TokenStream2 {
1279    let repr_bit_count = repr_bits(repr);
1280    let disc_start = discriminant.start;
1281
1282    let impls: Vec<TokenStream2> = variants
1283        .iter()
1284        .map(|v| {
1285            let variant_name = &v.name;
1286            let variant_type = variant_type_name(parent_name, variant_name);
1287            let builder_name = variant_builder_name(parent_name, variant_name);
1288            let disc_val = v.discriminant;
1289
1290            // Setters
1291            let setters: Vec<TokenStream2> = v.members
1292                .iter()
1293                .map(|m| match m {
1294                    MemberDef::Field { name: field_name, ty, .. } => {
1295                        quote! {
1296                            #[inline]
1297                            pub fn #field_name(mut self, val: #ty) -> Self {
1298                                self.#field_name = Some(val);
1299                                self
1300                            }
1301                        }
1302                    }
1303                    MemberDef::Flag { name: field_name, .. } => {
1304                        quote! {
1305                            #[inline]
1306                            pub fn #field_name(mut self, val: bool) -> Self {
1307                                self.#field_name = Some(val);
1308                                self
1309                            }
1310                        }
1311                    }
1312                })
1313                .collect();
1314
1315            // Validations
1316            let validations: Vec<TokenStream2> = v.members
1317                .iter()
1318                .filter_map(|m| match m {
1319                    MemberDef::Field { name: field_name, ty, range } => {
1320                        let field_str = field_name.to_string();
1321                        let len = range.len;
1322
1323                        let max_val = field_mask(repr, len, repr_bit_count);
1324
1325                        if is_primitive(ty) {
1326                            let type_bits: u32 = match ty {
1327                                Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
1328                                Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
1329                                Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
1330                                Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
1331                                Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
1332                                _ => 128,
1333                            };
1334
1335                            if len >= type_bits {
1336                                return None;
1337                            }
1338
1339                            let is_signed = matches!(ty,
1340                                Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
1341                                                 p.path.is_ident("i32") || p.path.is_ident("i64") ||
1342                                                 p.path.is_ident("i128")
1343                            );
1344
1345                            if is_signed {
1346                                let min_shift = len - 1;
1347                                Some(quote! {
1348                                    if let Some(v) = self.#field_name {
1349                                        let min_val = -((1i128 << #min_shift) as i128);
1350                                        let max_val = ((1i128 << #min_shift) - 1) as i128;
1351                                        let v_i128 = v as i128;
1352                                        if v_i128 < min_val || v_i128 > max_val {
1353                                            return Err(nexus_bits::FieldOverflow {
1354                                                field: #field_str,
1355                                                overflow: nexus_bits::Overflow {
1356                                                    value: (v as #repr),
1357                                                    max: #max_val,
1358                                                },
1359                                            });
1360                                        }
1361                                    }
1362                                })
1363                            } else {
1364                                Some(quote! {
1365                                    if let Some(v) = self.#field_name {
1366                                        if (v as #repr) > #max_val {
1367                                            return Err(nexus_bits::FieldOverflow {
1368                                                field: #field_str,
1369                                                overflow: nexus_bits::Overflow {
1370                                                    value: v as #repr,
1371                                                    max: #max_val,
1372                                                },
1373                                            });
1374                                        }
1375                                    }
1376                                })
1377                            }
1378                        } else {
1379                            // IntEnum field
1380                            Some(quote! {
1381                                if let Some(v) = self.#field_name {
1382                                    let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
1383                                    if repr_val > #max_val {
1384                                        return Err(nexus_bits::FieldOverflow {
1385                                            field: #field_str,
1386                                            overflow: nexus_bits::Overflow {
1387                                                value: repr_val,
1388                                                max: #max_val,
1389                                            },
1390                                        });
1391                                    }
1392                                }
1393                            })
1394                        }
1395                    }
1396                    MemberDef::Flag { .. } => None,
1397                })
1398                .collect();
1399
1400            // Pack statements
1401            let pack_statements: Vec<TokenStream2> = v.members
1402                .iter()
1403                .map(|m| {
1404                    match m {
1405                        MemberDef::Field { name: field_name, ty, range } => {
1406                            let start = range.start;
1407                            let len = range.len;
1408                            let mask = field_mask(repr, len, repr_bit_count);
1409
1410                            if is_primitive(ty) {
1411                                quote! {
1412                                    if let Some(v) = self.#field_name {
1413                                        val |= ((v as #repr) & #mask) << #start;
1414                                    }
1415                                }
1416                            } else {
1417                                quote! {
1418                                    if let Some(v) = self.#field_name {
1419                                        val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
1420                                    }
1421                                }
1422                            }
1423                        }
1424                        MemberDef::Flag { name: field_name, bit } => {
1425                            quote! {
1426                                if let Some(true) = self.#field_name {
1427                                    val |= (1 as #repr) << #bit;
1428                                }
1429                            }
1430                        }
1431                    }
1432                })
1433                .collect();
1434
1435            quote! {
1436                impl #builder_name {
1437                    #(#setters)*
1438
1439                    /// Build the variant type, validating all fields.
1440                    #[inline]
1441                    pub fn build(self) -> Result<#variant_type, nexus_bits::FieldOverflow<#repr>> {
1442                        #(#validations)*
1443
1444                        let mut val: #repr = 0;
1445                        // Set discriminant
1446                        val |= (#disc_val as #repr) << #disc_start;
1447                        #(#pack_statements)*
1448
1449                        Ok(#variant_type(val))
1450                    }
1451
1452                    /// Build directly to parent type, validating all fields.
1453                    #[inline]
1454                    pub fn build_parent(self) -> Result<#parent_name, nexus_bits::FieldOverflow<#repr>> {
1455                        self.build().map(|v| v.as_parent())
1456                    }
1457                }
1458            }
1459        })
1460        .collect();
1461
1462    quote! { #(#impls)* }
1463}
1464
1465fn generate_enum_from_impls(parent_name: &Ident, variants: &[ParsedVariant]) -> TokenStream2 {
1466    let impls: Vec<TokenStream2> = variants
1467        .iter()
1468        .map(|v| {
1469            let variant_type = variant_type_name(parent_name, &v.name);
1470            quote! {
1471                impl From<#variant_type> for #parent_name {
1472                    #[inline]
1473                    fn from(v: #variant_type) -> Self {
1474                        v.as_parent()
1475                    }
1476                }
1477            }
1478        })
1479        .collect();
1480
1481    quote! { #(#impls)* }
1482}
1483
1484fn to_snake_case(s: &str) -> String {
1485    let mut result = String::new();
1486    for (i, c) in s.chars().enumerate() {
1487        if c.is_uppercase() {
1488            if i > 0 {
1489                result.push('_');
1490            }
1491            result.push(c.to_lowercase().next().unwrap());
1492        } else {
1493            result.push(c);
1494        }
1495    }
1496    result
1497}