enum_ordinalize_derive/lib.rs
/*!
# Enum Ordinalize Derive

This library provides the `Ordinalize` derive macro, which lets enums not only obtain the ordinal values of their variants but also be constructed from an ordinal value. See the [`enum-ordinalize`](https://crates.io/crates/enum-ordinalize) crate.
*/
6
7#![no_std]
8
9#[macro_use]
10extern crate alloc;
11
12mod int128;
13mod int_wrapper;
14mod panic;
15mod variant_type;
16
17use alloc::{string::ToString, vec::Vec};
18
19use proc_macro::TokenStream;
20use quote::quote;
21use syn::{
22 Data, DeriveInput, Expr, Fields, Ident, Lit, Meta, Token, UnOp, Visibility,
23 parse::{Parse, ParseStream},
24 parse_macro_input,
25 punctuated::Punctuated,
26};
27use variant_type::VariantType;
28
29use crate::{int_wrapper::IntWrapper, int128::Int128};
30
31#[proc_macro_derive(Ordinalize, attributes(ordinalize))]
32pub fn ordinalize_derive(input: TokenStream) -> TokenStream {
33 struct ConstMember {
34 vis: Option<Visibility>,
35 ident: Ident,
36 meta: Vec<Meta>,
37 function: bool,
38 }
39
40 impl Parse for ConstMember {
41 #[inline]
42 fn parse(input: ParseStream) -> syn::Result<Self> {
43 let vis = input.parse::<Visibility>().ok();
44
45 let _ = input.parse::<Token![const]>();
46
47 let function = input.parse::<Token![fn]>().is_ok();
48
49 let ident = input.parse::<Ident>()?;
50
51 let mut meta = Vec::new();
52
53 if !input.is_empty() {
54 input.parse::<Token![,]>()?;
55
56 if !input.is_empty() {
57 let result = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
58
59 let mut has_inline = false;
60
61 for m in result {
62 if m.path().is_ident("inline") {
63 has_inline = true;
64 }
65
66 meta.push(m);
67 }
68
69 if !has_inline {
70 meta.push(syn::parse_str("inline")?);
71 }
72 }
73 }
74
75 Ok(Self {
76 vis,
77 ident,
78 meta,
79 function,
80 })
81 }
82 }
83
84 struct ConstFunctionMember {
85 vis: Option<Visibility>,
86 ident: Ident,
87 meta: Vec<Meta>,
88 }
89
90 impl Parse for ConstFunctionMember {
91 #[inline]
92 fn parse(input: ParseStream) -> syn::Result<Self> {
93 let vis = input.parse::<Visibility>().ok();
94
95 let _ = input.parse::<Token![const]>();
96
97 input.parse::<Token![fn]>()?;
98
99 let ident = input.parse::<Ident>()?;
100
101 let mut meta = Vec::new();
102
103 if !input.is_empty() {
104 input.parse::<Token![,]>()?;
105
106 if !input.is_empty() {
107 let result = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
108
109 let mut has_inline = false;
110
111 for m in result {
112 if m.path().is_ident("inline") {
113 has_inline = true;
114 }
115
116 meta.push(m);
117 }
118
119 if !has_inline {
120 meta.push(syn::parse_str("inline")?);
121 }
122 }
123 }
124
125 Ok(Self {
126 vis,
127 ident,
128 meta,
129 })
130 }
131 }
132
133 struct MyDeriveInput {
134 ast: DeriveInput,
135 variant_type: VariantType,
136 values: Vec<IntWrapper>,
137 variant_idents: Vec<Ident>,
138 use_constant_counter: bool,
139 enable_trait: bool,
140 enable_variant_count: Option<ConstMember>,
141 enable_variants: Option<ConstMember>,
142 enable_values: Option<ConstMember>,
143 enable_from_ordinal_unsafe: Option<ConstFunctionMember>,
144 enable_from_ordinal: Option<ConstFunctionMember>,
145 enable_ordinal: Option<ConstFunctionMember>,
146 }
147
148 impl Parse for MyDeriveInput {
149 fn parse(input: ParseStream) -> syn::Result<Self> {
150 let ast = input.parse::<DeriveInput>()?;
151
152 let mut variant_type = VariantType::default();
153 let mut enable_trait = cfg!(feature = "traits");
154 let mut enable_variant_count = None;
155 let mut enable_variants = None;
156 let mut enable_values = None;
157 let mut enable_from_ordinal_unsafe = None;
158 let mut enable_from_ordinal = None;
159 let mut enable_ordinal = None;
160
161 for attr in ast.attrs.iter() {
162 let path = attr.path();
163
164 if let Some(ident) = path.get_ident() {
165 match ident.to_string().as_str() {
166 "repr" => {
167 if let Meta::List(list) = &attr.meta {
168 let result = list.parse_args_with(
169 Punctuated::<Meta, Token![,]>::parse_terminated,
170 )?;
171
172 for meta in result {
173 if let Some(ident) = meta.path().get_ident() {
174 let repr_type = VariantType::from_str(ident.to_string());
175
176 if !matches!(repr_type, VariantType::NonDetermined) {
177 variant_type = repr_type;
178 break;
179 }
180 }
181 }
182 }
183 },
184 "ordinalize" => {
185 if let Meta::List(list) = &attr.meta {
186 let result = list.parse_args_with(
187 Punctuated::<Meta, Token![,]>::parse_terminated,
188 )?;
189
190 for meta in result {
191 let path = meta.path();
192
193 if let Some(ident) = path.get_ident() {
194 match ident.to_string().as_str() {
195 "impl_trait" => {
196 if let Meta::NameValue(name_value) = &meta {
197 if let Expr::Lit(lit) = &name_value.value {
198 if let Lit::Bool(value) = &lit.lit {
199 if cfg!(feature = "traits") {
200 enable_trait = value.value;
201 }
202 } else {
203 return Err(
204 panic::bool_attribute_usage(
205 ident, lit,
206 ),
207 );
208 }
209 } else {
210 return Err(panic::bool_attribute_usage(
211 ident,
212 &name_value.value,
213 ));
214 }
215 } else {
216 return Err(panic::bool_attribute_usage(
217 ident, &meta,
218 ));
219 }
220 },
221 "variant_count" => {
222 if let Meta::List(list) = &meta {
223 enable_variant_count = Some(list.parse_args()?);
224 } else {
225 return Err(panic::list_attribute_usage(
226 ident, &meta,
227 ));
228 }
229 },
230 "variants" => {
231 if let Meta::List(list) = &meta {
232 enable_variants = Some(list.parse_args()?);
233 } else {
234 return Err(panic::list_attribute_usage(
235 ident, &meta,
236 ));
237 }
238 },
239 "values" => {
240 if let Meta::List(list) = &meta {
241 enable_values = Some(list.parse_args()?);
242 } else {
243 return Err(panic::list_attribute_usage(
244 ident, &meta,
245 ));
246 }
247 },
248 "from_ordinal_unsafe" => {
249 if let Meta::List(list) = &meta {
250 enable_from_ordinal_unsafe =
251 Some(list.parse_args()?);
252 } else {
253 return Err(panic::list_attribute_usage(
254 ident, &meta,
255 ));
256 }
257 },
258 "from_ordinal" => {
259 if let Meta::List(list) = &meta {
260 enable_from_ordinal = Some(list.parse_args()?);
261 } else {
262 return Err(panic::list_attribute_usage(
263 ident, &meta,
264 ));
265 }
266 },
267 "ordinal" => {
268 if let Meta::List(list) = &meta {
269 enable_ordinal = Some(list.parse_args()?);
270 } else {
271 return Err(panic::list_attribute_usage(
272 ident, &meta,
273 ));
274 }
275 },
276 _ => {
277 return Err(panic::sub_attributes_for_ordinalize(
278 &meta,
279 ));
280 },
281 }
282 } else {
283 return Err(panic::sub_attributes_for_ordinalize(&meta));
284 }
285 }
286 } else {
287 return Err(panic::list_attribute_usage(ident, attr));
288 }
289 },
290 _ => (),
291 }
292 }
293 }
294
295 let name = &ast.ident;
296
297 if let Data::Enum(data) = &ast.data {
298 let variant_count = data.variants.len();
299
300 if variant_count == 0 {
301 return Err(panic::no_variant(name));
302 }
303
304 let mut values: Vec<IntWrapper> = Vec::with_capacity(variant_count);
305 let mut variant_idents: Vec<Ident> = Vec::with_capacity(variant_count);
306
307 let mut use_constant_counter = false;
308
309 if let VariantType::NonDetermined = variant_type {
310 let mut min = i128::MAX;
311 let mut max = i128::MIN;
312 let mut counter = 0;
313
314 for variant in data.variants.iter() {
315 if let Fields::Unit = variant.fields {
316 if let Some((_, exp)) = variant.discriminant.as_ref() {
317 match exp {
318 Expr::Lit(lit) => {
319 if let Lit::Int(lit) = &lit.lit {
320 counter = lit.base10_parse().map_err(|error| {
321 syn::Error::new_spanned(lit, error)
322 })?;
323 } else {
324 return Err(panic::unsupported_discriminant(lit));
325 }
326 },
327 Expr::Unary(unary) => {
328 if let UnOp::Neg(_) = unary.op {
329 match unary.expr.as_ref() {
330 Expr::Lit(lit) => {
331 if let Lit::Int(lit) = &lit.lit {
332 match lit.base10_parse::<i128>() {
333 Ok(i) => {
334 counter = -i;
335 },
336 Err(error) => {
337 // overflow
338 if lit.base10_digits() == "170141183460469231731687303715884105728" {
339 counter = i128::MIN;
340 } else {
341 return Err(syn::Error::new_spanned(lit, error));
342 }
343 },
344 }
345 } else {
346 return Err(panic::unsupported_discriminant(lit));
347 }
348 },
349 Expr::Path(_)
350 | Expr::Cast(_)
351 | Expr::Binary(_)
352 | Expr::Call(_) => {
353 return Err(panic::constant_variable_on_non_determined_size_enum(unary))
354 },
355 _ => return Err(panic::unsupported_discriminant(unary)),
356 }
357 } else {
358 return Err(panic::unsupported_discriminant(unary));
359 }
360 },
361 Expr::Path(_)
362 | Expr::Cast(_)
363 | Expr::Binary(_)
364 | Expr::Call(_) => {
365 return Err(
366 panic::constant_variable_on_non_determined_size_enum(
367 exp,
368 ),
369 );
370 },
371 _ => return Err(panic::unsupported_discriminant(exp)),
372 }
373 };
374
375 if min > counter {
376 min = counter;
377 }
378
379 if max < counter {
380 max = counter;
381 }
382
383 variant_idents.push(variant.ident.clone());
384
385 values.push(IntWrapper::from(counter));
386
387 counter = counter.saturating_add(1);
388 } else {
389 return Err(panic::not_unit_variant(variant));
390 }
391 }
392
393 if min >= i8::MIN as i128 && max <= i8::MAX as i128 {
394 variant_type = VariantType::I8;
395 } else if min >= i16::MIN as i128 && max <= i16::MAX as i128 {
396 variant_type = VariantType::I16;
397 } else if min >= i32::MIN as i128 && max <= i32::MAX as i128 {
398 variant_type = VariantType::I32;
399 } else if min >= i64::MIN as i128 && max <= i64::MAX as i128 {
400 variant_type = VariantType::I64;
401 } else {
402 variant_type = VariantType::I128;
403 }
404 } else {
405 let mut counter = Int128::ZERO;
406 let mut constant_counter = 0;
407 let mut last_exp: Option<&Expr> = None;
408
409 for variant in data.variants.iter() {
410 if let Fields::Unit = variant.fields {
411 if let Some((_, exp)) = variant.discriminant.as_ref() {
412 match exp {
413 Expr::Lit(lit) => {
414 if let Lit::Int(lit) = &lit.lit {
415 counter = lit.base10_parse().map_err(|error| {
416 syn::Error::new_spanned(lit, error)
417 })?;
418
419 values.push(IntWrapper::from(counter));
420
421 counter.inc();
422
423 last_exp = None;
424 } else {
425 return Err(panic::unsupported_discriminant(lit));
426 }
427 },
428 Expr::Unary(unary) => {
429 if let UnOp::Neg(_) = unary.op {
430 match unary.expr.as_ref() {
431 Expr::Lit(lit) => {
432 if let Lit::Int(lit) = &lit.lit {
433 counter = -lit.base10_parse().map_err(
434 |error| {
435 syn::Error::new_spanned(lit, error)
436 },
437 )?;
438
439 values.push(IntWrapper::from(counter));
440
441 counter.inc();
442
443 last_exp = None;
444 } else {
445 return Err(
446 panic::unsupported_discriminant(lit),
447 );
448 }
449 },
450 Expr::Path(_) => {
451 values.push(IntWrapper::from((exp, 0)));
452
453 last_exp = Some(exp);
454 constant_counter = 1;
455 },
456 Expr::Cast(_) | Expr::Binary(_) | Expr::Call(_) => {
457 values.push(IntWrapper::from((exp, 0)));
458
459 last_exp = Some(exp);
460 constant_counter = 1;
461
462 use_constant_counter = true;
463 },
464 _ => {
465 return Err(panic::unsupported_discriminant(
466 exp,
467 ));
468 },
469 }
470 } else {
471 return Err(panic::unsupported_discriminant(unary));
472 }
473 },
474 Expr::Path(_) => {
475 values.push(IntWrapper::from((exp, 0)));
476
477 last_exp = Some(exp);
478 constant_counter = 1;
479 },
480 Expr::Cast(_) | Expr::Binary(_) | Expr::Call(_) => {
481 values.push(IntWrapper::from((exp, 0)));
482
483 last_exp = Some(exp);
484 constant_counter = 1;
485
486 use_constant_counter = true;
487 },
488 _ => return Err(panic::unsupported_discriminant(exp)),
489 }
490 } else if let Some(exp) = last_exp {
491 values.push(IntWrapper::from((exp, constant_counter)));
492
493 constant_counter += 1;
494
495 use_constant_counter = true;
496 } else {
497 values.push(IntWrapper::from(counter));
498
499 counter.inc();
500 }
501
502 variant_idents.push(variant.ident.clone());
503 } else {
504 return Err(panic::not_unit_variant(variant));
505 }
506 }
507 }
508
509 Ok(MyDeriveInput {
510 ast,
511 variant_type,
512 values,
513 variant_idents,
514 use_constant_counter,
515 enable_trait,
516 enable_variant_count,
517 enable_variants,
518 enable_values,
519 enable_from_ordinal_unsafe,
520 enable_from_ordinal,
521 enable_ordinal,
522 })
523 } else {
524 Err(panic::not_enum(&ast.ident))
525 }
526 }
527 }
528
529 // Parse the token stream
530 let derive_input = parse_macro_input!(input as MyDeriveInput);
531
532 let MyDeriveInput {
533 ast,
534 variant_type,
535 values,
536 variant_idents,
537 use_constant_counter,
538 enable_trait,
539 enable_variant_count,
540 enable_variants,
541 enable_values,
542 enable_ordinal,
543 enable_from_ordinal_unsafe,
544 enable_from_ordinal,
545 } = derive_input;
546
547 // Get the identifier of the type.
548 let name = &ast.ident;
549
550 let variant_count = values.len();
551
552 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
553
554 // Build the code
555 let mut expanded = proc_macro2::TokenStream::new();
556
557 if enable_trait {
558 #[cfg(feature = "traits")]
559 {
560 let from_ordinal_unsafe = if variant_count == 1 {
561 let variant_ident = &variant_idents[0];
562
563 quote! {
564 #[inline]
565 unsafe fn from_ordinal_unsafe(_number: #variant_type) -> Self {
566 Self::#variant_ident
567 }
568 }
569 } else {
570 quote! {
571 #[inline]
572 unsafe fn from_ordinal_unsafe(number: #variant_type) -> Self {
573 unsafe { ::core::mem::transmute(number) }
574 }
575 }
576 };
577
578 let from_ordinal = if use_constant_counter {
579 quote! {
580 #[inline]
581 fn from_ordinal(number: #variant_type) -> Option<Self> {
582 if false {
583 unreachable!()
584 } #( else if number == #values {
585 Some(Self::#variant_idents)
586 } )* else {
587 None
588 }
589 }
590 }
591 } else {
592 quote! {
593 #[inline]
594 fn from_ordinal(number: #variant_type) -> Option<Self> {
595 match number{
596 #(
597 #values => Some(Self::#variant_idents),
598 )*
599 _ => None
600 }
601 }
602 }
603 };
604
605 expanded.extend(quote! {
606 impl #impl_generics ::enum_ordinalize::Ordinalize for #name #ty_generics #where_clause {
607 type VariantType = #variant_type;
608
609 const VARIANT_COUNT: usize = #variant_count;
610
611 const VARIANTS: &'static [Self] = &[#( Self::#variant_idents, )*];
612
613 const VALUES: &'static [#variant_type] = &[#( #values, )*];
614
615 #[inline]
616 fn ordinal(&self) -> #variant_type {
617 match self {
618 #(
619 Self::#variant_idents => #values,
620 )*
621 }
622 }
623
624 #from_ordinal_unsafe
625
626 #from_ordinal
627 }
628 });
629 }
630 }
631
632 let mut expanded_2 = proc_macro2::TokenStream::new();
633
634 if let Some(ConstMember {
635 vis,
636 ident,
637 meta,
638 function,
639 }) = enable_variant_count
640 {
641 expanded_2.extend(if function {
642 quote! {
643 #(#[#meta])*
644 #vis const fn #ident () -> usize {
645 #variant_count
646 }
647 }
648 } else {
649 quote! {
650 #(#[#meta])*
651 #vis const #ident: usize = #variant_count;
652 }
653 });
654 }
655
656 if let Some(ConstMember {
657 vis,
658 ident,
659 meta,
660 function,
661 }) = enable_variants
662 {
663 expanded_2.extend(if function {
664 quote! {
665 #(#[#meta])*
666 #vis const fn #ident () -> [Self; #variant_count] {
667 [#( Self::#variant_idents, )*]
668 }
669 }
670 } else {
671 quote! {
672 #(#[#meta])*
673 #vis const #ident: [Self; #variant_count] = [#( Self::#variant_idents, )*];
674 }
675 });
676 }
677
678 if let Some(ConstMember {
679 vis,
680 ident,
681 meta,
682 function,
683 }) = enable_values
684 {
685 expanded_2.extend(if function {
686 quote! {
687 #(#[#meta])*
688 #vis const fn #ident () -> [#variant_type; #variant_count] {
689 [#( #values, )*]
690 }
691 }
692 } else {
693 quote! {
694 #(#[#meta])*
695 #vis const #ident: [#variant_type; #variant_count] = [#( #values, )*];
696 }
697 });
698 }
699
700 if let Some(ConstFunctionMember {
701 vis,
702 ident,
703 meta,
704 }) = enable_from_ordinal_unsafe
705 {
706 let from_ordinal_unsafe = if variant_count == 1 {
707 let variant_ident = &variant_idents[0];
708
709 quote! {
710 #(#[#meta])*
711 #vis const unsafe fn #ident (_number: #variant_type) -> Self {
712 Self::#variant_ident
713 }
714 }
715 } else {
716 quote! {
717 #(#[#meta])*
718 #vis const unsafe fn #ident (number: #variant_type) -> Self {
719 unsafe { ::core::mem::transmute(number) }
720 }
721 }
722 };
723
724 expanded_2.extend(from_ordinal_unsafe);
725 }
726
727 if let Some(ConstFunctionMember {
728 vis,
729 ident,
730 meta,
731 }) = enable_from_ordinal
732 {
733 let from_ordinal = if use_constant_counter {
734 quote! {
735 #(#[#meta])*
736 #vis const fn #ident (number: #variant_type) -> Option<Self> {
737 if false {
738 unreachable!()
739 } #( else if number == #values {
740 Some(Self::#variant_idents)
741 } )* else {
742 None
743 }
744 }
745 }
746 } else {
747 quote! {
748 #(#[#meta])*
749 #vis const fn #ident (number: #variant_type) -> Option<Self> {
750 match number{
751 #(
752 #values => Some(Self::#variant_idents),
753 )*
754 _ => None
755 }
756 }
757 }
758 };
759
760 expanded_2.extend(from_ordinal);
761 }
762
763 if let Some(ConstFunctionMember {
764 vis,
765 ident,
766 meta,
767 }) = enable_ordinal
768 {
769 expanded_2.extend(quote! {
770 #(#[#meta])*
771 #vis const fn #ident (&self) -> #variant_type {
772 match self {
773 #(
774 Self::#variant_idents => #values,
775 )*
776 }
777 }
778 });
779 }
780
781 if !expanded_2.is_empty() {
782 expanded.extend(quote! {
783 impl #impl_generics #name #ty_generics #where_clause {
784 #expanded_2
785 }
786 });
787 }
788
789 expanded.into()
790}