Skip to main content

enum_ordinalize_derive/
lib.rs

/*!
# Enum Ordinalize Derive

This crate provides the `Ordinalize` derive macro, which lets enums not only obtain the ordinal values of their variants but also be constructed from an ordinal value. See the [`enum-ordinalize`](https://crates.io/crates/enum-ordinalize) crate.
*/
6
7#![no_std]
8
9#[macro_use]
10extern crate alloc;
11
12mod int128;
13mod int_wrapper;
14mod panic;
15mod variant_type;
16
17use alloc::{string::ToString, vec::Vec};
18
19use proc_macro::TokenStream;
20use quote::quote;
21use syn::{
22    Data, DeriveInput, Expr, Fields, Ident, Lit, Meta, Token, UnOp, Visibility,
23    parse::{Parse, ParseStream},
24    parse_macro_input,
25    punctuated::Punctuated,
26};
27use variant_type::VariantType;
28
29use crate::{int_wrapper::IntWrapper, int128::Int128};
30
31#[proc_macro_derive(Ordinalize, attributes(ordinalize))]
32pub fn ordinalize_derive(input: TokenStream) -> TokenStream {
33    struct ConstMember {
34        vis:      Option<Visibility>,
35        ident:    Ident,
36        meta:     Vec<Meta>,
37        function: bool,
38    }
39
40    impl Parse for ConstMember {
41        #[inline]
42        fn parse(input: ParseStream) -> syn::Result<Self> {
43            let vis = input.parse::<Visibility>().ok();
44
45            let _ = input.parse::<Token![const]>();
46
47            let function = input.parse::<Token![fn]>().is_ok();
48
49            let ident = input.parse::<Ident>()?;
50
51            let mut meta = Vec::new();
52
53            if !input.is_empty() {
54                input.parse::<Token![,]>()?;
55
56                if !input.is_empty() {
57                    let result = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
58
59                    let mut has_inline = false;
60
61                    for m in result {
62                        if m.path().is_ident("inline") {
63                            has_inline = true;
64                        }
65
66                        meta.push(m);
67                    }
68
69                    if !has_inline {
70                        meta.push(syn::parse_str("inline")?);
71                    }
72                }
73            }
74
75            Ok(Self {
76                vis,
77                ident,
78                meta,
79                function,
80            })
81        }
82    }
83
84    struct ConstFunctionMember {
85        vis:   Option<Visibility>,
86        ident: Ident,
87        meta:  Vec<Meta>,
88    }
89
90    impl Parse for ConstFunctionMember {
91        #[inline]
92        fn parse(input: ParseStream) -> syn::Result<Self> {
93            let vis = input.parse::<Visibility>().ok();
94
95            let _ = input.parse::<Token![const]>();
96
97            input.parse::<Token![fn]>()?;
98
99            let ident = input.parse::<Ident>()?;
100
101            let mut meta = Vec::new();
102
103            if !input.is_empty() {
104                input.parse::<Token![,]>()?;
105
106                if !input.is_empty() {
107                    let result = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
108
109                    let mut has_inline = false;
110
111                    for m in result {
112                        if m.path().is_ident("inline") {
113                            has_inline = true;
114                        }
115
116                        meta.push(m);
117                    }
118
119                    if !has_inline {
120                        meta.push(syn::parse_str("inline")?);
121                    }
122                }
123            }
124
125            Ok(Self {
126                vis,
127                ident,
128                meta,
129            })
130        }
131    }
132
133    struct MyDeriveInput {
134        ast:                        DeriveInput,
135        variant_type:               VariantType,
136        values:                     Vec<IntWrapper>,
137        variant_idents:             Vec<Ident>,
138        use_constant_counter:       bool,
139        enable_trait:               bool,
140        enable_variant_count:       Option<ConstMember>,
141        enable_variants:            Option<ConstMember>,
142        enable_values:              Option<ConstMember>,
143        enable_from_ordinal_unsafe: Option<ConstFunctionMember>,
144        enable_from_ordinal:        Option<ConstFunctionMember>,
145        enable_ordinal:             Option<ConstFunctionMember>,
146    }
147
148    impl Parse for MyDeriveInput {
149        fn parse(input: ParseStream) -> syn::Result<Self> {
150            let ast = input.parse::<DeriveInput>()?;
151
152            let mut variant_type = VariantType::default();
153            let mut enable_trait = cfg!(feature = "traits");
154            let mut enable_variant_count = None;
155            let mut enable_variants = None;
156            let mut enable_values = None;
157            let mut enable_from_ordinal_unsafe = None;
158            let mut enable_from_ordinal = None;
159            let mut enable_ordinal = None;
160
161            for attr in ast.attrs.iter() {
162                let path = attr.path();
163
164                if let Some(ident) = path.get_ident() {
165                    match ident.to_string().as_str() {
166                        "repr" => {
167                            if let Meta::List(list) = &attr.meta {
168                                let result = list.parse_args_with(
169                                    Punctuated::<Meta, Token![,]>::parse_terminated,
170                                )?;
171
172                                for meta in result {
173                                    if let Some(ident) = meta.path().get_ident() {
174                                        let repr_type = VariantType::from_str(ident.to_string());
175
176                                        if !matches!(repr_type, VariantType::NonDetermined) {
177                                            variant_type = repr_type;
178                                            break;
179                                        }
180                                    }
181                                }
182                            }
183                        },
184                        "ordinalize" => {
185                            if let Meta::List(list) = &attr.meta {
186                                let result = list.parse_args_with(
187                                    Punctuated::<Meta, Token![,]>::parse_terminated,
188                                )?;
189
190                                for meta in result {
191                                    let path = meta.path();
192
193                                    if let Some(ident) = path.get_ident() {
194                                        match ident.to_string().as_str() {
195                                            "impl_trait" => {
196                                                if let Meta::NameValue(name_value) = &meta {
197                                                    if let Expr::Lit(lit) = &name_value.value {
198                                                        if let Lit::Bool(value) = &lit.lit {
199                                                            if cfg!(feature = "traits") {
200                                                                enable_trait = value.value;
201                                                            }
202                                                        } else {
203                                                            return Err(
204                                                                panic::bool_attribute_usage(
205                                                                    ident, lit,
206                                                                ),
207                                                            );
208                                                        }
209                                                    } else {
210                                                        return Err(panic::bool_attribute_usage(
211                                                            ident,
212                                                            &name_value.value,
213                                                        ));
214                                                    }
215                                                } else {
216                                                    return Err(panic::bool_attribute_usage(
217                                                        ident, &meta,
218                                                    ));
219                                                }
220                                            },
221                                            "variant_count" => {
222                                                if let Meta::List(list) = &meta {
223                                                    enable_variant_count = Some(list.parse_args()?);
224                                                } else {
225                                                    return Err(panic::list_attribute_usage(
226                                                        ident, &meta,
227                                                    ));
228                                                }
229                                            },
230                                            "variants" => {
231                                                if let Meta::List(list) = &meta {
232                                                    enable_variants = Some(list.parse_args()?);
233                                                } else {
234                                                    return Err(panic::list_attribute_usage(
235                                                        ident, &meta,
236                                                    ));
237                                                }
238                                            },
239                                            "values" => {
240                                                if let Meta::List(list) = &meta {
241                                                    enable_values = Some(list.parse_args()?);
242                                                } else {
243                                                    return Err(panic::list_attribute_usage(
244                                                        ident, &meta,
245                                                    ));
246                                                }
247                                            },
248                                            "from_ordinal_unsafe" => {
249                                                if let Meta::List(list) = &meta {
250                                                    enable_from_ordinal_unsafe =
251                                                        Some(list.parse_args()?);
252                                                } else {
253                                                    return Err(panic::list_attribute_usage(
254                                                        ident, &meta,
255                                                    ));
256                                                }
257                                            },
258                                            "from_ordinal" => {
259                                                if let Meta::List(list) = &meta {
260                                                    enable_from_ordinal = Some(list.parse_args()?);
261                                                } else {
262                                                    return Err(panic::list_attribute_usage(
263                                                        ident, &meta,
264                                                    ));
265                                                }
266                                            },
267                                            "ordinal" => {
268                                                if let Meta::List(list) = &meta {
269                                                    enable_ordinal = Some(list.parse_args()?);
270                                                } else {
271                                                    return Err(panic::list_attribute_usage(
272                                                        ident, &meta,
273                                                    ));
274                                                }
275                                            },
276                                            _ => {
277                                                return Err(panic::sub_attributes_for_ordinalize(
278                                                    &meta,
279                                                ));
280                                            },
281                                        }
282                                    } else {
283                                        return Err(panic::sub_attributes_for_ordinalize(&meta));
284                                    }
285                                }
286                            } else {
287                                return Err(panic::list_attribute_usage(ident, attr));
288                            }
289                        },
290                        _ => (),
291                    }
292                }
293            }
294
295            let name = &ast.ident;
296
297            if let Data::Enum(data) = &ast.data {
298                let variant_count = data.variants.len();
299
300                if variant_count == 0 {
301                    return Err(panic::no_variant(name));
302                }
303
304                let mut values: Vec<IntWrapper> = Vec::with_capacity(variant_count);
305                let mut variant_idents: Vec<Ident> = Vec::with_capacity(variant_count);
306
307                let mut use_constant_counter = false;
308
309                if let VariantType::NonDetermined = variant_type {
310                    let mut min = i128::MAX;
311                    let mut max = i128::MIN;
312                    let mut counter = 0;
313
314                    for variant in data.variants.iter() {
315                        if let Fields::Unit = variant.fields {
316                            if let Some((_, exp)) = variant.discriminant.as_ref() {
317                                match exp {
318                                    Expr::Lit(lit) => {
319                                        if let Lit::Int(lit) = &lit.lit {
320                                            counter = lit.base10_parse().map_err(|error| {
321                                                syn::Error::new_spanned(lit, error)
322                                            })?;
323                                        } else {
324                                            return Err(panic::unsupported_discriminant(lit));
325                                        }
326                                    },
327                                    Expr::Unary(unary) => {
328                                        if let UnOp::Neg(_) = unary.op {
329                                            match unary.expr.as_ref() {
330                                            Expr::Lit(lit) => {
331                                                if let Lit::Int(lit) = &lit.lit {
332                                                    match lit.base10_parse::<i128>() {
333                                                        Ok(i) => {
334                                                            counter = -i;
335                                                        },
336                                                        Err(error) => {
337                                                            // overflow
338                                                            if lit.base10_digits() == "170141183460469231731687303715884105728" {
339                                                                counter = i128::MIN;
340                                                            } else {
341                                                                return Err(syn::Error::new_spanned(lit, error));
342                                                            }
343                                                        },
344                                                    }
345                                                } else {
346                                                    return Err(panic::unsupported_discriminant(lit));
347                                                }
348                                            },
349                                            Expr::Path(_)
350                                            | Expr::Cast(_)
351                                            | Expr::Binary(_)
352                                            | Expr::Call(_) => {
353                                                return Err(panic::constant_variable_on_non_determined_size_enum(unary))
354                                            },
355                                            _ => return Err(panic::unsupported_discriminant(unary)),
356                                        }
357                                        } else {
358                                            return Err(panic::unsupported_discriminant(unary));
359                                        }
360                                    },
361                                    Expr::Path(_)
362                                    | Expr::Cast(_)
363                                    | Expr::Binary(_)
364                                    | Expr::Call(_) => {
365                                        return Err(
366                                            panic::constant_variable_on_non_determined_size_enum(
367                                                exp,
368                                            ),
369                                        );
370                                    },
371                                    _ => return Err(panic::unsupported_discriminant(exp)),
372                                }
373                            };
374
375                            if min > counter {
376                                min = counter;
377                            }
378
379                            if max < counter {
380                                max = counter;
381                            }
382
383                            variant_idents.push(variant.ident.clone());
384
385                            values.push(IntWrapper::from(counter));
386
387                            counter = counter.saturating_add(1);
388                        } else {
389                            return Err(panic::not_unit_variant(variant));
390                        }
391                    }
392
393                    if min >= i8::MIN as i128 && max <= i8::MAX as i128 {
394                        variant_type = VariantType::I8;
395                    } else if min >= i16::MIN as i128 && max <= i16::MAX as i128 {
396                        variant_type = VariantType::I16;
397                    } else if min >= i32::MIN as i128 && max <= i32::MAX as i128 {
398                        variant_type = VariantType::I32;
399                    } else if min >= i64::MIN as i128 && max <= i64::MAX as i128 {
400                        variant_type = VariantType::I64;
401                    } else {
402                        variant_type = VariantType::I128;
403                    }
404                } else {
405                    let mut counter = Int128::ZERO;
406                    let mut constant_counter = 0;
407                    let mut last_exp: Option<&Expr> = None;
408
409                    for variant in data.variants.iter() {
410                        if let Fields::Unit = variant.fields {
411                            if let Some((_, exp)) = variant.discriminant.as_ref() {
412                                match exp {
413                                    Expr::Lit(lit) => {
414                                        if let Lit::Int(lit) = &lit.lit {
415                                            counter = lit.base10_parse().map_err(|error| {
416                                                syn::Error::new_spanned(lit, error)
417                                            })?;
418
419                                            values.push(IntWrapper::from(counter));
420
421                                            counter.inc();
422
423                                            last_exp = None;
424                                        } else {
425                                            return Err(panic::unsupported_discriminant(lit));
426                                        }
427                                    },
428                                    Expr::Unary(unary) => {
429                                        if let UnOp::Neg(_) = unary.op {
430                                            match unary.expr.as_ref() {
431                                                Expr::Lit(lit) => {
432                                                    if let Lit::Int(lit) = &lit.lit {
433                                                        counter = -lit.base10_parse().map_err(
434                                                            |error| {
435                                                                syn::Error::new_spanned(lit, error)
436                                                            },
437                                                        )?;
438
439                                                        values.push(IntWrapper::from(counter));
440
441                                                        counter.inc();
442
443                                                        last_exp = None;
444                                                    } else {
445                                                        return Err(
446                                                            panic::unsupported_discriminant(lit),
447                                                        );
448                                                    }
449                                                },
450                                                Expr::Path(_) => {
451                                                    values.push(IntWrapper::from((exp, 0)));
452
453                                                    last_exp = Some(exp);
454                                                    constant_counter = 1;
455                                                },
456                                                Expr::Cast(_) | Expr::Binary(_) | Expr::Call(_) => {
457                                                    values.push(IntWrapper::from((exp, 0)));
458
459                                                    last_exp = Some(exp);
460                                                    constant_counter = 1;
461
462                                                    use_constant_counter = true;
463                                                },
464                                                _ => {
465                                                    return Err(panic::unsupported_discriminant(
466                                                        exp,
467                                                    ));
468                                                },
469                                            }
470                                        } else {
471                                            return Err(panic::unsupported_discriminant(unary));
472                                        }
473                                    },
474                                    Expr::Path(_) => {
475                                        values.push(IntWrapper::from((exp, 0)));
476
477                                        last_exp = Some(exp);
478                                        constant_counter = 1;
479                                    },
480                                    Expr::Cast(_) | Expr::Binary(_) | Expr::Call(_) => {
481                                        values.push(IntWrapper::from((exp, 0)));
482
483                                        last_exp = Some(exp);
484                                        constant_counter = 1;
485
486                                        use_constant_counter = true;
487                                    },
488                                    _ => return Err(panic::unsupported_discriminant(exp)),
489                                }
490                            } else if let Some(exp) = last_exp {
491                                values.push(IntWrapper::from((exp, constant_counter)));
492
493                                constant_counter += 1;
494
495                                use_constant_counter = true;
496                            } else {
497                                values.push(IntWrapper::from(counter));
498
499                                counter.inc();
500                            }
501
502                            variant_idents.push(variant.ident.clone());
503                        } else {
504                            return Err(panic::not_unit_variant(variant));
505                        }
506                    }
507                }
508
509                Ok(MyDeriveInput {
510                    ast,
511                    variant_type,
512                    values,
513                    variant_idents,
514                    use_constant_counter,
515                    enable_trait,
516                    enable_variant_count,
517                    enable_variants,
518                    enable_values,
519                    enable_from_ordinal_unsafe,
520                    enable_from_ordinal,
521                    enable_ordinal,
522                })
523            } else {
524                Err(panic::not_enum(&ast.ident))
525            }
526        }
527    }
528
529    // Parse the token stream
530    let derive_input = parse_macro_input!(input as MyDeriveInput);
531
532    let MyDeriveInput {
533        ast,
534        variant_type,
535        values,
536        variant_idents,
537        use_constant_counter,
538        enable_trait,
539        enable_variant_count,
540        enable_variants,
541        enable_values,
542        enable_ordinal,
543        enable_from_ordinal_unsafe,
544        enable_from_ordinal,
545    } = derive_input;
546
547    // Get the identifier of the type.
548    let name = &ast.ident;
549
550    let variant_count = values.len();
551
552    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
553
554    // Build the code
555    let mut expanded = proc_macro2::TokenStream::new();
556
557    if enable_trait {
558        #[cfg(feature = "traits")]
559        {
560            let from_ordinal_unsafe = if variant_count == 1 {
561                let variant_ident = &variant_idents[0];
562
563                quote! {
564                    #[inline]
565                    unsafe fn from_ordinal_unsafe(_number: #variant_type) -> Self {
566                        Self::#variant_ident
567                    }
568                }
569            } else {
570                quote! {
571                    #[inline]
572                    unsafe fn from_ordinal_unsafe(number: #variant_type) -> Self {
573                        unsafe { ::core::mem::transmute(number) }
574                    }
575                }
576            };
577
578            let from_ordinal = if use_constant_counter {
579                quote! {
580                    #[inline]
581                    fn from_ordinal(number: #variant_type) -> Option<Self> {
582                        if false {
583                            unreachable!()
584                        } #( else if number == #values {
585                            Some(Self::#variant_idents)
586                        } )* else {
587                            None
588                        }
589                    }
590                }
591            } else {
592                quote! {
593                    #[inline]
594                    fn from_ordinal(number: #variant_type) -> Option<Self> {
595                        match number{
596                            #(
597                                #values => Some(Self::#variant_idents),
598                            )*
599                            _ => None
600                        }
601                    }
602                }
603            };
604
605            expanded.extend(quote! {
606                impl #impl_generics ::enum_ordinalize::Ordinalize for #name #ty_generics #where_clause {
607                    type VariantType = #variant_type;
608
609                    const VARIANT_COUNT: usize = #variant_count;
610
611                    const VARIANTS: &'static [Self] = &[#( Self::#variant_idents, )*];
612
613                    const VALUES: &'static [#variant_type] = &[#( #values, )*];
614
615                    #[inline]
616                    fn ordinal(&self) -> #variant_type {
617                        match self {
618                            #(
619                                Self::#variant_idents => #values,
620                            )*
621                        }
622                    }
623
624                    #from_ordinal_unsafe
625
626                    #from_ordinal
627                }
628            });
629        }
630    }
631
632    let mut expanded_2 = proc_macro2::TokenStream::new();
633
634    if let Some(ConstMember {
635        vis,
636        ident,
637        meta,
638        function,
639    }) = enable_variant_count
640    {
641        expanded_2.extend(if function {
642            quote! {
643                #(#[#meta])*
644                #vis const fn #ident () -> usize {
645                    #variant_count
646                }
647            }
648        } else {
649            quote! {
650                #(#[#meta])*
651                #vis const #ident: usize = #variant_count;
652            }
653        });
654    }
655
656    if let Some(ConstMember {
657        vis,
658        ident,
659        meta,
660        function,
661    }) = enable_variants
662    {
663        expanded_2.extend(if function {
664            quote! {
665                #(#[#meta])*
666                #vis const fn #ident () -> [Self; #variant_count] {
667                    [#( Self::#variant_idents, )*]
668                }
669            }
670        } else {
671            quote! {
672                #(#[#meta])*
673                #vis const #ident: [Self; #variant_count] = [#( Self::#variant_idents, )*];
674            }
675        });
676    }
677
678    if let Some(ConstMember {
679        vis,
680        ident,
681        meta,
682        function,
683    }) = enable_values
684    {
685        expanded_2.extend(if function {
686            quote! {
687                #(#[#meta])*
688                #vis const fn #ident () -> [#variant_type; #variant_count] {
689                    [#( #values, )*]
690                }
691            }
692        } else {
693            quote! {
694                #(#[#meta])*
695                #vis const #ident: [#variant_type; #variant_count] = [#( #values, )*];
696            }
697        });
698    }
699
700    if let Some(ConstFunctionMember {
701        vis,
702        ident,
703        meta,
704    }) = enable_from_ordinal_unsafe
705    {
706        let from_ordinal_unsafe = if variant_count == 1 {
707            let variant_ident = &variant_idents[0];
708
709            quote! {
710                #(#[#meta])*
711                #vis const unsafe fn #ident (_number: #variant_type) -> Self {
712                    Self::#variant_ident
713                }
714            }
715        } else {
716            quote! {
717                #(#[#meta])*
718                #vis const unsafe fn #ident (number: #variant_type) -> Self {
719                    unsafe { ::core::mem::transmute(number) }
720                }
721            }
722        };
723
724        expanded_2.extend(from_ordinal_unsafe);
725    }
726
727    if let Some(ConstFunctionMember {
728        vis,
729        ident,
730        meta,
731    }) = enable_from_ordinal
732    {
733        let from_ordinal = if use_constant_counter {
734            quote! {
735                #(#[#meta])*
736                #vis const fn #ident (number: #variant_type) -> Option<Self> {
737                    if false {
738                        unreachable!()
739                    } #( else if number == #values {
740                        Some(Self::#variant_idents)
741                    } )* else {
742                        None
743                    }
744                }
745            }
746        } else {
747            quote! {
748                #(#[#meta])*
749                #vis const fn #ident (number: #variant_type) -> Option<Self> {
750                    match number{
751                        #(
752                            #values => Some(Self::#variant_idents),
753                        )*
754                        _ => None
755                    }
756                }
757            }
758        };
759
760        expanded_2.extend(from_ordinal);
761    }
762
763    if let Some(ConstFunctionMember {
764        vis,
765        ident,
766        meta,
767    }) = enable_ordinal
768    {
769        expanded_2.extend(quote! {
770            #(#[#meta])*
771            #vis const fn #ident (&self) -> #variant_type {
772                match self {
773                    #(
774                        Self::#variant_idents => #values,
775                    )*
776                }
777            }
778        });
779    }
780
781    if !expanded_2.is_empty() {
782        expanded.extend(quote! {
783            impl #impl_generics #name #ty_generics #where_clause {
784                #expanded_2
785            }
786        });
787    }
788
789    expanded.into()
790}