#![doc = include_str!("lib.md")]
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Error, Expr, Fields, Type, Variant};
/// Entry point of `#[derive(UnitEnum)]`.
///
/// Parses the annotated item, validates it, and emits the generated inherent `impl`.
/// Any validation failure is turned into a `compile_error!` at the offending span
/// instead of panicking inside the macro.
#[proc_macro_derive(UnitEnum, attributes(unit_enum))]
pub fn unit_enum_derive(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    validate_and_process(&ast)
        .map(|(discriminant_type, unit_variants, other_variant)| {
            impl_unit_enum(&ast, &discriminant_type, &unit_variants, other_variant)
        })
        .unwrap_or_else(|err| err.to_compile_error().into())
}
/// Accumulator used by `validate_and_process` while it classifies the enum's variants.
struct ValidationResult<'v> {
    /// Field-less variants, collected in the order they are visited.
    unit_variants: Vec<&'v Variant>,
    /// The single `#[unit_enum(other)]` tuple variant (if any) paired with its field type.
    other_variant: Option<(&'v Variant, Type)>,
}
fn validate_and_process(ast: &DeriveInput) -> Result<(Type, Vec<&Variant>, Option<(&Variant, Type)>), Error> {
let discriminant_type = get_discriminant_type(ast)?;
let data_enum = match &ast.data {
Data::Enum(data_enum) => data_enum,
_ => return Err(Error::new_spanned(ast, "UnitEnum can only be derived for enums")),
};
let mut validation = ValidationResult {
unit_variants: Vec::new(),
other_variant: None,
};
for variant in &data_enum.variants {
match &variant.fields {
Fields::Unit => {
if has_unit_enum_attr(variant) {
return Err(Error::new_spanned(variant,
"Unit variants cannot have #[unit_enum] attributes"));
}
validation.unit_variants.push(variant);
}
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
if has_unit_enum_other_attr(variant) {
if validation.other_variant.is_some() {
return Err(Error::new_spanned(variant,
"Multiple #[unit_enum(other)] variants found. Only one is allowed"));
}
validation.other_variant = Some((variant, fields.unnamed[0].ty.clone()));
} else {
return Err(Error::new_spanned(variant,
"Non-unit variant must be marked with #[unit_enum(other)] to be used as the catch-all variant"));
}
}
_ => return Err(Error::new_spanned(variant,
"Invalid variant. UnitEnum only supports unit variants and a single tuple variant marked with #[unit_enum(other)]")),
}
}
Ok((discriminant_type, validation.unit_variants, validation.other_variant))
}
fn get_discriminant_type(ast: &DeriveInput) -> Result<Type, Error> {
ast.attrs
.iter()
.find(|attr| attr.path().is_ident("repr"))
.map_or(Ok(syn::parse_quote!(i32)), |attr| {
attr.parse_args::<Type>()
.map_err(|_| Error::new_spanned(attr, "Invalid repr attribute"))
})
}
/// Reports whether any attribute on `variant` has the path `unit_enum`,
/// regardless of what (if anything) is inside its parentheses.
fn has_unit_enum_attr(variant: &Variant) -> bool {
    for attr in &variant.attrs {
        if attr.path().is_ident("unit_enum") {
            return true;
        }
    }
    false
}
/// Returns `true` if `variant` carries a `#[unit_enum(other)]` attribute.
///
/// An attribute counts only when its argument list parses cleanly *and* actually
/// names `other`. An empty `#[unit_enum()]` parses without error but is not accepted.
/// Neither is `#[unit_enum(foo)]`, whose unknown key makes parsing fail. Parse errors
/// are not propagated; the caller reports its own error for such variants.
fn has_unit_enum_other_attr(variant: &Variant) -> bool {
    variant.attrs.iter().any(|attr| {
        if !attr.path().is_ident("unit_enum") {
            return false;
        }
        // Track whether `other` was seen: an empty argument list yields `Ok(())`
        // without ever invoking the closure.
        let mut names_other = false;
        let parsed = attr.parse_nested_meta(|meta| {
            if meta.path.is_ident("other") {
                names_other = true;
                Ok(())
            } else {
                Err(meta.error("Invalid unit_enum attribute"))
            }
        });
        parsed.is_ok() && names_other
    })
}
/// Computes one discriminant expression per entry of `variants`, in order.
///
/// An explicit discriminant (`A = expr`) is used verbatim. A variant without one gets
/// the previous entry's discriminant plus one, or `0` if it is the first entry. The
/// caller passes only unit variants, so a `#[unit_enum(other)]` variant does not
/// advance the counter.
///
/// The previous expression is parenthesized before `+ 1` is appended. Splicing raw
/// tokens would otherwise let `+` bind inside it: `1 << 2` would become `1 << 2 + 1`,
/// i.e. `1 << 3` = 8, instead of 5.
fn compute_discriminants(variants: &[&Variant]) -> Vec<Expr> {
    let mut discriminants: Vec<Expr> = Vec::with_capacity(variants.len());
    for variant in variants {
        let discriminant = match (&variant.discriminant, discriminants.last()) {
            (Some((_, expr)), _) => expr.clone(),
            (None, Some(prev)) => syn::parse_quote! { (#prev) + 1 },
            (None, None) => syn::parse_quote! { 0 },
        };
        discriminants.push(discriminant);
    }
    discriminants
}
/// Assembles the inherent `impl` block for the derived enum.
///
/// The block contains, in this order: `name`, `ordinal`, `from_ordinal`,
/// `discriminant`, `from_discriminant`, `len` (the number of unit variants) and `values`.
fn impl_unit_enum(
    ast: &DeriveInput, discriminant_type: &Type, unit_variants: &[&Variant], other_variant: Option<(&Variant, Type)>,
) -> TokenStream {
    let enum_ident = &ast.ident;
    let discriminants = compute_discriminants(unit_variants);
    let variant_count = unit_variants.len();

    let methods = vec![
        generate_name_impl(enum_ident, unit_variants, &other_variant),
        generate_ordinal_impl(enum_ident, unit_variants, &other_variant, variant_count),
        generate_from_ordinal_impl(enum_ident, unit_variants),
        generate_discriminant_impl(enum_ident, unit_variants, &other_variant, discriminant_type, &discriminants),
        generate_from_discriminant_impl(enum_ident, unit_variants, &other_variant, discriminant_type, &discriminants),
    ];
    let values_method = generate_values_impl(enum_ident, unit_variants, &discriminants, &other_variant);

    let expanded = quote! {
        impl #enum_ident {
            #(#methods)*
            pub const fn len() -> usize {
                #variant_count
            }
            #values_method
        }
    };
    expanded.into()
}
/// Generates `pub const fn name(&self) -> &'static str`.
///
/// Each arm returns the variant's identifier as written in the source, via `stringify!`.
/// The catch-all variant, if present, reports its own identifier regardless of its payload.
/// The return type is `&'static str` because every arm yields a string literal. A plain
/// `&str` would needlessly tie the result's lifetime to the borrow of `self`. The change
/// is backward-compatible for callers.
fn generate_name_impl(
    name: &syn::Ident, unit_variants: &[&Variant], other_variant: &Option<(&Variant, Type)>,
) -> proc_macro2::TokenStream {
    let unit_match_arms = unit_variants.iter().map(|variant| {
        let variant_name = &variant.ident;
        quote! { #name::#variant_name => stringify!(#variant_name) }
    });
    let other_arm = other_variant.as_ref().map(|(variant, _)| {
        let variant_name = &variant.ident;
        quote! { #name::#variant_name(_) => stringify!(#variant_name) }
    });
    quote! {
        pub const fn name(&self) -> &'static str {
            match self {
                #(#unit_match_arms,)*
                #other_arm
            }
        }
    }
}
/// Generates `pub const fn ordinal(&self) -> usize`.
///
/// Unit variants map to their zero-based position among the unit variants. The
/// catch-all variant, if present, maps to `num_variants`, one past the last unit ordinal.
fn generate_ordinal_impl(
    name: &syn::Ident, unit_variants: &[&Variant], other_variant: &Option<(&Variant, Type)>, num_variants: usize,
) -> proc_macro2::TokenStream {
    let mut arms = Vec::with_capacity(unit_variants.len());
    for (position, variant) in unit_variants.iter().enumerate() {
        let ident = &variant.ident;
        arms.push(quote! { #name::#ident => #position });
    }

    let catch_all = match other_variant {
        Some((variant, _)) => {
            let ident = &variant.ident;
            Some(quote! { #name::#ident(_) => #num_variants })
        }
        None => None,
    };

    quote! {
        pub const fn ordinal(&self) -> usize {
            match self {
                #(#arms,)*
                #catch_all
            }
        }
    }
}
/// Generates `pub const fn from_ordinal(ord: usize) -> Option<Self>`.
///
/// Returns the unit variant at position `ord`, or `None` for any other value. The
/// catch-all variant is never produced, since it has no fixed value to construct.
fn generate_from_ordinal_impl(name: &syn::Ident, unit_variants: &[&Variant]) -> proc_macro2::TokenStream {
    let mut arms = proc_macro2::TokenStream::new();
    for (ordinal, variant) in unit_variants.iter().enumerate() {
        let ident = &variant.ident;
        arms.extend(quote! { #ordinal => Some(#name::#ident), });
    }

    quote! {
        pub const fn from_ordinal(ord: usize) -> Option<Self> {
            match ord {
                #arms
                _ => None
            }
        }
    }
}
/// Generates `pub const fn discriminant(&self) -> <discriminant_type>`.
///
/// Each unit variant returns its precomputed discriminant expression (paired
/// positionally with `discriminants`). The catch-all variant returns its stored payload
/// via `*val`. The generated code therefore only compiles if that payload's type is the
/// discriminant type.
fn generate_discriminant_impl(
    name: &syn::Ident, unit_variants: &[&Variant], other_variant: &Option<(&Variant, Type)>, discriminant_type: &Type,
    discriminants: &[Expr],
) -> proc_macro2::TokenStream {
    let unit_arms: Vec<proc_macro2::TokenStream> = unit_variants
        .iter()
        .zip(discriminants.iter())
        .map(|(variant, value)| {
            let ident = &variant.ident;
            quote! { #name::#ident => #value }
        })
        .collect();

    let other_arm = other_variant
        .as_ref()
        .map(|(variant, _)| &variant.ident)
        .map(|ident| quote! { #name::#ident(val) => *val });

    quote! {
        pub const fn discriminant(&self) -> #discriminant_type {
            match self {
                #(#unit_arms,)*
                #other_arm
            }
        }
    }
}
/// Generates `pub const fn from_discriminant(discr: <discriminant_type>)`.
///
/// Unit discriminants are tested in declaration order with match guards, because the
/// expressions need not be literal patterns. The signature depends on whether a
/// catch-all variant exists:
/// - With `#[unit_enum(other)]`, the function is infallible and returns `Self`. Unmatched
///   values are wrapped in the catch-all variant.
/// - Without it, the function returns `Option<Self>`, with `None` for unmatched values.
fn generate_from_discriminant_impl(
    name: &syn::Ident, unit_variants: &[&Variant], other_variant: &Option<(&Variant, Type)>, discriminant_type: &Type,
    discriminants: &[Expr],
) -> proc_macro2::TokenStream {
    let pairs = unit_variants.iter().zip(discriminants);

    match other_variant {
        Some((catch_all, _)) => {
            let arms = pairs.map(|(variant, value)| {
                let ident = &variant.ident;
                quote! { x if x == #value => #name::#ident }
            });
            let catch_all_ident = &catch_all.ident;
            quote! {
                pub const fn from_discriminant(discr: #discriminant_type) -> Self {
                    match discr {
                        #(#arms,)*
                        other => #name::#catch_all_ident(other)
                    }
                }
            }
        }
        None => {
            let arms = pairs.map(|(variant, value)| {
                let ident = &variant.ident;
                quote! { x if x == #value => Some(#name::#ident) }
            });
            quote! {
                pub const fn from_discriminant(discr: #discriminant_type) -> Option<Self> {
                    match discr {
                        #(#arms,)*
                        _ => None
                    }
                }
            }
        }
    }
}
/// Generates `pub fn values() -> impl Iterator<Item = Self>`, yielding every unit
/// variant in declaration order. The `#[unit_enum(other)]` variant is not included,
/// since it carries a payload.
///
/// The variants are placed in a fixed-size array and iterated by value rather than
/// collected with `vec!`. Calling `values()` therefore no longer heap-allocates, and
/// the generated code no longer needs `alloc`, so it also works in `#![no_std]` crates.
/// `IntoIterator::into_iter` is called in function form because the method form
/// `[..].into_iter()` yields references on editions before 2021.
///
/// `_discriminants` and `_other_variant` are unused. They are kept so the signature
/// matches the other `generate_*` helpers.
fn generate_values_impl(
    name: &syn::Ident, unit_variants: &[&Variant], _discriminants: &[Expr], _other_variant: &Option<(&Variant, Type)>,
) -> proc_macro2::TokenStream {
    let count = unit_variants.len();
    let variant_exprs = unit_variants.iter().map(|variant| {
        let variant_name = &variant.ident;
        quote! { #name::#variant_name }
    });
    quote! {
        pub fn values() -> impl Iterator<Item = Self> {
            // Explicit array type keeps inference working even when there are no unit variants.
            let all: [Self; #count] = [#(#variant_exprs),*];
            IntoIterator::into_iter(all)
        }
    }
}