discriminant_macro/
lib.rs1use proc_macro::TokenStream as TokenStream1;
2use proc_macro2::{Span, TokenStream};
3use proc_macro_error::{abort, proc_macro_error};
4use syn::punctuated::Punctuated;
5use syn::*;
6use template_quote::quote;
7
/// Returns a `u64` taken from a freshly constructed [`RandomState`] hasher.
///
/// Nothing is written into the hasher, so the result depends only on the
/// keys `RandomState::new()` picked for this call. This avoids a dependency
/// on an RNG crate.
///
/// [`RandomState`]: std::collections::hash_map::RandomState
fn random() -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let state = RandomState::new();
    let hasher = state.build_hasher();
    hasher.finish()
}
14
15fn internal(input: ItemEnum) -> TokenStream {
16 let krate: Path = input
17 .attrs
18 .iter()
19 .filter_map(|a| match &a.meta {
20 Meta::List(MetaList { path, tokens, .. }) => {
21 if let (true, krate) = (path.is_ident("discriminant"), parse_quote!(#tokens)) {
22 Some(krate)
23 } else {
24 None
25 }
26 }
27 _ => None,
28 })
29 .next()
30 .unwrap_or(parse_quote!(::discriminant));
31 let discriminant_attrs = input
32 .attrs
33 .iter()
34 .filter_map(|a| match &a.meta {
35 Meta::NameValue(MetaNameValue { path, value, .. })
36 if path.is_ident("discriminant_attr") =>
37 {
38 let s: LitStr = parse2(quote! {#value}).unwrap();
39 Some(s.value())
40 }
41 _ => None,
42 })
43 .collect::<Vec<_>>();
44 let discriminant_attrs = core::convert::identity::<ItemStruct>(
45 parse_str(&format!("{} struct S {{}}", discriminant_attrs.join(""))).unwrap(),
46 )
47 .attrs;
48 let specified_repr = discriminant_attrs
49 .iter()
50 .chain(&input.attrs)
51 .filter_map(|a| match &a.meta {
52 Meta::List(MetaList { path, tokens, .. }) if path.is_ident("repr") => {
53 if let Ok(reprs) = parse::Parser::parse2(
54 Punctuated::<Meta, Token![,]>::parse_terminated,
55 tokens.clone(),
56 ) {
57 reprs
58 .iter()
59 .filter_map(|r| Some(r.path().get_ident()?.to_string()))
60 .filter_map(|r| match r.as_str() {
61 "u8" | "u16" | "u32" | "u64" | "usize" | "i8" | "i16" | "i32"
62 | "i64" | "isize" => Some(Ident::new(&r, Span::call_site())),
63 _ => None,
64 })
65 .next()
66 } else {
67 None
68 }
69 }
70 _ => None,
71 })
72 .next();
73 let repr = specified_repr
74 .clone()
75 .unwrap_or(Ident::new("isize", Span::call_site()));
76 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
77 let discriminant_enum_ident = Ident::new(
78 &format!("__Discriminant_{}_{}", &input.ident, random() % 1000),
79 Span::call_site(),
80 );
81 let disc_indices = input
82 .variants
83 .iter()
84 .scan(parse_quote!(0), |acc, variant| {
85 if let Some((_, expr)) = &variant.discriminant {
86 *acc = expr.clone();
87 }
88 let ret = acc.clone();
89 *acc = parse_quote!(#ret + 1);
90 Some(ret)
91 })
92 .collect::<Vec<Expr>>();
93 quote! {
94 #[repr(#repr)]
95 #(#discriminant_attrs)*
96 #[derive(
97 ::core::marker::Copy,
98 ::core::clone::Clone,
99 ::core::fmt::Debug,
100 ::core::hash::Hash,
101 ::core::cmp::PartialEq,
102 ::core::cmp::Eq,
103 )]
104 #{&input.vis} enum #discriminant_enum_ident {
105 #(for variant in &input.variants) {
106 #{
107 variant.attrs.iter().filter_map(|a| match &a.meta {
108 Meta::NameValue(MetaNameValue{path, value, ..}) if path.is_ident("discriminant_attr") => {
109 let s: LitStr = parse2(quote! {#value}).unwrap();
110 let discriminant_attrs = core::convert::identity::<ItemStruct>(
111 parse_str(&format!("{} struct S {{}}", s.value())).unwrap()
112 ).attrs;
113 Some(quote!{#(#discriminant_attrs)*})
114 },
115 _ => None,
116 }).next()
117 }
118 #{&variant.ident}
119 #(if let Some((eq_token, expr)) = &variant.discriminant) {
120 #eq_token #expr
121 },
122 }
123 }
124
125 impl ::core::fmt::Display for #discriminant_enum_ident {
126 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
127 <Self as ::core::fmt::Debug>::fmt(self, f)
128 }
129 }
130
131 impl ::core::cmp::PartialOrd for #discriminant_enum_ident {
132 fn partial_cmp(&self, other: &Self) -> ::core::option::Option<::core::cmp::Ordering> {
133 (*self as #repr).partial_cmp(&(*other as #repr))
134 }
135 }
136
137 impl ::core::cmp::Ord for #discriminant_enum_ident {
138 fn cmp(&self, other: &Self) -> ::core::cmp::Ordering {
139 (*self as #repr).cmp(&(*other as #repr))
140 }
141 }
142
143 #[automatically_derived]
144 unsafe impl #impl_generics #krate::Enum for #{&input.ident}
145 #ty_generics #where_clause
146 {
147 type Discriminant = #discriminant_enum_ident;
148
149 fn discriminant(&self) -> Self::Discriminant {
150 match self {
151 #(for Variant{ident, fields, ..} in &input.variants) {
152 Self::#ident
153 #(if let Fields::Unnamed(_) = fields) { (..) }
154 #(if let Fields::Named(_) = fields) { {..} }
155 => #discriminant_enum_ident::#ident,
156 }
157 }
158 }
159 }
160
161 impl ::core::convert::TryFrom<#repr> for #discriminant_enum_ident {
162 type Error = ();
163 fn try_from(value: #repr) -> ::core::result::Result<Self, Self::Error> {
164 #(for (variant, disc) in input.variants.iter().zip(&disc_indices)) {
165 if value == #disc { ::core::result::Result::Ok(Self::#{&variant.ident}) } else
166 }
167 { ::core::result::Result::Err(()) }
168 }
169 }
170
171 impl ::core::convert::Into<#repr> for #discriminant_enum_ident {
172 fn into(self) -> #repr {
173 self as #repr
174 }
175 }
176
177 unsafe impl #krate::Discriminant for #discriminant_enum_ident {
178 type Repr = #repr;
179 fn all() -> impl ::core::iter::Iterator<Item = Self> {
180 struct Iter(::core::option::Option<#discriminant_enum_ident>);
181 impl ::core::iter::Iterator for Iter {
182 type Item = #discriminant_enum_ident;
183 fn next(&mut self) -> Option<Self::Item> {
184 match self.0 {
185 #(for (curr, next) in input.variants.iter().zip(
186 input.variants.iter().skip(1).map(Some).chain(core::iter::once(None))
187 )) {
188 ::core::option::Option::Some(#discriminant_enum_ident::#{&curr.ident}) => {
189 let ret = self.0;
190 self.0 = #(if let Some(next) = next) {
191 Some(#discriminant_enum_ident::#{&next.ident})
192 } #(else) { None };
193 ret
194 }
195 }
196 ::core::option::Option::None => ::core::option::Option::None,
197 }
198 }
199 fn size_hint(&self) -> (
200 ::core::primitive::usize,
201 ::core::option::Option<::core::primitive::usize>
202 ) {
203 let n = Self(self.0).count();
204 (n, ::core::option::Option::Some(n))
205 }
206 fn count(self) -> usize {
207 match self.0 {
208 #(for (n, variant) in input.variants.iter().enumerate()) {
209 ::core::option::Option::Some(#discriminant_enum_ident::#{&variant.ident}) => #{disc_indices.len() - n},
210 }
211 ::core::option::Option::None => 0,
212 }
213 }
214 fn last(self) -> Option<Self::Item> {
215 #(if let Some(last) = &input.variants.iter().last()) {
216 self.0.map(|_| #discriminant_enum_ident::#{&last.ident})
217 } #(else) {
218 ::core::option::Option::None
219 }
220 }
221 }
222 #(if let Some(item) = input.variants.iter().next()) {
223 Iter(::core::option::Option::Some(#discriminant_enum_ident::#{&item.ident}))
224 } #(else) {
225 Iter(::core::option::Option::None)
226 }
227 }
228 }
229 }
230}
231
232#[proc_macro_derive(Enum, attributes(discriminant, discriminant_attr))]
233#[proc_macro_error]
234pub fn derive_enum(input: TokenStream1) -> TokenStream1 {
235 internal(parse(input).unwrap_or_else(|_| {
236 abort!(
237 Span::call_site(),
238 "#[derive(Enum)] is only applicative on enums."
239 )
240 }))
241 .into()
242}