1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{
5 Data, DeriveInput, Error, Fields, Ident, Result, Type, parse::Parser, parse_macro_input,
6};
7
8#[proc_macro_derive(IntEnum)]
13pub fn derive_int_enum(input: TokenStream) -> TokenStream {
14 let input = parse_macro_input!(input as DeriveInput);
15
16 match derive_int_enum_impl(&input) {
17 Ok(tokens) => tokens.into(),
18 Err(err) => err.to_compile_error().into(),
19 }
20}
21
22fn derive_int_enum_impl(input: &DeriveInput) -> Result<TokenStream2> {
23 let variants = match &input.data {
24 Data::Enum(data) => &data.variants,
25 _ => {
26 return Err(Error::new_spanned(
27 input,
28 "IntEnum can only be derived for enums",
29 ));
30 }
31 };
32
33 let repr = parse_repr(input)?;
34
35 for variant in variants {
36 if !matches!(variant.fields, Fields::Unit) {
37 return Err(Error::new_spanned(
38 variant,
39 "IntEnum variants cannot have fields",
40 ));
41 }
42 }
43
44 let name = &input.ident;
45 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
46
47 let from_arms = variants.iter().map(|v| {
48 let variant_name = &v.ident;
49 quote! {
50 x if x == #name::#variant_name as #repr => Some(#name::#variant_name),
51 }
52 });
53
54 Ok(quote! {
55 impl #impl_generics nexus_bits::IntEnum for #name #ty_generics #where_clause {
56 type Repr = #repr;
57
58 #[inline]
59 fn into_repr(self) -> #repr {
60 self as #repr
61 }
62
63 #[inline]
64 fn try_from_repr(repr: #repr) -> Option<Self> {
65 match repr {
66 #(#from_arms)*
67 _ => None,
68 }
69 }
70 }
71 })
72}
73
74fn parse_repr(input: &DeriveInput) -> Result<Ident> {
75 for attr in &input.attrs {
76 if attr.path().is_ident("repr") {
77 let repr: Ident = attr.parse_args()?;
78 match repr.to_string().as_str() {
79 "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" => {
80 return Ok(repr);
81 }
82 _ => {
83 return Err(Error::new_spanned(
84 repr,
85 "IntEnum requires a primitive integer repr (u8..u128, i8..i128)",
86 ));
87 }
88 }
89 }
90 }
91
92 Err(Error::new_spanned(
93 input,
94 "IntEnum requires a #[repr(u8/u16/u32/u64/i8/i16/i32/i64)] attribute",
95 ))
96}
97
98#[proc_macro_attribute]
103pub fn bit_storage(attr: TokenStream, item: TokenStream) -> TokenStream {
104 let attr = proc_macro2::TokenStream::from(attr);
105 let item = parse_macro_input!(item as DeriveInput);
106
107 match bit_storage_impl(attr, &item) {
108 Ok(tokens) => tokens.into(),
109 Err(err) => err.to_compile_error().into(),
110 }
111}
112
113fn bit_storage_impl(attr: TokenStream2, input: &DeriveInput) -> Result<TokenStream2> {
114 let storage_attr = parse_storage_attr_tokens(attr)?;
115
116 match &input.data {
117 Data::Struct(data) => derive_storage_struct(input, data, &storage_attr),
118 Data::Enum(data) => derive_storage_enum(input, data, &storage_attr),
119 Data::Union(_) => Err(Error::new_spanned(
120 input,
121 "bit_storage cannot be applied to unions",
122 )),
123 }
124}
125
/// Parsed arguments of the `#[bit_storage(...)]` attribute.
struct StorageAttr {
    /// Integer type named by `repr = ...`; validated to be one of
    /// `u8..u128` / `i8..i128` by `parse_storage_attr_tokens`.
    repr: Ident,
    /// Bit range given by `discriminant(start = N, len = M)`, if present.
    discriminant: Option<BitRange>,
}
135
/// A run of `len` bits beginning at bit index `start`.
#[derive(Clone, Copy)]
struct BitRange {
    /// Index of the first bit of the range.
    start: u32,
    /// Number of bits in the range; `parse_bit_range` rejects zero.
    len: u32,
}
142
143#[allow(clippy::large_enum_variant)]
147enum MemberDef {
148 Field {
149 name: Ident,
150 ty: Type,
151 range: BitRange,
152 },
153 Flag {
154 name: Ident,
155 bit: u32,
156 },
157}
158
159impl MemberDef {
160 fn name(&self) -> &Ident {
161 match self {
162 MemberDef::Field { name, .. } | MemberDef::Flag { name, .. } => name,
163 }
164 }
165}
166
167fn parse_storage_attr_tokens(attr: TokenStream2) -> Result<StorageAttr> {
172 let mut repr = None;
173 let mut discriminant = None;
174
175 let parser = syn::meta::parser(|meta| {
176 if meta.path.is_ident("repr") {
177 meta.input.parse::<syn::Token![=]>()?;
178 repr = Some(meta.input.parse::<Ident>()?);
179 Ok(())
180 } else if meta.path.is_ident("discriminant") {
181 let content;
182 syn::parenthesized!(content in meta.input);
183 discriminant = Some(parse_bit_range(&content)?);
184 Ok(())
185 } else {
186 Err(meta.error("expected `repr` or `discriminant`"))
187 }
188 });
189
190 parser.parse2(attr)?;
191
192 let repr = repr.ok_or_else(|| {
193 Error::new(
194 proc_macro2::Span::call_site(),
195 "bit_storage requires `repr = ...`",
196 )
197 })?;
198
199 match repr.to_string().as_str() {
201 "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" => {}
202 _ => return Err(Error::new_spanned(&repr, "repr must be an integer type")),
203 }
204
205 Ok(StorageAttr { repr, discriminant })
206}
207
208fn parse_bit_range(input: syn::parse::ParseStream) -> Result<BitRange> {
209 let mut start = None;
210 let mut len = None;
211
212 while !input.is_empty() {
213 let ident: Ident = input.parse()?;
214 input.parse::<syn::Token![=]>()?;
215 let lit: syn::LitInt = input.parse()?;
216 let value: u32 = lit.base10_parse()?;
217
218 match ident.to_string().as_str() {
219 "start" => start = Some(value),
220 "len" => len = Some(value),
221 _ => return Err(Error::new_spanned(ident, "expected `start` or `len`")),
222 }
223
224 if input.peek(syn::Token![,]) {
225 input.parse::<syn::Token![,]>()?;
226 }
227 }
228
229 let start = start.ok_or_else(|| Error::new(input.span(), "missing `start`"))?;
230 let len = len.ok_or_else(|| Error::new(input.span(), "missing `len`"))?;
231
232 if len == 0 {
233 return Err(Error::new(input.span(), "len must be > 0"));
234 }
235
236 Ok(BitRange { start, len })
237}
238
239fn parse_member(field: &syn::Field) -> Result<MemberDef> {
240 let name = field
241 .ident
242 .clone()
243 .ok_or_else(|| Error::new_spanned(field, "tuple structs not supported"))?;
244 let ty = field.ty.clone();
245
246 for attr in &field.attrs {
247 if attr.path().is_ident("field") {
248 let range = attr.parse_args_with(parse_bit_range)?;
249 return Ok(MemberDef::Field { name, ty, range });
250 } else if attr.path().is_ident("flag") {
251 let bit: syn::LitInt = attr.parse_args()?;
252 let bit: u32 = bit.base10_parse()?;
253 return Ok(MemberDef::Flag { name, bit });
254 }
255 }
256
257 Err(Error::new_spanned(
258 field,
259 "field requires #[field(start = N, len = M)] or #[flag(N)] attribute",
260 ))
261}
262
263fn parse_variant_attr(attrs: &[syn::Attribute]) -> Result<u64> {
264 for attr in attrs {
265 if attr.path().is_ident("variant") {
266 let lit: syn::LitInt = attr.parse_args()?;
267 return lit.base10_parse();
268 }
269 }
270 Err(Error::new(
271 proc_macro2::Span::call_site(),
272 "enum variant requires #[variant(N)] attribute",
273 ))
274}
275
276fn is_primitive(ty: &Type) -> bool {
281 if let Type::Path(type_path) = ty {
282 if let Some(ident) = type_path.path.get_ident() {
283 return matches!(
284 ident.to_string().as_str(),
285 "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128"
286 );
287 }
288 }
289 false
290}
291
292fn is_signed_primitive(ty: &Type) -> bool {
293 if let Type::Path(type_path) = ty {
294 if let Some(ident) = type_path.path.get_ident() {
295 return matches!(
296 ident.to_string().as_str(),
297 "i8" | "i16" | "i32" | "i64" | "i128"
298 );
299 }
300 }
301 false
302}
303
304fn primitive_bits(ty: &Type) -> u32 {
305 if let Type::Path(type_path) = ty {
306 if let Some(ident) = type_path.path.get_ident() {
307 return match ident.to_string().as_str() {
308 "u8" | "i8" => 8,
309 "u16" | "i16" => 16,
310 "u32" | "i32" => 32,
311 "u64" | "i64" => 64,
312 "u128" | "i128" => 128,
313 _ => 0,
314 };
315 }
316 }
317 0
318}
319
320fn repr_bits(repr: &Ident) -> u32 {
321 match repr.to_string().as_str() {
322 "u8" | "i8" => 8,
323 "u16" | "i16" => 16,
324 "u32" | "i32" => 32,
325 "u64" | "i64" => 64,
326 "u128" | "i128" => 128,
327 _ => 0,
328 }
329}
330
331fn field_mask(repr: &Ident, len: u32, repr_bit_count: u32) -> TokenStream2 {
337 if len >= repr_bit_count {
338 quote! { (!0 as #repr) }
339 } else {
340 quote! { (((1u128 << #len) - 1) as #repr) }
341 }
342}
343
344fn validate_members(members: &[MemberDef], repr: &Ident) -> Result<()> {
349 let bits = repr_bits(repr);
350
351 for member in members {
353 match member {
354 MemberDef::Field { name, range, .. } => {
355 if range.start + range.len > bits {
356 return Err(Error::new_spanned(
357 name,
358 format!(
359 "field exceeds {} bits (start {} + len {} = {})",
360 bits,
361 range.start,
362 range.len,
363 range.start + range.len
364 ),
365 ));
366 }
367 }
368 MemberDef::Flag { name, bit, .. } => {
369 if *bit >= bits {
370 return Err(Error::new_spanned(
371 name,
372 format!("flag bit {} exceeds {} bits", bit, bits),
373 ));
374 }
375 }
376 }
377 }
378
379 for (i, a) in members.iter().enumerate() {
381 for b in members.iter().skip(i + 1) {
382 if ranges_overlap(a, b) {
383 return Err(Error::new_spanned(
384 b.name(),
385 format!("field '{}' overlaps with '{}'", b.name(), a.name()),
386 ));
387 }
388 }
389 }
390
391 Ok(())
392}
393
394fn ranges_overlap(a: &MemberDef, b: &MemberDef) -> bool {
395 let (a_start, a_end) = member_bit_range(a);
396 let (b_start, b_end) = member_bit_range(b);
397 a_start < b_end && b_start < a_end
398}
399
400fn member_bit_range(m: &MemberDef) -> (u32, u32) {
401 match m {
402 MemberDef::Field { range, .. } => (range.start, range.start + range.len),
403 MemberDef::Flag { bit, .. } => (*bit, bit + 1),
404 }
405}
406
407fn derive_storage_struct(
412 input: &DeriveInput,
413 data: &syn::DataStruct,
414 storage_attr: &StorageAttr,
415) -> Result<TokenStream2> {
416 let fields = match &data.fields {
417 Fields::Named(f) => &f.named,
418 _ => {
419 return Err(Error::new_spanned(
420 input,
421 "bit_storage requires named fields",
422 ));
423 }
424 };
425
426 if storage_attr.discriminant.is_some() {
427 return Err(Error::new_spanned(
428 input,
429 "discriminant is only valid for enums",
430 ));
431 }
432
433 let members: Vec<MemberDef> = fields.iter().map(parse_member).collect::<Result<_>>()?;
434
435 validate_members(&members, &storage_attr.repr)?;
436
437 let vis = &input.vis;
438 let name = &input.ident;
439 let repr = &storage_attr.repr;
440 let builder_name = Ident::new(&format!("{}Builder", name), name.span());
441
442 let newtype = generate_struct_newtype(vis, name, repr);
443 let builder_struct = generate_struct_builder_struct(vis, &builder_name, &members);
444 let newtype_impl = generate_struct_newtype_impl(name, &builder_name, repr, &members);
445 let builder_impl = generate_struct_builder_impl(name, &builder_name, repr, &members);
446
447 Ok(quote! {
448 #newtype
449 #builder_struct
450 #newtype_impl
451 #builder_impl
452 })
453}
454
455fn generate_struct_newtype(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
456 quote! {
457 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
458 #[repr(transparent)]
459 #vis struct #name(#vis #repr);
460 }
461}
462
463fn generate_struct_builder_struct(
464 vis: &syn::Visibility,
465 builder_name: &Ident,
466 members: &[MemberDef],
467) -> TokenStream2 {
468 let fields: Vec<TokenStream2> = members
469 .iter()
470 .map(|m| match m {
471 MemberDef::Field { name, ty, .. } => {
472 quote! { #name: Option<#ty>, }
473 }
474 MemberDef::Flag { name, .. } => {
475 quote! { #name: Option<bool>, }
476 }
477 })
478 .collect();
479
480 quote! {
481 #[derive(Debug, Clone, Copy, Default)]
482 #vis struct #builder_name {
483 #(#fields)*
484 }
485 }
486}
487
488fn generate_struct_newtype_impl(
489 name: &Ident,
490 builder_name: &Ident,
491 repr: &Ident,
492 members: &[MemberDef],
493) -> TokenStream2 {
494 let repr_bit_count = repr_bits(repr);
495
496 let accessors: Vec<TokenStream2> = members.iter().map(|m| {
497 match m {
498 MemberDef::Field { name: field_name, ty, range } => {
499 let start = range.start;
500 let len = range.len;
501 let mask = field_mask(repr, len, repr_bit_count);
502
503 if is_primitive(ty) {
504 let type_bits = primitive_bits(ty);
505 if is_signed_primitive(ty) && len < type_bits {
506 let shift = type_bits - len;
508 quote! {
509 #[inline]
510 pub const fn #field_name(&self) -> #ty {
511 let raw = ((self.0 >> #start) & #mask) as #ty;
512 (raw << #shift) >> #shift
513 }
514 }
515 } else {
516 quote! {
517 #[inline]
518 pub const fn #field_name(&self) -> #ty {
519 ((self.0 >> #start) & #mask) as #ty
520 }
521 }
522 }
523 } else {
524 quote! {
526 #[inline]
527 pub fn #field_name(&self) -> Result<#ty, nexus_bits::UnknownDiscriminant<#repr>> {
528 let field_repr = ((self.0 >> #start) & #mask);
529 <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
530 .ok_or(nexus_bits::UnknownDiscriminant {
531 field: stringify!(#field_name),
532 value: field_repr as #repr,
533 })
534 }
535 }
536 }
537 }
538 MemberDef::Flag { name: field_name, bit } => {
539 quote! {
540 #[inline]
541 pub const fn #field_name(&self) -> bool {
542 (self.0 >> #bit) & 1 != 0
543 }
544 }
545 }
546 }
547 }).collect();
548
549 quote! {
550 impl #name {
551 #[inline]
553 pub const fn from_raw(raw: #repr) -> Self {
554 Self(raw)
555 }
556
557 #[inline]
559 pub const fn raw(self) -> #repr {
560 self.0
561 }
562
563 #[inline]
565 pub fn builder() -> #builder_name {
566 #builder_name::default()
567 }
568
569 #(#accessors)*
570 }
571 }
572}
573
574fn generate_struct_builder_impl(
575 name: &Ident,
576 builder_name: &Ident,
577 repr: &Ident,
578 members: &[MemberDef],
579) -> TokenStream2 {
580 let repr_bit_count = repr_bits(repr);
581
582 let setters: Vec<TokenStream2> = members
584 .iter()
585 .map(|m| match m {
586 MemberDef::Field {
587 name: field_name,
588 ty,
589 ..
590 } => {
591 quote! {
592 #[inline]
593 pub fn #field_name(mut self, val: #ty) -> Self {
594 self.#field_name = Some(val);
595 self
596 }
597 }
598 }
599 MemberDef::Flag {
600 name: field_name, ..
601 } => {
602 quote! {
603 #[inline]
604 pub fn #field_name(mut self, val: bool) -> Self {
605 self.#field_name = Some(val);
606 self
607 }
608 }
609 }
610 })
611 .collect();
612
613 let validations: Vec<TokenStream2> = members
615 .iter()
616 .filter_map(|m| match m {
617 MemberDef::Field {
618 name: field_name,
619 ty,
620 range,
621 } => {
622 let field_str = field_name.to_string();
623 let len = range.len;
624
625 let max_val = field_mask(repr, len, repr_bit_count);
626
627 if is_primitive(ty) {
628 let type_bits: u32 = match ty {
629 Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
630 Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
631 Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
632 Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
633 Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
634 _ => 128,
635 };
636
637 if len >= type_bits {
639 return None;
640 }
641
642 let is_signed = matches!(ty,
644 Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
645 p.path.is_ident("i32") || p.path.is_ident("i64") ||
646 p.path.is_ident("i128")
647 );
648
649 if is_signed {
650 let min_shift = len - 1;
655 Some(quote! {
656 if let Some(v) = self.#field_name {
657 let min_val = -((1i128 << #min_shift) as i128);
658 let max_val = ((1i128 << #min_shift) - 1) as i128;
659 let v_i128 = v as i128;
660 if v_i128 < min_val || v_i128 > max_val {
661 return Err(nexus_bits::FieldOverflow {
662 field: #field_str,
663 overflow: nexus_bits::Overflow {
664 value: (v as #repr),
665 max: #max_val,
666 },
667 });
668 }
669 }
670 })
671 } else {
672 Some(quote! {
674 if let Some(v) = self.#field_name {
675 if (v as #repr) > #max_val {
676 return Err(nexus_bits::FieldOverflow {
677 field: #field_str,
678 overflow: nexus_bits::Overflow {
679 value: v as #repr,
680 max: #max_val,
681 },
682 });
683 }
684 }
685 })
686 }
687 } else {
688 Some(quote! {
690 const _: () = assert!(
691 core::mem::size_of::<<#ty as nexus_bits::IntEnum>::Repr>() <= core::mem::size_of::<#repr>(),
692 "IntEnum repr type is wider than storage repr — values may be truncated"
693 );
694 if let Some(v) = self.#field_name {
695 let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
696 if repr_val > #max_val {
697 return Err(nexus_bits::FieldOverflow {
698 field: #field_str,
699 overflow: nexus_bits::Overflow {
700 value: repr_val,
701 max: #max_val,
702 },
703 });
704 }
705 }
706 })
707 }
708 }
709 MemberDef::Flag { .. } => None,
710 })
711 .collect();
712
713 let pack_statements: Vec<TokenStream2> = members
715 .iter()
716 .map(|m| {
717 match m {
718 MemberDef::Field {
719 name: field_name,
720 ty,
721 range,
722 } => {
723 let start = range.start;
724 let len = range.len;
725 let mask = field_mask(repr, len, repr_bit_count);
726
727 if is_primitive(ty) {
728 quote! {
729 if let Some(v) = self.#field_name {
730 val |= ((v as #repr) & #mask) << #start;
731 }
732 }
733 } else {
734 quote! {
736 if let Some(v) = self.#field_name {
737 val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
738 }
739 }
740 }
741 }
742 MemberDef::Flag {
743 name: field_name,
744 bit,
745 } => {
746 quote! {
747 if let Some(true) = self.#field_name {
748 val |= (1 as #repr) << #bit;
749 }
750 }
751 }
752 }
753 })
754 .collect();
755
756 quote! {
757 impl #builder_name {
758 #(#setters)*
759
760 #[inline]
762 pub fn build(self) -> Result<#name, nexus_bits::FieldOverflow<#repr>> {
763 #(#validations)*
765
766 let mut val: #repr = 0;
768 #(#pack_statements)*
769
770 Ok(#name(val))
771 }
772 }
773 }
774}
775
/// One enum variant after its attributes and fields have been parsed.
struct ParsedVariant {
    /// The variant's identifier.
    name: Ident,
    /// Value read from the variant's `#[variant(N)]` attribute.
    discriminant: u64,
    /// Members parsed from the variant's named fields; empty for unit variants.
    members: Vec<MemberDef>,
}
786
787fn derive_storage_enum(
788 input: &DeriveInput,
789 data: &syn::DataEnum,
790 storage_attr: &StorageAttr,
791) -> Result<TokenStream2> {
792 let discriminant = storage_attr.discriminant.ok_or_else(|| {
793 Error::new_spanned(
794 input,
795 "bit_storage enum requires discriminant: #[bit_storage(repr = T, discriminant(start = N, len = M))]",
796 )
797 })?;
798
799 let repr = &storage_attr.repr;
800 let bits = repr_bits(repr);
801
802 if discriminant.start + discriminant.len > bits {
804 return Err(Error::new_spanned(
805 input,
806 format!(
807 "discriminant exceeds {} bits (start {} + len {} = {})",
808 bits,
809 discriminant.start,
810 discriminant.len,
811 discriminant.start + discriminant.len
812 ),
813 ));
814 }
815
816 let max_discriminant = if discriminant.len >= 64 {
817 u64::MAX
818 } else {
819 (1u64 << discriminant.len) - 1
820 };
821
822 let mut variants = Vec::new();
824 for variant in &data.variants {
825 let disc = parse_variant_attr(&variant.attrs)?;
826
827 if disc > max_discriminant {
828 return Err(Error::new_spanned(
829 &variant.ident,
830 format!(
831 "variant discriminant {} exceeds max {} for {}-bit field",
832 disc, max_discriminant, discriminant.len
833 ),
834 ));
835 }
836
837 for existing in &variants {
839 let existing: &ParsedVariant = existing;
840 if existing.discriminant == disc {
841 return Err(Error::new_spanned(
842 &variant.ident,
843 format!(
844 "duplicate discriminant {}: already used by '{}'",
845 disc, existing.name
846 ),
847 ));
848 }
849 }
850
851 let members: Vec<MemberDef> = match &variant.fields {
852 Fields::Named(fields) => fields
853 .named
854 .iter()
855 .map(parse_member)
856 .collect::<Result<_>>()?,
857 Fields::Unit => Vec::new(),
858 Fields::Unnamed(_) => {
859 return Err(Error::new_spanned(
860 variant,
861 "tuple variants not supported, use named fields",
862 ));
863 }
864 };
865
866 let disc_range = MemberDef::Field {
868 name: Ident::new("__discriminant", proc_macro2::Span::call_site()),
869 ty: syn::parse_quote!(u64),
870 range: discriminant,
871 };
872
873 for member in &members {
874 if ranges_overlap(&disc_range, member) {
875 return Err(Error::new_spanned(
876 member.name(),
877 format!("field '{}' overlaps with discriminant", member.name()),
878 ));
879 }
880 }
881
882 validate_members(&members, repr)?;
884
885 variants.push(ParsedVariant {
886 name: variant.ident.clone(),
887 discriminant: disc,
888 members,
889 });
890 }
891
892 let vis = &input.vis;
893 let name = &input.ident;
894
895 let parent_type = generate_enum_parent_type(vis, name, repr);
896 let variant_types = generate_enum_variant_types(vis, name, repr, &variants);
897 let kind_enum = generate_enum_kind(vis, name, &variants);
898 let builder_structs = generate_enum_builder_structs(vis, name, &variants);
899 let parent_impl = generate_enum_parent_impl(name, repr, discriminant, &variants);
900 let variant_impls = generate_enum_variant_impls(name, repr, &variants);
901 let builder_impls = generate_enum_builder_impls(name, repr, discriminant, &variants);
902 let from_impls = generate_enum_from_impls(name, &variants);
903
904 Ok(quote! {
905 #parent_type
906 #variant_types
907 #kind_enum
908 #builder_structs
909 #parent_impl
910 #variant_impls
911 #builder_impls
912 #from_impls
913 })
914}
915
916fn variant_type_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
917 Ident::new(
918 &format!("{}{}", parent_name, variant_name),
919 variant_name.span(),
920 )
921}
922
923fn variant_builder_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
924 Ident::new(
925 &format!("{}{}Builder", parent_name, variant_name),
926 variant_name.span(),
927 )
928}
929
930fn kind_enum_name(parent_name: &Ident) -> Ident {
931 Ident::new(&format!("{}Kind", parent_name), parent_name.span())
932}
933
934fn generate_enum_parent_type(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
935 quote! {
936 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
937 #[repr(transparent)]
938 #vis struct #name(#vis #repr);
939 }
940}
941
942fn generate_enum_variant_types(
943 vis: &syn::Visibility,
944 parent_name: &Ident,
945 repr: &Ident,
946 variants: &[ParsedVariant],
947) -> TokenStream2 {
948 let types: Vec<TokenStream2> = variants
949 .iter()
950 .map(|v| {
951 let type_name = variant_type_name(parent_name, &v.name);
952 quote! {
953 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
954 #[repr(transparent)]
955 #vis struct #type_name(#repr);
956 }
957 })
958 .collect();
959
960 quote! { #(#types)* }
961}
962
963fn generate_enum_kind(
964 vis: &syn::Visibility,
965 parent_name: &Ident,
966 variants: &[ParsedVariant],
967) -> TokenStream2 {
968 let kind_name = kind_enum_name(parent_name);
969 let variant_names: Vec<&Ident> = variants.iter().map(|v| &v.name).collect();
970
971 quote! {
972 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
973 #vis enum #kind_name {
974 #(#variant_names),*
975 }
976 }
977}
978
979fn generate_enum_builder_structs(
980 vis: &syn::Visibility,
981 parent_name: &Ident,
982 variants: &[ParsedVariant],
983) -> TokenStream2 {
984 let builders: Vec<TokenStream2> = variants
985 .iter()
986 .map(|v| {
987 let builder_name = variant_builder_name(parent_name, &v.name);
988
989 let fields: Vec<TokenStream2> = v
990 .members
991 .iter()
992 .map(|m| match m {
993 MemberDef::Field { name, ty, .. } => {
994 quote! { #name: Option<#ty>, }
995 }
996 MemberDef::Flag { name, .. } => {
997 quote! { #name: Option<bool>, }
998 }
999 })
1000 .collect();
1001
1002 quote! {
1003 #[derive(Debug, Clone, Copy, Default)]
1004 #vis struct #builder_name {
1005 #(#fields)*
1006 }
1007 }
1008 })
1009 .collect();
1010
1011 quote! { #(#builders)* }
1012}
1013
1014fn generate_enum_parent_impl(
1015 name: &Ident,
1016 repr: &Ident,
1017 discriminant: BitRange,
1018 variants: &[ParsedVariant],
1019) -> TokenStream2 {
1020 let repr_bit_count = repr_bits(repr);
1021 let kind_name = kind_enum_name(name);
1022 let disc_start = discriminant.start;
1023 let disc_len = discriminant.len;
1024
1025 assert!(
1028 disc_len <= 64,
1029 "discriminant length must be <= 64 bits (got {disc_len})"
1030 );
1031
1032 let disc_mask = field_mask(repr, disc_len, repr_bit_count);
1033
1034 let kind_arms: Vec<TokenStream2> = variants
1036 .iter()
1037 .map(|v| {
1038 let variant_name = &v.name;
1039 let disc_val = v.discriminant;
1040 quote! {
1041 #disc_val => Ok(#kind_name::#variant_name),
1042 }
1043 })
1044 .collect();
1045
1046 let is_methods: Vec<TokenStream2> = variants
1048 .iter()
1049 .map(|v| {
1050 let variant_name = &v.name;
1051 let method_name = Ident::new(
1052 &format!("is_{}", to_snake_case(&variant_name.to_string())),
1053 variant_name.span(),
1054 );
1055 let disc_val = v.discriminant;
1056 quote! {
1057 #[inline]
1058 pub fn #method_name(&self) -> bool {
1059 let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1060 disc == #disc_val
1061 }
1062 }
1063 })
1064 .collect();
1065
1066 let as_methods: Vec<TokenStream2> = variants
1068 .iter()
1069 .map(|v| {
1070 let variant_name = &v.name;
1071 let variant_type = variant_type_name(name, variant_name);
1072 let method_name = Ident::new(
1073 &format!("as_{}", to_snake_case(&variant_name.to_string())),
1074 variant_name.span(),
1075 );
1076 let disc_val = v.discriminant;
1077
1078 let validations: Vec<TokenStream2> = v.members
1080 .iter()
1081 .filter_map(|m| {
1082 if let MemberDef::Field { name: field_name, ty, range } = m {
1083 if !is_primitive(ty) {
1084 let start = range.start;
1085 let len = range.len;
1086 let repr_bit_count = repr_bits(repr);
1087 let mask = field_mask(repr, len, repr_bit_count);
1088 return Some(quote! {
1089 let field_repr = ((self.0 >> #start) & #mask);
1090 if <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _).is_none() {
1091 return Err(nexus_bits::UnknownDiscriminant {
1092 field: stringify!(#field_name),
1093 value: field_repr as #repr,
1094 });
1095 }
1096 });
1097 }
1098 }
1099 None
1100 })
1101 .collect();
1102
1103 quote! {
1104 #[inline]
1105 pub fn #method_name(&self) -> Result<#variant_type, nexus_bits::UnknownDiscriminant<#repr>> {
1106 let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1107 if disc != #disc_val {
1108 return Err(nexus_bits::UnknownDiscriminant {
1109 field: "__discriminant",
1110 value: disc as #repr,
1111 });
1112 }
1113 #(#validations)*
1114 Ok(#variant_type(self.0))
1115 }
1116 }
1117 })
1118 .collect();
1119
1120 let builder_methods: Vec<TokenStream2> = variants
1122 .iter()
1123 .map(|v| {
1124 let variant_name = &v.name;
1125 let builder_name = variant_builder_name(name, variant_name);
1126 let method_name = Ident::new(
1127 &to_snake_case(&variant_name.to_string()),
1128 variant_name.span(),
1129 );
1130 quote! {
1131 #[inline]
1132 pub fn #method_name() -> #builder_name {
1133 #builder_name::default()
1134 }
1135 }
1136 })
1137 .collect();
1138
1139 quote! {
1140 impl #name {
1141 #[inline]
1143 pub const fn from_raw(raw: #repr) -> Self {
1144 Self(raw)
1145 }
1146
1147 #[inline]
1149 pub const fn raw(self) -> #repr {
1150 self.0
1151 }
1152
1153 #[inline]
1155 pub fn kind(&self) -> Result<#kind_name, nexus_bits::UnknownDiscriminant<#repr>> {
1156 let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1157 match disc {
1158 #(#kind_arms)*
1159 _ => Err(nexus_bits::UnknownDiscriminant {
1160 field: "__discriminant",
1161 value: disc as #repr,
1162 }),
1163 }
1164 }
1165
1166 #(#is_methods)*
1167
1168 #(#as_methods)*
1169
1170 #(#builder_methods)*
1171 }
1172 }
1173}
1174
1175fn generate_enum_variant_impls(
1176 parent_name: &Ident,
1177 repr: &Ident,
1178 variants: &[ParsedVariant],
1179) -> TokenStream2 {
1180 let repr_bit_count = repr_bits(repr);
1181
1182 let impls: Vec<TokenStream2> =
1183 variants
1184 .iter()
1185 .map(|v| {
1186 let variant_name = &v.name;
1187 let variant_type = variant_type_name(parent_name, variant_name);
1188 let builder_name = variant_builder_name(parent_name, variant_name);
1189
1190 let accessors: Vec<TokenStream2> = v.members
1192 .iter()
1193 .map(|m| {
1194 match m {
1195 MemberDef::Field { name: field_name, ty, range } => {
1196 let start = range.start;
1197 let len = range.len;
1198 let mask = field_mask(repr, len, repr_bit_count);
1199
1200 if is_primitive(ty) {
1201 let type_bits = primitive_bits(ty);
1202 if is_signed_primitive(ty) && len < type_bits {
1203 let shift = type_bits - len;
1204 quote! {
1205 #[inline]
1206 pub const fn #field_name(&self) -> #ty {
1207 let raw = ((self.0 >> #start) & #mask) as #ty;
1208 (raw << #shift) >> #shift
1209 }
1210 }
1211 } else {
1212 quote! {
1213 #[inline]
1214 pub const fn #field_name(&self) -> #ty {
1215 ((self.0 >> #start) & #mask) as #ty
1216 }
1217 }
1218 }
1219 } else {
1220 quote! {
1222 #[inline]
1223 pub fn #field_name(&self) -> #ty {
1224 let field_repr = ((self.0 >> #start) & #mask);
1225 <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
1227 .expect("variant type invariant violated")
1228 }
1229 }
1230 }
1231 }
1232 MemberDef::Flag { name: field_name, bit } => {
1233 quote! {
1234 #[inline]
1235 pub const fn #field_name(&self) -> bool {
1236 (self.0 >> #bit) & 1 != 0
1237 }
1238 }
1239 }
1240 }
1241 })
1242 .collect();
1243
1244 quote! {
1245 impl #variant_type {
1246 #[inline]
1248 pub fn builder() -> #builder_name {
1249 #builder_name::default()
1250 }
1251
1252 #[inline]
1254 pub const fn raw(self) -> #repr {
1255 self.0
1256 }
1257
1258 #[inline]
1260 pub const fn as_parent(self) -> #parent_name {
1261 #parent_name(self.0)
1262 }
1263
1264 #(#accessors)*
1265 }
1266 }
1267 })
1268 .collect();
1269
1270 quote! { #(#impls)* }
1271}
1272
1273fn generate_enum_builder_impls(
1274 parent_name: &Ident,
1275 repr: &Ident,
1276 discriminant: BitRange,
1277 variants: &[ParsedVariant],
1278) -> TokenStream2 {
1279 let repr_bit_count = repr_bits(repr);
1280 let disc_start = discriminant.start;
1281
1282 let impls: Vec<TokenStream2> = variants
1283 .iter()
1284 .map(|v| {
1285 let variant_name = &v.name;
1286 let variant_type = variant_type_name(parent_name, variant_name);
1287 let builder_name = variant_builder_name(parent_name, variant_name);
1288 let disc_val = v.discriminant;
1289
1290 let setters: Vec<TokenStream2> = v.members
1292 .iter()
1293 .map(|m| match m {
1294 MemberDef::Field { name: field_name, ty, .. } => {
1295 quote! {
1296 #[inline]
1297 pub fn #field_name(mut self, val: #ty) -> Self {
1298 self.#field_name = Some(val);
1299 self
1300 }
1301 }
1302 }
1303 MemberDef::Flag { name: field_name, .. } => {
1304 quote! {
1305 #[inline]
1306 pub fn #field_name(mut self, val: bool) -> Self {
1307 self.#field_name = Some(val);
1308 self
1309 }
1310 }
1311 }
1312 })
1313 .collect();
1314
1315 let validations: Vec<TokenStream2> = v.members
1317 .iter()
1318 .filter_map(|m| match m {
1319 MemberDef::Field { name: field_name, ty, range } => {
1320 let field_str = field_name.to_string();
1321 let len = range.len;
1322
1323 let max_val = field_mask(repr, len, repr_bit_count);
1324
1325 if is_primitive(ty) {
1326 let type_bits: u32 = match ty {
1327 Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
1328 Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
1329 Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
1330 Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
1331 Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
1332 _ => 128,
1333 };
1334
1335 if len >= type_bits {
1336 return None;
1337 }
1338
1339 let is_signed = matches!(ty,
1340 Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
1341 p.path.is_ident("i32") || p.path.is_ident("i64") ||
1342 p.path.is_ident("i128")
1343 );
1344
1345 if is_signed {
1346 let min_shift = len - 1;
1347 Some(quote! {
1348 if let Some(v) = self.#field_name {
1349 let min_val = -((1i128 << #min_shift) as i128);
1350 let max_val = ((1i128 << #min_shift) - 1) as i128;
1351 let v_i128 = v as i128;
1352 if v_i128 < min_val || v_i128 > max_val {
1353 return Err(nexus_bits::FieldOverflow {
1354 field: #field_str,
1355 overflow: nexus_bits::Overflow {
1356 value: (v as #repr),
1357 max: #max_val,
1358 },
1359 });
1360 }
1361 }
1362 })
1363 } else {
1364 Some(quote! {
1365 if let Some(v) = self.#field_name {
1366 if (v as #repr) > #max_val {
1367 return Err(nexus_bits::FieldOverflow {
1368 field: #field_str,
1369 overflow: nexus_bits::Overflow {
1370 value: v as #repr,
1371 max: #max_val,
1372 },
1373 });
1374 }
1375 }
1376 })
1377 }
1378 } else {
1379 Some(quote! {
1381 if let Some(v) = self.#field_name {
1382 let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
1383 if repr_val > #max_val {
1384 return Err(nexus_bits::FieldOverflow {
1385 field: #field_str,
1386 overflow: nexus_bits::Overflow {
1387 value: repr_val,
1388 max: #max_val,
1389 },
1390 });
1391 }
1392 }
1393 })
1394 }
1395 }
1396 MemberDef::Flag { .. } => None,
1397 })
1398 .collect();
1399
1400 let pack_statements: Vec<TokenStream2> = v.members
1402 .iter()
1403 .map(|m| {
1404 match m {
1405 MemberDef::Field { name: field_name, ty, range } => {
1406 let start = range.start;
1407 let len = range.len;
1408 let mask = field_mask(repr, len, repr_bit_count);
1409
1410 if is_primitive(ty) {
1411 quote! {
1412 if let Some(v) = self.#field_name {
1413 val |= ((v as #repr) & #mask) << #start;
1414 }
1415 }
1416 } else {
1417 quote! {
1418 if let Some(v) = self.#field_name {
1419 val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
1420 }
1421 }
1422 }
1423 }
1424 MemberDef::Flag { name: field_name, bit } => {
1425 quote! {
1426 if let Some(true) = self.#field_name {
1427 val |= (1 as #repr) << #bit;
1428 }
1429 }
1430 }
1431 }
1432 })
1433 .collect();
1434
1435 quote! {
1436 impl #builder_name {
1437 #(#setters)*
1438
1439 #[inline]
1441 pub fn build(self) -> Result<#variant_type, nexus_bits::FieldOverflow<#repr>> {
1442 #(#validations)*
1443
1444 let mut val: #repr = 0;
1445 val |= (#disc_val as #repr) << #disc_start;
1447 #(#pack_statements)*
1448
1449 Ok(#variant_type(val))
1450 }
1451
1452 #[inline]
1454 pub fn build_parent(self) -> Result<#parent_name, nexus_bits::FieldOverflow<#repr>> {
1455 self.build().map(|v| v.as_parent())
1456 }
1457 }
1458 }
1459 })
1460 .collect();
1461
1462 quote! { #(#impls)* }
1463}
1464
1465fn generate_enum_from_impls(parent_name: &Ident, variants: &[ParsedVariant]) -> TokenStream2 {
1466 let impls: Vec<TokenStream2> = variants
1467 .iter()
1468 .map(|v| {
1469 let variant_type = variant_type_name(parent_name, &v.name);
1470 quote! {
1471 impl From<#variant_type> for #parent_name {
1472 #[inline]
1473 fn from(v: #variant_type) -> Self {
1474 v.as_parent()
1475 }
1476 }
1477 }
1478 })
1479 .collect();
1480
1481 quote! { #(#impls)* }
1482}
1483
/// Converts a `CamelCase` name to `snake_case`.
///
/// Every uppercase character is lowercased and, unless it is the first
/// character, preceded by `_`. Runs of capitals are therefore split letter
/// by letter (`"HTTPServer"` becomes `"h_t_t_p_server"`). All other
/// characters are copied unchanged.
///
/// The full lowercase mapping of each character is appended, so characters
/// whose lowercase form is more than one `char` are not truncated.
fn to_snake_case(s: &str) -> String {
    // Room for the input plus a few inserted underscores.
    let mut result = String::with_capacity(s.len() + 4);

    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }

    result
}