const_enum/
lib.rs

1#![deny(missing_docs)]
2
3//! This crate provides a procedural derive macro for constant `From` trait implementations for
4//! enumerations based on their `repr` type.
5//!
//! Because the generated implementations are `const`, this library requires a nightly Rust
//! toolchain. Additionally, you must enable the following feature flag in your crate root:
8//!
9//! ```rust
10//! #![feature(const_trait_impl)]   // always required
11//! ```
12//!
//! This is required because `const` trait implementations are currently gated behind this flag.
//! Further usage documentation can be found on the `ConstEnum` derive macro itself.
15
16extern crate proc_macro;
17
18use proc_macro::TokenStream as NativeTokenStream;
19use proc_macro2::{Delimiter, Span, TokenStream, TokenTree};
20use quote::quote;
21
22struct EnumVariant {
23    name: syn::Ident,
24    value: syn::Expr,
25}
26
27/// This function defines the procedural derive macro `ConstEnum`.
28/// It can be used on any enum with a `repr` type, e.g. `repr(u8)`.
29///
30/// When being used on an enum, this will automatically add the following implementations:
31///
32/// - `From<repr_type> for <enum_type>`: Convert repr type to enum, panic if invalid value
33/// - `From<enum_type> for <repr_type>`: Convert enum to repr type
34///
35/// # Example
36/// ```rust
37/// #![feature(const_trait_impl)]
38///
39/// use const_enum::ConstEnum;
40///
41/// #[derive(Copy, Clone, Debug, Eq, PartialEq, ConstEnum)]
42/// #[repr(u8)]
43/// enum Test {
44///     A = 0b010,
45///     B = 0b100,
46///     C = 0b001
47/// }
48///
49/// pub fn example() {
50///     println!("{:?}", Test::from(0b010 as u8));
51///     println!("{:?}", u8::from(Test::A));
52/// }
53/// ```
54#[proc_macro_derive(ConstEnum)]
55pub fn const_enum(input: NativeTokenStream) -> NativeTokenStream {
56    let input = syn::parse_macro_input!(input as syn::DeriveInput);
57
58    let enum_name = input.ident;
59    let enum_variants = get_enum_variants(&input.data);
60    let enum_type = get_enum_repr_type(&input.attrs);
61    let match_impl = build_from_match(&enum_name, &enum_variants);
62
63    let expanded = quote! {
64        impl const core::convert::From<#enum_name> for #enum_type {
65            fn from(value: #enum_name) -> Self {
66                value as Self
67            }
68        }
69
70        impl const core::convert::From<#enum_type> for #enum_name {
71            fn from(value: #enum_type) -> Self {
72                #match_impl
73            }
74        }
75    };
76
77    NativeTokenStream::from(expanded)
78}
79
80fn get_enum_repr_type(attrs: &Vec<syn::Attribute>) -> syn::Ident {
81    let repr = syn::Ident::new("repr", Span::call_site());
82    let repr_attr = attrs.iter().find(|attr| match attr.style {
83        syn::AttrStyle::Outer => attr.path.is_ident(&repr),
84        _ => false,
85    }).unwrap_or_else(|| panic!("repr attribute not found on enum"));
86
87    let repr_tokens = repr_attr.tokens.clone();
88    let mut repr_tokens_iter = repr_tokens.into_iter();
89
90    let first_token = repr_tokens_iter.next();
91    if first_token.is_none() || repr_tokens_iter.next().is_some() {
92        panic!("malformed repr attribute, expected repr(TYPE)");
93    }
94
95    match first_token.unwrap().clone() {
96        TokenTree::Group(repr_items) => {
97            if repr_items.delimiter() != Delimiter::Parenthesis {
98                panic!("malformed repr attribute, expected repr(TYPE)");
99            }
100
101            let mut repr_types_iter = repr_items.stream().into_iter();
102            let first_repr_item = repr_types_iter.next().unwrap();
103
104            if let Some(_) = repr_types_iter.next() {
105                panic!("malformed repr attribute, expected single type");
106            }
107
108            match first_repr_item.clone() {
109                TokenTree::Ident(repr_type) => repr_type,
110                _ => panic!("malformed repr attribute, unexpected type"),
111            }
112        },
113        _ => panic!("malformed repr attribute, unexpected token"),
114    }
115}
116
117fn get_enum_variants(data: &syn::Data) -> Vec<EnumVariant> {
118    match *data {
119        syn::Data::Enum(ref data) => {
120            data.variants.iter().map(|variant| {
121                let pair = variant.discriminant.as_ref().unwrap();
122                let name = variant.ident.clone();
123                let value = pair.1.clone();
124
125                EnumVariant { name, value }
126            }).collect()
127        }
128        syn::Data::Struct(_) => panic!("unexpected struct, const-enum only supports enums"),
129        syn::Data::Union(_) => panic!("unexpected union, const-enum only supports enums"),
130    }
131}
132
133fn build_from_match(enum_name: &syn::Ident, variants: &Vec<EnumVariant>) -> TokenStream {
134    let mut match_arms = TokenStream::new();
135
136    // Generate a match arm for each variant
137    variants.iter().for_each(|variant| {
138        let (name, value) = (&variant.name, &variant.value);
139
140        match_arms.extend(quote! {
141            #value => #enum_name::#name,
142        });
143    });
144
145    // Add exhaustive default match arm resulting in error
146    match_arms.extend(quote! {
147        _ => panic!("invalid value provided"),
148    });
149
150    return quote! {
151        match value {
152            #match_arms
153        }
154    };
155}