1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{Data, DeriveInput, Error, Fields, Ident, Result, Type, parse_macro_input};
5
6#[proc_macro_derive(IntEnum)]
11pub fn derive_int_enum(input: TokenStream) -> TokenStream {
12 let input = parse_macro_input!(input as DeriveInput);
13
14 match derive_int_enum_impl(input) {
15 Ok(tokens) => tokens.into(),
16 Err(err) => err.to_compile_error().into(),
17 }
18}
19
20fn derive_int_enum_impl(input: DeriveInput) -> Result<TokenStream2> {
21 let variants = match &input.data {
22 Data::Enum(data) => &data.variants,
23 _ => {
24 return Err(Error::new_spanned(
25 &input,
26 "IntEnum can only be derived for enums",
27 ));
28 }
29 };
30
31 let repr = parse_repr(&input)?;
32
33 for variant in variants {
34 if !matches!(variant.fields, Fields::Unit) {
35 return Err(Error::new_spanned(
36 variant,
37 "IntEnum variants cannot have fields",
38 ));
39 }
40 }
41
42 let name = &input.ident;
43 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
44
45 let from_arms = variants.iter().map(|v| {
46 let variant_name = &v.ident;
47 quote! {
48 x if x == #name::#variant_name as #repr => Some(#name::#variant_name),
49 }
50 });
51
52 Ok(quote! {
53 impl #impl_generics nexus_bits::IntEnum for #name #ty_generics #where_clause {
54 type Repr = #repr;
55
56 #[inline]
57 fn into_repr(self) -> #repr {
58 self as #repr
59 }
60
61 #[inline]
62 fn try_from_repr(repr: #repr) -> Option<Self> {
63 match repr {
64 #(#from_arms)*
65 _ => None,
66 }
67 }
68 }
69 })
70}
71
72fn parse_repr(input: &DeriveInput) -> Result<Ident> {
73 for attr in &input.attrs {
74 if attr.path().is_ident("repr") {
75 let repr: Ident = attr.parse_args()?;
76 match repr.to_string().as_str() {
77 "u8" | "u16" | "u32" | "u64" | "i8" | "i16" | "i32" | "i64" => {
78 return Ok(repr);
79 }
80 _ => {
81 return Err(Error::new_spanned(
82 repr,
83 "IntEnum requires repr(u8), repr(u16), repr(u32), repr(u64), repr(i8), repr(i16), repr(i32), or repr(i64)",
84 ));
85 }
86 }
87 }
88 }
89
90 Err(Error::new_spanned(
91 input,
92 "IntEnum requires a #[repr(u8/u16/u32/u64/i8/i16/i32/i64)] attribute",
93 ))
94}
95
96#[proc_macro_derive(BitPacked, attributes(packed, field, flag, variant))]
101pub fn derive_bit_packed(input: TokenStream) -> TokenStream {
102 let input = parse_macro_input!(input as DeriveInput);
103
104 match derive_bit_packed_impl(input) {
105 Ok(tokens) => tokens.into(),
106 Err(err) => err.to_compile_error().into(),
107 }
108}
109
110fn derive_bit_packed_impl(input: DeriveInput) -> Result<TokenStream2> {
111 match &input.data {
112 Data::Struct(data) => derive_packed_struct(&input, data),
113 Data::Enum(data) => derive_packed_enum(&input, data),
114 Data::Union(_) => Err(Error::new_spanned(
115 &input,
116 "BitPacked cannot be derived for unions",
117 )),
118 }
119}
120
121struct PackedAttr {
127 repr: Ident,
128 discriminant: Option<BitRange>,
129}
130
/// A contiguous run of `len` bits whose lowest bit is `start` (bit 0 is the least
/// significant bit, since generated code shifts values left by `start`).
struct BitRange {
    /// Number of bits in the range; `parse_bit_range` rejects zero.
    len: u32,
    /// Index of the lowest bit of the range.
    start: u32,
}
136
137enum MemberDef {
139 Field {
140 name: Ident,
141 ty: Type,
142 range: BitRange,
143 },
144 Flag {
145 name: Ident,
146 bit: u32,
147 },
148}
149
150impl MemberDef {
151 fn name(&self) -> &Ident {
152 match self {
153 MemberDef::Field { name, .. } => name,
154 MemberDef::Flag { name, .. } => name,
155 }
156 }
157}
158
159fn parse_packed_attr(attrs: &[syn::Attribute]) -> Result<PackedAttr> {
164 for attr in attrs {
165 if attr.path().is_ident("packed") {
166 return parse_packed_attr_inner(attr);
167 }
168 }
169 Err(Error::new(
170 proc_macro2::Span::call_site(),
171 "BitPacked requires a #[packed(repr = ...)] attribute",
172 ))
173}
174
175fn parse_packed_attr_inner(attr: &syn::Attribute) -> Result<PackedAttr> {
176 let mut repr = None;
177 let mut discriminant = None;
178
179 attr.parse_nested_meta(|meta| {
180 if meta.path.is_ident("repr") {
181 meta.input.parse::<syn::Token![=]>()?;
182 repr = Some(meta.input.parse::<Ident>()?);
183 Ok(())
184 } else if meta.path.is_ident("discriminant") {
185 let content;
186 syn::parenthesized!(content in meta.input);
187 discriminant = Some(parse_bit_range(&content)?);
188 Ok(())
189 } else {
190 Err(meta.error("expected `repr` or `discriminant`"))
191 }
192 })?;
193
194 let repr =
195 repr.ok_or_else(|| Error::new_spanned(attr, "packed attribute requires `repr = ...`"))?;
196
197 match repr.to_string().as_str() {
199 "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" => {}
200 _ => return Err(Error::new_spanned(&repr, "repr must be an integer type")),
201 }
202
203 Ok(PackedAttr { repr, discriminant })
204}
205
206fn parse_bit_range(input: syn::parse::ParseStream) -> Result<BitRange> {
207 let mut start = None;
208 let mut len = None;
209
210 while !input.is_empty() {
211 let ident: Ident = input.parse()?;
212 input.parse::<syn::Token![=]>()?;
213 let lit: syn::LitInt = input.parse()?;
214 let value: u32 = lit.base10_parse()?;
215
216 match ident.to_string().as_str() {
217 "start" => start = Some(value),
218 "len" => len = Some(value),
219 _ => return Err(Error::new_spanned(ident, "expected `start` or `len`")),
220 }
221
222 if input.peek(syn::Token![,]) {
223 input.parse::<syn::Token![,]>()?;
224 }
225 }
226
227 let start = start.ok_or_else(|| Error::new(input.span(), "missing `start`"))?;
228 let len = len.ok_or_else(|| Error::new(input.span(), "missing `len`"))?;
229
230 if len == 0 {
231 return Err(Error::new(input.span(), "len must be > 0"));
232 }
233
234 Ok(BitRange { start, len })
235}
236
237fn parse_member(field: &syn::Field) -> Result<MemberDef> {
238 let name = field
239 .ident
240 .clone()
241 .ok_or_else(|| Error::new_spanned(field, "tuple structs not supported"))?;
242 let ty = field.ty.clone();
243
244 for attr in &field.attrs {
245 if attr.path().is_ident("field") {
246 let range = attr.parse_args_with(parse_bit_range)?;
247 return Ok(MemberDef::Field { name, ty, range });
248 } else if attr.path().is_ident("flag") {
249 let bit: syn::LitInt = attr.parse_args()?;
250 let bit: u32 = bit.base10_parse()?;
251 return Ok(MemberDef::Flag { name, bit });
252 }
253 }
254
255 Err(Error::new_spanned(
256 field,
257 "field requires #[field(start = N, len = M)] or #[flag(N)] attribute",
258 ))
259}
260
261fn is_primitive(ty: &Type) -> bool {
266 if let Type::Path(type_path) = ty {
267 if let Some(ident) = type_path.path.get_ident() {
268 return matches!(
269 ident.to_string().as_str(),
270 "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128"
271 );
272 }
273 }
274 false
275}
276
277fn repr_bits(repr: &Ident) -> u32 {
278 match repr.to_string().as_str() {
279 "u8" | "i8" => 8,
280 "u16" | "i16" => 16,
281 "u32" | "i32" => 32,
282 "u64" | "i64" => 64,
283 "u128" | "i128" => 128,
284 _ => 0,
285 }
286}
287
288fn validate_members(members: &[MemberDef], repr: &Ident) -> Result<()> {
293 let bits = repr_bits(repr);
294
295 for member in members {
297 match member {
298 MemberDef::Field { name, range, .. } => {
299 if range.start + range.len > bits {
300 return Err(Error::new_spanned(
301 name,
302 format!(
303 "field exceeds {} bits (start {} + len {} = {})",
304 bits,
305 range.start,
306 range.len,
307 range.start + range.len
308 ),
309 ));
310 }
311 }
312 MemberDef::Flag { name, bit, .. } => {
313 if *bit >= bits {
314 return Err(Error::new_spanned(
315 name,
316 format!("flag bit {} exceeds {} bits", bit, bits),
317 ));
318 }
319 }
320 }
321 }
322
323 for (i, a) in members.iter().enumerate() {
325 for b in members.iter().skip(i + 1) {
326 if ranges_overlap(a, b) {
327 return Err(Error::new_spanned(
328 b.name(),
329 format!("field '{}' overlaps with '{}'", b.name(), a.name()),
330 ));
331 }
332 }
333 }
334
335 Ok(())
336}
337
338fn ranges_overlap(a: &MemberDef, b: &MemberDef) -> bool {
339 let (a_start, a_end) = member_bit_range(a);
340 let (b_start, b_end) = member_bit_range(b);
341 a_start < b_end && b_start < a_end
342}
343
344fn member_bit_range(m: &MemberDef) -> (u32, u32) {
345 match m {
346 MemberDef::Field { range, .. } => (range.start, range.start + range.len),
347 MemberDef::Flag { bit, .. } => (*bit, bit + 1),
348 }
349}
350
351fn derive_packed_struct(input: &DeriveInput, data: &syn::DataStruct) -> Result<TokenStream2> {
356 let fields = match &data.fields {
357 Fields::Named(f) => &f.named,
358 _ => return Err(Error::new_spanned(input, "BitPacked requires named fields")),
359 };
360
361 let packed_attr = parse_packed_attr(&input.attrs)?;
362
363 if packed_attr.discriminant.is_some() {
364 return Err(Error::new_spanned(
365 input,
366 "discriminant is only valid for enums",
367 ));
368 }
369
370 let members: Vec<MemberDef> = fields.iter().map(parse_member).collect::<Result<_>>()?;
371
372 validate_members(&members, &packed_attr.repr)?;
373
374 let name = &input.ident;
375 let repr = &packed_attr.repr;
376 let has_enum_fields = members
377 .iter()
378 .any(|m| matches!(m, MemberDef::Field { ty, .. } if !is_primitive(ty)));
379
380 let pack_fn = generate_struct_pack(repr, &members);
381 let unpack_fn = generate_struct_unpack(repr, &members, has_enum_fields);
382
383 Ok(quote! {
384 impl #name {
385 #pack_fn
386 #unpack_fn
387 }
388 })
389}
390
391fn generate_struct_pack(repr: &Ident, members: &[MemberDef]) -> TokenStream2 {
392 let pack_statements: Vec<TokenStream2> = members.iter().map(|m| {
393 match m {
394 MemberDef::Field { name: field_name, ty, range } => {
395 let field_str = field_name.to_string();
396 let start = range.start;
397 let len = range.len;
398 let max_val = if len >= 64 {
399 quote! { #repr::MAX }
400 } else {
401 quote! { ((1 as #repr) << #len) - 1 }
402 };
403
404 if is_primitive(ty) {
405 quote! {
406 let field_val = self.#field_name as #repr;
407 if field_val > #max_val {
408 return Err(nexus_bits::FieldOverflow {
409 field: #field_str,
410 overflow: nexus_bits::Overflow { value: field_val, max: #max_val },
411 });
412 }
413 val |= field_val << #start;
414 }
415 } else {
416 quote! {
418 let field_val = nexus_bits::IntEnum::into_repr(self.#field_name) as #repr;
419 if field_val > #max_val {
420 return Err(nexus_bits::FieldOverflow {
421 field: #field_str,
422 overflow: nexus_bits::Overflow { value: field_val, max: #max_val },
423 });
424 }
425 val |= field_val << #start;
426 }
427 }
428 }
429 MemberDef::Flag { name: field_name, bit } => {
430 quote! {
431 if self.#field_name {
432 val |= (1 as #repr) << #bit;
433 }
434 }
435 }
436 }
437 }).collect();
438
439 quote! {
440 #[inline]
442 pub fn pack(&self) -> Result<#repr, nexus_bits::FieldOverflow<#repr>> {
443 let mut val: #repr = 0;
444 #(#pack_statements)*
445 Ok(val)
446 }
447 }
448}
449
450fn generate_struct_unpack(
451 repr: &Ident,
452 members: &[MemberDef],
453 has_enum_fields: bool,
454) -> TokenStream2 {
455 let unpack_statements: Vec<TokenStream2> = members.iter().map(|m| {
456 match m {
457 MemberDef::Field { name: field_name, ty, range } => {
458 let start = range.start;
459 let len = range.len;
460 let mask = if len >= 64 {
461 quote! { #repr::MAX }
462 } else {
463 quote! { ((1 as #repr) << #len) - 1 }
464 };
465
466 if is_primitive(ty) {
467 quote! {
468 let #field_name = ((raw >> #start) & #mask) as #ty;
469 }
470 } else {
471 let field_str = field_name.to_string();
473 quote! {
474 let field_repr = ((raw >> #start) & #mask);
475 let #field_name = <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
476 .ok_or(nexus_bits::UnknownDiscriminant {
477 field: #field_str,
478 value: raw,
479 })?;
480 }
481 }
482 }
483 MemberDef::Flag { name: field_name, bit } => {
484 quote! {
485 let #field_name = (raw >> #bit) & 1 != 0;
486 }
487 }
488 }
489 }).collect();
490
491 let field_names: Vec<&Ident> = members.iter().map(|m| m.name()).collect();
492
493 if has_enum_fields {
494 quote! {
495 #[inline]
497 pub fn unpack(raw: #repr) -> Result<Self, nexus_bits::UnknownDiscriminant<#repr>> {
498 #(#unpack_statements)*
499 Ok(Self { #(#field_names),* })
500 }
501 }
502 } else {
503 quote! {
504 #[inline]
506 pub fn unpack(raw: #repr) -> Self {
507 #(#unpack_statements)*
508 Self { #(#field_names),* }
509 }
510 }
511 }
512}
513
514struct ParsedVariant {
520 name: Ident,
521 discriminant: u64,
522 members: Vec<MemberDef>,
523}
524
525fn parse_variant_attr(attrs: &[syn::Attribute]) -> Result<u64> {
526 for attr in attrs {
527 if attr.path().is_ident("variant") {
528 let lit: syn::LitInt = attr.parse_args()?;
529 return lit.base10_parse();
530 }
531 }
532 Err(Error::new(
533 proc_macro2::Span::call_site(),
534 "enum variant requires #[variant(N)] attribute",
535 ))
536}
537
538fn derive_packed_enum(input: &DeriveInput, data: &syn::DataEnum) -> Result<TokenStream2> {
539 let packed_attr = parse_packed_attr(&input.attrs)?;
540
541 let discriminant = packed_attr.discriminant.ok_or_else(|| {
542 Error::new_spanned(
543 input,
544 "BitPacked enum requires discriminant: #[packed(repr = T, discriminant(start = N, len = M))]",
545 )
546 })?;
547
548 let repr = &packed_attr.repr;
549 let bits = repr_bits(repr);
550
551 if discriminant.start + discriminant.len > bits {
553 return Err(Error::new_spanned(
554 input,
555 format!(
556 "discriminant exceeds {} bits (start {} + len {} = {})",
557 bits, discriminant.start, discriminant.len, discriminant.start + discriminant.len
558 ),
559 ));
560 }
561
562 let max_discriminant = if discriminant.len >= 64 {
563 u64::MAX
564 } else {
565 (1u64 << discriminant.len) - 1
566 };
567
568 let mut variants = Vec::new();
570 for variant in &data.variants {
571 let disc = parse_variant_attr(&variant.attrs)?;
572
573 if disc > max_discriminant {
574 return Err(Error::new_spanned(
575 &variant.ident,
576 format!(
577 "variant discriminant {} exceeds max {} for {}-bit field",
578 disc, max_discriminant, discriminant.len
579 ),
580 ));
581 }
582
583 for existing in &variants {
585 let existing: &ParsedVariant = existing;
586 if existing.discriminant == disc {
587 return Err(Error::new_spanned(
588 &variant.ident,
589 format!(
590 "duplicate discriminant {}: already used by '{}'",
591 disc, existing.name
592 ),
593 ));
594 }
595 }
596
597 let members: Vec<MemberDef> = match &variant.fields {
598 Fields::Named(fields) => {
599 fields.named.iter().map(parse_member).collect::<Result<_>>()?
600 }
601 Fields::Unit => Vec::new(),
602 Fields::Unnamed(_) => {
603 return Err(Error::new_spanned(
604 variant,
605 "tuple variants not supported, use named fields",
606 ));
607 }
608 };
609
610 let disc_range = MemberDef::Field {
612 name: Ident::new("__discriminant", proc_macro2::Span::call_site()),
613 ty: syn::parse_quote!(u64),
614 range: BitRange {
615 start: discriminant.start,
616 len: discriminant.len,
617 },
618 };
619
620 for member in &members {
621 if ranges_overlap(&disc_range, member) {
622 return Err(Error::new_spanned(
623 member.name(),
624 format!("field '{}' overlaps with discriminant", member.name()),
625 ));
626 }
627 }
628
629 validate_members(&members, repr)?;
631
632 variants.push(ParsedVariant {
633 name: variant.ident.clone(),
634 discriminant: disc,
635 members,
636 });
637 }
638
639 let name = &input.ident;
640 let pack_fn = generate_enum_pack(repr, &discriminant, &variants);
641 let unpack_fn = generate_enum_unpack(repr, &discriminant, &variants);
642
643 Ok(quote! {
644 impl #name {
645 #pack_fn
646 #unpack_fn
647 }
648 })
649}
650
651fn generate_enum_pack(
652 repr: &Ident,
653 discriminant: &BitRange,
654 variants: &[ParsedVariant],
655) -> TokenStream2 {
656 let disc_start = discriminant.start;
657
658 let match_arms: Vec<TokenStream2> = variants.iter().map(|variant| {
659 let variant_name = &variant.name;
660 let disc_val = variant.discriminant;
661
662 let field_names: Vec<&Ident> = variant.members.iter().map(|m| m.name()).collect();
663
664 let pack_statements: Vec<TokenStream2> = variant.members.iter().map(|m| {
665 match m {
666 MemberDef::Field { name: field_name, ty, range } => {
667 let field_str = field_name.to_string();
668 let start = range.start;
669 let len = range.len;
670 let max_val = if len >= 64 {
671 quote! { #repr::MAX }
672 } else {
673 quote! { ((1 as #repr) << #len) - 1 }
674 };
675
676 if is_primitive(ty) {
677 quote! {
678 let field_val = *#field_name as #repr;
679 if field_val > #max_val {
680 return Err(nexus_bits::FieldOverflow {
681 field: #field_str,
682 overflow: nexus_bits::Overflow { value: field_val, max: #max_val },
683 });
684 }
685 val |= field_val << #start;
686 }
687 } else {
688 quote! {
689 let field_val = nexus_bits::IntEnum::into_repr(*#field_name) as #repr;
690 if field_val > #max_val {
691 return Err(nexus_bits::FieldOverflow {
692 field: #field_str,
693 overflow: nexus_bits::Overflow { value: field_val, max: #max_val },
694 });
695 }
696 val |= field_val << #start;
697 }
698 }
699 }
700 MemberDef::Flag { name: field_name, bit } => {
701 quote! {
702 if *#field_name {
703 val |= (1 as #repr) << #bit;
704 }
705 }
706 }
707 }
708 }).collect();
709
710 if field_names.is_empty() {
711 quote! {
712 Self::#variant_name => {
713 let mut val: #repr = 0;
714 val |= (#disc_val as #repr) << #disc_start;
715 Ok(val)
716 }
717 }
718 } else {
719 quote! {
720 Self::#variant_name { #(#field_names),* } => {
721 let mut val: #repr = 0;
722 val |= (#disc_val as #repr) << #disc_start;
723 #(#pack_statements)*
724 Ok(val)
725 }
726 }
727 }
728 }).collect();
729
730 quote! {
731 #[inline]
733 pub fn pack(&self) -> Result<#repr, nexus_bits::FieldOverflow<#repr>> {
734 match self {
735 #(#match_arms)*
736 }
737 }
738 }
739}
740
741fn generate_enum_unpack(
742 repr: &Ident,
743 discriminant: &BitRange,
744 variants: &[ParsedVariant],
745) -> TokenStream2 {
746 let disc_start = discriminant.start;
747 let disc_len = discriminant.len;
748 let disc_mask = if disc_len >= 64 {
749 quote! { #repr::MAX }
750 } else {
751 quote! { ((1 as #repr) << #disc_len) - 1 }
752 };
753
754 let match_arms: Vec<TokenStream2> = variants.iter().map(|variant| {
755 let variant_name = &variant.name;
756 let disc_val = variant.discriminant;
757
758 let unpack_statements: Vec<TokenStream2> = variant.members.iter().map(|m| {
759 match m {
760 MemberDef::Field { name: field_name, ty, range } => {
761 let start = range.start;
762 let len = range.len;
763 let mask = if len >= 64 {
764 quote! { #repr::MAX }
765 } else {
766 quote! { ((1 as #repr) << #len) - 1 }
767 };
768
769 if is_primitive(ty) {
770 quote! {
771 let #field_name = ((raw >> #start) & #mask) as #ty;
772 }
773 } else {
774 let field_str = field_name.to_string();
775 quote! {
776 let field_repr = ((raw >> #start) & #mask);
777 let #field_name = <#ty as nexus_bits::IntEnum>::try_from_repr(field_repr as _)
778 .ok_or(nexus_bits::UnknownDiscriminant {
779 field: #field_str,
780 value: raw,
781 })?;
782 }
783 }
784 }
785 MemberDef::Flag { name: field_name, bit } => {
786 quote! {
787 let #field_name = (raw >> #bit) & 1 != 0;
788 }
789 }
790 }
791 }).collect();
792
793 let field_names: Vec<&Ident> = variant.members.iter().map(|m| m.name()).collect();
794
795 if field_names.is_empty() {
796 quote! {
797 #disc_val => Ok(Self::#variant_name),
798 }
799 } else {
800 quote! {
801 #disc_val => {
802 #(#unpack_statements)*
803 Ok(Self::#variant_name { #(#field_names),* })
804 }
805 }
806 }
807 }).collect();
808
809 quote! {
810 #[inline]
812 pub fn unpack(raw: #repr) -> Result<Self, nexus_bits::UnknownDiscriminant<#repr>> {
813 let discriminant = ((raw >> #disc_start) & #disc_mask) as u64;
814 match discriminant {
815 #(#match_arms)*
816 _ => Err(nexus_bits::UnknownDiscriminant {
817 field: "",
818 value: raw,
819 }),
820 }
821 }
822 }
823}