use std::fmt;
use proc_macro2::{Span, TokenStream};
use quote::{quote, spanned::Spanned};
use syn::LitInt;
use syn::{meta::ParseNestedMeta, Token};
use crate::bit_size::Bits;
enum Error<'a> {
MissingSize,
InvalidExhaustive,
InvalidAttribute,
InvalidExhaustiveNextToken,
NotConditional,
NotExhaustive {
count: u32,
size: usize,
},
Exhaustive {
count: u32,
max_count: u128,
},
TooManyVariants {
count: u32,
max_count: u128,
},
NonLitDiscriminant {
variant: &'a syn::Ident,
suggest: u128,
},
TooLargeDiscriminant {
max_value: u128,
},
}
impl fmt::Display for Error<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::MissingSize => write!(
f,
"Missing the storage type. It must be explicitly declared with #[bitenum(uN)], where N in range 1..=64"
),
Self::InvalidAttribute => write!(
f,
"Invalid attribute. Expected either 'uN' where N in range 1..=64, or 'exhaustive = …'."
),
Self::InvalidExhaustiveNextToken => write!(
f,
"'exhaustive' should be specified as 'exhaustive = …'. Possible values are: true, false, conditional."
),
Self::InvalidExhaustive => write!(
f,
"The specified 'exhaustive' is invalid. Possible values are: true, false, conditional."
),
Self::NotConditional => write!(
f,
"The enum contains at least one variant with a '#[cfg(…)]' attribute. It should be specified as 'exhaustive = conditional'."
),
Self::NotExhaustive { count, size } => write!(
f,
"The enum has {count} variants, it is exhaustive for {size} bits. Either remove variants, use a larger storage type, or mark this enum as 'exhaustive = true'."
),
Self::Exhaustive { count, max_count } => write!(
f,
"The enum has {count} variants, it would need {max_count} variants to be exhaustive. Either add variants, use a smaller storage type, or mark this enum as 'exhaustive = false'."
),
Self::TooManyVariants { count, max_count } => write!(
f,
"The enum has more variants than can be stored in the provided storage type. Either use a larger storage type or reduce the number of variants. Up to {max_count} variants possible, got {count}."
),
Self::NonLitDiscriminant { variant, suggest } => write!(
f,
"Discriminants must be literal integers. Eg: '{variant} = {suggest}' '{suggest:#x}' '{suggest:#b}' '{suggest:#o}'."
),
Self::TooLargeDiscriminant { max_value } => write!(
f,
"The largest discriminant value is larger than can be stored in the provided storage type. Max discriminant value is {max_value}."
),
}
}
}
#[derive(Clone, Copy)]
enum Exhaustiveness {
True,
False,
Conditional,
}
struct Exhaustive {
span: Span,
kind: Exhaustiveness,
}
impl Exhaustive {
fn is_conditional(&self) -> bool {
matches!(self.kind, Exhaustiveness::Conditional)
}
fn matches(&self, expected: bool) -> bool {
!matches!(
(self.kind, expected),
(Exhaustiveness::True, false) | (Exhaustiveness::False, true)
)
}
}
impl syn::parse::Parse for Exhaustive {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let span = input.span();
let kind = match (input.parse::<syn::LitBool>(), input.parse::<syn::Ident>()) {
(Ok(bool), _) if bool.value => Exhaustiveness::True,
(Ok(_), _) => Exhaustiveness::False,
(_, Ok(ident)) if ident == "conditional" => Exhaustiveness::Conditional,
_ => return Err(syn::Error::new(span, Error::InvalidExhaustive)),
};
Ok(Exhaustive { span, kind })
}
}
#[derive(Default)]
pub(crate) struct Config {
explicit_bits: Option<Bits>,
explicit_exhaustive: Option<Exhaustive>,
}
struct FullConfig {
bits: Bits,
exhaustive: Exhaustive,
}
impl Config {
pub(crate) fn parse(&mut self, meta: ParseNestedMeta) -> syn::Result<()> {
if meta.path.is_ident("exhaustive") {
let result1 = meta.input.parse::<Token![:]>();
let result2 = meta.input.parse::<Token![=]>();
if result1.is_err() && result2.is_err() {
return Err(meta.error(Error::InvalidExhaustiveNextToken));
}
self.explicit_exhaustive = Some(meta.input.parse()?);
} else {
let err = || meta.error(Error::InvalidAttribute);
let last_segment = meta.path.segments.last().ok_or_else(err)?;
let value = last_segment.ident.to_string();
let ("u", size) = value.split_at(1) else {
return Err(err());
};
self.explicit_bits = Some(Bits {
path: meta.path.clone(),
size: size.parse().map_err(|_| err())?,
});
}
Ok(())
}
fn explicit(self) -> syn::Result<FullConfig> {
let span = Span::call_site();
let Some(bits) = self.explicit_bits else {
return Err(syn::Error::new(span, Error::MissingSize));
};
let exhaustive = self.explicit_exhaustive.unwrap_or(Exhaustive {
span,
kind: Exhaustiveness::False,
});
Ok(FullConfig { bits, exhaustive })
}
}
fn conditional_attr(attr: &syn::Attribute) -> bool {
attr.path().is_ident("cfg")
}
fn check_explicit_conditional(config: &FullConfig, input: &syn::ItemEnum) -> syn::Result<()> {
let conditional_variant = |v: &syn::Variant| v.attrs.iter().any(conditional_attr);
let is_conditional = input.variants.iter().any(conditional_variant);
if is_conditional && !config.exhaustive.is_conditional() {
let span = config.exhaustive.span;
Err(syn::Error::new(span, Error::NotConditional))
} else {
Ok(())
}
}
fn parse_expr(expr: &syn::Expr) -> Option<u128> {
let syn::Expr::Lit(lit) = expr else {
return None;
};
let syn::Lit::Int(lit_int) = &lit.lit else {
return None;
};
lit_int.base10_parse().ok()
}
fn check_explicit_exhaustive(
config: &FullConfig,
input: &syn::ItemEnum,
) -> syn::Result<Vec<LitInt>> {
let max_count = 1_u128 << config.bits.size;
let count = input.variants.len() as u128;
let actually_exhaustive = match count.cmp(&max_count) {
std::cmp::Ordering::Equal => true,
std::cmp::Ordering::Greater if !config.exhaustive.is_conditional() => {
let count = count as u32;
let err = Error::TooManyVariants { max_count, count };
return Err(syn::Error::new_spanned(&config.bits.path, err));
}
_ => false,
};
let mut values = Vec::new();
if !config.exhaustive.matches(actually_exhaustive) {
let size = config.bits.size;
let count = count as u32;
let err = if actually_exhaustive {
Error::NotExhaustive { count, size }
} else {
Error::Exhaustive { count, max_count }
};
return Err(syn::Error::new(config.exhaustive.span, err));
}
let (mut max_discr, mut next_implicit_discr, mut max_discr_span) = (0, 0, Span::call_site());
for variant in &input.variants {
let (value, discr_span) = match variant.discriminant.as_ref() {
None => (next_implicit_discr, variant.ident.span()),
Some((_, discriminant)) => {
let Some(value) = parse_expr(discriminant) else {
let variant = &variant.ident;
let suggest = max_discr + 1;
let err = Error::NonLitDiscriminant { variant, suggest };
return Err(syn::Error::new_spanned(discriminant, err));
};
(value, discriminant.__span())
}
};
next_implicit_discr = value + 1;
if value > max_discr {
max_discr = value;
max_discr_span = discr_span;
}
values.push(LitInt::new(&value.to_string(), discr_span));
}
if max_discr >= max_count {
let max_value = max_count - 1;
let err = Error::TooLargeDiscriminant { max_value };
return Err(syn::Error::new(max_discr_span, err));
}
Ok(values)
}
pub(crate) fn fallback_impl(input: &syn::ItemEnum) -> TokenStream {
let name = &input.ident;
quote! {
#[derive(Copy, Clone)]
#input
impl #name {
pub const fn raw_value<T>(self) -> T {
unreachable!()
}
pub const fn new_with_raw_value<T>(value: T) -> Self {
unreachable!()
}
pub fn unwrap(self) -> Self { self }
}
}
}
pub(crate) fn bitenum(config: Config, input: &syn::ItemEnum) -> syn::Result<TokenStream> {
let config = config.explicit()?;
check_explicit_conditional(&config, input)?;
let values = check_explicit_exhaustive(&config, input)?;
let bits = config.bits;
let (base_type, qualified_type) = (bits.base_type()?, bits.qualified_path()?);
let (raw_value_constructor, reader) = (bits.constructor()?, bits.reader());
let non_exhaustive = config.exhaustive.matches(false);
let ok = non_exhaustive.then_some(quote!(Ok));
let new_return_type = match non_exhaustive {
true => quote!(Result<Self, #qualified_type>),
false => quote!(Self),
};
let new_default_branch = match non_exhaustive {
true => quote!(_ => Err(value)),
false => quote!(_ => unreachable!()),
};
let new_match_branches = input.variants.iter().zip(values).map(|(variant, value)| {
let cfg_attrs = variant.attrs.iter().filter(|a| conditional_attr(a));
let variant_name = &variant.ident;
quote!( #( #cfg_attrs )* (#value) => #ok(Self::#variant_name) )
});
let (attrs, vis, name, variants) = (&input.attrs, &input.vis, &input.ident, &input.variants);
Ok(quote! {
#[derive(Copy, Clone)]
#( #attrs )*
#vis enum #name {
#variants
}
impl #name {
pub const fn raw_value(self) -> #qualified_type {
#raw_value_constructor(self as #base_type)
}
pub const fn new_with_raw_value(value: #qualified_type) -> #new_return_type {
match value #reader {
#( #new_match_branches ,)*
#new_default_branch
}
}
}
})
}