1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{
5 Data, DeriveInput, Error, Fields, Ident, Result, Type, parse::Parser, parse_macro_input,
6};
7
8#[proc_macro_derive(IntEnum)]
13pub fn derive_int_enum(input: TokenStream) -> TokenStream {
14 let input = parse_macro_input!(input as DeriveInput);
15
16 match derive_int_enum_impl(&input) {
17 Ok(tokens) => tokens.into(),
18 Err(err) => err.to_compile_error().into(),
19 }
20}
21
22fn derive_int_enum_impl(input: &DeriveInput) -> Result<TokenStream2> {
23 let variants = match &input.data {
24 Data::Enum(data) => &data.variants,
25 _ => {
26 return Err(Error::new_spanned(
27 input,
28 "IntEnum can only be derived for enums",
29 ));
30 }
31 };
32
33 let repr = parse_repr(input)?;
34
35 for variant in variants {
36 if !matches!(variant.fields, Fields::Unit) {
37 return Err(Error::new_spanned(
38 variant,
39 "IntEnum variants cannot have fields",
40 ));
41 }
42 }
43
44 let name = &input.ident;
45 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
46
47 let from_arms = variants.iter().map(|v| {
48 let variant_name = &v.ident;
49 quote! {
50 x if x == #name::#variant_name as #repr => Some(#name::#variant_name),
51 }
52 });
53
54 Ok(quote! {
55 impl #impl_generics nexus_bits::IntEnum for #name #ty_generics #where_clause {
56 type Repr = #repr;
57
58 #[inline]
59 fn into_repr(self) -> #repr {
60 self as #repr
61 }
62
63 #[inline]
64 fn try_from_repr(repr: #repr) -> Option<Self> {
65 match repr {
66 #(#from_arms)*
67 _ => None,
68 }
69 }
70 }
71 })
72}
73
74fn parse_repr(input: &DeriveInput) -> Result<Ident> {
75 for attr in &input.attrs {
76 if attr.path().is_ident("repr") {
77 let repr: Ident = attr.parse_args()?;
78 match repr.to_string().as_str() {
79 "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" => {
80 return Ok(repr);
81 }
82 _ => {
83 return Err(Error::new_spanned(
84 repr,
85 "IntEnum requires a primitive integer repr (u8..u128, i8..i128)",
86 ));
87 }
88 }
89 }
90 }
91
92 Err(Error::new_spanned(
93 input,
94 "IntEnum requires a #[repr(u8/u16/u32/u64/i8/i16/i32/i64)] attribute",
95 ))
96}
97
98#[proc_macro_attribute]
103pub fn bit_storage(attr: TokenStream, item: TokenStream) -> TokenStream {
104 let attr = proc_macro2::TokenStream::from(attr);
105 let item = parse_macro_input!(item as DeriveInput);
106
107 match bit_storage_impl(attr, &item) {
108 Ok(tokens) => tokens.into(),
109 Err(err) => err.to_compile_error().into(),
110 }
111}
112
113fn bit_storage_impl(attr: TokenStream2, input: &DeriveInput) -> Result<TokenStream2> {
114 let storage_attr = parse_storage_attr_tokens(attr)?;
115
116 match &input.data {
117 Data::Struct(data) => derive_storage_struct(input, data, &storage_attr),
118 Data::Enum(data) => derive_storage_enum(input, data, &storage_attr),
119 Data::Union(_) => Err(Error::new_spanned(
120 input,
121 "bit_storage cannot be applied to unions",
122 )),
123 }
124}
125
/// Arguments parsed from `#[bit_storage(...)]` by `parse_storage_attr_tokens`.
struct StorageAttr {
    /// Integer type backing the generated newtype. It has already been checked to be one of
    /// `u8`..`u128` / `i8`..`i128`.
    repr: Ident,
    /// Bit range that holds the variant discriminant. It is required for enums and rejected
    /// for structs.
    discriminant: Option<BitRange>,
}
135
/// A contiguous run of bits inside the storage integer, parsed from
/// `(start = N, len = M)`.
#[derive(Clone, Copy)]
struct BitRange {
    /// Index of the lowest bit of the range (bit 0 is the least significant bit, since
    /// values are read as `(raw >> start) & mask`).
    start: u32,
    /// Number of bits. `parse_bit_range` rejects 0, and the upper bound is checked against
    /// the repr width by the callers.
    len: u32,
}
142
/// One bit-packed member of a `bit_storage` struct or enum variant, as produced by
/// `parse_member`.
// NOTE(review): the allow is presumably there because `Field` holds a whole `syn::Type` and
// is far larger than `Flag`. These values exist only during macro expansion, so boxing the
// type would add noise without a measurable benefit.
#[allow(clippy::large_enum_variant)]
enum MemberDef {
    /// `#[field(start = N, len = M)] name: Ty`, a multi-bit value. A bare primitive integer
    /// `ty` is packed with plain casts. Any other `ty` is treated as a `nexus_bits::IntEnum`.
    Field {
        name: Ident,
        ty: Type,
        range: BitRange,
    },
    /// `#[flag(N)] name: _`, a single bit exposed as `bool`. The declared Rust type of the
    /// field is not recorded.
    Flag {
        name: Ident,
        bit: u32,
    },
}
158
159impl MemberDef {
160 fn name(&self) -> &Ident {
161 match self {
162 MemberDef::Field { name, .. } | MemberDef::Flag { name, .. } => name,
163 }
164 }
165}
166
167fn parse_storage_attr_tokens(attr: TokenStream2) -> Result<StorageAttr> {
172 let mut repr = None;
173 let mut discriminant = None;
174
175 let parser = syn::meta::parser(|meta| {
176 if meta.path.is_ident("repr") {
177 meta.input.parse::<syn::Token![=]>()?;
178 repr = Some(meta.input.parse::<Ident>()?);
179 Ok(())
180 } else if meta.path.is_ident("discriminant") {
181 let content;
182 syn::parenthesized!(content in meta.input);
183 discriminant = Some(parse_bit_range(&content)?);
184 Ok(())
185 } else {
186 Err(meta.error("expected `repr` or `discriminant`"))
187 }
188 });
189
190 parser.parse2(attr)?;
191
192 let repr = repr.ok_or_else(|| {
193 Error::new(
194 proc_macro2::Span::call_site(),
195 "bit_storage requires `repr = ...`",
196 )
197 })?;
198
199 match repr.to_string().as_str() {
201 "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" => {}
202 _ => return Err(Error::new_spanned(&repr, "repr must be an integer type")),
203 }
204
205 Ok(StorageAttr { repr, discriminant })
206}
207
208fn parse_bit_range(input: syn::parse::ParseStream) -> Result<BitRange> {
209 let mut start = None;
210 let mut len = None;
211
212 while !input.is_empty() {
213 let ident: Ident = input.parse()?;
214 input.parse::<syn::Token![=]>()?;
215 let lit: syn::LitInt = input.parse()?;
216 let value: u32 = lit.base10_parse()?;
217
218 match ident.to_string().as_str() {
219 "start" => start = Some(value),
220 "len" => len = Some(value),
221 _ => return Err(Error::new_spanned(ident, "expected `start` or `len`")),
222 }
223
224 if input.peek(syn::Token![,]) {
225 input.parse::<syn::Token![,]>()?;
226 }
227 }
228
229 let start = start.ok_or_else(|| Error::new(input.span(), "missing `start`"))?;
230 let len = len.ok_or_else(|| Error::new(input.span(), "missing `len`"))?;
231
232 if len == 0 {
233 return Err(Error::new(input.span(), "len must be > 0"));
234 }
235
236 Ok(BitRange { start, len })
237}
238
239fn parse_member(field: &syn::Field) -> Result<MemberDef> {
240 let name = field
241 .ident
242 .clone()
243 .ok_or_else(|| Error::new_spanned(field, "tuple structs not supported"))?;
244 let ty = field.ty.clone();
245
246 for attr in &field.attrs {
247 if attr.path().is_ident("field") {
248 let range = attr.parse_args_with(parse_bit_range)?;
249 return Ok(MemberDef::Field { name, ty, range });
250 } else if attr.path().is_ident("flag") {
251 let bit: syn::LitInt = attr.parse_args()?;
252 let bit: u32 = bit.base10_parse()?;
253 return Ok(MemberDef::Flag { name, bit });
254 }
255 }
256
257 Err(Error::new_spanned(
258 field,
259 "field requires #[field(start = N, len = M)] or #[flag(N)] attribute",
260 ))
261}
262
263fn parse_variant_attr(attrs: &[syn::Attribute]) -> Result<u64> {
264 for attr in attrs {
265 if attr.path().is_ident("variant") {
266 let lit: syn::LitInt = attr.parse_args()?;
267 return lit.base10_parse();
268 }
269 }
270 Err(Error::new(
271 proc_macro2::Span::call_site(),
272 "enum variant requires #[variant(N)] attribute",
273 ))
274}
275
276fn is_primitive(ty: &Type) -> bool {
281 if let Type::Path(type_path) = ty {
282 if let Some(ident) = type_path.path.get_ident() {
283 return matches!(
284 ident.to_string().as_str(),
285 "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128"
286 );
287 }
288 }
289 false
290}
291
292fn repr_bits(repr: &Ident) -> u32 {
293 match repr.to_string().as_str() {
294 "u8" | "i8" => 8,
295 "u16" | "i16" => 16,
296 "u32" | "i32" => 32,
297 "u64" | "i64" => 64,
298 "u128" | "i128" => 128,
299 _ => 0,
300 }
301}
302
303fn validate_members(members: &[MemberDef], repr: &Ident) -> Result<()> {
308 let bits = repr_bits(repr);
309
310 for member in members {
312 match member {
313 MemberDef::Field { name, range, .. } => {
314 if range.start + range.len > bits {
315 return Err(Error::new_spanned(
316 name,
317 format!(
318 "field exceeds {} bits (start {} + len {} = {})",
319 bits,
320 range.start,
321 range.len,
322 range.start + range.len
323 ),
324 ));
325 }
326 }
327 MemberDef::Flag { name, bit, .. } => {
328 if *bit >= bits {
329 return Err(Error::new_spanned(
330 name,
331 format!("flag bit {} exceeds {} bits", bit, bits),
332 ));
333 }
334 }
335 }
336 }
337
338 for (i, a) in members.iter().enumerate() {
340 for b in members.iter().skip(i + 1) {
341 if ranges_overlap(a, b) {
342 return Err(Error::new_spanned(
343 b.name(),
344 format!("field '{}' overlaps with '{}'", b.name(), a.name()),
345 ));
346 }
347 }
348 }
349
350 Ok(())
351}
352
353fn ranges_overlap(a: &MemberDef, b: &MemberDef) -> bool {
354 let (a_start, a_end) = member_bit_range(a);
355 let (b_start, b_end) = member_bit_range(b);
356 a_start < b_end && b_start < a_end
357}
358
359fn member_bit_range(m: &MemberDef) -> (u32, u32) {
360 match m {
361 MemberDef::Field { range, .. } => (range.start, range.start + range.len),
362 MemberDef::Flag { bit, .. } => (*bit, bit + 1),
363 }
364}
365
366fn derive_storage_struct(
371 input: &DeriveInput,
372 data: &syn::DataStruct,
373 storage_attr: &StorageAttr,
374) -> Result<TokenStream2> {
375 let fields = match &data.fields {
376 Fields::Named(f) => &f.named,
377 _ => {
378 return Err(Error::new_spanned(
379 input,
380 "bit_storage requires named fields",
381 ));
382 }
383 };
384
385 if storage_attr.discriminant.is_some() {
386 return Err(Error::new_spanned(
387 input,
388 "discriminant is only valid for enums",
389 ));
390 }
391
392 let members: Vec<MemberDef> = fields.iter().map(parse_member).collect::<Result<_>>()?;
393
394 validate_members(&members, &storage_attr.repr)?;
395
396 let vis = &input.vis;
397 let name = &input.ident;
398 let repr = &storage_attr.repr;
399 let builder_name = Ident::new(&format!("{}Builder", name), name.span());
400
401 let newtype = generate_struct_newtype(vis, name, repr);
402 let builder_struct = generate_struct_builder_struct(vis, &builder_name, &members);
403 let newtype_impl = generate_struct_newtype_impl(name, &builder_name, repr, &members);
404 let builder_impl = generate_struct_builder_impl(name, &builder_name, repr, &members);
405
406 Ok(quote! {
407 #newtype
408 #builder_struct
409 #newtype_impl
410 #builder_impl
411 })
412}
413
414fn generate_struct_newtype(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
415 quote! {
416 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
417 #[repr(transparent)]
418 #vis struct #name(#vis #repr);
419 }
420}
421
422fn generate_struct_builder_struct(
423 vis: &syn::Visibility,
424 builder_name: &Ident,
425 members: &[MemberDef],
426) -> TokenStream2 {
427 let fields: Vec<TokenStream2> = members
428 .iter()
429 .map(|m| match m {
430 MemberDef::Field { name, ty, .. } => {
431 quote! { #name: Option<#ty>, }
432 }
433 MemberDef::Flag { name, .. } => {
434 quote! { #name: Option<bool>, }
435 }
436 })
437 .collect();
438
439 quote! {
440 #[derive(Debug, Clone, Copy, Default)]
441 #vis struct #builder_name {
442 #(#fields)*
443 }
444 }
445}
446
447fn generate_struct_newtype_impl(
448 name: &Ident,
449 builder_name: &Ident,
450 repr: &Ident,
451 members: &[MemberDef],
452) -> TokenStream2 {
453 let repr_bit_count = repr_bits(repr);
454
455 let accessors: Vec<TokenStream2> = members.iter().map(|m| {
456 match m {
457 MemberDef::Field { name: field_name, ty, range } => {
458 let start = range.start;
459 let len = range.len;
460 let mask = if len >= repr_bit_count {
461 quote! { #repr::MAX }
462 } else {
463 quote! { ((1 as #repr) << #len) - 1 }
464 };
465
466 if is_primitive(ty) {
467 quote! {
468 #[inline]
469 pub const fn #field_name(&self) -> #ty {
470 ((self.0 >> #start) & #mask) as #ty
471 }
472 }
473 } else {
474 quote! {
476 #[inline]
477 pub fn #field_name(&self) -> Result<#ty, nexus_bits::UnknownDiscriminant<#repr>> {
478 let field_repr = ((self.0 >> #start) & #mask);
479 <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
480 .ok_or(nexus_bits::UnknownDiscriminant {
481 field: stringify!(#field_name),
482 value: field_repr as #repr,
483 })
484 }
485 }
486 }
487 }
488 MemberDef::Flag { name: field_name, bit } => {
489 quote! {
490 #[inline]
491 pub const fn #field_name(&self) -> bool {
492 (self.0 >> #bit) & 1 != 0
493 }
494 }
495 }
496 }
497 }).collect();
498
499 quote! {
500 impl #name {
501 #[inline]
503 pub const fn from_raw(raw: #repr) -> Self {
504 Self(raw)
505 }
506
507 #[inline]
509 pub const fn raw(self) -> #repr {
510 self.0
511 }
512
513 #[inline]
515 pub fn builder() -> #builder_name {
516 #builder_name::default()
517 }
518
519 #(#accessors)*
520 }
521 }
522}
523
524fn generate_struct_builder_impl(
525 name: &Ident,
526 builder_name: &Ident,
527 repr: &Ident,
528 members: &[MemberDef],
529) -> TokenStream2 {
530 let repr_bit_count = repr_bits(repr);
531
532 let setters: Vec<TokenStream2> = members
534 .iter()
535 .map(|m| match m {
536 MemberDef::Field {
537 name: field_name,
538 ty,
539 ..
540 } => {
541 quote! {
542 #[inline]
543 pub fn #field_name(mut self, val: #ty) -> Self {
544 self.#field_name = Some(val);
545 self
546 }
547 }
548 }
549 MemberDef::Flag {
550 name: field_name, ..
551 } => {
552 quote! {
553 #[inline]
554 pub fn #field_name(mut self, val: bool) -> Self {
555 self.#field_name = Some(val);
556 self
557 }
558 }
559 }
560 })
561 .collect();
562
563 let validations: Vec<TokenStream2> = members
565 .iter()
566 .filter_map(|m| match m {
567 MemberDef::Field {
568 name: field_name,
569 ty,
570 range,
571 } => {
572 let field_str = field_name.to_string();
573 let len = range.len;
574
575 let max_val = if len >= repr_bit_count {
576 quote! { #repr::MAX }
577 } else {
578 quote! { ((1 as #repr) << #len) - 1 }
579 };
580
581 if is_primitive(ty) {
582 let type_bits: u32 = match ty {
583 Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
584 Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
585 Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
586 Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
587 Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
588 _ => 128,
589 };
590
591 if len >= type_bits {
593 return None;
594 }
595
596 let is_signed = matches!(ty,
598 Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
599 p.path.is_ident("i32") || p.path.is_ident("i64") ||
600 p.path.is_ident("i128")
601 );
602
603 if is_signed {
604 let min_shift = len - 1;
609 Some(quote! {
610 if let Some(v) = self.#field_name {
611 let min_val = -((1i128 << #min_shift) as i128);
612 let max_val = ((1i128 << #min_shift) - 1) as i128;
613 let v_i128 = v as i128;
614 if v_i128 < min_val || v_i128 > max_val {
615 return Err(nexus_bits::FieldOverflow {
616 field: #field_str,
617 overflow: nexus_bits::Overflow {
618 value: (v as #repr),
619 max: #max_val,
620 },
621 });
622 }
623 }
624 })
625 } else {
626 Some(quote! {
628 if let Some(v) = self.#field_name {
629 if (v as #repr) > #max_val {
630 return Err(nexus_bits::FieldOverflow {
631 field: #field_str,
632 overflow: nexus_bits::Overflow {
633 value: v as #repr,
634 max: #max_val,
635 },
636 });
637 }
638 }
639 })
640 }
641 } else {
642 Some(quote! {
644 const _: () = assert!(
645 core::mem::size_of::<<#ty as nexus_bits::IntEnum>::Repr>() <= core::mem::size_of::<#repr>(),
646 "IntEnum repr type is wider than storage repr — values may be truncated"
647 );
648 if let Some(v) = self.#field_name {
649 let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
650 if repr_val > #max_val {
651 return Err(nexus_bits::FieldOverflow {
652 field: #field_str,
653 overflow: nexus_bits::Overflow {
654 value: repr_val,
655 max: #max_val,
656 },
657 });
658 }
659 }
660 })
661 }
662 }
663 MemberDef::Flag { .. } => None,
664 })
665 .collect();
666
667 let pack_statements: Vec<TokenStream2> = members
669 .iter()
670 .map(|m| {
671 match m {
672 MemberDef::Field {
673 name: field_name,
674 ty,
675 range,
676 } => {
677 let start = range.start;
678 let len = range.len;
679 let mask = if len >= repr_bit_count {
680 quote! { #repr::MAX }
681 } else {
682 quote! { ((1 as #repr) << #len) - 1 }
683 };
684
685 if is_primitive(ty) {
686 quote! {
687 if let Some(v) = self.#field_name {
688 val |= ((v as #repr) & #mask) << #start;
689 }
690 }
691 } else {
692 quote! {
694 if let Some(v) = self.#field_name {
695 val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
696 }
697 }
698 }
699 }
700 MemberDef::Flag {
701 name: field_name,
702 bit,
703 } => {
704 quote! {
705 if let Some(true) = self.#field_name {
706 val |= (1 as #repr) << #bit;
707 }
708 }
709 }
710 }
711 })
712 .collect();
713
714 quote! {
715 impl #builder_name {
716 #(#setters)*
717
718 #[inline]
720 pub fn build(self) -> Result<#name, nexus_bits::FieldOverflow<#repr>> {
721 #(#validations)*
723
724 let mut val: #repr = 0;
726 #(#pack_statements)*
727
728 Ok(#name(val))
729 }
730 }
731 }
732}
733
/// A validated variant of a `bit_storage` enum, collected by `derive_storage_enum`.
struct ParsedVariant {
    /// The variant's identifier as written in the source enum.
    name: Ident,
    /// Value from `#[variant(N)]`. It has been checked to fit the discriminant field and to
    /// be unique among the variants.
    discriminant: u64,
    /// The variant's bit-packed members (empty for unit variants).
    members: Vec<MemberDef>,
}
744
745fn derive_storage_enum(
746 input: &DeriveInput,
747 data: &syn::DataEnum,
748 storage_attr: &StorageAttr,
749) -> Result<TokenStream2> {
750 let discriminant = storage_attr.discriminant.ok_or_else(|| {
751 Error::new_spanned(
752 input,
753 "bit_storage enum requires discriminant: #[bit_storage(repr = T, discriminant(start = N, len = M))]",
754 )
755 })?;
756
757 let repr = &storage_attr.repr;
758 let bits = repr_bits(repr);
759
760 if discriminant.start + discriminant.len > bits {
762 return Err(Error::new_spanned(
763 input,
764 format!(
765 "discriminant exceeds {} bits (start {} + len {} = {})",
766 bits,
767 discriminant.start,
768 discriminant.len,
769 discriminant.start + discriminant.len
770 ),
771 ));
772 }
773
774 let max_discriminant = if discriminant.len >= 64 {
775 u64::MAX
776 } else {
777 (1u64 << discriminant.len) - 1
778 };
779
780 let mut variants = Vec::new();
782 for variant in &data.variants {
783 let disc = parse_variant_attr(&variant.attrs)?;
784
785 if disc > max_discriminant {
786 return Err(Error::new_spanned(
787 &variant.ident,
788 format!(
789 "variant discriminant {} exceeds max {} for {}-bit field",
790 disc, max_discriminant, discriminant.len
791 ),
792 ));
793 }
794
795 for existing in &variants {
797 let existing: &ParsedVariant = existing;
798 if existing.discriminant == disc {
799 return Err(Error::new_spanned(
800 &variant.ident,
801 format!(
802 "duplicate discriminant {}: already used by '{}'",
803 disc, existing.name
804 ),
805 ));
806 }
807 }
808
809 let members: Vec<MemberDef> = match &variant.fields {
810 Fields::Named(fields) => fields
811 .named
812 .iter()
813 .map(parse_member)
814 .collect::<Result<_>>()?,
815 Fields::Unit => Vec::new(),
816 Fields::Unnamed(_) => {
817 return Err(Error::new_spanned(
818 variant,
819 "tuple variants not supported, use named fields",
820 ));
821 }
822 };
823
824 let disc_range = MemberDef::Field {
826 name: Ident::new("__discriminant", proc_macro2::Span::call_site()),
827 ty: syn::parse_quote!(u64),
828 range: discriminant,
829 };
830
831 for member in &members {
832 if ranges_overlap(&disc_range, member) {
833 return Err(Error::new_spanned(
834 member.name(),
835 format!("field '{}' overlaps with discriminant", member.name()),
836 ));
837 }
838 }
839
840 validate_members(&members, repr)?;
842
843 variants.push(ParsedVariant {
844 name: variant.ident.clone(),
845 discriminant: disc,
846 members,
847 });
848 }
849
850 let vis = &input.vis;
851 let name = &input.ident;
852
853 let parent_type = generate_enum_parent_type(vis, name, repr);
854 let variant_types = generate_enum_variant_types(vis, name, repr, &variants);
855 let kind_enum = generate_enum_kind(vis, name, &variants);
856 let builder_structs = generate_enum_builder_structs(vis, name, &variants);
857 let parent_impl = generate_enum_parent_impl(name, repr, discriminant, &variants);
858 let variant_impls = generate_enum_variant_impls(name, repr, &variants);
859 let builder_impls = generate_enum_builder_impls(name, repr, discriminant, &variants);
860 let from_impls = generate_enum_from_impls(name, &variants);
861
862 Ok(quote! {
863 #parent_type
864 #variant_types
865 #kind_enum
866 #builder_structs
867 #parent_impl
868 #variant_impls
869 #builder_impls
870 #from_impls
871 })
872}
873
874fn variant_type_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
875 Ident::new(
876 &format!("{}{}", parent_name, variant_name),
877 variant_name.span(),
878 )
879}
880
881fn variant_builder_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
882 Ident::new(
883 &format!("{}{}Builder", parent_name, variant_name),
884 variant_name.span(),
885 )
886}
887
888fn kind_enum_name(parent_name: &Ident) -> Ident {
889 Ident::new(&format!("{}Kind", parent_name), parent_name.span())
890}
891
892fn generate_enum_parent_type(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
893 quote! {
894 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
895 #[repr(transparent)]
896 #vis struct #name(#vis #repr);
897 }
898}
899
900fn generate_enum_variant_types(
901 vis: &syn::Visibility,
902 parent_name: &Ident,
903 repr: &Ident,
904 variants: &[ParsedVariant],
905) -> TokenStream2 {
906 let types: Vec<TokenStream2> = variants
907 .iter()
908 .map(|v| {
909 let type_name = variant_type_name(parent_name, &v.name);
910 quote! {
911 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
912 #[repr(transparent)]
913 #vis struct #type_name(#repr);
914 }
915 })
916 .collect();
917
918 quote! { #(#types)* }
919}
920
921fn generate_enum_kind(
922 vis: &syn::Visibility,
923 parent_name: &Ident,
924 variants: &[ParsedVariant],
925) -> TokenStream2 {
926 let kind_name = kind_enum_name(parent_name);
927 let variant_names: Vec<&Ident> = variants.iter().map(|v| &v.name).collect();
928
929 quote! {
930 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
931 #vis enum #kind_name {
932 #(#variant_names),*
933 }
934 }
935}
936
937fn generate_enum_builder_structs(
938 vis: &syn::Visibility,
939 parent_name: &Ident,
940 variants: &[ParsedVariant],
941) -> TokenStream2 {
942 let builders: Vec<TokenStream2> = variants
943 .iter()
944 .map(|v| {
945 let builder_name = variant_builder_name(parent_name, &v.name);
946
947 let fields: Vec<TokenStream2> = v
948 .members
949 .iter()
950 .map(|m| match m {
951 MemberDef::Field { name, ty, .. } => {
952 quote! { #name: Option<#ty>, }
953 }
954 MemberDef::Flag { name, .. } => {
955 quote! { #name: Option<bool>, }
956 }
957 })
958 .collect();
959
960 quote! {
961 #[derive(Debug, Clone, Copy, Default)]
962 #vis struct #builder_name {
963 #(#fields)*
964 }
965 }
966 })
967 .collect();
968
969 quote! { #(#builders)* }
970}
971
972fn generate_enum_parent_impl(
973 name: &Ident,
974 repr: &Ident,
975 discriminant: BitRange,
976 variants: &[ParsedVariant],
977) -> TokenStream2 {
978 let repr_bit_count = repr_bits(repr);
979 let kind_name = kind_enum_name(name);
980 let disc_start = discriminant.start;
981 let disc_len = discriminant.len;
982
983 assert!(
986 disc_len <= 64,
987 "discriminant length must be <= 64 bits (got {disc_len})"
988 );
989
990 let disc_mask = if disc_len >= repr_bit_count {
991 quote! { #repr::MAX }
992 } else {
993 quote! { ((1 as #repr) << #disc_len) - 1 }
994 };
995
996 let kind_arms: Vec<TokenStream2> = variants
998 .iter()
999 .map(|v| {
1000 let variant_name = &v.name;
1001 let disc_val = v.discriminant;
1002 quote! {
1003 #disc_val => Ok(#kind_name::#variant_name),
1004 }
1005 })
1006 .collect();
1007
1008 let is_methods: Vec<TokenStream2> = variants
1010 .iter()
1011 .map(|v| {
1012 let variant_name = &v.name;
1013 let method_name = Ident::new(
1014 &format!("is_{}", to_snake_case(&variant_name.to_string())),
1015 variant_name.span(),
1016 );
1017 let disc_val = v.discriminant;
1018 quote! {
1019 #[inline]
1020 pub fn #method_name(&self) -> bool {
1021 let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1022 disc == #disc_val
1023 }
1024 }
1025 })
1026 .collect();
1027
1028 let as_methods: Vec<TokenStream2> = variants
1030 .iter()
1031 .map(|v| {
1032 let variant_name = &v.name;
1033 let variant_type = variant_type_name(name, variant_name);
1034 let method_name = Ident::new(
1035 &format!("as_{}", to_snake_case(&variant_name.to_string())),
1036 variant_name.span(),
1037 );
1038 let disc_val = v.discriminant;
1039
1040 let validations: Vec<TokenStream2> = v.members
1042 .iter()
1043 .filter_map(|m| {
1044 if let MemberDef::Field { name: field_name, ty, range } = m {
1045 if !is_primitive(ty) {
1046 let start = range.start;
1047 let len = range.len;
1048 let repr_bit_count = repr_bits(repr);
1049 let mask = if len >= repr_bit_count {
1050 quote! { #repr::MAX }
1051 } else {
1052 quote! { ((1 as #repr) << #len) - 1 }
1053 };
1054 return Some(quote! {
1055 let field_repr = ((self.0 >> #start) & #mask);
1056 if <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _).is_none() {
1057 return Err(nexus_bits::UnknownDiscriminant {
1058 field: stringify!(#field_name),
1059 value: field_repr as #repr,
1060 });
1061 }
1062 });
1063 }
1064 }
1065 None
1066 })
1067 .collect();
1068
1069 quote! {
1070 #[inline]
1071 pub fn #method_name(&self) -> Result<#variant_type, nexus_bits::UnknownDiscriminant<#repr>> {
1072 let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1073 if disc != #disc_val {
1074 return Err(nexus_bits::UnknownDiscriminant {
1075 field: "__discriminant",
1076 value: disc as #repr,
1077 });
1078 }
1079 #(#validations)*
1080 Ok(#variant_type(self.0))
1081 }
1082 }
1083 })
1084 .collect();
1085
1086 let builder_methods: Vec<TokenStream2> = variants
1088 .iter()
1089 .map(|v| {
1090 let variant_name = &v.name;
1091 let builder_name = variant_builder_name(name, variant_name);
1092 let method_name = Ident::new(
1093 &to_snake_case(&variant_name.to_string()),
1094 variant_name.span(),
1095 );
1096 quote! {
1097 #[inline]
1098 pub fn #method_name() -> #builder_name {
1099 #builder_name::default()
1100 }
1101 }
1102 })
1103 .collect();
1104
1105 quote! {
1106 impl #name {
1107 #[inline]
1109 pub const fn from_raw(raw: #repr) -> Self {
1110 Self(raw)
1111 }
1112
1113 #[inline]
1115 pub const fn raw(self) -> #repr {
1116 self.0
1117 }
1118
1119 #[inline]
1121 pub fn kind(&self) -> Result<#kind_name, nexus_bits::UnknownDiscriminant<#repr>> {
1122 let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1123 match disc {
1124 #(#kind_arms)*
1125 _ => Err(nexus_bits::UnknownDiscriminant {
1126 field: "__discriminant",
1127 value: disc as #repr,
1128 }),
1129 }
1130 }
1131
1132 #(#is_methods)*
1133
1134 #(#as_methods)*
1135
1136 #(#builder_methods)*
1137 }
1138 }
1139}
1140
1141fn generate_enum_variant_impls(
1142 parent_name: &Ident,
1143 repr: &Ident,
1144 variants: &[ParsedVariant],
1145) -> TokenStream2 {
1146 let repr_bit_count = repr_bits(repr);
1147
1148 let impls: Vec<TokenStream2> =
1149 variants
1150 .iter()
1151 .map(|v| {
1152 let variant_name = &v.name;
1153 let variant_type = variant_type_name(parent_name, variant_name);
1154 let builder_name = variant_builder_name(parent_name, variant_name);
1155
1156 let accessors: Vec<TokenStream2> = v.members
1158 .iter()
1159 .map(|m| {
1160 match m {
1161 MemberDef::Field { name: field_name, ty, range } => {
1162 let start = range.start;
1163 let len = range.len;
1164 let mask = if len >= repr_bit_count {
1165 quote! { #repr::MAX }
1166 } else {
1167 quote! { ((1 as #repr) << #len) - 1 }
1168 };
1169
1170 if is_primitive(ty) {
1171 quote! {
1172 #[inline]
1173 pub const fn #field_name(&self) -> #ty {
1174 ((self.0 >> #start) & #mask) as #ty
1175 }
1176 }
1177 } else {
1178 quote! {
1180 #[inline]
1181 pub fn #field_name(&self) -> #ty {
1182 let field_repr = ((self.0 >> #start) & #mask);
1183 <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
1185 .expect("variant type invariant violated")
1186 }
1187 }
1188 }
1189 }
1190 MemberDef::Flag { name: field_name, bit } => {
1191 quote! {
1192 #[inline]
1193 pub const fn #field_name(&self) -> bool {
1194 (self.0 >> #bit) & 1 != 0
1195 }
1196 }
1197 }
1198 }
1199 })
1200 .collect();
1201
1202 quote! {
1203 impl #variant_type {
1204 #[inline]
1206 pub fn builder() -> #builder_name {
1207 #builder_name::default()
1208 }
1209
1210 #[inline]
1212 pub const fn raw(self) -> #repr {
1213 self.0
1214 }
1215
1216 #[inline]
1218 pub const fn as_parent(self) -> #parent_name {
1219 #parent_name(self.0)
1220 }
1221
1222 #(#accessors)*
1223 }
1224 }
1225 })
1226 .collect();
1227
1228 quote! { #(#impls)* }
1229}
1230
1231fn generate_enum_builder_impls(
1232 parent_name: &Ident,
1233 repr: &Ident,
1234 discriminant: BitRange,
1235 variants: &[ParsedVariant],
1236) -> TokenStream2 {
1237 let repr_bit_count = repr_bits(repr);
1238 let disc_start = discriminant.start;
1239
1240 let impls: Vec<TokenStream2> = variants
1241 .iter()
1242 .map(|v| {
1243 let variant_name = &v.name;
1244 let variant_type = variant_type_name(parent_name, variant_name);
1245 let builder_name = variant_builder_name(parent_name, variant_name);
1246 let disc_val = v.discriminant;
1247
1248 let setters: Vec<TokenStream2> = v.members
1250 .iter()
1251 .map(|m| match m {
1252 MemberDef::Field { name: field_name, ty, .. } => {
1253 quote! {
1254 #[inline]
1255 pub fn #field_name(mut self, val: #ty) -> Self {
1256 self.#field_name = Some(val);
1257 self
1258 }
1259 }
1260 }
1261 MemberDef::Flag { name: field_name, .. } => {
1262 quote! {
1263 #[inline]
1264 pub fn #field_name(mut self, val: bool) -> Self {
1265 self.#field_name = Some(val);
1266 self
1267 }
1268 }
1269 }
1270 })
1271 .collect();
1272
1273 let validations: Vec<TokenStream2> = v.members
1275 .iter()
1276 .filter_map(|m| match m {
1277 MemberDef::Field { name: field_name, ty, range } => {
1278 let field_str = field_name.to_string();
1279 let len = range.len;
1280
1281 let max_val = if len >= repr_bit_count {
1282 quote! { #repr::MAX }
1283 } else {
1284 quote! { ((1 as #repr) << #len) - 1 }
1285 };
1286
1287 if is_primitive(ty) {
1288 let type_bits: u32 = match ty {
1289 Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
1290 Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
1291 Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
1292 Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
1293 Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
1294 _ => 128,
1295 };
1296
1297 if len >= type_bits {
1298 return None;
1299 }
1300
1301 let is_signed = matches!(ty,
1302 Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
1303 p.path.is_ident("i32") || p.path.is_ident("i64") ||
1304 p.path.is_ident("i128")
1305 );
1306
1307 if is_signed {
1308 let min_shift = len - 1;
1309 Some(quote! {
1310 if let Some(v) = self.#field_name {
1311 let min_val = -((1i128 << #min_shift) as i128);
1312 let max_val = ((1i128 << #min_shift) - 1) as i128;
1313 let v_i128 = v as i128;
1314 if v_i128 < min_val || v_i128 > max_val {
1315 return Err(nexus_bits::FieldOverflow {
1316 field: #field_str,
1317 overflow: nexus_bits::Overflow {
1318 value: (v as #repr),
1319 max: #max_val,
1320 },
1321 });
1322 }
1323 }
1324 })
1325 } else {
1326 Some(quote! {
1327 if let Some(v) = self.#field_name {
1328 if (v as #repr) > #max_val {
1329 return Err(nexus_bits::FieldOverflow {
1330 field: #field_str,
1331 overflow: nexus_bits::Overflow {
1332 value: v as #repr,
1333 max: #max_val,
1334 },
1335 });
1336 }
1337 }
1338 })
1339 }
1340 } else {
1341 Some(quote! {
1343 if let Some(v) = self.#field_name {
1344 let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
1345 if repr_val > #max_val {
1346 return Err(nexus_bits::FieldOverflow {
1347 field: #field_str,
1348 overflow: nexus_bits::Overflow {
1349 value: repr_val,
1350 max: #max_val,
1351 },
1352 });
1353 }
1354 }
1355 })
1356 }
1357 }
1358 MemberDef::Flag { .. } => None,
1359 })
1360 .collect();
1361
1362 let pack_statements: Vec<TokenStream2> = v.members
1364 .iter()
1365 .map(|m| {
1366 match m {
1367 MemberDef::Field { name: field_name, ty, range } => {
1368 let start = range.start;
1369 let len = range.len;
1370 let mask = if len >= repr_bit_count {
1371 quote! { #repr::MAX }
1372 } else {
1373 quote! { ((1 as #repr) << #len) - 1 }
1374 };
1375
1376 if is_primitive(ty) {
1377 quote! {
1378 if let Some(v) = self.#field_name {
1379 val |= ((v as #repr) & #mask) << #start;
1380 }
1381 }
1382 } else {
1383 quote! {
1384 if let Some(v) = self.#field_name {
1385 val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
1386 }
1387 }
1388 }
1389 }
1390 MemberDef::Flag { name: field_name, bit } => {
1391 quote! {
1392 if let Some(true) = self.#field_name {
1393 val |= (1 as #repr) << #bit;
1394 }
1395 }
1396 }
1397 }
1398 })
1399 .collect();
1400
1401 quote! {
1402 impl #builder_name {
1403 #(#setters)*
1404
1405 #[inline]
1407 pub fn build(self) -> Result<#variant_type, nexus_bits::FieldOverflow<#repr>> {
1408 #(#validations)*
1409
1410 let mut val: #repr = 0;
1411 val |= (#disc_val as #repr) << #disc_start;
1413 #(#pack_statements)*
1414
1415 Ok(#variant_type(val))
1416 }
1417
1418 #[inline]
1420 pub fn build_parent(self) -> Result<#parent_name, nexus_bits::FieldOverflow<#repr>> {
1421 self.build().map(|v| v.as_parent())
1422 }
1423 }
1424 }
1425 })
1426 .collect();
1427
1428 quote! { #(#impls)* }
1429}
1430
1431fn generate_enum_from_impls(parent_name: &Ident, variants: &[ParsedVariant]) -> TokenStream2 {
1432 let impls: Vec<TokenStream2> = variants
1433 .iter()
1434 .map(|v| {
1435 let variant_type = variant_type_name(parent_name, &v.name);
1436 quote! {
1437 impl From<#variant_type> for #parent_name {
1438 #[inline]
1439 fn from(v: #variant_type) -> Self {
1440 v.as_parent()
1441 }
1442 }
1443 }
1444 })
1445 .collect();
1446
1447 quote! { #(#impls)* }
1448}
1449
/// Converts an `UpperCamelCase` identifier to `snake_case`.
///
/// An `_` is inserted before every uppercase character except the first, and that character
/// is lowercased.
///
/// Every uppercase letter starts a new word, so acronyms are split letter by letter
/// (`HTTPRequest` becomes `h_t_t_p_request`). Generated method names depend on this, so the
/// behaviour is kept. Characters whose lowercase form has several code points (e.g. `'İ'`)
/// keep all of them instead of being truncated to the first one.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}