1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{
5 Data, DeriveInput, Error, Fields, Ident, Result, Type, parse::Parser, parse_macro_input,
6};
7
8#[proc_macro_derive(IntEnum)]
13pub fn derive_int_enum(input: TokenStream) -> TokenStream {
14 let input = parse_macro_input!(input as DeriveInput);
15
16 match derive_int_enum_impl(input) {
17 Ok(tokens) => tokens.into(),
18 Err(err) => err.to_compile_error().into(),
19 }
20}
21
22fn derive_int_enum_impl(input: DeriveInput) -> Result<TokenStream2> {
23 let variants = match &input.data {
24 Data::Enum(data) => &data.variants,
25 _ => {
26 return Err(Error::new_spanned(
27 &input,
28 "IntEnum can only be derived for enums",
29 ));
30 }
31 };
32
33 let repr = parse_repr(&input)?;
34
35 for variant in variants {
36 if !matches!(variant.fields, Fields::Unit) {
37 return Err(Error::new_spanned(
38 variant,
39 "IntEnum variants cannot have fields",
40 ));
41 }
42 }
43
44 let name = &input.ident;
45 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
46
47 let from_arms = variants.iter().map(|v| {
48 let variant_name = &v.ident;
49 quote! {
50 x if x == #name::#variant_name as #repr => Some(#name::#variant_name),
51 }
52 });
53
54 Ok(quote! {
55 impl #impl_generics nexus_bits::IntEnum for #name #ty_generics #where_clause {
56 type Repr = #repr;
57
58 #[inline]
59 fn into_repr(self) -> #repr {
60 self as #repr
61 }
62
63 #[inline]
64 fn try_from_repr(repr: #repr) -> Option<Self> {
65 match repr {
66 #(#from_arms)*
67 _ => None,
68 }
69 }
70 }
71 })
72}
73
74fn parse_repr(input: &DeriveInput) -> Result<Ident> {
75 for attr in &input.attrs {
76 if attr.path().is_ident("repr") {
77 let repr: Ident = attr.parse_args()?;
78 match repr.to_string().as_str() {
79 "u8" | "u16" | "u32" | "u64" | "i8" | "i16" | "i32" | "i64" => {
80 return Ok(repr);
81 }
82 _ => {
83 return Err(Error::new_spanned(
84 repr,
85 "IntEnum requires repr(u8), repr(u16), repr(u32), repr(u64), repr(i8), repr(i16), repr(i32), or repr(i64)",
86 ));
87 }
88 }
89 }
90 }
91
92 Err(Error::new_spanned(
93 input,
94 "IntEnum requires a #[repr(u8/u16/u32/u64/i8/i16/i32/i64)] attribute",
95 ))
96}
97
98#[proc_macro_attribute]
103pub fn bit_storage(attr: TokenStream, item: TokenStream) -> TokenStream {
104 let attr = proc_macro2::TokenStream::from(attr);
105 let item = parse_macro_input!(item as DeriveInput);
106
107 match bit_storage_impl(attr, item) {
108 Ok(tokens) => tokens.into(),
109 Err(err) => err.to_compile_error().into(),
110 }
111}
112
113fn bit_storage_impl(attr: TokenStream2, input: DeriveInput) -> Result<TokenStream2> {
114 let storage_attr = parse_storage_attr_tokens(attr)?;
115
116 match &input.data {
117 Data::Struct(data) => derive_storage_struct(&input, data, storage_attr),
118 Data::Enum(data) => derive_storage_enum(&input, data, storage_attr),
119 Data::Union(_) => Err(Error::new_spanned(
120 &input,
121 "bit_storage cannot be applied to unions",
122 )),
123 }
124}
125
/// Parsed arguments of `#[bit_storage(...)]`.
struct StorageAttr {
    /// Backing integer type from `repr = T`; validated to be a primitive integer type.
    repr: Ident,
    /// Bit range from `discriminant(start = N, len = M)`, if present.
    /// Required for enums and rejected for structs.
    discriminant: Option<BitRange>,
}
135
/// A contiguous bit span `[start, start + len)` inside the backing integer.
struct BitRange {
    /// Index of the span's least-significant bit (accessors shift right by this).
    start: u32,
    /// Number of bits in the span; `parse_bit_range` rejects zero.
    len: u32,
}
141
/// One declared member of a bit storage struct or enum variant.
enum MemberDef {
    /// Multi-bit field from `#[field(start = N, len = M)]`. `ty` is either a bare
    /// primitive integer, or any other type, which the generated code treats as a
    /// `nexus_bits::IntEnum` implementor.
    Field {
        name: Ident,
        ty: Type,
        range: BitRange,
    },
    /// Single-bit boolean from `#[flag(N)]`; `bit` is the bit index.
    Flag {
        name: Ident,
        bit: u32,
    },
}
154
155impl MemberDef {
156 fn name(&self) -> &Ident {
157 match self {
158 MemberDef::Field { name, .. } => name,
159 MemberDef::Flag { name, .. } => name,
160 }
161 }
162}
163
164fn parse_storage_attr_tokens(attr: TokenStream2) -> Result<StorageAttr> {
169 let mut repr = None;
170 let mut discriminant = None;
171
172 let parser = syn::meta::parser(|meta| {
173 if meta.path.is_ident("repr") {
174 meta.input.parse::<syn::Token![=]>()?;
175 repr = Some(meta.input.parse::<Ident>()?);
176 Ok(())
177 } else if meta.path.is_ident("discriminant") {
178 let content;
179 syn::parenthesized!(content in meta.input);
180 discriminant = Some(parse_bit_range(&content)?);
181 Ok(())
182 } else {
183 Err(meta.error("expected `repr` or `discriminant`"))
184 }
185 });
186
187 parser.parse2(attr)?;
188
189 let repr = repr.ok_or_else(|| {
190 Error::new(
191 proc_macro2::Span::call_site(),
192 "bit_storage requires `repr = ...`",
193 )
194 })?;
195
196 match repr.to_string().as_str() {
198 "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" => {}
199 _ => return Err(Error::new_spanned(&repr, "repr must be an integer type")),
200 }
201
202 Ok(StorageAttr { repr, discriminant })
203}
204
205fn parse_bit_range(input: syn::parse::ParseStream) -> Result<BitRange> {
206 let mut start = None;
207 let mut len = None;
208
209 while !input.is_empty() {
210 let ident: Ident = input.parse()?;
211 input.parse::<syn::Token![=]>()?;
212 let lit: syn::LitInt = input.parse()?;
213 let value: u32 = lit.base10_parse()?;
214
215 match ident.to_string().as_str() {
216 "start" => start = Some(value),
217 "len" => len = Some(value),
218 _ => return Err(Error::new_spanned(ident, "expected `start` or `len`")),
219 }
220
221 if input.peek(syn::Token![,]) {
222 input.parse::<syn::Token![,]>()?;
223 }
224 }
225
226 let start = start.ok_or_else(|| Error::new(input.span(), "missing `start`"))?;
227 let len = len.ok_or_else(|| Error::new(input.span(), "missing `len`"))?;
228
229 if len == 0 {
230 return Err(Error::new(input.span(), "len must be > 0"));
231 }
232
233 Ok(BitRange { start, len })
234}
235
236fn parse_member(field: &syn::Field) -> Result<MemberDef> {
237 let name = field
238 .ident
239 .clone()
240 .ok_or_else(|| Error::new_spanned(field, "tuple structs not supported"))?;
241 let ty = field.ty.clone();
242
243 for attr in &field.attrs {
244 if attr.path().is_ident("field") {
245 let range = attr.parse_args_with(parse_bit_range)?;
246 return Ok(MemberDef::Field { name, ty, range });
247 } else if attr.path().is_ident("flag") {
248 let bit: syn::LitInt = attr.parse_args()?;
249 let bit: u32 = bit.base10_parse()?;
250 return Ok(MemberDef::Flag { name, bit });
251 }
252 }
253
254 Err(Error::new_spanned(
255 field,
256 "field requires #[field(start = N, len = M)] or #[flag(N)] attribute",
257 ))
258}
259
260fn parse_variant_attr(attrs: &[syn::Attribute]) -> Result<u64> {
261 for attr in attrs {
262 if attr.path().is_ident("variant") {
263 let lit: syn::LitInt = attr.parse_args()?;
264 return lit.base10_parse();
265 }
266 }
267 Err(Error::new(
268 proc_macro2::Span::call_site(),
269 "enum variant requires #[variant(N)] attribute",
270 ))
271}
272
273fn is_primitive(ty: &Type) -> bool {
278 if let Type::Path(type_path) = ty {
279 if let Some(ident) = type_path.path.get_ident() {
280 return matches!(
281 ident.to_string().as_str(),
282 "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128"
283 );
284 }
285 }
286 false
287}
288
289fn repr_bits(repr: &Ident) -> u32 {
290 match repr.to_string().as_str() {
291 "u8" | "i8" => 8,
292 "u16" | "i16" => 16,
293 "u32" | "i32" => 32,
294 "u64" | "i64" => 64,
295 "u128" | "i128" => 128,
296 _ => 0,
297 }
298}
299
300fn validate_members(members: &[MemberDef], repr: &Ident) -> Result<()> {
305 let bits = repr_bits(repr);
306
307 for member in members {
309 match member {
310 MemberDef::Field { name, range, .. } => {
311 if range.start + range.len > bits {
312 return Err(Error::new_spanned(
313 name,
314 format!(
315 "field exceeds {} bits (start {} + len {} = {})",
316 bits,
317 range.start,
318 range.len,
319 range.start + range.len
320 ),
321 ));
322 }
323 }
324 MemberDef::Flag { name, bit, .. } => {
325 if *bit >= bits {
326 return Err(Error::new_spanned(
327 name,
328 format!("flag bit {} exceeds {} bits", bit, bits),
329 ));
330 }
331 }
332 }
333 }
334
335 for (i, a) in members.iter().enumerate() {
337 for b in members.iter().skip(i + 1) {
338 if ranges_overlap(a, b) {
339 return Err(Error::new_spanned(
340 b.name(),
341 format!("field '{}' overlaps with '{}'", b.name(), a.name()),
342 ));
343 }
344 }
345 }
346
347 Ok(())
348}
349
350fn ranges_overlap(a: &MemberDef, b: &MemberDef) -> bool {
351 let (a_start, a_end) = member_bit_range(a);
352 let (b_start, b_end) = member_bit_range(b);
353 a_start < b_end && b_start < a_end
354}
355
356fn member_bit_range(m: &MemberDef) -> (u32, u32) {
357 match m {
358 MemberDef::Field { range, .. } => (range.start, range.start + range.len),
359 MemberDef::Flag { bit, .. } => (*bit, bit + 1),
360 }
361}
362
363fn derive_storage_struct(
368 input: &DeriveInput,
369 data: &syn::DataStruct,
370 storage_attr: StorageAttr,
371) -> Result<TokenStream2> {
372 let fields = match &data.fields {
373 Fields::Named(f) => &f.named,
374 _ => {
375 return Err(Error::new_spanned(
376 input,
377 "bit_storage requires named fields",
378 ));
379 }
380 };
381
382 if storage_attr.discriminant.is_some() {
383 return Err(Error::new_spanned(
384 input,
385 "discriminant is only valid for enums",
386 ));
387 }
388
389 let members: Vec<MemberDef> = fields.iter().map(parse_member).collect::<Result<_>>()?;
390
391 validate_members(&members, &storage_attr.repr)?;
392
393 let vis = &input.vis;
394 let name = &input.ident;
395 let repr = &storage_attr.repr;
396 let builder_name = Ident::new(&format!("{}Builder", name), name.span());
397
398 let newtype = generate_struct_newtype(vis, name, repr);
399 let builder_struct = generate_struct_builder_struct(vis, &builder_name, &members);
400 let newtype_impl = generate_struct_newtype_impl(name, &builder_name, repr, &members);
401 let builder_impl = generate_struct_builder_impl(name, &builder_name, repr, &members);
402
403 Ok(quote! {
404 #newtype
405 #builder_struct
406 #newtype_impl
407 #builder_impl
408 })
409}
410
411fn generate_struct_newtype(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
412 quote! {
413 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
414 #[repr(transparent)]
415 #vis struct #name(#vis #repr);
416 }
417}
418
419fn generate_struct_builder_struct(
420 vis: &syn::Visibility,
421 builder_name: &Ident,
422 members: &[MemberDef],
423) -> TokenStream2 {
424 let fields: Vec<TokenStream2> = members
425 .iter()
426 .map(|m| match m {
427 MemberDef::Field { name, ty, .. } => {
428 quote! { #name: Option<#ty>, }
429 }
430 MemberDef::Flag { name, .. } => {
431 quote! { #name: Option<bool>, }
432 }
433 })
434 .collect();
435
436 quote! {
437 #[derive(Debug, Clone, Copy, Default)]
438 #vis struct #builder_name {
439 #(#fields)*
440 }
441 }
442}
443
444fn generate_struct_newtype_impl(
445 name: &Ident,
446 builder_name: &Ident,
447 repr: &Ident,
448 members: &[MemberDef],
449) -> TokenStream2 {
450 let repr_bit_count = repr_bits(repr);
451
452 let accessors: Vec<TokenStream2> = members.iter().map(|m| {
453 match m {
454 MemberDef::Field { name: field_name, ty, range } => {
455 let start = range.start;
456 let len = range.len;
457 let mask = if len >= repr_bit_count {
458 quote! { #repr::MAX }
459 } else {
460 quote! { ((1 as #repr) << #len) - 1 }
461 };
462
463 if is_primitive(ty) {
464 quote! {
465 #[inline]
466 pub const fn #field_name(&self) -> #ty {
467 ((self.0 >> #start) & #mask) as #ty
468 }
469 }
470 } else {
471 quote! {
473 #[inline]
474 pub fn #field_name(&self) -> Result<#ty, nexus_bits::UnknownDiscriminant<#repr>> {
475 let field_repr = ((self.0 >> #start) & #mask);
476 <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
477 .ok_or(nexus_bits::UnknownDiscriminant {
478 field: stringify!(#field_name),
479 value: self.0,
480 })
481 }
482 }
483 }
484 }
485 MemberDef::Flag { name: field_name, bit } => {
486 quote! {
487 #[inline]
488 pub const fn #field_name(&self) -> bool {
489 (self.0 >> #bit) & 1 != 0
490 }
491 }
492 }
493 }
494 }).collect();
495
496 quote! {
497 impl #name {
498 #[inline]
500 pub const fn from_raw(raw: #repr) -> Self {
501 Self(raw)
502 }
503
504 #[inline]
506 pub const fn raw(self) -> #repr {
507 self.0
508 }
509
510 #[inline]
512 pub fn builder() -> #builder_name {
513 #builder_name::default()
514 }
515
516 #(#accessors)*
517 }
518 }
519}
520
521fn generate_struct_builder_impl(
522 name: &Ident,
523 builder_name: &Ident,
524 repr: &Ident,
525 members: &[MemberDef],
526) -> TokenStream2 {
527 let repr_bit_count = repr_bits(repr);
528
529 let setters: Vec<TokenStream2> = members
531 .iter()
532 .map(|m| match m {
533 MemberDef::Field {
534 name: field_name,
535 ty,
536 ..
537 } => {
538 quote! {
539 #[inline]
540 pub fn #field_name(mut self, val: #ty) -> Self {
541 self.#field_name = Some(val);
542 self
543 }
544 }
545 }
546 MemberDef::Flag {
547 name: field_name, ..
548 } => {
549 quote! {
550 #[inline]
551 pub fn #field_name(mut self, val: bool) -> Self {
552 self.#field_name = Some(val);
553 self
554 }
555 }
556 }
557 })
558 .collect();
559
560 let validations: Vec<TokenStream2> = members
562 .iter()
563 .filter_map(|m| match m {
564 MemberDef::Field {
565 name: field_name,
566 ty,
567 range,
568 } => {
569 let field_str = field_name.to_string();
570 let len = range.len;
571
572 let max_val = if len >= repr_bit_count {
573 quote! { #repr::MAX }
574 } else {
575 quote! { ((1 as #repr) << #len) - 1 }
576 };
577
578 if is_primitive(ty) {
579 let type_bits: u32 = match ty {
580 Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
581 Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
582 Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
583 Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
584 Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
585 _ => 128,
586 };
587
588 if len >= type_bits {
590 return None;
591 }
592
593 let is_signed = matches!(ty,
595 Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
596 p.path.is_ident("i32") || p.path.is_ident("i64") ||
597 p.path.is_ident("i128")
598 );
599
600 if is_signed {
601 let min_shift = len - 1;
604 Some(quote! {
605 if let Some(v) = self.#field_name {
606 let min_val = -((1i128 << #min_shift) as i128);
607 let max_val = ((1i128 << #min_shift) - 1) as i128;
608 let v_i128 = v as i128;
609 if v_i128 < min_val || v_i128 > max_val {
610 return Err(nexus_bits::FieldOverflow {
611 field: #field_str,
612 overflow: nexus_bits::Overflow {
613 value: (v as #repr),
614 max: #max_val,
615 },
616 });
617 }
618 }
619 })
620 } else {
621 Some(quote! {
623 if let Some(v) = self.#field_name {
624 if (v as #repr) > #max_val {
625 return Err(nexus_bits::FieldOverflow {
626 field: #field_str,
627 overflow: nexus_bits::Overflow {
628 value: v as #repr,
629 max: #max_val,
630 },
631 });
632 }
633 }
634 })
635 }
636 } else {
637 Some(quote! {
639 if let Some(v) = self.#field_name {
640 let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
641 if repr_val > #max_val {
642 return Err(nexus_bits::FieldOverflow {
643 field: #field_str,
644 overflow: nexus_bits::Overflow {
645 value: repr_val,
646 max: #max_val,
647 },
648 });
649 }
650 }
651 })
652 }
653 }
654 _ => None,
655 })
656 .collect();
657
658 let pack_statements: Vec<TokenStream2> = members
660 .iter()
661 .map(|m| {
662 match m {
663 MemberDef::Field {
664 name: field_name,
665 ty,
666 range,
667 } => {
668 let start = range.start;
669 let len = range.len;
670 let mask = if len >= repr_bit_count {
671 quote! { #repr::MAX }
672 } else {
673 quote! { ((1 as #repr) << #len) - 1 }
674 };
675
676 if is_primitive(ty) {
677 quote! {
678 if let Some(v) = self.#field_name {
679 val |= ((v as #repr) & #mask) << #start;
680 }
681 }
682 } else {
683 quote! {
685 if let Some(v) = self.#field_name {
686 val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
687 }
688 }
689 }
690 }
691 MemberDef::Flag {
692 name: field_name,
693 bit,
694 } => {
695 quote! {
696 if let Some(true) = self.#field_name {
697 val |= (1 as #repr) << #bit;
698 }
699 }
700 }
701 }
702 })
703 .collect();
704
705 quote! {
706 impl #builder_name {
707 #(#setters)*
708
709 #[inline]
711 pub fn build(self) -> Result<#name, nexus_bits::FieldOverflow<#repr>> {
712 #(#validations)*
714
715 let mut val: #repr = 0;
717 #(#pack_statements)*
718
719 Ok(#name(val))
720 }
721 }
722 }
723}
724
/// An enum variant collected and validated by `derive_storage_enum`.
struct ParsedVariant {
    /// Variant identifier from the input enum.
    name: Ident,
    /// Value from `#[variant(N)]`; checked to fit the discriminant bits and be unique.
    discriminant: u64,
    /// Bit members of the variant's named fields (empty for unit variants).
    members: Vec<MemberDef>,
}
735
736fn derive_storage_enum(
737 input: &DeriveInput,
738 data: &syn::DataEnum,
739 storage_attr: StorageAttr,
740) -> Result<TokenStream2> {
741 let discriminant = storage_attr.discriminant.ok_or_else(|| {
742 Error::new_spanned(
743 input,
744 "bit_storage enum requires discriminant: #[bit_storage(repr = T, discriminant(start = N, len = M))]",
745 )
746 })?;
747
748 let repr = &storage_attr.repr;
749 let bits = repr_bits(repr);
750
751 if discriminant.start + discriminant.len > bits {
753 return Err(Error::new_spanned(
754 input,
755 format!(
756 "discriminant exceeds {} bits (start {} + len {} = {})",
757 bits,
758 discriminant.start,
759 discriminant.len,
760 discriminant.start + discriminant.len
761 ),
762 ));
763 }
764
765 let max_discriminant = if discriminant.len >= 64 {
766 u64::MAX
767 } else {
768 (1u64 << discriminant.len) - 1
769 };
770
771 let mut variants = Vec::new();
773 for variant in &data.variants {
774 let disc = parse_variant_attr(&variant.attrs)?;
775
776 if disc > max_discriminant {
777 return Err(Error::new_spanned(
778 &variant.ident,
779 format!(
780 "variant discriminant {} exceeds max {} for {}-bit field",
781 disc, max_discriminant, discriminant.len
782 ),
783 ));
784 }
785
786 for existing in &variants {
788 let existing: &ParsedVariant = existing;
789 if existing.discriminant == disc {
790 return Err(Error::new_spanned(
791 &variant.ident,
792 format!(
793 "duplicate discriminant {}: already used by '{}'",
794 disc, existing.name
795 ),
796 ));
797 }
798 }
799
800 let members: Vec<MemberDef> = match &variant.fields {
801 Fields::Named(fields) => fields
802 .named
803 .iter()
804 .map(parse_member)
805 .collect::<Result<_>>()?,
806 Fields::Unit => Vec::new(),
807 Fields::Unnamed(_) => {
808 return Err(Error::new_spanned(
809 variant,
810 "tuple variants not supported, use named fields",
811 ));
812 }
813 };
814
815 let disc_range = MemberDef::Field {
817 name: Ident::new("__discriminant", proc_macro2::Span::call_site()),
818 ty: syn::parse_quote!(u64),
819 range: BitRange {
820 start: discriminant.start,
821 len: discriminant.len,
822 },
823 };
824
825 for member in &members {
826 if ranges_overlap(&disc_range, member) {
827 return Err(Error::new_spanned(
828 member.name(),
829 format!("field '{}' overlaps with discriminant", member.name()),
830 ));
831 }
832 }
833
834 validate_members(&members, repr)?;
836
837 variants.push(ParsedVariant {
838 name: variant.ident.clone(),
839 discriminant: disc,
840 members,
841 });
842 }
843
844 let vis = &input.vis;
845 let name = &input.ident;
846
847 let parent_type = generate_enum_parent_type(vis, name, repr);
848 let variant_types = generate_enum_variant_types(vis, name, repr, &variants);
849 let kind_enum = generate_enum_kind(vis, name, &variants);
850 let builder_structs = generate_enum_builder_structs(vis, name, &variants);
851 let parent_impl = generate_enum_parent_impl(name, repr, &discriminant, &variants);
852 let variant_impls = generate_enum_variant_impls(name, repr, &variants);
853 let builder_impls = generate_enum_builder_impls(name, repr, &discriminant, &variants);
854 let from_impls = generate_enum_from_impls(name, &variants);
855
856 Ok(quote! {
857 #parent_type
858 #variant_types
859 #kind_enum
860 #builder_structs
861 #parent_impl
862 #variant_impls
863 #builder_impls
864 #from_impls
865 })
866}
867
868fn variant_type_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
869 Ident::new(
870 &format!("{}{}", parent_name, variant_name),
871 variant_name.span(),
872 )
873}
874
875fn variant_builder_name(parent_name: &Ident, variant_name: &Ident) -> Ident {
876 Ident::new(
877 &format!("{}{}Builder", parent_name, variant_name),
878 variant_name.span(),
879 )
880}
881
882fn kind_enum_name(parent_name: &Ident) -> Ident {
883 Ident::new(&format!("{}Kind", parent_name), parent_name.span())
884}
885
886fn generate_enum_parent_type(vis: &syn::Visibility, name: &Ident, repr: &Ident) -> TokenStream2 {
887 quote! {
888 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
889 #[repr(transparent)]
890 #vis struct #name(#vis #repr);
891 }
892}
893
894fn generate_enum_variant_types(
895 vis: &syn::Visibility,
896 parent_name: &Ident,
897 repr: &Ident,
898 variants: &[ParsedVariant],
899) -> TokenStream2 {
900 let types: Vec<TokenStream2> = variants
901 .iter()
902 .map(|v| {
903 let type_name = variant_type_name(parent_name, &v.name);
904 quote! {
905 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
906 #[repr(transparent)]
907 #vis struct #type_name(#repr);
908 }
909 })
910 .collect();
911
912 quote! { #(#types)* }
913}
914
915fn generate_enum_kind(
916 vis: &syn::Visibility,
917 parent_name: &Ident,
918 variants: &[ParsedVariant],
919) -> TokenStream2 {
920 let kind_name = kind_enum_name(parent_name);
921 let variant_names: Vec<&Ident> = variants.iter().map(|v| &v.name).collect();
922
923 quote! {
924 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
925 #vis enum #kind_name {
926 #(#variant_names),*
927 }
928 }
929}
930
931fn generate_enum_builder_structs(
932 vis: &syn::Visibility,
933 parent_name: &Ident,
934 variants: &[ParsedVariant],
935) -> TokenStream2 {
936 let builders: Vec<TokenStream2> = variants
937 .iter()
938 .map(|v| {
939 let builder_name = variant_builder_name(parent_name, &v.name);
940
941 let fields: Vec<TokenStream2> = v
942 .members
943 .iter()
944 .map(|m| match m {
945 MemberDef::Field { name, ty, .. } => {
946 quote! { #name: Option<#ty>, }
947 }
948 MemberDef::Flag { name, .. } => {
949 quote! { #name: Option<bool>, }
950 }
951 })
952 .collect();
953
954 quote! {
955 #[derive(Debug, Clone, Copy, Default)]
956 #vis struct #builder_name {
957 #(#fields)*
958 }
959 }
960 })
961 .collect();
962
963 quote! { #(#builders)* }
964}
965
966fn generate_enum_parent_impl(
967 name: &Ident,
968 repr: &Ident,
969 discriminant: &BitRange,
970 variants: &[ParsedVariant],
971) -> TokenStream2 {
972 let repr_bit_count = repr_bits(repr);
973 let kind_name = kind_enum_name(name);
974 let disc_start = discriminant.start;
975 let disc_len = discriminant.len;
976
977 let disc_mask = if disc_len >= repr_bit_count {
978 quote! { #repr::MAX }
979 } else {
980 quote! { ((1 as #repr) << #disc_len) - 1 }
981 };
982
983 let kind_arms: Vec<TokenStream2> = variants
985 .iter()
986 .map(|v| {
987 let variant_name = &v.name;
988 let disc_val = v.discriminant;
989 quote! {
990 #disc_val => Ok(#kind_name::#variant_name),
991 }
992 })
993 .collect();
994
995 let is_methods: Vec<TokenStream2> = variants
997 .iter()
998 .map(|v| {
999 let variant_name = &v.name;
1000 let method_name = Ident::new(
1001 &format!("is_{}", to_snake_case(&variant_name.to_string())),
1002 variant_name.span(),
1003 );
1004 let disc_val = v.discriminant;
1005 quote! {
1006 #[inline]
1007 pub fn #method_name(&self) -> bool {
1008 let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1009 disc == #disc_val
1010 }
1011 }
1012 })
1013 .collect();
1014
1015 let as_methods: Vec<TokenStream2> = variants
1017 .iter()
1018 .map(|v| {
1019 let variant_name = &v.name;
1020 let variant_type = variant_type_name(name, variant_name);
1021 let method_name = Ident::new(
1022 &format!("as_{}", to_snake_case(&variant_name.to_string())),
1023 variant_name.span(),
1024 );
1025 let disc_val = v.discriminant;
1026
1027 let validations: Vec<TokenStream2> = v.members
1029 .iter()
1030 .filter_map(|m| {
1031 if let MemberDef::Field { name: field_name, ty, range } = m {
1032 if !is_primitive(ty) {
1033 let start = range.start;
1034 let len = range.len;
1035 let repr_bit_count = repr_bits(repr);
1036 let mask = if len >= repr_bit_count {
1037 quote! { #repr::MAX }
1038 } else {
1039 quote! { ((1 as #repr) << #len) - 1 }
1040 };
1041 return Some(quote! {
1042 let field_repr = ((self.0 >> #start) & #mask);
1043 if <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _).is_none() {
1044 return Err(nexus_bits::UnknownDiscriminant {
1045 field: stringify!(#field_name),
1046 value: self.0,
1047 });
1048 }
1049 });
1050 }
1051 }
1052 None
1053 })
1054 .collect();
1055
1056 quote! {
1057 #[inline]
1058 pub fn #method_name(&self) -> Result<#variant_type, nexus_bits::UnknownDiscriminant<#repr>> {
1059 let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1060 if disc != #disc_val {
1061 return Err(nexus_bits::UnknownDiscriminant {
1062 field: "__discriminant",
1063 value: self.0,
1064 });
1065 }
1066 #(#validations)*
1067 Ok(#variant_type(self.0))
1068 }
1069 }
1070 })
1071 .collect();
1072
1073 let builder_methods: Vec<TokenStream2> = variants
1075 .iter()
1076 .map(|v| {
1077 let variant_name = &v.name;
1078 let builder_name = variant_builder_name(name, variant_name);
1079 let method_name = Ident::new(
1080 &to_snake_case(&variant_name.to_string()),
1081 variant_name.span(),
1082 );
1083 quote! {
1084 #[inline]
1085 pub fn #method_name() -> #builder_name {
1086 #builder_name::default()
1087 }
1088 }
1089 })
1090 .collect();
1091
1092 quote! {
1093 impl #name {
1094 #[inline]
1096 pub const fn from_raw(raw: #repr) -> Self {
1097 Self(raw)
1098 }
1099
1100 #[inline]
1102 pub const fn raw(self) -> #repr {
1103 self.0
1104 }
1105
1106 #[inline]
1108 pub fn kind(&self) -> Result<#kind_name, nexus_bits::UnknownDiscriminant<#repr>> {
1109 let disc = ((self.0 >> #disc_start) & #disc_mask) as u64;
1110 match disc {
1111 #(#kind_arms)*
1112 _ => Err(nexus_bits::UnknownDiscriminant {
1113 field: "__discriminant",
1114 value: self.0,
1115 }),
1116 }
1117 }
1118
1119 #(#is_methods)*
1120
1121 #(#as_methods)*
1122
1123 #(#builder_methods)*
1124 }
1125 }
1126}
1127
1128fn generate_enum_variant_impls(
1129 parent_name: &Ident,
1130 repr: &Ident,
1131 variants: &[ParsedVariant],
1132) -> TokenStream2 {
1133 let repr_bit_count = repr_bits(repr);
1134
1135 let impls: Vec<TokenStream2> =
1136 variants
1137 .iter()
1138 .map(|v| {
1139 let variant_name = &v.name;
1140 let variant_type = variant_type_name(parent_name, variant_name);
1141 let builder_name = variant_builder_name(parent_name, variant_name);
1142
1143 let accessors: Vec<TokenStream2> = v.members
1145 .iter()
1146 .map(|m| {
1147 match m {
1148 MemberDef::Field { name: field_name, ty, range } => {
1149 let start = range.start;
1150 let len = range.len;
1151 let mask = if len >= repr_bit_count {
1152 quote! { #repr::MAX }
1153 } else {
1154 quote! { ((1 as #repr) << #len) - 1 }
1155 };
1156
1157 if is_primitive(ty) {
1158 quote! {
1159 #[inline]
1160 pub const fn #field_name(&self) -> #ty {
1161 ((self.0 >> #start) & #mask) as #ty
1162 }
1163 }
1164 } else {
1165 quote! {
1167 #[inline]
1168 pub fn #field_name(&self) -> #ty {
1169 let field_repr = ((self.0 >> #start) & #mask);
1170 <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
1172 .expect("variant type invariant violated")
1173 }
1174 }
1175 }
1176 }
1177 MemberDef::Flag { name: field_name, bit } => {
1178 quote! {
1179 #[inline]
1180 pub const fn #field_name(&self) -> bool {
1181 (self.0 >> #bit) & 1 != 0
1182 }
1183 }
1184 }
1185 }
1186 })
1187 .collect();
1188
1189 quote! {
1190 impl #variant_type {
1191 #[inline]
1193 pub fn builder() -> #builder_name {
1194 #builder_name::default()
1195 }
1196
1197 #[inline]
1199 pub const fn raw(self) -> #repr {
1200 self.0
1201 }
1202
1203 #[inline]
1205 pub const fn as_parent(self) -> #parent_name {
1206 #parent_name(self.0)
1207 }
1208
1209 #(#accessors)*
1210 }
1211 }
1212 })
1213 .collect();
1214
1215 quote! { #(#impls)* }
1216}
1217
1218fn generate_enum_builder_impls(
1219 parent_name: &Ident,
1220 repr: &Ident,
1221 discriminant: &BitRange,
1222 variants: &[ParsedVariant],
1223) -> TokenStream2 {
1224 let repr_bit_count = repr_bits(repr);
1225 let disc_start = discriminant.start;
1226
1227 let impls: Vec<TokenStream2> = variants
1228 .iter()
1229 .map(|v| {
1230 let variant_name = &v.name;
1231 let variant_type = variant_type_name(parent_name, variant_name);
1232 let builder_name = variant_builder_name(parent_name, variant_name);
1233 let disc_val = v.discriminant;
1234
1235 let setters: Vec<TokenStream2> = v.members
1237 .iter()
1238 .map(|m| match m {
1239 MemberDef::Field { name: field_name, ty, .. } => {
1240 quote! {
1241 #[inline]
1242 pub fn #field_name(mut self, val: #ty) -> Self {
1243 self.#field_name = Some(val);
1244 self
1245 }
1246 }
1247 }
1248 MemberDef::Flag { name: field_name, .. } => {
1249 quote! {
1250 #[inline]
1251 pub fn #field_name(mut self, val: bool) -> Self {
1252 self.#field_name = Some(val);
1253 self
1254 }
1255 }
1256 }
1257 })
1258 .collect();
1259
1260 let validations: Vec<TokenStream2> = v.members
1262 .iter()
1263 .filter_map(|m| match m {
1264 MemberDef::Field { name: field_name, ty, range } => {
1265 let field_str = field_name.to_string();
1266 let len = range.len;
1267
1268 let max_val = if len >= repr_bit_count {
1269 quote! { #repr::MAX }
1270 } else {
1271 quote! { ((1 as #repr) << #len) - 1 }
1272 };
1273
1274 if is_primitive(ty) {
1275 let type_bits: u32 = match ty {
1276 Type::Path(p) if p.path.is_ident("u8") || p.path.is_ident("i8") => 8,
1277 Type::Path(p) if p.path.is_ident("u16") || p.path.is_ident("i16") => 16,
1278 Type::Path(p) if p.path.is_ident("u32") || p.path.is_ident("i32") => 32,
1279 Type::Path(p) if p.path.is_ident("u64") || p.path.is_ident("i64") => 64,
1280 Type::Path(p) if p.path.is_ident("u128") || p.path.is_ident("i128") => 128,
1281 _ => 128,
1282 };
1283
1284 if len >= type_bits {
1285 return None;
1286 }
1287
1288 let is_signed = matches!(ty,
1289 Type::Path(p) if p.path.is_ident("i8") || p.path.is_ident("i16") ||
1290 p.path.is_ident("i32") || p.path.is_ident("i64") ||
1291 p.path.is_ident("i128")
1292 );
1293
1294 if is_signed {
1295 let min_shift = len - 1;
1296 Some(quote! {
1297 if let Some(v) = self.#field_name {
1298 let min_val = -((1i128 << #min_shift) as i128);
1299 let max_val = ((1i128 << #min_shift) - 1) as i128;
1300 let v_i128 = v as i128;
1301 if v_i128 < min_val || v_i128 > max_val {
1302 return Err(nexus_bits::FieldOverflow {
1303 field: #field_str,
1304 overflow: nexus_bits::Overflow {
1305 value: (v as #repr),
1306 max: #max_val,
1307 },
1308 });
1309 }
1310 }
1311 })
1312 } else {
1313 Some(quote! {
1314 if let Some(v) = self.#field_name {
1315 if (v as #repr) > #max_val {
1316 return Err(nexus_bits::FieldOverflow {
1317 field: #field_str,
1318 overflow: nexus_bits::Overflow {
1319 value: v as #repr,
1320 max: #max_val,
1321 },
1322 });
1323 }
1324 }
1325 })
1326 }
1327 } else {
1328 Some(quote! {
1330 if let Some(v) = self.#field_name {
1331 let repr_val = nexus_bits::IntEnum::into_repr(v) as #repr;
1332 if repr_val > #max_val {
1333 return Err(nexus_bits::FieldOverflow {
1334 field: #field_str,
1335 overflow: nexus_bits::Overflow {
1336 value: repr_val,
1337 max: #max_val,
1338 },
1339 });
1340 }
1341 }
1342 })
1343 }
1344 }
1345 _ => None,
1346 })
1347 .collect();
1348
1349 let pack_statements: Vec<TokenStream2> = v.members
1351 .iter()
1352 .map(|m| {
1353 match m {
1354 MemberDef::Field { name: field_name, ty, range } => {
1355 let start = range.start;
1356 let len = range.len;
1357 let mask = if len >= repr_bit_count {
1358 quote! { #repr::MAX }
1359 } else {
1360 quote! { ((1 as #repr) << #len) - 1 }
1361 };
1362
1363 if is_primitive(ty) {
1364 quote! {
1365 if let Some(v) = self.#field_name {
1366 val |= ((v as #repr) & #mask) << #start;
1367 }
1368 }
1369 } else {
1370 quote! {
1371 if let Some(v) = self.#field_name {
1372 val |= ((nexus_bits::IntEnum::into_repr(v) as #repr) & #mask) << #start;
1373 }
1374 }
1375 }
1376 }
1377 MemberDef::Flag { name: field_name, bit } => {
1378 quote! {
1379 if let Some(true) = self.#field_name {
1380 val |= (1 as #repr) << #bit;
1381 }
1382 }
1383 }
1384 }
1385 })
1386 .collect();
1387
1388 quote! {
1389 impl #builder_name {
1390 #(#setters)*
1391
1392 #[inline]
1394 pub fn build(self) -> Result<#variant_type, nexus_bits::FieldOverflow<#repr>> {
1395 #(#validations)*
1396
1397 let mut val: #repr = 0;
1398 val |= (#disc_val as #repr) << #disc_start;
1400 #(#pack_statements)*
1401
1402 Ok(#variant_type(val))
1403 }
1404
1405 #[inline]
1407 pub fn build_parent(self) -> Result<#parent_name, nexus_bits::FieldOverflow<#repr>> {
1408 self.build().map(|v| v.as_parent())
1409 }
1410 }
1411 }
1412 })
1413 .collect();
1414
1415 quote! { #(#impls)* }
1416}
1417
1418fn generate_enum_from_impls(parent_name: &Ident, variants: &[ParsedVariant]) -> TokenStream2 {
1419 let impls: Vec<TokenStream2> = variants
1420 .iter()
1421 .map(|v| {
1422 let variant_type = variant_type_name(parent_name, &v.name);
1423 quote! {
1424 impl From<#variant_type> for #parent_name {
1425 #[inline]
1426 fn from(v: #variant_type) -> Self {
1427 v.as_parent()
1428 }
1429 }
1430 }
1431 })
1432 .collect();
1433
1434 quote! { #(#impls)* }
1435}
1436
/// Converts an `UpperCamelCase` identifier to `snake_case`.
///
/// An underscore is inserted before every uppercase character except a leading
/// one, so acronyms split per letter (`"HTTPCode"` becomes `"h_t_t_p_code"`).
/// All other characters are copied unchanged.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // Use the full lowercase mapping. Some characters lower to several
            // code points (e.g. 'İ' -> "i\u{307}"), and keeping only the first
            // would silently drop data.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}