nexus_bits_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{
5    Data, DeriveInput, Error, Fields, Ident, Result, Type, parse::Parser, parse_macro_input,
6};
7
8// =============================================================================
9// IntEnum derive
10// =============================================================================
11
12#[proc_macro_derive(IntEnum)]
13pub fn derive_int_enum(input: TokenStream) -> TokenStream {
14    let input = parse_macro_input!(input as DeriveInput);
15
16    match derive_int_enum_impl(input) {
17        Ok(tokens) => tokens.into(),
18        Err(err) => err.to_compile_error().into(),
19    }
20}
21
22fn derive_int_enum_impl(input: DeriveInput) -> Result<TokenStream2> {
23    let variants = match &input.data {
24        Data::Enum(data) => &data.variants,
25        _ => {
26            return Err(Error::new_spanned(
27                &input,
28                "IntEnum can only be derived for enums",
29            ));
30        }
31    };
32
33    let repr = parse_repr(&input)?;
34
35    for variant in variants {
36        if !matches!(variant.fields, Fields::Unit) {
37            return Err(Error::new_spanned(
38                variant,
39                "IntEnum variants cannot have fields",
40            ));
41        }
42    }
43
44    let name = &input.ident;
45    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
46
47    let from_arms = variants.iter().map(|v| {
48        let variant_name = &v.ident;
49        quote! {
50            x if x == #name::#variant_name as #repr => Some(#name::#variant_name),
51        }
52    });
53
54    Ok(quote! {
55        impl #impl_generics nexus_bits::IntEnum for #name #ty_generics #where_clause {
56            type Repr = #repr;
57
58            #[inline]
59            fn into_repr(self) -> #repr {
60                self as #repr
61            }
62
63            #[inline]
64            fn try_from_repr(repr: #repr) -> Option<Self> {
65                match repr {
66                    #(#from_arms)*
67                    _ => None,
68                }
69            }
70        }
71    })
72}
73
74fn parse_repr(input: &DeriveInput) -> Result<Ident> {
75    for attr in &input.attrs {
76        if attr.path().is_ident("repr") {
77            let repr: Ident = attr.parse_args()?;
78            match repr.to_string().as_str() {
79                "u8" | "u16" | "u32" | "u64" | "i8" | "i16" | "i32" | "i64" => {
80                    return Ok(repr);
81                }
82                _ => {
83                    return Err(Error::new_spanned(
84                        repr,
85                        "IntEnum requires repr(u8), repr(u16), repr(u32), repr(u64), repr(i8), repr(i16), repr(i32), or repr(i64)",
86                    ));
87                }
88            }
89        }
90    }
91
92    Err(Error::new_spanned(
93        input,
94        "IntEnum requires a #[repr(u8/u16/u32/u64/i8/i16/i32/i64)] attribute",
95    ))
96}
97
98// =============================================================================
99// BitStorage attribute macro
100// =============================================================================
101
102#[proc_macro_attribute]
103pub fn bit_storage(attr: TokenStream, item: TokenStream) -> TokenStream {
104    let attr = proc_macro2::TokenStream::from(attr);
105    let item = parse_macro_input!(item as DeriveInput);
106
107    match bit_storage_impl(attr, item) {
108        Ok(tokens) => tokens.into(),
109        Err(err) => err.to_compile_error().into(),
110    }
111}
112
113fn bit_storage_impl(attr: TokenStream2, input: DeriveInput) -> Result<TokenStream2> {
114    let storage_attr = parse_storage_attr_tokens(attr)?;
115
116    match &input.data {
117        Data::Struct(data) => derive_storage_struct(&input, data, storage_attr),
118        Data::Enum(data) => derive_storage_enum(&input, data, storage_attr),
119        Data::Union(_) => Err(Error::new_spanned(
120            &input,
121            "bit_storage cannot be applied to unions",
122        )),
123    }
124}
125
126// =============================================================================
127// Attribute types
128// =============================================================================
129
130/// Parsed #[bit_storage(repr = T)] or #[bit_storage(repr = T, discriminant(start = N, len = M))]
131struct StorageAttr {
132    repr: Ident,
133    discriminant: Option<BitRange>,
134}
135
/// A contiguous run of `len` bits beginning at bit index `start`, counted from
/// the least-significant bit of the backing integer (accessors shift right by `start`).
struct BitRange {
    /// Index of the lowest bit of the range.
    start: u32,
    /// Number of bits covered; `parse_bit_range` rejects zero.
    len: u32,
}
141
142/// Parsed field/flag from struct
143enum MemberDef {
144    Field {
145        name: Ident,
146        ty: Type,
147        range: BitRange,
148    },
149    Flag {
150        name: Ident,
151        bit: u32,
152    },
153}
154
155impl MemberDef {
156    fn name(&self) -> &Ident {
157        match self {
158            MemberDef::Field { name, .. } => name,
159            MemberDef::Flag { name, .. } => name,
160        }
161    }
162}
163
164// =============================================================================
165// Attribute parsing
166// =============================================================================
167
168fn parse_storage_attr_tokens(attr: TokenStream2) -> Result<StorageAttr> {
169    let mut repr = None;
170    let mut discriminant = None;
171
172    let parser = syn::meta::parser(|meta| {
173        if meta.path.is_ident("repr") {
174            meta.input.parse::<syn::Token![=]>()?;
175            repr = Some(meta.input.parse::<Ident>()?);
176            Ok(())
177        } else if meta.path.is_ident("discriminant") {
178            let content;
179            syn::parenthesized!(content in meta.input);
180            discriminant = Some(parse_bit_range(&content)?);
181            Ok(())
182        } else {
183            Err(meta.error("expected `repr` or `discriminant`"))
184        }
185    });
186
187    parser.parse2(attr)?;
188
189    let repr = repr.ok_or_else(|| {
190        Error::new(
191            proc_macro2::Span::call_site(),
192            "bit_storage requires `repr = ...`",
193        )
194    })?;
195
196    // Validate repr
197    match repr.to_string().as_str() {
198        "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" => {}
199        _ => return Err(Error::new_spanned(&repr, "repr must be an integer type")),
200    }
201
202    Ok(StorageAttr { repr, discriminant })
203}
204
205fn parse_bit_range(input: syn::parse::ParseStream) -> Result<BitRange> {
206    let mut start = None;
207    let mut len = None;
208
209    while !input.is_empty() {
210        let ident: Ident = input.parse()?;
211        input.parse::<syn::Token![=]>()?;
212        let lit: syn::LitInt = input.parse()?;
213        let value: u32 = lit.base10_parse()?;
214
215        match ident.to_string().as_str() {
216            "start" => start = Some(value),
217            "len" => len = Some(value),
218            _ => return Err(Error::new_spanned(ident, "expected `start` or `len`")),
219        }
220
221        if input.peek(syn::Token![,]) {
222            input.parse::<syn::Token![,]>()?;
223        }
224    }
225
226    let start = start.ok_or_else(|| Error::new(input.span(), "missing `start`"))?;
227    let len = len.ok_or_else(|| Error::new(input.span(), "missing `len`"))?;
228
229    if len == 0 {
230        return Err(Error::new(input.span(), "len must be > 0"));
231    }
232
233    Ok(BitRange { start, len })
234}
235
236fn parse_member(field: &syn::Field) -> Result<MemberDef> {
237    let name = field
238        .ident
239        .clone()
240        .ok_or_else(|| Error::new_spanned(field, "tuple structs not supported"))?;
241    let ty = field.ty.clone();
242
243    for attr in &field.attrs {
244        if attr.path().is_ident("field") {
245            let range = attr.parse_args_with(parse_bit_range)?;
246            return Ok(MemberDef::Field { name, ty, range });
247        } else if attr.path().is_ident("flag") {
248            let bit: syn::LitInt = attr.parse_args()?;
249            let bit: u32 = bit.base10_parse()?;
250            return Ok(MemberDef::Flag { name, bit });
251        }
252    }
253
254    Err(Error::new_spanned(
255        field,
256        "field requires #[field(start = N, len = M)] or #[flag(N)] attribute",
257    ))
258}
259
260fn parse_variant_attr(attrs: &[syn::Attribute]) -> Result<u64> {
261    for attr in attrs {
262        if attr.path().is_ident("variant") {
263            let lit: syn::LitInt = attr.parse_args()?;
264            return lit.base10_parse();
265        }
266    }
267    Err(Error::new(
268        proc_macro2::Span::call_site(),
269        "enum variant requires #[variant(N)] attribute",
270    ))
271}
272
273// =============================================================================
274// Helpers
275// =============================================================================
276
277fn is_primitive(ty: &Type) -> bool {
278    if let Type::Path(type_path) = ty {
279        if let Some(ident) = type_path.path.get_ident() {
280            return matches!(
281                ident.to_string().as_str(),
282                "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128"
283            );
284        }
285    }
286    false
287}
288
289fn repr_bits(repr: &Ident) -> u32 {
290    match repr.to_string().as_str() {
291        "u8" | "i8" => 8,
292        "u16" | "i16" => 16,
293        "u32" | "i32" => 32,
294        "u64" | "i64" => 64,
295        "u128" | "i128" => 128,
296        _ => 0,
297    }
298}
299
300// =============================================================================
301// Validation
302// =============================================================================
303
304fn validate_members(members: &[MemberDef], repr: &Ident) -> Result<()> {
305    let bits = repr_bits(repr);
306
307    // Check each field fits
308    for member in members {
309        match member {
310            MemberDef::Field { name, range, .. } => {
311                if range.start + range.len > bits {
312                    return Err(Error::new_spanned(
313                        name,
314                        format!(
315                            "field exceeds {} bits (start {} + len {} = {})",
316                            bits,
317                            range.start,
318                            range.len,
319                            range.start + range.len
320                        ),
321                    ));
322                }
323            }
324            MemberDef::Flag { name, bit, .. } => {
325                if *bit >= bits {
326                    return Err(Error::new_spanned(
327                        name,
328                        format!("flag bit {} exceeds {} bits", bit, bits),
329                    ));
330                }
331            }
332        }
333    }
334
335    // Check no overlap (simple O(n²) for now)
336    for (i, a) in members.iter().enumerate() {
337        for b in members.iter().skip(i + 1) {
338            if ranges_overlap(a, b) {
339                return Err(Error::new_spanned(
340                    b.name(),
341                    format!("field '{}' overlaps with '{}'", b.name(), a.name()),
342                ));
343            }
344        }
345    }
346
347    Ok(())
348}
349
350fn ranges_overlap(a: &MemberDef, b: &MemberDef) -> bool {
351    let (a_start, a_end) = member_bit_range(a);
352    let (b_start, b_end) = member_bit_range(b);
353    a_start < b_end && b_start < a_end
354}
355
356fn member_bit_range(m: &MemberDef) -> (u32, u32) {
357    match m {
358        MemberDef::Field { range, .. } => (range.start, range.start + range.len),
359        MemberDef::Flag { bit, .. } => (*bit, bit + 1),
360    }
361}
362
363// =============================================================================
364// Struct codegen
365// =============================================================================
366
367fn derive_storage_struct(
368    input: &DeriveInput,
369    data: &syn::DataStruct,
370    storage_attr: StorageAttr,
371) -> Result<TokenStream2> {
372    let fields = match &data.fields {
373        Fields::Named(f) => &f.named,
374        _ => {
375            return Err(Error::new_spanned(
376                input,
377                "bit_storage requires named fields",
378            ));
379        }
380    };
381
382    if storage_attr.discriminant.is_some() {
383        return Err(Error::new_spanned(
384            input,
385            "discriminant is only valid for enums",
386        ));
387    }
388
389    let members: Vec<MemberDef> = fields.iter().map(parse_member).collect::<Result<_>>()?;
390
391    validate_members(&members, &storage_attr.repr)?;
392
393    let vis = &input.vis;
394    let name = &input.ident;
395    let repr = &storage_attr.repr;
396    let builder_name = Ident::new(&format!("{}Builder", name), name.span());
397
398    let newtype = generate_struct_newtype(vis, name, repr);
399    let builder_struct = generate_struct_builder_struct(vis, &builder_name, &members);
400    let newtype_impl = generate_struct_newtype_impl(name, &builder_name, repr, &members);
401    let builder_impl = generate_struct_builder_impl(name, &builder_name, repr, &members);
402
403    Ok(quote! {
404        #newtype
405        #builder_struct
406        #newtype_impl
407        #builder_impl
408    })
409}
410
411fn generate_struct_newtype(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
412    quote! {
413        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
414        #[repr(transparent)]
415        #vis struct #name(#vis #repr);
416    }
417}
418
419fn generate_struct_builder_struct(
420    vis: &syn::Visibility,
421    builder_name: &Ident,
422    members: &[MemberDef],
423) -> TokenStream2 {
424    let fields: Vec<TokenStream2> = members
425        .iter()
426        .map(|m| match m {
427            MemberDef::Field { name, ty, .. } => {
428                quote! { #name: Option<#ty>, }
429            }
430            MemberDef::Flag { name, .. } => {
431                quote! { #name: Option<bool>, }
432            }
433        })
434        .collect();
435
436    quote! {
437        #[derive(Debug, Clone, Copy, Default)]
438        #vis struct #builder_name {
439            #(#fields)*
440        }
441    }
442}
443
444fn generate_struct_newtype_impl(
445    name: &Ident,
446    builder_name: &Ident,
447    repr: &Ident,
448    members: &[MemberDef],
449) -> TokenStream2 {
450    let repr_bit_count = repr_bits(repr);
451
452    let accessors: Vec<TokenStream2> = members.iter().map(|m| {
453        match m {
454            MemberDef::Field { name: field_name, ty, range } => {
455                let start = range.start;
456                let len = range.len;
457                let mask = if len >= repr_bit_count {
458                    quote! { #repr::MAX }
459                } else {
460                    quote! { ((1 as #repr) << #len) - 1 }
461                };
462
463                if is_primitive(ty) {
464                    quote! {
465                        #[inline]
466                        pub const fn #field_name(&self) -> #ty {
467                            ((self.0 >> #start) & #mask) as #ty
468                        }
469                    }
470                } else {
471                    // IntEnum field - returns Result
472                    quote! {
473                        #[inline]
474                        pub fn #field_name(&self) -> Result<#ty, nexus_bits::UnknownDiscriminant<#repr>> {
475                            let field_repr = ((self.0 >> #start) & #mask);
476                            <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
477                                .ok_or(nexus_bits::UnknownDiscriminant {
478                                    field: stringify!(#field_name),
479                                    value: self.0,
480                                })
481                        }
482                    }
483                }
484            }
485            MemberDef::Flag { name: field_name, bit } => {
486                quote! {
487                    #[inline]
488                    pub const fn #field_name(&self) -> bool {
489                        (self.0 >> #bit) & 1 != 0
490                    }
491                }
492            }
493        }
494    }).collect();
495
496    quote! {
497        impl #name {
498            /// Create from raw integer value.
499            #[inline]
500            pub const fn from_raw(raw: #repr) -> Self {
501                Self(raw)
502            }
503
504            /// Get the raw integer value.
505            #[inline]
506            pub const fn raw(self) -> #repr {
507                self.0
508            }
509
510            /// Create a builder for this type.
511            #[inline]
512            pub fn builder() -> #builder_name {
513                #builder_name::default()
514            }
515
516            #(#accessors)*
517        }
518    }
519}
520
521fn generate_struct_builder_impl(
522    name: &Ident,
523    builder_name: &Ident,
524    repr: &Ident,
525    members: &[MemberDef],
526) -> TokenStream2 {
527    let repr_bit_count = repr_bits(repr);
528
529    // Setters - wrap in Some
530    let setters: Vec<TokenStream2> = members
531        .iter()
532        .map(|m| match m {
533            MemberDef::Field {
534                name: field_name,
535                ty,
536                ..
537            } => {
538                quote! {
539                    #[inline]
540                    pub fn #field_name(mut self, val: #ty) -> Self {
541                        self.#field_name = Some(val);
542                        self
543                    }
544                }
545            }
546            MemberDef::Flag {
547                name: field_name, ..
548            } => {
549                quote! {
550                    #[inline]
551                    pub fn #field_name(mut self, val: bool) -> Self {
552                        self.#field_name = Some(val);
553                        self
554                    }
555                }
556            }
557        })
558        .collect();
559
560    // Validations
561    let validations: Vec<TokenStream2> = members
562        .iter()
563        .filter_map(|m| match m {
564            MemberDef::Field {
565                name: field_name,
566                ty,
567                range,
568            } => {
569                let field_str = field_name.to_string();
570                let len = range.len;
571
572                let max_val = if len >= repr_bit_count {
573                    quote! { #repr::MAX }
574                } else {
575                    quote! { ((1 as #repr) << #len) - 1 }
576                };
577
578                if is_primitive(ty) {
579                    let type_bits: u32 = match ty {
580                        Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
581                        Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
582                        Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
583                        Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
584                        Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
585                        _ => 128,
586                    };
587
588                    // Skip validation if field can hold entire type
589                    if len >= type_bits {
590                        return None;
591                    }
592
593                    // Check if this is a signed type
594                    let is_signed = matches!(ty,
595                        Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
596                                         p.path.is_ident("i32") || p.path.is_ident("i64") ||
597                                         p.path.is_ident("i128")
598                    );
599
600                    if is_signed {
601                        // For signed types, check that value fits in signed field range
602                        // A signed N-bit field can hold -(2^(N-1)) to (2^(N-1) - 1)
603                        let min_shift = len - 1;
604                        Some(quote! {
605                            if let Some(v) = self.#field_name {
606                                let min_val = -((1i128 << #min_shift) as i128);
607                                let max_val = ((1i128 << #min_shift) - 1) as i128;
608                                let v_i128 = v as i128;
609                                if v_i128 < min_val || v_i128 > max_val {
610                                    return Err(nexus_bits::FieldOverflow {
611                                        field: #field_str,
612                                        overflow: nexus_bits::Overflow {
613                                            value: (v as #repr),
614                                            max: #max_val,
615                                        },
616                                    });
617                                }
618                            }
619                        })
620                    } else {
621                        // Unsigned - simple max check
622                        Some(quote! {
623                            if let Some(v) = self.#field_name {
624                                if (v as #repr) > #max_val {
625                                    return Err(nexus_bits::FieldOverflow {
626                                        field: #field_str,
627                                        overflow: nexus_bits::Overflow {
628                                            value: v as #repr,
629                                            max: #max_val,
630                                        },
631                                    });
632                                }
633                            }
634                        })
635                    }
636                } else {
637                    // IntEnum field - validate repr value fits in field
638                    Some(quote! {
639                        if let Some(v) = self.#field_name {
640                            let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
641                            if repr_val > #max_val {
642                                return Err(nexus_bits::FieldOverflow {
643                                    field: #field_str,
644                                    overflow: nexus_bits::Overflow {
645                                        value: repr_val,
646                                        max: #max_val,
647                                    },
648                                });
649                            }
650                        }
651                    })
652                }
653            }
654            _ => None,
655        })
656        .collect();
657
658    // Pack statements - ALWAYS mask to prevent sign extension corruption
659    let pack_statements: Vec<TokenStream2> = members
660        .iter()
661        .map(|m| {
662            match m {
663                MemberDef::Field {
664                    name: field_name,
665                    ty,
666                    range,
667                } => {
668                    let start = range.start;
669                    let len = range.len;
670                    let mask = if len >= repr_bit_count {
671                        quote! { #repr::MAX }
672                    } else {
673                        quote! { ((1 as #repr) << #len) - 1 }
674                    };
675
676                    if is_primitive(ty) {
677                        quote! {
678                            if let Some(v) = self.#field_name {
679                                val |= ((v as #repr) & #mask) << #start;
680                            }
681                        }
682                    } else {
683                        // IntEnum
684                        quote! {
685                            if let Some(v) = self.#field_name {
686                                val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
687                            }
688                        }
689                    }
690                }
691                MemberDef::Flag {
692                    name: field_name,
693                    bit,
694                } => {
695                    quote! {
696                        if let Some(true) = self.#field_name {
697                            val |= (1 as #repr) << #bit;
698                        }
699                    }
700                }
701            }
702        })
703        .collect();
704
705    quote! {
706        impl #builder_name {
707            #(#setters)*
708
709            /// Build the final value, validating all fields.
710            #[inline]
711            pub fn build(self) -> Result<#name, nexus_bits::FieldOverflow<#repr>> {
712                // Validate
713                #(#validations)*
714
715                // Pack
716                let mut val: #repr = 0;
717                #(#pack_statements)*
718
719                Ok(#name(val))
720            }
721        }
722    }
723}
724
725// =============================================================================
726// Enum codegen
727// =============================================================================
728
729/// Parsed variant for tagged enum
730struct ParsedVariant {
731    name: Ident,
732    discriminant: u64,
733    members: Vec<MemberDef>,
734}
735
736fn derive_storage_enum(
737    input: &DeriveInput,
738    data: &syn::DataEnum,
739    storage_attr: StorageAttr,
740) -> Result<TokenStream2> {
741    let discriminant = storage_attr.discriminant.ok_or_else(|| {
742        Error::new_spanned(
743            input,
744            "bit_storage enum requires discriminant: #[bit_storage(repr = T, discriminant(start = N, len = M))]",
745        )
746    })?;
747
748    let repr = &storage_attr.repr;
749    let bits = repr_bits(repr);
750
751    // Validate discriminant fits
752    if discriminant.start + discriminant.len > bits {
753        return Err(Error::new_spanned(
754            input,
755            format!(
756                "discriminant exceeds {} bits (start {} + len {} = {})",
757                bits,
758                discriminant.start,
759                discriminant.len,
760                discriminant.start + discriminant.len
761            ),
762        ));
763    }
764
765    let max_discriminant = if discriminant.len >= 64 {
766        u64::MAX
767    } else {
768        (1u64 << discriminant.len) - 1
769    };
770
771    // Parse all variants
772    let mut variants = Vec::new();
773    for variant in &data.variants {
774        let disc = parse_variant_attr(&variant.attrs)?;
775
776        if disc > max_discriminant {
777            return Err(Error::new_spanned(
778                &variant.ident,
779                format!(
780                    "variant discriminant {} exceeds max {} for {}-bit field",
781                    disc, max_discriminant, discriminant.len
782                ),
783            ));
784        }
785
786        // Check for duplicate discriminants
787        for existing in &variants {
788            let existing: &ParsedVariant = existing;
789            if existing.discriminant == disc {
790                return Err(Error::new_spanned(
791                    &variant.ident,
792                    format!(
793                        "duplicate discriminant {}: already used by '{}'",
794                        disc, existing.name
795                    ),
796                ));
797            }
798        }
799
800        let members: Vec<MemberDef> = match &variant.fields {
801            Fields::Named(fields) => fields
802                .named
803                .iter()
804                .map(parse_member)
805                .collect::<Result<_>>()?,
806            Fields::Unit => Vec::new(),
807            Fields::Unnamed(_) => {
808                return Err(Error::new_spanned(
809                    variant,
810                    "tuple variants not supported, use named fields",
811                ));
812            }
813        };
814
815        // Validate members don't overlap with discriminant
816        let disc_range = MemberDef::Field {
817            name: Ident::new("__discriminant", proc_macro2::Span::call_site()),
818            ty: syn::parse_quote!(u64),
819            range: BitRange {
820                start: discriminant.start,
821                len: discriminant.len,
822            },
823        };
824
825        for member in &members {
826            if ranges_overlap(&disc_range, member) {
827                return Err(Error::new_spanned(
828                    member.name(),
829                    format!("field '{}' overlaps with discriminant", member.name()),
830                ));
831            }
832        }
833
834        // Validate members within this variant
835        validate_members(&members, repr)?;
836
837        variants.push(ParsedVariant {
838            name: variant.ident.clone(),
839            discriminant: disc,
840            members,
841        });
842    }
843
844    let vis = &input.vis;
845    let name = &input.ident;
846
847    let parent_type = generate_enum_parent_type(vis, name, repr);
848    let variant_types = generate_enum_variant_types(vis, name, repr, &variants);
849    let kind_enum = generate_enum_kind(vis, name, &variants);
850    let builder_structs = generate_enum_builder_structs(vis, name, &variants);
851    let parent_impl = generate_enum_parent_impl(name, repr, &discriminant, &variants);
852    let variant_impls = generate_enum_variant_impls(name, repr, &variants);
853    let builder_impls = generate_enum_builder_impls(name, repr, &discriminant, &variants);
854    let from_impls = generate_enum_from_impls(name, &variants);
855
856    Ok(quote! {
857        #parent_type
858        #variant_types
859        #kind_enum
860        #builder_structs
861        #parent_impl
862        #variant_impls
863        #builder_impls
864        #from_impls
865    })
866}
867
868fn variant_type_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
869    Ident::new(
870        &format!("{}{}", parent_name, variant_name),
871        variant_name.span(),
872    )
873}
874
875fn variant_builder_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
876    Ident::new(
877        &format!("{}{}Builder", parent_name, variant_name),
878        variant_name.span(),
879    )
880}
881
882fn kind_enum_name(parent_name: &Ident) -> Ident {
883    Ident::new(&format!("{}Kind", parent_name), parent_name.span())
884}
885
886fn generate_enum_parent_type(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
887    quote! {
888        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
889        #[repr(transparent)]
890        #vis struct #name(#vis #repr);
891    }
892}
893
894fn generate_enum_variant_types(
895    vis: &syn::Visibility,
896    parent_name: &Ident,
897    repr: &Ident,
898    variants: &[ParsedVariant],
899) -> TokenStream2 {
900    let types: Vec<TokenStream2> = variants
901        .iter()
902        .map(|v| {
903            let type_name = variant_type_name(parent_name, &v.name);
904            quote! {
905                #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
906                #[repr(transparent)]
907                #vis struct #type_name(#repr);
908            }
909        })
910        .collect();
911
912    quote! { #(#types)* }
913}
914
915fn generate_enum_kind(
916    vis: &syn::Visibility,
917    parent_name: &Ident,
918    variants: &[ParsedVariant],
919) -> TokenStream2 {
920    let kind_name = kind_enum_name(parent_name);
921    let variant_names: Vec<&Ident> = variants.iter().map(|v| &v.name).collect();
922
923    quote! {
924        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
925        #vis enum #kind_name {
926            #(#variant_names),*
927        }
928    }
929}
930
931fn generate_enum_builder_structs(
932    vis: &syn::Visibility,
933    parent_name: &Ident,
934    variants: &[ParsedVariant],
935) -> TokenStream2 {
936    let builders: Vec<TokenStream2> = variants
937        .iter()
938        .map(|v| {
939            let builder_name = variant_builder_name(parent_name, &v.name);
940
941            let fields: Vec<TokenStream2> = v
942                .members
943                .iter()
944                .map(|m| match m {
945                    MemberDef::Field { name, ty, .. } => {
946                        quote! { #name: Option<#ty>, }
947                    }
948                    MemberDef::Flag { name, .. } => {
949                        quote! { #name: Option<bool>, }
950                    }
951                })
952                .collect();
953
954            quote! {
955                #[derive(Debug, Clone, Copy, Default)]
956                #vis struct #builder_name {
957                    #(#fields)*
958                }
959            }
960        })
961        .collect();
962
963    quote! { #(#builders)* }
964}
965
966fn generate_enum_parent_impl(
967    name: &Ident,
968    repr: &Ident,
969    discriminant: &BitRange,
970    variants: &[ParsedVariant],
971) -> TokenStream2 {
972    let repr_bit_count = repr_bits(repr);
973    let kind_name = kind_enum_name(name);
974    let disc_start = discriminant.start;
975    let disc_len = discriminant.len;
976
977    let disc_mask = if disc_len >= repr_bit_count {
978        quote! { #repr::MAX }
979    } else {
980        quote! { ((1 as #repr) << #disc_len) - 1 }
981    };
982
983    // kind() match arms
984    let kind_arms: Vec<TokenStream2> = variants
985        .iter()
986        .map(|v| {
987            let variant_name = &v.name;
988            let disc_val = v.discriminant;
989            quote! {
990                #disc_val => Ok(#kind_name::#variant_name),
991            }
992        })
993        .collect();
994
995    // is_* methods
996    let is_methods: Vec<TokenStream2> = variants
997        .iter()
998        .map(|v| {
999            let variant_name = &v.name;
1000            let method_name = Ident::new(
1001                &format!("is_{}", to_snake_case(&variant_name.to_string())),
1002                variant_name.span(),
1003            );
1004            let disc_val = v.discriminant;
1005            quote! {
1006                #[inline]
1007                pub fn #method_name(&self) -> bool {
1008                    let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1009                    disc == #disc_val
1010                }
1011            }
1012        })
1013        .collect();
1014
1015    // as_* methods
1016    let as_methods: Vec<TokenStream2> = variants
1017        .iter()
1018        .map(|v| {
1019            let variant_name = &v.name;
1020            let variant_type = variant_type_name(name, variant_name);
1021            let method_name = Ident::new(
1022                &format!("as_{}", to_snake_case(&variant_name.to_string())),
1023                variant_name.span(),
1024            );
1025            let disc_val = v.discriminant;
1026
1027            // Validation for IntEnum fields
1028            let validations: Vec<TokenStream2> = v.members
1029                .iter()
1030                .filter_map(|m| {
1031                    if let MemberDef::Field { name: field_name, ty, range } = m {
1032                        if !is_primitive(ty) {
1033                            let start = range.start;
1034                            let len = range.len;
1035                            let repr_bit_count = repr_bits(repr);
1036                            let mask = if len >= repr_bit_count {
1037                                quote! { #repr::MAX }
1038                            } else {
1039                                quote! { ((1 as #repr) << #len) - 1 }
1040                            };
1041                            return Some(quote! {
1042                                let field_repr = ((self.0 >> #start) & #mask);
1043                                if <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _).is_none() {
1044                                    return Err(nexus_bits::UnknownDiscriminant {
1045                                        field: stringify!(#field_name),
1046                                        value: self.0,
1047                                    });
1048                                }
1049                            });
1050                        }
1051                    }
1052                    None
1053                })
1054                .collect();
1055
1056            quote! {
1057                #[inline]
1058                pub fn #method_name(&self) -> Result<#variant_type, nexus_bits::UnknownDiscriminant<#repr>> {
1059                    let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1060                    if disc != #disc_val {
1061                        return Err(nexus_bits::UnknownDiscriminant {
1062                            field: "__discriminant",
1063                            value: self.0,
1064                        });
1065                    }
1066                    #(#validations)*
1067                    Ok(#variant_type(self.0))
1068                }
1069            }
1070        })
1071        .collect();
1072
1073    // Builder shortcut methods
1074    let builder_methods: Vec<TokenStream2> = variants
1075        .iter()
1076        .map(|v| {
1077            let variant_name = &v.name;
1078            let builder_name = variant_builder_name(name, variant_name);
1079            let method_name = Ident::new(
1080                &to_snake_case(&variant_name.to_string()),
1081                variant_name.span(),
1082            );
1083            quote! {
1084                #[inline]
1085                pub fn #method_name() -> #builder_name {
1086                    #builder_name::default()
1087                }
1088            }
1089        })
1090        .collect();
1091
1092    quote! {
1093        impl #name {
1094            /// Create from raw integer value.
1095            #[inline]
1096            pub const fn from_raw(raw: #repr) -> Self {
1097                Self(raw)
1098            }
1099
1100            /// Get the raw integer value.
1101            #[inline]
1102            pub const fn raw(self) -> #repr {
1103                self.0
1104            }
1105
1106            /// Get the kind (discriminant) of this value.
1107            #[inline]
1108            pub fn kind(&self) -> Result<#kind_name, nexus_bits::UnknownDiscriminant<#repr>> {
1109                let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1110                match disc {
1111                    #(#kind_arms)*
1112                    _ => Err(nexus_bits::UnknownDiscriminant {
1113                        field: "__discriminant",
1114                        value: self.0,
1115                    }),
1116                }
1117            }
1118
1119            #(#is_methods)*
1120
1121            #(#as_methods)*
1122
1123            #(#builder_methods)*
1124        }
1125    }
1126}
1127
1128fn generate_enum_variant_impls(
1129    parent_name: &Ident,
1130    repr: &Ident,
1131    variants: &[ParsedVariant],
1132) -> TokenStream2 {
1133    let repr_bit_count = repr_bits(repr);
1134
1135    let impls: Vec<TokenStream2> =
1136        variants
1137            .iter()
1138            .map(|v| {
1139                let variant_name = &v.name;
1140                let variant_type = variant_type_name(parent_name, variant_name);
1141                let builder_name = variant_builder_name(parent_name, variant_name);
1142
1143                // Accessors - infallible since variant is pre-validated
1144                let accessors: Vec<TokenStream2> = v.members
1145                .iter()
1146                .map(|m| {
1147                    match m {
1148                        MemberDef::Field { name: field_name, ty, range } => {
1149                            let start = range.start;
1150                            let len = range.len;
1151                            let mask = if len >= repr_bit_count {
1152                                quote! { #repr::MAX }
1153                            } else {
1154                                quote! { ((1 as #repr) << #len) - 1 }
1155                            };
1156
1157                            if is_primitive(ty) {
1158                                quote! {
1159                                    #[inline]
1160                                    pub const fn #field_name(&self) -> #ty {
1161                                        ((self.0 >> #start) & #mask) as #ty
1162                                    }
1163                                }
1164                            } else {
1165                                // IntEnum - infallible because already validated
1166                                quote! {
1167                                    #[inline]
1168                                    pub fn #field_name(&self) -> #ty {
1169                                        let field_repr = ((self.0 >> #start) & #mask);
1170                                        // SAFETY: This type was validated during construction
1171                                        <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
1172                                            .expect("variant type invariant violated")
1173                                    }
1174                                }
1175                            }
1176                        }
1177                        MemberDef::Flag { name: field_name, bit } => {
1178                            quote! {
1179                                #[inline]
1180                                pub const fn #field_name(&self) -> bool {
1181                                    (self.0 >> #bit) & 1 != 0
1182                                }
1183                            }
1184                        }
1185                    }
1186                })
1187                .collect();
1188
1189                quote! {
1190                    impl #variant_type {
1191                        /// Create a builder for this variant.
1192                        #[inline]
1193                        pub fn builder() -> #builder_name {
1194                            #builder_name::default()
1195                        }
1196
1197                        /// Get the raw integer value.
1198                        #[inline]
1199                        pub const fn raw(self) -> #repr {
1200                            self.0
1201                        }
1202
1203                        /// Convert to parent type.
1204                        #[inline]
1205                        pub const fn as_parent(self) -> #parent_name {
1206                            #parent_name(self.0)
1207                        }
1208
1209                        #(#accessors)*
1210                    }
1211                }
1212            })
1213            .collect();
1214
1215    quote! { #(#impls)* }
1216}
1217
1218fn generate_enum_builder_impls(
1219    parent_name: &Ident,
1220    repr: &Ident,
1221    discriminant: &BitRange,
1222    variants: &[ParsedVariant],
1223) -> TokenStream2 {
1224    let repr_bit_count = repr_bits(repr);
1225    let disc_start = discriminant.start;
1226
1227    let impls: Vec<TokenStream2> = variants
1228        .iter()
1229        .map(|v| {
1230            let variant_name = &v.name;
1231            let variant_type = variant_type_name(parent_name, variant_name);
1232            let builder_name = variant_builder_name(parent_name, variant_name);
1233            let disc_val = v.discriminant;
1234
1235            // Setters
1236            let setters: Vec<TokenStream2> = v.members
1237                .iter()
1238                .map(|m| match m {
1239                    MemberDef::Field { name: field_name, ty, .. } => {
1240                        quote! {
1241                            #[inline]
1242                            pub fn #field_name(mut self, val: #ty) -> Self {
1243                                self.#field_name = Some(val);
1244                                self
1245                            }
1246                        }
1247                    }
1248                    MemberDef::Flag { name: field_name, .. } => {
1249                        quote! {
1250                            #[inline]
1251                            pub fn #field_name(mut self, val: bool) -> Self {
1252                                self.#field_name = Some(val);
1253                                self
1254                            }
1255                        }
1256                    }
1257                })
1258                .collect();
1259
1260            // Validations
1261            let validations: Vec<TokenStream2> = v.members
1262                .iter()
1263                .filter_map(|m| match m {
1264                    MemberDef::Field { name: field_name, ty, range } => {
1265                        let field_str = field_name.to_string();
1266                        let len = range.len;
1267
1268                        let max_val = if len >= repr_bit_count {
1269                            quote! { #repr::MAX }
1270                        } else {
1271                            quote! { ((1 as #repr) << #len) - 1 }
1272                        };
1273
1274                        if is_primitive(ty) {
1275                            let type_bits: u32 = match ty {
1276                                Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
1277                                Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
1278                                Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
1279                                Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
1280                                Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
1281                                _ => 128,
1282                            };
1283
1284                            if len >= type_bits {
1285                                return None;
1286                            }
1287
1288                            let is_signed = matches!(ty,
1289                                Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
1290                                                 p.path.is_ident("i32") || p.path.is_ident("i64") ||
1291                                                 p.path.is_ident("i128")
1292                            );
1293
1294                            if is_signed {
1295                                let min_shift = len - 1;
1296                                Some(quote! {
1297                                    if let Some(v) = self.#field_name {
1298                                        let min_val = -((1i128 << #min_shift) as i128);
1299                                        let max_val = ((1i128 << #min_shift) - 1) as i128;
1300                                        let v_i128 = v as i128;
1301                                        if v_i128 < min_val || v_i128 > max_val {
1302                                            return Err(nexus_bits::FieldOverflow {
1303                                                field: #field_str,
1304                                                overflow: nexus_bits::Overflow {
1305                                                    value: (v as #repr),
1306                                                    max: #max_val,
1307                                                },
1308                                            });
1309                                        }
1310                                    }
1311                                })
1312                            } else {
1313                                Some(quote! {
1314                                    if let Some(v) = self.#field_name {
1315                                        if (v as #repr) > #max_val {
1316                                            return Err(nexus_bits::FieldOverflow {
1317                                                field: #field_str,
1318                                                overflow: nexus_bits::Overflow {
1319                                                    value: v as #repr,
1320                                                    max: #max_val,
1321                                                },
1322                                            });
1323                                        }
1324                                    }
1325                                })
1326                            }
1327                        } else {
1328                            // IntEnum field
1329                            Some(quote! {
1330                                if let Some(v) = self.#field_name {
1331                                    let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
1332                                    if repr_val > #max_val {
1333                                        return Err(nexus_bits::FieldOverflow {
1334                                            field: #field_str,
1335                                            overflow: nexus_bits::Overflow {
1336                                                value: repr_val,
1337                                                max: #max_val,
1338                                            },
1339                                        });
1340                                    }
1341                                }
1342                            })
1343                        }
1344                    }
1345                    _ => None,
1346                })
1347                .collect();
1348
1349            // Pack statements
1350            let pack_statements: Vec<TokenStream2> = v.members
1351                .iter()
1352                .map(|m| {
1353                    match m {
1354                        MemberDef::Field { name: field_name, ty, range } => {
1355                            let start = range.start;
1356                            let len = range.len;
1357                            let mask = if len >= repr_bit_count {
1358                                quote! { #repr::MAX }
1359                            } else {
1360                                quote! { ((1 as #repr) << #len) - 1 }
1361                            };
1362
1363                            if is_primitive(ty) {
1364                                quote! {
1365                                    if let Some(v) = self.#field_name {
1366                                        val |= ((v as #repr) & #mask) << #start;
1367                                    }
1368                                }
1369                            } else {
1370                                quote! {
1371                                    if let Some(v) = self.#field_name {
1372                                        val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
1373                                    }
1374                                }
1375                            }
1376                        }
1377                        MemberDef::Flag { name: field_name, bit } => {
1378                            quote! {
1379                                if let Some(true) = self.#field_name {
1380                                    val |= (1 as #repr) << #bit;
1381                                }
1382                            }
1383                        }
1384                    }
1385                })
1386                .collect();
1387
1388            quote! {
1389                impl #builder_name {
1390                    #(#setters)*
1391
1392                    /// Build the variant type, validating all fields.
1393                    #[inline]
1394                    pub fn build(self) -> Result<#variant_type, nexus_bits::FieldOverflow<#repr>> {
1395                        #(#validations)*
1396
1397                        let mut val: #repr = 0;
1398                        // Set discriminant
1399                        val |= (#disc_val as #repr) << #disc_start;
1400                        #(#pack_statements)*
1401
1402                        Ok(#variant_type(val))
1403                    }
1404
1405                    /// Build directly to parent type, validating all fields.
1406                    #[inline]
1407                    pub fn build_parent(self) -> Result<#parent_name, nexus_bits::FieldOverflow<#repr>> {
1408                        self.build().map(|v| v.as_parent())
1409                    }
1410                }
1411            }
1412        })
1413        .collect();
1414
1415    quote! { #(#impls)* }
1416}
1417
1418fn generate_enum_from_impls(parent_name: &Ident, variants: &[ParsedVariant]) -> TokenStream2 {
1419    let impls: Vec<TokenStream2> = variants
1420        .iter()
1421        .map(|v| {
1422            let variant_type = variant_type_name(parent_name, &v.name);
1423            quote! {
1424                impl From<#variant_type> for #parent_name {
1425                    #[inline]
1426                    fn from(v: #variant_type) -> Self {
1427                        v.as_parent()
1428                    }
1429                }
1430            }
1431        })
1432        .collect();
1433
1434    quote! { #(#impls)* }
1435}
1436
/// Converts an `UpperCamelCase` identifier to `snake_case`.
///
/// An underscore is inserted before every uppercase character except a leading
/// one, so acronyms are split letter by letter (`HTTPCode` -> `h_t_t_p_code`).
/// Lowercasing uses the full Unicode mapping. A character whose lowercase form
/// is several code points long (e.g. `İ`) is therefore emitted in full, not
/// truncated to its first code point.
fn to_snake_case(s: &str) -> String {
    // Most inputs gain only a few underscores; reserve some headroom up front.
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}