1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{
5 Data, DeriveInput, Error, Fields, Ident, Result, Type, parse::Parser, parse_macro_input,
6};
7
8#[proc_macro_derive(IntEnum)]
13pub fn derive_int_enum(input: TokenStream) -> TokenStream {
14 let input = parse_macro_input!(input as DeriveInput);
15
16 match derive_int_enum_impl(input) {
17 Ok(tokens) => tokens.into(),
18 Err(err) => err.to_compile_error().into(),
19 }
20}
21
22fn derive_int_enum_impl(input: DeriveInput) -> Result<TokenStream2> {
23 let variants = match &input.data {
24 Data::Enum(data) => &data.variants,
25 _ => {
26 return Err(Error::new_spanned(
27 &input,
28 "IntEnum can only be derived for enums",
29 ));
30 }
31 };
32
33 let repr = parse_repr(&input)?;
34
35 for variant in variants {
36 if !matches!(variant.fields, Fields::Unit) {
37 return Err(Error::new_spanned(
38 variant,
39 "IntEnum variants cannot have fields",
40 ));
41 }
42 }
43
44 let name = &input.ident;
45 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
46
47 let from_arms = variants.iter().map(|v| {
48 let variant_name = &v.ident;
49 quote! {
50 x if x == #name::#variant_name as #repr => Some(#name::#variant_name),
51 }
52 });
53
54 Ok(quote! {
55 impl #impl_generics nexus_bits::IntEnum for #name #ty_generics #where_clause {
56 type Repr = #repr;
57
58 #[inline]
59 fn into_repr(self) -> #repr {
60 self as #repr
61 }
62
63 #[inline]
64 fn try_from_repr(repr: #repr) -> Option<Self> {
65 match repr {
66 #(#from_arms)*
67 _ => None,
68 }
69 }
70 }
71 })
72}
73
74fn parse_repr(input: &DeriveInput) -> Result<Ident> {
75 for attr in &input.attrs {
76 if attr.path().is_ident("repr") {
77 let repr: Ident = attr.parse_args()?;
78 match repr.to_string().as_str() {
79 "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64"
80 | "i128" => {
81 return Ok(repr);
82 }
83 _ => {
84 return Err(Error::new_spanned(
85 repr,
86 "IntEnum requires a primitive integer repr (u8..u128, i8..i128)",
87 ));
88 }
89 }
90 }
91 }
92
93 Err(Error::new_spanned(
94 input,
95 "IntEnum requires a #[repr(u8/u16/u32/u64/i8/i16/i32/i64)] attribute",
96 ))
97}
98
99#[proc_macro_attribute]
104pub fn bit_storage(attr: TokenStream, item: TokenStream) -> TokenStream {
105 let attr = proc_macro2::TokenStream::from(attr);
106 let item = parse_macro_input!(item as DeriveInput);
107
108 match bit_storage_impl(attr, item) {
109 Ok(tokens) => tokens.into(),
110 Err(err) => err.to_compile_error().into(),
111 }
112}
113
114fn bit_storage_impl(attr: TokenStream2, input: DeriveInput) -> Result<TokenStream2> {
115 let storage_attr = parse_storage_attr_tokens(attr)?;
116
117 match &input.data {
118 Data::Struct(data) => derive_storage_struct(&input, data, storage_attr),
119 Data::Enum(data) => derive_storage_enum(&input, data, storage_attr),
120 Data::Union(_) => Err(Error::new_spanned(
121 &input,
122 "bit_storage cannot be applied to unions",
123 )),
124 }
125}
126
/// Arguments parsed from `#[bit_storage(...)]`.
struct StorageAttr {
    /// Storage integer named by `repr = ...`; `parse_storage_attr_tokens` has
    /// already checked that it is one of `u8`..`u128` / `i8`..`i128`.
    repr: Ident,
    /// Range from `discriminant(start = N, len = M)`. Required for enums and
    /// rejected for structs.
    discriminant: Option<BitRange>,
}
136
/// A contiguous run of bits inside the storage integer, parsed from
/// `start = N, len = M`.
struct BitRange {
    /// Bit offset of the range's least-significant bit; generated accessors
    /// read it as `self.0 >> start`.
    start: u32,
    /// Number of bits in the range; `parse_bit_range` rejects 0.
    len: u32,
}
142
/// One named member of a `bit_storage` struct or enum variant.
enum MemberDef {
    /// A multi-bit field from `#[field(start = N, len = M)]`. `ty` is either a
    /// bare primitive integer or a type expected to implement
    /// `nexus_bits::IntEnum`.
    Field {
        name: Ident,
        ty: Type,
        range: BitRange,
    },
    /// A single-bit boolean from `#[flag(N)]`; `bit` is the bit index. The
    /// declared field type is not used.
    Flag {
        name: Ident,
        bit: u32,
    },
}
155
156impl MemberDef {
157 fn name(&self) -> &Ident {
158 match self {
159 MemberDef::Field { name, .. } => name,
160 MemberDef::Flag { name, .. } => name,
161 }
162 }
163}
164
165fn parse_storage_attr_tokens(attr: TokenStream2) -> Result<StorageAttr> {
170 let mut repr = None;
171 let mut discriminant = None;
172
173 let parser = syn::meta::parser(|meta| {
174 if meta.path.is_ident("repr") {
175 meta.input.parse::<syn::Token![=]>()?;
176 repr = Some(meta.input.parse::<Ident>()?);
177 Ok(())
178 } else if meta.path.is_ident("discriminant") {
179 let content;
180 syn::parenthesized!(content in meta.input);
181 discriminant = Some(parse_bit_range(&content)?);
182 Ok(())
183 } else {
184 Err(meta.error("expected `repr` or `discriminant`"))
185 }
186 });
187
188 parser.parse2(attr)?;
189
190 let repr = repr.ok_or_else(|| {
191 Error::new(
192 proc_macro2::Span::call_site(),
193 "bit_storage requires `repr = ...`",
194 )
195 })?;
196
197 match repr.to_string().as_str() {
199 "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" => {}
200 _ => return Err(Error::new_spanned(&repr, "repr must be an integer type")),
201 }
202
203 Ok(StorageAttr { repr, discriminant })
204}
205
206fn parse_bit_range(input: syn::parse::ParseStream) -> Result<BitRange> {
207 let mut start = None;
208 let mut len = None;
209
210 while !input.is_empty() {
211 let ident: Ident = input.parse()?;
212 input.parse::<syn::Token![=]>()?;
213 let lit: syn::LitInt = input.parse()?;
214 let value: u32 = lit.base10_parse()?;
215
216 match ident.to_string().as_str() {
217 "start" => start = Some(value),
218 "len" => len = Some(value),
219 _ => return Err(Error::new_spanned(ident, "expected `start` or `len`")),
220 }
221
222 if input.peek(syn::Token![,]) {
223 input.parse::<syn::Token![,]>()?;
224 }
225 }
226
227 let start = start.ok_or_else(|| Error::new(input.span(), "missing `start`"))?;
228 let len = len.ok_or_else(|| Error::new(input.span(), "missing `len`"))?;
229
230 if len == 0 {
231 return Err(Error::new(input.span(), "len must be > 0"));
232 }
233
234 Ok(BitRange { start, len })
235}
236
237fn parse_member(field: &syn::Field) -> Result<MemberDef> {
238 let name = field
239 .ident
240 .clone()
241 .ok_or_else(|| Error::new_spanned(field, "tuple structs not supported"))?;
242 let ty = field.ty.clone();
243
244 for attr in &field.attrs {
245 if attr.path().is_ident("field") {
246 let range = attr.parse_args_with(parse_bit_range)?;
247 return Ok(MemberDef::Field { name, ty, range });
248 } else if attr.path().is_ident("flag") {
249 let bit: syn::LitInt = attr.parse_args()?;
250 let bit: u32 = bit.base10_parse()?;
251 return Ok(MemberDef::Flag { name, bit });
252 }
253 }
254
255 Err(Error::new_spanned(
256 field,
257 "field requires #[field(start = N, len = M)] or #[flag(N)] attribute",
258 ))
259}
260
261fn parse_variant_attr(attrs: &[syn::Attribute]) -> Result<u64> {
262 for attr in attrs {
263 if attr.path().is_ident("variant") {
264 let lit: syn::LitInt = attr.parse_args()?;
265 return lit.base10_parse();
266 }
267 }
268 Err(Error::new(
269 proc_macro2::Span::call_site(),
270 "enum variant requires #[variant(N)] attribute",
271 ))
272}
273
274fn is_primitive(ty: &Type) -> bool {
279 if let Type::Path(type_path) = ty {
280 if let Some(ident) = type_path.path.get_ident() {
281 return matches!(
282 ident.to_string().as_str(),
283 "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128"
284 );
285 }
286 }
287 false
288}
289
290fn repr_bits(repr: &Ident) -> u32 {
291 match repr.to_string().as_str() {
292 "u8" | "i8" => 8,
293 "u16" | "i16" => 16,
294 "u32" | "i32" => 32,
295 "u64" | "i64" => 64,
296 "u128" | "i128" => 128,
297 _ => 0,
298 }
299}
300
301fn validate_members(members: &[MemberDef], repr: &Ident) -> Result<()> {
306 let bits = repr_bits(repr);
307
308 for member in members {
310 match member {
311 MemberDef::Field { name, range, .. } => {
312 if range.start + range.len > bits {
313 return Err(Error::new_spanned(
314 name,
315 format!(
316 "field exceeds {} bits (start {} + len {} = {})",
317 bits,
318 range.start,
319 range.len,
320 range.start + range.len
321 ),
322 ));
323 }
324 }
325 MemberDef::Flag { name, bit, .. } => {
326 if *bit >= bits {
327 return Err(Error::new_spanned(
328 name,
329 format!("flag bit {} exceeds {} bits", bit, bits),
330 ));
331 }
332 }
333 }
334 }
335
336 for (i, a) in members.iter().enumerate() {
338 for b in members.iter().skip(i + 1) {
339 if ranges_overlap(a, b) {
340 return Err(Error::new_spanned(
341 b.name(),
342 format!("field '{}' overlaps with '{}'", b.name(), a.name()),
343 ));
344 }
345 }
346 }
347
348 Ok(())
349}
350
351fn ranges_overlap(a: &MemberDef, b: &MemberDef) -> bool {
352 let (a_start, a_end) = member_bit_range(a);
353 let (b_start, b_end) = member_bit_range(b);
354 a_start < b_end && b_start < a_end
355}
356
357fn member_bit_range(m: &MemberDef) -> (u32, u32) {
358 match m {
359 MemberDef::Field { range, .. } => (range.start, range.start + range.len),
360 MemberDef::Flag { bit, .. } => (*bit, bit + 1),
361 }
362}
363
364fn derive_storage_struct(
369 input: &DeriveInput,
370 data: &syn::DataStruct,
371 storage_attr: StorageAttr,
372) -> Result<TokenStream2> {
373 let fields = match &data.fields {
374 Fields::Named(f) => &f.named,
375 _ => {
376 return Err(Error::new_spanned(
377 input,
378 "bit_storage requires named fields",
379 ));
380 }
381 };
382
383 if storage_attr.discriminant.is_some() {
384 return Err(Error::new_spanned(
385 input,
386 "discriminant is only valid for enums",
387 ));
388 }
389
390 let members: Vec<MemberDef> = fields.iter().map(parse_member).collect::<Result<_>>()?;
391
392 validate_members(&members, &storage_attr.repr)?;
393
394 let vis = &input.vis;
395 let name = &input.ident;
396 let repr = &storage_attr.repr;
397 let builder_name = Ident::new(&format!("{}Builder", name), name.span());
398
399 let newtype = generate_struct_newtype(vis, name, repr);
400 let builder_struct = generate_struct_builder_struct(vis, &builder_name, &members);
401 let newtype_impl = generate_struct_newtype_impl(name, &builder_name, repr, &members);
402 let builder_impl = generate_struct_builder_impl(name, &builder_name, repr, &members);
403
404 Ok(quote! {
405 #newtype
406 #builder_struct
407 #newtype_impl
408 #builder_impl
409 })
410}
411
412fn generate_struct_newtype(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
413 quote! {
414 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
415 #[repr(transparent)]
416 #vis struct #name(#vis #repr);
417 }
418}
419
420fn generate_struct_builder_struct(
421 vis: &syn::Visibility,
422 builder_name: &Ident,
423 members: &[MemberDef],
424) -> TokenStream2 {
425 let fields: Vec<TokenStream2> = members
426 .iter()
427 .map(|m| match m {
428 MemberDef::Field { name, ty, .. } => {
429 quote! { #name: Option<#ty>, }
430 }
431 MemberDef::Flag { name, .. } => {
432 quote! { #name: Option<bool>, }
433 }
434 })
435 .collect();
436
437 quote! {
438 #[derive(Debug, Clone, Copy, Default)]
439 #vis struct #builder_name {
440 #(#fields)*
441 }
442 }
443}
444
445fn generate_struct_newtype_impl(
446 name: &Ident,
447 builder_name: &Ident,
448 repr: &Ident,
449 members: &[MemberDef],
450) -> TokenStream2 {
451 let repr_bit_count = repr_bits(repr);
452
453 let accessors: Vec<TokenStream2> = members.iter().map(|m| {
454 match m {
455 MemberDef::Field { name: field_name, ty, range } => {
456 let start = range.start;
457 let len = range.len;
458 let mask = if len >= repr_bit_count {
459 quote! { #repr::MAX }
460 } else {
461 quote! { ((1 as #repr) << #len) - 1 }
462 };
463
464 if is_primitive(ty) {
465 quote! {
466 #[inline]
467 pub const fn #field_name(&self) -> #ty {
468 ((self.0 >> #start) & #mask) as #ty
469 }
470 }
471 } else {
472 quote! {
474 #[inline]
475 pub fn #field_name(&self) -> Result<#ty, nexus_bits::UnknownDiscriminant<#repr>> {
476 let field_repr = ((self.0 >> #start) & #mask);
477 <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
478 .ok_or(nexus_bits::UnknownDiscriminant {
479 field: stringify!(#field_name),
480 value: field_repr as #repr,
481 })
482 }
483 }
484 }
485 }
486 MemberDef::Flag { name: field_name, bit } => {
487 quote! {
488 #[inline]
489 pub const fn #field_name(&self) -> bool {
490 (self.0 >> #bit) & 1 != 0
491 }
492 }
493 }
494 }
495 }).collect();
496
497 quote! {
498 impl #name {
499 #[inline]
501 pub const fn from_raw(raw: #repr) -> Self {
502 Self(raw)
503 }
504
505 #[inline]
507 pub const fn raw(self) -> #repr {
508 self.0
509 }
510
511 #[inline]
513 pub fn builder() -> #builder_name {
514 #builder_name::default()
515 }
516
517 #(#accessors)*
518 }
519 }
520}
521
522fn generate_struct_builder_impl(
523 name: &Ident,
524 builder_name: &Ident,
525 repr: &Ident,
526 members: &[MemberDef],
527) -> TokenStream2 {
528 let repr_bit_count = repr_bits(repr);
529
530 let setters: Vec<TokenStream2> = members
532 .iter()
533 .map(|m| match m {
534 MemberDef::Field {
535 name: field_name,
536 ty,
537 ..
538 } => {
539 quote! {
540 #[inline]
541 pub fn #field_name(mut self, val: #ty) -> Self {
542 self.#field_name = Some(val);
543 self
544 }
545 }
546 }
547 MemberDef::Flag {
548 name: field_name, ..
549 } => {
550 quote! {
551 #[inline]
552 pub fn #field_name(mut self, val: bool) -> Self {
553 self.#field_name = Some(val);
554 self
555 }
556 }
557 }
558 })
559 .collect();
560
561 let validations: Vec<TokenStream2> = members
563 .iter()
564 .filter_map(|m| match m {
565 MemberDef::Field {
566 name: field_name,
567 ty,
568 range,
569 } => {
570 let field_str = field_name.to_string();
571 let len = range.len;
572
573 let max_val = if len >= repr_bit_count {
574 quote! { #repr::MAX }
575 } else {
576 quote! { ((1 as #repr) << #len) - 1 }
577 };
578
579 if is_primitive(ty) {
580 let type_bits: u32 = match ty {
581 Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
582 Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
583 Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
584 Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
585 Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
586 _ => 128,
587 };
588
589 if len >= type_bits {
591 return None;
592 }
593
594 let is_signed = matches!(ty,
596 Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
597 p.path.is_ident("i32") || p.path.is_ident("i64") ||
598 p.path.is_ident("i128")
599 );
600
601 if is_signed {
602 let min_shift = len - 1;
605 Some(quote! {
606 if let Some(v) = self.#field_name {
607 let min_val = -((1i128 << #min_shift) as i128);
608 let max_val = ((1i128 << #min_shift) - 1) as i128;
609 let v_i128 = v as i128;
610 if v_i128 < min_val || v_i128 > max_val {
611 return Err(nexus_bits::FieldOverflow {
612 field: #field_str,
613 overflow: nexus_bits::Overflow {
614 value: (v as #repr),
615 max: #max_val,
616 },
617 });
618 }
619 }
620 })
621 } else {
622 Some(quote! {
624 if let Some(v) = self.#field_name {
625 if (v as #repr) > #max_val {
626 return Err(nexus_bits::FieldOverflow {
627 field: #field_str,
628 overflow: nexus_bits::Overflow {
629 value: v as #repr,
630 max: #max_val,
631 },
632 });
633 }
634 }
635 })
636 }
637 } else {
638 Some(quote! {
640 const _: () = assert!(
641 core::mem::size_of::<<#ty as nexus_bits::IntEnum>::Repr>() <= core::mem::size_of::<#repr>(),
642 "IntEnum repr type is wider than storage repr — values may be truncated"
643 );
644 if let Some(v) = self.#field_name {
645 let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
646 if repr_val > #max_val {
647 return Err(nexus_bits::FieldOverflow {
648 field: #field_str,
649 overflow: nexus_bits::Overflow {
650 value: repr_val,
651 max: #max_val,
652 },
653 });
654 }
655 }
656 })
657 }
658 }
659 _ => None,
660 })
661 .collect();
662
663 let pack_statements: Vec<TokenStream2> = members
665 .iter()
666 .map(|m| {
667 match m {
668 MemberDef::Field {
669 name: field_name,
670 ty,
671 range,
672 } => {
673 let start = range.start;
674 let len = range.len;
675 let mask = if len >= repr_bit_count {
676 quote! { #repr::MAX }
677 } else {
678 quote! { ((1 as #repr) << #len) - 1 }
679 };
680
681 if is_primitive(ty) {
682 quote! {
683 if let Some(v) = self.#field_name {
684 val |= ((v as #repr) & #mask) << #start;
685 }
686 }
687 } else {
688 quote! {
690 if let Some(v) = self.#field_name {
691 val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
692 }
693 }
694 }
695 }
696 MemberDef::Flag {
697 name: field_name,
698 bit,
699 } => {
700 quote! {
701 if let Some(true) = self.#field_name {
702 val |= (1 as #repr) << #bit;
703 }
704 }
705 }
706 }
707 })
708 .collect();
709
710 quote! {
711 impl #builder_name {
712 #(#setters)*
713
714 #[inline]
716 pub fn build(self) -> Result<#name, nexus_bits::FieldOverflow<#repr>> {
717 #(#validations)*
719
720 let mut val: #repr = 0;
722 #(#pack_statements)*
723
724 Ok(#name(val))
725 }
726 }
727 }
728}
729
/// One validated variant of a `bit_storage` enum.
struct ParsedVariant {
    /// The variant's identifier as written in the enum.
    name: Ident,
    /// Value from `#[variant(N)]`. `derive_storage_enum` has checked that it
    /// fits the discriminant width and is unique among the enum's variants.
    discriminant: u64,
    /// The variant's bit fields and flags (empty for unit variants).
    members: Vec<MemberDef>,
}
740
741fn derive_storage_enum(
742 input: &DeriveInput,
743 data: &syn::DataEnum,
744 storage_attr: StorageAttr,
745) -> Result<TokenStream2> {
746 let discriminant = storage_attr.discriminant.ok_or_else(|| {
747 Error::new_spanned(
748 input,
749 "bit_storage enum requires discriminant: #[bit_storage(repr = T, discriminant(start = N, len = M))]",
750 )
751 })?;
752
753 let repr = &storage_attr.repr;
754 let bits = repr_bits(repr);
755
756 if discriminant.start + discriminant.len > bits {
758 return Err(Error::new_spanned(
759 input,
760 format!(
761 "discriminant exceeds {} bits (start {} + len {} = {})",
762 bits,
763 discriminant.start,
764 discriminant.len,
765 discriminant.start + discriminant.len
766 ),
767 ));
768 }
769
770 let max_discriminant = if discriminant.len >= 64 {
771 u64::MAX
772 } else {
773 (1u64 << discriminant.len) - 1
774 };
775
776 let mut variants = Vec::new();
778 for variant in &data.variants {
779 let disc = parse_variant_attr(&variant.attrs)?;
780
781 if disc > max_discriminant {
782 return Err(Error::new_spanned(
783 &variant.ident,
784 format!(
785 "variant discriminant {} exceeds max {} for {}-bit field",
786 disc, max_discriminant, discriminant.len
787 ),
788 ));
789 }
790
791 for existing in &variants {
793 let existing: &ParsedVariant = existing;
794 if existing.discriminant == disc {
795 return Err(Error::new_spanned(
796 &variant.ident,
797 format!(
798 "duplicate discriminant {}: already used by '{}'",
799 disc, existing.name
800 ),
801 ));
802 }
803 }
804
805 let members: Vec<MemberDef> = match &variant.fields {
806 Fields::Named(fields) => fields
807 .named
808 .iter()
809 .map(parse_member)
810 .collect::<Result<_>>()?,
811 Fields::Unit => Vec::new(),
812 Fields::Unnamed(_) => {
813 return Err(Error::new_spanned(
814 variant,
815 "tuple variants not supported, use named fields",
816 ));
817 }
818 };
819
820 let disc_range = MemberDef::Field {
822 name: Ident::new("__discriminant", proc_macro2::Span::call_site()),
823 ty: syn::parse_quote!(u64),
824 range: BitRange {
825 start: discriminant.start,
826 len: discriminant.len,
827 },
828 };
829
830 for member in &members {
831 if ranges_overlap(&disc_range, member) {
832 return Err(Error::new_spanned(
833 member.name(),
834 format!("field '{}' overlaps with discriminant", member.name()),
835 ));
836 }
837 }
838
839 validate_members(&members, repr)?;
841
842 variants.push(ParsedVariant {
843 name: variant.ident.clone(),
844 discriminant: disc,
845 members,
846 });
847 }
848
849 let vis = &input.vis;
850 let name = &input.ident;
851
852 let parent_type = generate_enum_parent_type(vis, name, repr);
853 let variant_types = generate_enum_variant_types(vis, name, repr, &variants);
854 let kind_enum = generate_enum_kind(vis, name, &variants);
855 let builder_structs = generate_enum_builder_structs(vis, name, &variants);
856 let parent_impl = generate_enum_parent_impl(name, repr, &discriminant, &variants);
857 let variant_impls = generate_enum_variant_impls(name, repr, &variants);
858 let builder_impls = generate_enum_builder_impls(name, repr, &discriminant, &variants);
859 let from_impls = generate_enum_from_impls(name, &variants);
860
861 Ok(quote! {
862 #parent_type
863 #variant_types
864 #kind_enum
865 #builder_structs
866 #parent_impl
867 #variant_impls
868 #builder_impls
869 #from_impls
870 })
871}
872
873fn variant_type_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
874 Ident::new(
875 &format!("{}{}", parent_name, variant_name),
876 variant_name.span(),
877 )
878}
879
880fn variant_builder_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
881 Ident::new(
882 &format!("{}{}Builder", parent_name, variant_name),
883 variant_name.span(),
884 )
885}
886
887fn kind_enum_name(parent_name: &Ident) -> Ident {
888 Ident::new(&format!("{}Kind", parent_name), parent_name.span())
889}
890
891fn generate_enum_parent_type(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
892 quote! {
893 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
894 #[repr(transparent)]
895 #vis struct #name(#vis #repr);
896 }
897}
898
899fn generate_enum_variant_types(
900 vis: &syn::Visibility,
901 parent_name: &Ident,
902 repr: &Ident,
903 variants: &[ParsedVariant],
904) -> TokenStream2 {
905 let types: Vec<TokenStream2> = variants
906 .iter()
907 .map(|v| {
908 let type_name = variant_type_name(parent_name, &v.name);
909 quote! {
910 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
911 #[repr(transparent)]
912 #vis struct #type_name(#repr);
913 }
914 })
915 .collect();
916
917 quote! { #(#types)* }
918}
919
920fn generate_enum_kind(
921 vis: &syn::Visibility,
922 parent_name: &Ident,
923 variants: &[ParsedVariant],
924) -> TokenStream2 {
925 let kind_name = kind_enum_name(parent_name);
926 let variant_names: Vec<&Ident> = variants.iter().map(|v| &v.name).collect();
927
928 quote! {
929 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
930 #vis enum #kind_name {
931 #(#variant_names),*
932 }
933 }
934}
935
936fn generate_enum_builder_structs(
937 vis: &syn::Visibility,
938 parent_name: &Ident,
939 variants: &[ParsedVariant],
940) -> TokenStream2 {
941 let builders: Vec<TokenStream2> = variants
942 .iter()
943 .map(|v| {
944 let builder_name = variant_builder_name(parent_name, &v.name);
945
946 let fields: Vec<TokenStream2> = v
947 .members
948 .iter()
949 .map(|m| match m {
950 MemberDef::Field { name, ty, .. } => {
951 quote! { #name: Option<#ty>, }
952 }
953 MemberDef::Flag { name, .. } => {
954 quote! { #name: Option<bool>, }
955 }
956 })
957 .collect();
958
959 quote! {
960 #[derive(Debug, Clone, Copy, Default)]
961 #vis struct #builder_name {
962 #(#fields)*
963 }
964 }
965 })
966 .collect();
967
968 quote! { #(#builders)* }
969}
970
971fn generate_enum_parent_impl(
972 name: &Ident,
973 repr: &Ident,
974 discriminant: &BitRange,
975 variants: &[ParsedVariant],
976) -> TokenStream2 {
977 let repr_bit_count = repr_bits(repr);
978 let kind_name = kind_enum_name(name);
979 let disc_start = discriminant.start;
980 let disc_len = discriminant.len;
981
982 assert!(
985 disc_len <= 64,
986 "discriminant length must be <= 64 bits (got {disc_len})"
987 );
988
989 let disc_mask = if disc_len >= repr_bit_count {
990 quote! { #repr::MAX }
991 } else {
992 quote! { ((1 as #repr) << #disc_len) - 1 }
993 };
994
995 let kind_arms: Vec<TokenStream2> = variants
997 .iter()
998 .map(|v| {
999 let variant_name = &v.name;
1000 let disc_val = v.discriminant;
1001 quote! {
1002 #disc_val => Ok(#kind_name::#variant_name),
1003 }
1004 })
1005 .collect();
1006
1007 let is_methods: Vec<TokenStream2> = variants
1009 .iter()
1010 .map(|v| {
1011 let variant_name = &v.name;
1012 let method_name = Ident::new(
1013 &format!("is_{}", to_snake_case(&variant_name.to_string())),
1014 variant_name.span(),
1015 );
1016 let disc_val = v.discriminant;
1017 quote! {
1018 #[inline]
1019 pub fn #method_name(&self) -> bool {
1020 let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1021 disc == #disc_val
1022 }
1023 }
1024 })
1025 .collect();
1026
1027 let as_methods: Vec<TokenStream2> = variants
1029 .iter()
1030 .map(|v| {
1031 let variant_name = &v.name;
1032 let variant_type = variant_type_name(name, variant_name);
1033 let method_name = Ident::new(
1034 &format!("as_{}", to_snake_case(&variant_name.to_string())),
1035 variant_name.span(),
1036 );
1037 let disc_val = v.discriminant;
1038
1039 let validations: Vec<TokenStream2> = v.members
1041 .iter()
1042 .filter_map(|m| {
1043 if let MemberDef::Field { name: field_name, ty, range } = m {
1044 if !is_primitive(ty) {
1045 let start = range.start;
1046 let len = range.len;
1047 let repr_bit_count = repr_bits(repr);
1048 let mask = if len >= repr_bit_count {
1049 quote! { #repr::MAX }
1050 } else {
1051 quote! { ((1 as #repr) << #len) - 1 }
1052 };
1053 return Some(quote! {
1054 let field_repr = ((self.0 >> #start) & #mask);
1055 if <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _).is_none() {
1056 return Err(nexus_bits::UnknownDiscriminant {
1057 field: stringify!(#field_name),
1058 value: field_repr as #repr,
1059 });
1060 }
1061 });
1062 }
1063 }
1064 None
1065 })
1066 .collect();
1067
1068 quote! {
1069 #[inline]
1070 pub fn #method_name(&self) -> Result<#variant_type, nexus_bits::UnknownDiscriminant<#repr>> {
1071 let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1072 if disc != #disc_val {
1073 return Err(nexus_bits::UnknownDiscriminant {
1074 field: "__discriminant",
1075 value: disc as #repr,
1076 });
1077 }
1078 #(#validations)*
1079 Ok(#variant_type(self.0))
1080 }
1081 }
1082 })
1083 .collect();
1084
1085 let builder_methods: Vec<TokenStream2> = variants
1087 .iter()
1088 .map(|v| {
1089 let variant_name = &v.name;
1090 let builder_name = variant_builder_name(name, variant_name);
1091 let method_name = Ident::new(
1092 &to_snake_case(&variant_name.to_string()),
1093 variant_name.span(),
1094 );
1095 quote! {
1096 #[inline]
1097 pub fn #method_name() -> #builder_name {
1098 #builder_name::default()
1099 }
1100 }
1101 })
1102 .collect();
1103
1104 quote! {
1105 impl #name {
1106 #[inline]
1108 pub const fn from_raw(raw: #repr) -> Self {
1109 Self(raw)
1110 }
1111
1112 #[inline]
1114 pub const fn raw(self) -> #repr {
1115 self.0
1116 }
1117
1118 #[inline]
1120 pub fn kind(&self) -> Result<#kind_name, nexus_bits::UnknownDiscriminant<#repr>> {
1121 let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1122 match disc {
1123 #(#kind_arms)*
1124 _ => Err(nexus_bits::UnknownDiscriminant {
1125 field: "__discriminant",
1126 value: disc as #repr,
1127 }),
1128 }
1129 }
1130
1131 #(#is_methods)*
1132
1133 #(#as_methods)*
1134
1135 #(#builder_methods)*
1136 }
1137 }
1138}
1139
1140fn generate_enum_variant_impls(
1141 parent_name: &Ident,
1142 repr: &Ident,
1143 variants: &[ParsedVariant],
1144) -> TokenStream2 {
1145 let repr_bit_count = repr_bits(repr);
1146
1147 let impls: Vec<TokenStream2> =
1148 variants
1149 .iter()
1150 .map(|v| {
1151 let variant_name = &v.name;
1152 let variant_type = variant_type_name(parent_name, variant_name);
1153 let builder_name = variant_builder_name(parent_name, variant_name);
1154
1155 let accessors: Vec<TokenStream2> = v.members
1157 .iter()
1158 .map(|m| {
1159 match m {
1160 MemberDef::Field { name: field_name, ty, range } => {
1161 let start = range.start;
1162 let len = range.len;
1163 let mask = if len >= repr_bit_count {
1164 quote! { #repr::MAX }
1165 } else {
1166 quote! { ((1 as #repr) << #len) - 1 }
1167 };
1168
1169 if is_primitive(ty) {
1170 quote! {
1171 #[inline]
1172 pub const fn #field_name(&self) -> #ty {
1173 ((self.0 >> #start) & #mask) as #ty
1174 }
1175 }
1176 } else {
1177 quote! {
1179 #[inline]
1180 pub fn #field_name(&self) -> #ty {
1181 let field_repr = ((self.0 >> #start) & #mask);
1182 <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
1184 .expect("variant type invariant violated")
1185 }
1186 }
1187 }
1188 }
1189 MemberDef::Flag { name: field_name, bit } => {
1190 quote! {
1191 #[inline]
1192 pub const fn #field_name(&self) -> bool {
1193 (self.0 >> #bit) & 1 != 0
1194 }
1195 }
1196 }
1197 }
1198 })
1199 .collect();
1200
1201 quote! {
1202 impl #variant_type {
1203 #[inline]
1205 pub fn builder() -> #builder_name {
1206 #builder_name::default()
1207 }
1208
1209 #[inline]
1211 pub const fn raw(self) -> #repr {
1212 self.0
1213 }
1214
1215 #[inline]
1217 pub const fn as_parent(self) -> #parent_name {
1218 #parent_name(self.0)
1219 }
1220
1221 #(#accessors)*
1222 }
1223 }
1224 })
1225 .collect();
1226
1227 quote! { #(#impls)* }
1228}
1229
1230fn generate_enum_builder_impls(
1231 parent_name: &Ident,
1232 repr: &Ident,
1233 discriminant: &BitRange,
1234 variants: &[ParsedVariant],
1235) -> TokenStream2 {
1236 let repr_bit_count = repr_bits(repr);
1237 let disc_start = discriminant.start;
1238
1239 let impls: Vec<TokenStream2> = variants
1240 .iter()
1241 .map(|v| {
1242 let variant_name = &v.name;
1243 let variant_type = variant_type_name(parent_name, variant_name);
1244 let builder_name = variant_builder_name(parent_name, variant_name);
1245 let disc_val = v.discriminant;
1246
1247 let setters: Vec<TokenStream2> = v.members
1249 .iter()
1250 .map(|m| match m {
1251 MemberDef::Field { name: field_name, ty, .. } => {
1252 quote! {
1253 #[inline]
1254 pub fn #field_name(mut self, val: #ty) -> Self {
1255 self.#field_name = Some(val);
1256 self
1257 }
1258 }
1259 }
1260 MemberDef::Flag { name: field_name, .. } => {
1261 quote! {
1262 #[inline]
1263 pub fn #field_name(mut self, val: bool) -> Self {
1264 self.#field_name = Some(val);
1265 self
1266 }
1267 }
1268 }
1269 })
1270 .collect();
1271
1272 let validations: Vec<TokenStream2> = v.members
1274 .iter()
1275 .filter_map(|m| match m {
1276 MemberDef::Field { name: field_name, ty, range } => {
1277 let field_str = field_name.to_string();
1278 let len = range.len;
1279
1280 let max_val = if len >= repr_bit_count {
1281 quote! { #repr::MAX }
1282 } else {
1283 quote! { ((1 as #repr) << #len) - 1 }
1284 };
1285
1286 if is_primitive(ty) {
1287 let type_bits: u32 = match ty {
1288 Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
1289 Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
1290 Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
1291 Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
1292 Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
1293 _ => 128,
1294 };
1295
1296 if len >= type_bits {
1297 return None;
1298 }
1299
1300 let is_signed = matches!(ty,
1301 Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
1302 p.path.is_ident("i32") || p.path.is_ident("i64") ||
1303 p.path.is_ident("i128")
1304 );
1305
1306 if is_signed {
1307 let min_shift = len - 1;
1308 Some(quote! {
1309 if let Some(v) = self.#field_name {
1310 let min_val = -((1i128 << #min_shift) as i128);
1311 let max_val = ((1i128 << #min_shift) - 1) as i128;
1312 let v_i128 = v as i128;
1313 if v_i128 < min_val || v_i128 > max_val {
1314 return Err(nexus_bits::FieldOverflow {
1315 field: #field_str,
1316 overflow: nexus_bits::Overflow {
1317 value: (v as #repr),
1318 max: #max_val,
1319 },
1320 });
1321 }
1322 }
1323 })
1324 } else {
1325 Some(quote! {
1326 if let Some(v) = self.#field_name {
1327 if (v as #repr) > #max_val {
1328 return Err(nexus_bits::FieldOverflow {
1329 field: #field_str,
1330 overflow: nexus_bits::Overflow {
1331 value: v as #repr,
1332 max: #max_val,
1333 },
1334 });
1335 }
1336 }
1337 })
1338 }
1339 } else {
1340 Some(quote! {
1342 if let Some(v) = self.#field_name {
1343 let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
1344 if repr_val > #max_val {
1345 return Err(nexus_bits::FieldOverflow {
1346 field: #field_str,
1347 overflow: nexus_bits::Overflow {
1348 value: repr_val,
1349 max: #max_val,
1350 },
1351 });
1352 }
1353 }
1354 })
1355 }
1356 }
1357 _ => None,
1358 })
1359 .collect();
1360
1361 let pack_statements: Vec<TokenStream2> = v.members
1363 .iter()
1364 .map(|m| {
1365 match m {
1366 MemberDef::Field { name: field_name, ty, range } => {
1367 let start = range.start;
1368 let len = range.len;
1369 let mask = if len >= repr_bit_count {
1370 quote! { #repr::MAX }
1371 } else {
1372 quote! { ((1 as #repr) << #len) - 1 }
1373 };
1374
1375 if is_primitive(ty) {
1376 quote! {
1377 if let Some(v) = self.#field_name {
1378 val |= ((v as #repr) & #mask) << #start;
1379 }
1380 }
1381 } else {
1382 quote! {
1383 if let Some(v) = self.#field_name {
1384 val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
1385 }
1386 }
1387 }
1388 }
1389 MemberDef::Flag { name: field_name, bit } => {
1390 quote! {
1391 if let Some(true) = self.#field_name {
1392 val |= (1 as #repr) << #bit;
1393 }
1394 }
1395 }
1396 }
1397 })
1398 .collect();
1399
1400 quote! {
1401 impl #builder_name {
1402 #(#setters)*
1403
1404 #[inline]
1406 pub fn build(self) -> Result<#variant_type, nexus_bits::FieldOverflow<#repr>> {
1407 #(#validations)*
1408
1409 let mut val: #repr = 0;
1410 val |= (#disc_val as #repr) << #disc_start;
1412 #(#pack_statements)*
1413
1414 Ok(#variant_type(val))
1415 }
1416
1417 #[inline]
1419 pub fn build_parent(self) -> Result<#parent_name, nexus_bits::FieldOverflow<#repr>> {
1420 self.build().map(|v| v.as_parent())
1421 }
1422 }
1423 }
1424 })
1425 .collect();
1426
1427 quote! { #(#impls)* }
1428}
1429
1430fn generate_enum_from_impls(parent_name: &Ident, variants: &[ParsedVariant]) -> TokenStream2 {
1431 let impls: Vec<TokenStream2> = variants
1432 .iter()
1433 .map(|v| {
1434 let variant_type = variant_type_name(parent_name, &v.name);
1435 quote! {
1436 impl From<#variant_type> for #parent_name {
1437 #[inline]
1438 fn from(v: #variant_type) -> Self {
1439 v.as_parent()
1440 }
1441 }
1442 }
1443 })
1444 .collect();
1445
1446 quote! { #(#impls)* }
1447}
1448
/// Converts an UpperCamelCase identifier to snake_case.
///
/// Each uppercase character is lowercased and, unless it is the first
/// character, preceded by `_`. Consecutive capitals are therefore split
/// individually (`HTTP` -> `h_t_t_p`). Other characters are copied unchanged.
fn to_snake_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 4);
    for (idx, ch) in s.chars().enumerate() {
        if !ch.is_uppercase() {
            out.push(ch);
            continue;
        }
        if idx != 0 {
            out.push('_');
        }
        // Only the first char of the lowercase mapping is kept.
        out.extend(ch.to_lowercase().take(1));
    }
    out
}