1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use proc_macro2::{Ident, TokenStream};
use proc_macro_error::*;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};
/// Facts about one enum variant gathered from the derive input.
///
/// Borrows the variant's identifier from the parsed AST for lifetime `'ast`.
struct Variant<'ast> {
    /// Name of the variant as written in the enum declaration.
    ident: &'ast Ident,
    /// `true` for struct-like variants (`V { .. }`).
    has_named_fields: bool,
    /// Number of positional fields of a tuple variant (`V(A, B)` gives 2);
    /// 0 for unit and struct-like variants.
    unit_field_count: usize,
}
/// Derives an inherent `ordinal(&self) -> usize` method for an enum.
///
/// The generated method returns the zero-based position of the active variant
/// in declaration order. The position comes from the variant's index in the
/// enum body, not from any explicit discriminant value.
///
/// Generic enums are supported: the enum's generic parameters and where-clause
/// are forwarded onto the generated `impl` block.
///
/// Aborts compilation with a call-site error if applied to a struct or union.
#[proc_macro_error]
#[proc_macro_derive(Ordinal)]
pub fn derive_ordinal(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let variants = detect_variants(&input);
    let match_arms = generate_match_arms(&variants, &input);
    let enum_ident = &input.ident;
    // Without forwarding generics, `impl Enum { .. }` fails to compile for
    // e.g. `enum Tree<T: Ord> { .. }` ("missing generics for enum").
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    let tokens = quote! {
        impl #impl_generics #enum_ident #ty_generics #where_clause {
            /// Returns the zero-based declaration index of this variant.
            pub fn ordinal(&self) -> usize {
                // Match on the place `*self` rather than the reference `self`:
                // an empty `match` on `&Enum` is rejected as non-exhaustive,
                // whereas `match *self {}` type-checks for an enum with no
                // variants. The generated arms bind nothing, so nothing is
                // moved out of the borrow.
                match *self {
                    #(#match_arms,)*
                }
            }
        }
    };
    tokens.into()
}
fn detect_variants(input: &DeriveInput) -> Vec<Variant> {
let mut vec = Vec::new();
let data = match &input.data {
syn::Data::Enum(data) => data,
_ => abort_call_site!("cannot derive `Ordinal` on an item which is not an enum"),
};
for variant in &data.variants {
vec.push(detect_variant(variant));
}
vec
}
/// Records a variant's name and the shape of its fields.
///
/// - Struct-like variants report `has_named_fields = true` with a count of 0.
/// - Tuple variants report their number of positional fields.
/// - Unit variants report `false` and 0.
fn detect_variant(variant: &syn::Variant) -> Variant {
    let has_named_fields = matches!(variant.fields, syn::Fields::Named(_));
    let unit_field_count = match &variant.fields {
        syn::Fields::Unnamed(tuple_fields) => tuple_fields.unnamed.len(),
        syn::Fields::Named(_) | syn::Fields::Unit => 0,
    };
    Variant {
        ident: &variant.ident,
        unit_field_count,
        has_named_fields,
    }
}
/// Builds one `match` arm per variant, mapping it to its declaration index.
///
/// Each arm has the form `Enum::Variant <pattern> => <index>usize`. The pattern
/// only selects the variant; it never binds any of the variant's fields.
fn generate_match_arms(variants: &[Variant], input: &DeriveInput) -> Vec<TokenStream> {
    let enum_ident = &input.ident;
    variants
        .iter()
        .enumerate()
        .map(|(ordinal, variant)| {
            let variant_ident = variant.ident;
            let pattern = match (variant.has_named_fields, variant.unit_field_count) {
                // Tuple variant with at least one field: `Enum::V(..)`.
                (false, count) if count > 0 => quote! { #enum_ident::#variant_ident(..) },
                // Struct-like variants, unit variants, and zero-field tuple
                // variants (`V()`). A braced `{ .. }` pattern matches a variant
                // of any kind. The bare path `Enum::V` would be rejected for
                // `V()` (E0532), and `Variant` cannot tell `V` apart from `V()`.
                _ => quote! { #enum_ident::#variant_ident { .. } },
            };
            quote! { #pattern => #ordinal }
        })
        .collect()
}