1use proc_macro2::{Ident, TokenStream};
25use proc_macro_error::*;
26use quote::quote;
27use syn::{parse_macro_input, DeriveInput};
28
29struct Variant<'a> {
30 ident: &'a Ident,
31 unit_field_count: usize,
32 has_named_fields: bool,
33}
34
35#[proc_macro_error]
45#[proc_macro_derive(Ordinal)]
46pub fn derive_ordinal(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
47 let input = parse_macro_input!(input as DeriveInput);
48
49 let variants = detect_variants(&input);
50
51 let match_arms = generate_match_arms(&variants, &input);
52
53 let enum_ident = &input.ident;
54
55 let tokens = quote! {
56 impl #enum_ident {
57 pub fn ordinal(&self) -> usize {
58 match self {
59 #(#match_arms,)*
60 }
61 }
62 }
63 };
64 tokens.into()
65}
66
67fn detect_variants(input: &DeriveInput) -> Vec<Variant> {
68 let mut vec = Vec::new();
69
70 let data = match &input.data {
71 syn::Data::Enum(data) => data,
72 _ => abort_call_site!("cannot derive `Ordinal` on an item which is not an enum"),
73 };
74
75 for variant in &data.variants {
76 vec.push(detect_variant(variant));
77 }
78
79 vec
80}
81
82fn detect_variant(variant: &syn::Variant) -> Variant {
83 let ident = &variant.ident;
84
85 let (unit_field_count, has_named_fields) = match &variant.fields {
86 syn::Fields::Named(_) => (0, true),
87 syn::Fields::Unit => (0, false),
88 syn::Fields::Unnamed(unnanmed) => (unnanmed.unnamed.len(), false),
89 };
90
91 Variant {
92 ident,
93 unit_field_count,
94 has_named_fields,
95 }
96}
97
98fn generate_match_arms(variants: &[Variant], input: &DeriveInput) -> Vec<TokenStream> {
99 let mut vec = Vec::new();
100 let enum_ident = &input.ident;
101
102 for (ordinal, variant) in variants.iter().enumerate() {
103 let variant_ident = variant.ident;
104 let pattern = match (variant.has_named_fields, variant.unit_field_count) {
105 (true, _) => quote! { #enum_ident::#variant_ident { .. } },
106 (false, x) if x != 0 => {
107 let underscores: Vec<_> = (0..x).map(|_| quote! { _ }).collect();
108
109 quote! {
110 #enum_ident::#variant_ident(#(#underscores),*)
111 }
112 }
113 (false, 0) => quote! { #enum_ident::#variant_ident },
114 _ => unreachable!(),
115 };
116
117 vec.push(quote! {
118 #pattern => #ordinal
119 });
120 }
121
122 vec
123}