discriminant_macro/
lib.rs

1use proc_macro::TokenStream as TokenStream1;
2use proc_macro2::{Span, TokenStream};
3use proc_macro_error::{abort, proc_macro_error};
4use syn::punctuated::Punctuated;
5use syn::*;
6use template_quote::quote;
7
/// Returns a `u64` derived from a freshly constructed `RandomState`.
///
/// No data is written to the hasher; the result comes solely from the keys `std` assigns to
/// the new `RandomState`. The caller in this file reduces it modulo 1000 to suffix the
/// generated discriminant enum's identifier.
fn random() -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let state = RandomState::new();
    let empty_hasher = state.build_hasher();
    empty_hasher.finish()
}
14
15fn internal(input: ItemEnum) -> TokenStream {
16    let krate: Path = input
17        .attrs
18        .iter()
19        .filter_map(|a| match &a.meta {
20            Meta::List(MetaList { path, tokens, .. }) => {
21                if let (true, krate) = (path.is_ident("discriminant"), parse_quote!(#tokens)) {
22                    Some(krate)
23                } else {
24                    None
25                }
26            }
27            _ => None,
28        })
29        .next()
30        .unwrap_or(parse_quote!(::discriminant));
31    let discriminant_attrs = input
32        .attrs
33        .iter()
34        .filter_map(|a| match &a.meta {
35            Meta::NameValue(MetaNameValue { path, value, .. })
36                if path.is_ident("discriminant_attr") =>
37            {
38                let s: LitStr = parse2(quote! {#value}).unwrap();
39                Some(s.value())
40            }
41            _ => None,
42        })
43        .collect::<Vec<_>>();
44    let discriminant_attrs = core::convert::identity::<ItemStruct>(
45        parse_str(&format!("{} struct S {{}}", discriminant_attrs.join(""))).unwrap(),
46    )
47    .attrs;
48    let specified_repr = discriminant_attrs
49        .iter()
50        .chain(&input.attrs)
51        .filter_map(|a| match &a.meta {
52            Meta::List(MetaList { path, tokens, .. }) if path.is_ident("repr") => {
53                if let Ok(reprs) = parse::Parser::parse2(
54                    Punctuated::<Meta, Token![,]>::parse_terminated,
55                    tokens.clone(),
56                ) {
57                    reprs
58                        .iter()
59                        .filter_map(|r| Some(r.path().get_ident()?.to_string()))
60                        .filter_map(|r| match r.as_str() {
61                            "u8" | "u16" | "u32" | "u64" | "usize" | "i8" | "i16" | "i32"
62                            | "i64" | "isize" => Some(Ident::new(&r, Span::call_site())),
63                            _ => None,
64                        })
65                        .next()
66                } else {
67                    None
68                }
69            }
70            _ => None,
71        })
72        .next();
73    let repr = specified_repr
74        .clone()
75        .unwrap_or(Ident::new("isize", Span::call_site()));
76    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
77    let discriminant_enum_ident = Ident::new(
78        &format!("__Discriminant_{}_{}", &input.ident, random() % 1000),
79        Span::call_site(),
80    );
81    let disc_indices = input
82        .variants
83        .iter()
84        .scan(parse_quote!(0), |acc, variant| {
85            if let Some((_, expr)) = &variant.discriminant {
86                *acc = expr.clone();
87            }
88            let ret = acc.clone();
89            *acc = parse_quote!(#ret + 1);
90            Some(ret)
91        })
92        .collect::<Vec<Expr>>();
93    quote! {
94        #[repr(#repr)]
95        #(#discriminant_attrs)*
96        #[derive(
97            ::core::marker::Copy,
98            ::core::clone::Clone,
99            ::core::fmt::Debug,
100            ::core::hash::Hash,
101            ::core::cmp::PartialEq,
102            ::core::cmp::Eq,
103        )]
104        #{&input.vis} enum #discriminant_enum_ident {
105            #(for variant in &input.variants) {
106                #{
107                    variant.attrs.iter().filter_map(|a| match &a.meta {
108                        Meta::NameValue(MetaNameValue{path, value, ..}) if path.is_ident("discriminant_attr") => {
109                            let s: LitStr = parse2(quote! {#value}).unwrap();
110                            let discriminant_attrs = core::convert::identity::<ItemStruct>(
111                                parse_str(&format!("{} struct S {{}}", s.value())).unwrap()
112                            ).attrs;
113                            Some(quote!{#(#discriminant_attrs)*})
114                        },
115                        _ => None,
116                    }).next()
117                }
118                #{&variant.ident}
119                #(if let Some((eq_token, expr)) = &variant.discriminant) {
120                    #eq_token #expr
121                },
122            }
123        }
124
125        impl ::core::fmt::Display for #discriminant_enum_ident {
126            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
127                <Self as ::core::fmt::Debug>::fmt(self, f)
128            }
129        }
130
131        impl ::core::cmp::PartialOrd for #discriminant_enum_ident {
132            fn partial_cmp(&self, other: &Self) -> ::core::option::Option<::core::cmp::Ordering> {
133                (*self as #repr).partial_cmp(&(*other as #repr))
134            }
135        }
136
137        impl ::core::cmp::Ord for #discriminant_enum_ident {
138            fn cmp(&self, other: &Self) -> ::core::cmp::Ordering {
139                (*self as #repr).cmp(&(*other as #repr))
140            }
141        }
142
143        #[automatically_derived]
144        unsafe impl #impl_generics #krate::Enum for #{&input.ident}
145        #ty_generics #where_clause
146        {
147            type Discriminant = #discriminant_enum_ident;
148
149            fn discriminant(&self) -> Self::Discriminant {
150                match self {
151                    #(for Variant{ident, fields, ..} in &input.variants) {
152                        Self::#ident
153                        #(if let Fields::Unnamed(_) = fields) { (..) }
154                        #(if let Fields::Named(_) = fields) { {..} }
155                        => #discriminant_enum_ident::#ident,
156                    }
157                }
158            }
159        }
160
161        impl ::core::convert::TryFrom<#repr> for #discriminant_enum_ident {
162            type Error = ();
163            fn try_from(value: #repr) -> ::core::result::Result<Self, Self::Error> {
164                #(for (variant, disc) in input.variants.iter().zip(&disc_indices)) {
165                    if value == #disc { ::core::result::Result::Ok(Self::#{&variant.ident}) } else
166                }
167                { ::core::result::Result::Err(()) }
168            }
169        }
170
171        impl ::core::convert::Into<#repr> for #discriminant_enum_ident {
172            fn into(self) -> #repr {
173                self as #repr
174            }
175        }
176
177        unsafe impl #krate::Discriminant for #discriminant_enum_ident {
178            type Repr = #repr;
179            fn all() -> impl ::core::iter::Iterator<Item = Self> {
180                struct Iter(::core::option::Option<#discriminant_enum_ident>);
181                impl ::core::iter::Iterator for Iter {
182                    type Item = #discriminant_enum_ident;
183                    fn next(&mut self) -> Option<Self::Item> {
184                        match self.0 {
185                            #(for (curr, next) in input.variants.iter().zip(
186                                    input.variants.iter().skip(1).map(Some).chain(core::iter::once(None))
187                            )) {
188                                ::core::option::Option::Some(#discriminant_enum_ident::#{&curr.ident}) => {
189                                    let ret = self.0;
190                                    self.0 = #(if let Some(next) = next) {
191                                        Some(#discriminant_enum_ident::#{&next.ident})
192                                    } #(else) { None };
193                                    ret
194                                }
195                            }
196                            ::core::option::Option::None => ::core::option::Option::None,
197                        }
198                    }
199                    fn size_hint(&self) -> (
200                        ::core::primitive::usize,
201                        ::core::option::Option<::core::primitive::usize>
202                    ) {
203                        let n = Self(self.0).count();
204                        (n, ::core::option::Option::Some(n))
205                    }
206                    fn count(self) -> usize {
207                        match self.0 {
208                            #(for (n, variant) in input.variants.iter().enumerate()) {
209                                ::core::option::Option::Some(#discriminant_enum_ident::#{&variant.ident}) => #{disc_indices.len() - n},
210                            }
211                            ::core::option::Option::None => 0,
212                        }
213                    }
214                    fn last(self) -> Option<Self::Item> {
215                        #(if let Some(last) = &input.variants.iter().last()) {
216                            self.0.map(|_| #discriminant_enum_ident::#{&last.ident})
217                        } #(else) {
218                            ::core::option::Option::None
219                        }
220                    }
221                }
222                #(if let Some(item) = input.variants.iter().next()) {
223                    Iter(::core::option::Option::Some(#discriminant_enum_ident::#{&item.ident}))
224                } #(else) {
225                    Iter(::core::option::Option::None)
226                }
227            }
228        }
229    }
230}
231
232#[proc_macro_derive(Enum, attributes(discriminant, discriminant_attr))]
233#[proc_macro_error]
234pub fn derive_enum(input: TokenStream1) -> TokenStream1 {
235    internal(parse(input).unwrap_or_else(|_| {
236        abort!(
237            Span::call_site(),
238            "#[derive(Enum)] is only applicative on enums."
239        )
240    }))
241    .into()
242}