ordinalizer/
lib.rs

//! A simple derive macro that generates an `ordinal()`
//! method for enums.
//!
//! Unlike `num_derive::ToPrimitive`, this derive macro also
//! supports non-C-like enums (enums whose variants carry fields).
//! The value returned by `ordinal` depends only on which variant
//! a value is; the variant's fields are ignored.
8//!
9//! # Example
10//! ```
11//! use ordinalizer::Ordinal;
12//! #[derive(Ordinal)]
13//! enum Animal {
14//!     Dog,
15//!     Cat {
16//!         age: i32,
17//!     }
18//! }
19//!
20//! assert_eq!(Animal::Dog.ordinal(), 0);
21//! assert_eq!((Animal::Cat { age: 10 }).ordinal(), 1);
22//! ```
23
24use proc_macro2::{Ident, TokenStream};
25use proc_macro_error::*;
26use quote::quote;
27use syn::{parse_macro_input, DeriveInput};
28
29struct Variant<'a> {
30    ident: &'a Ident,
31    unit_field_count: usize,
32    has_named_fields: bool,
33}
34
35/// Generates a `fn ordinal(&self) -> usize` for an enum.
36///
37/// The enum may have any number of variants. It is not
38/// required to be a C-like enum, i.e. its variants
39/// may have named or unnamed fields.
40///
41/// The returned ordinals will correspond to the variant's
42/// index in the enum definition. For example, the first
43/// variant of enum will have ordinal `0`.
44#[proc_macro_error]
45#[proc_macro_derive(Ordinal)]
46pub fn derive_ordinal(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
47    let input = parse_macro_input!(input as DeriveInput);
48
49    let variants = detect_variants(&input);
50
51    let match_arms = generate_match_arms(&variants, &input);
52
53    let enum_ident = &input.ident;
54
55    let tokens = quote! {
56        impl #enum_ident {
57            pub fn ordinal(&self) -> usize {
58                match self {
59                    #(#match_arms,)*
60                }
61            }
62        }
63    };
64    tokens.into()
65}
66
67fn detect_variants(input: &DeriveInput) -> Vec<Variant> {
68    let mut vec = Vec::new();
69
70    let data = match &input.data {
71        syn::Data::Enum(data) => data,
72        _ => abort_call_site!("cannot derive `Ordinal` on an item which is not an enum"),
73    };
74
75    for variant in &data.variants {
76        vec.push(detect_variant(variant));
77    }
78
79    vec
80}
81
82fn detect_variant(variant: &syn::Variant) -> Variant {
83    let ident = &variant.ident;
84
85    let (unit_field_count, has_named_fields) = match &variant.fields {
86        syn::Fields::Named(_) => (0, true),
87        syn::Fields::Unit => (0, false),
88        syn::Fields::Unnamed(unnanmed) => (unnanmed.unnamed.len(), false),
89    };
90
91    Variant {
92        ident,
93        unit_field_count,
94        has_named_fields,
95    }
96}
97
98fn generate_match_arms(variants: &[Variant], input: &DeriveInput) -> Vec<TokenStream> {
99    let mut vec = Vec::new();
100    let enum_ident = &input.ident;
101
102    for (ordinal, variant) in variants.iter().enumerate() {
103        let variant_ident = variant.ident;
104        let pattern = match (variant.has_named_fields, variant.unit_field_count) {
105            (true, _) => quote! { #enum_ident::#variant_ident { .. } },
106            (false, x) if x != 0 => {
107                let underscores: Vec<_> = (0..x).map(|_| quote! { _ }).collect();
108
109                quote! {
110                    #enum_ident::#variant_ident(#(#underscores),*)
111                }
112            }
113            (false, 0) => quote! { #enum_ident::#variant_ident },
114            _ => unreachable!(),
115        };
116
117        vec.push(quote! {
118            #pattern => #ordinal
119        });
120    }
121
122    vec
123}