mod bit_par_iter;
mod int_repr;
use std::{collections::HashMap, fmt::Display, str::FromStr};
use bit_par_iter::{BitParityIter, IntegerParity};
use darling::{FromAttributes, FromMeta};
use int_repr::IntRepr;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{Expr, ItemEnum, Variant, parse_macro_input, spanned::Spanned};
/// Parity that the discriminants produced by `#[bit_parity]` are expected to have.
///
/// Parsed from the attribute arguments through darling's `FromMeta` derive; the variant
/// names determine the words the user writes, so do not rename them casually.
#[derive(Copy, Clone, Debug, FromMeta)]
enum Parity {
    /// Even discriminant values.
    Even,
    /// Odd discriminant values.
    Odd,
}
impl Display for Parity {
    /// Renders the parity as the lowercase word `even` or `odd` (used in diagnostics).
    ///
    /// Goes through `Formatter::pad`, so width, fill, alignment and precision flags
    /// (e.g. `{:>6}`) are honored exactly as they are for `str`; `write!(f, "{}", ..)`
    /// would silently drop them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(match self {
            Self::Even => "even",
            Self::Odd => "odd",
        })
    }
}
/// Arguments accepted by the `#[bit_parity(...)]` attribute.
///
/// Parsed with darling; `derive_syn_parse` additionally generates a `syn` `Parse` impl,
/// which is what allows `parse_macro_input!(args as BitParityArgs)` in the entry point.
#[derive(Debug, Clone, FromMeta)]
#[darling(derive_syn_parse)]
struct BitParityArgs {
    /// Requested discriminant parity. It is flattened into the argument list, so it is
    /// presumably written as a bare word such as `#[bit_parity(even)]` — NOTE(review):
    /// confirm against darling's `flatten` semantics for unit enums.
    #[darling(flatten)]
    parity: Parity,
    /// When set, explicit discriminants that do not have `parity` are accepted instead of
    /// producing a compile error. Falls back to `bool::default()` (`false`) when omitted.
    #[darling(default)]
    allow_explicit_overrides: bool,
}
/// Settings for a single macro invocation, shared by all expansion helpers.
struct Ctx {
    /// Parity that generated discriminants are drawn with, and that explicit
    /// discriminants are checked against.
    parity: Parity,
    /// If `true`, an explicit discriminant with the wrong parity is accepted as-is.
    allow_explicit_overrides: bool,
    /// Integer representation read from the enum's attributes via
    /// `IntRepr::from_attributes` (presumably its `#[repr(...)]`); selects the
    /// concrete integer type used for the expansion.
    repr: IntRepr,
}
fn parse_discriminant<N>(ctx: &Ctx, (_eq_tok, expr): (syn::token::Eq, Expr)) -> syn::Result<N>
where
N: IntegerParity + darling::ToTokens + FromStr,
N::Err: Display,
{
let Expr::Lit(syn::ExprLit {
lit: syn::Lit::Int(lit),
..
}) = expr.clone()
else {
return Err(syn::Error::new(
expr.span(),
"Invalid or unsupported enum discriminant value. Only literals are allowed",
));
};
let lit = lit.base10_parse::<N>()?;
if lit.has_parity(ctx.parity) || ctx.allow_explicit_overrides {
Ok(lit)
} else {
Err(syn::Error::new(
expr.span(),
format!(
"explicit discriminant does not have `{}` parity",
ctx.parity,
),
))
}
}
fn next_discriminant<N>(
ctx: &Ctx,
bpi: &mut BitParityIter<N>,
variant: &Variant,
explicit_discriminants: &HashMap<N, Span>,
) -> syn::Result<N>
where
N: IntegerParity + Eq + std::hash::Hash,
{
if let Some(next_val) = bpi.next() {
if let Some(span) = explicit_discriminants.get(&next_val) {
let mut err = syn::Error::new(*span, "previous assignment here");
err.combine(syn::Error::new(
variant.span(),
"discriminant value is already assigned",
));
return Err(err);
}
return Ok(next_val);
}
Err(syn::Error::new_spanned(
variant,
format!(
"ran out of discriminant values for `{}` repr type",
ctx.repr
),
))
}
/// Rewrites every variant of `enum_item` so that it carries an explicit discriminant of
/// the concrete repr type `T`, then re-emits the enum.
///
/// Explicit discriminants are validated by `parse_discriminant` and passed to the parity
/// iterator's `set_override` (presumably so that the following implicit variants continue
/// from there — see `BitParityIter`). Implicit variants take the iterator's next value.
/// A value already claimed by any explicit discriminant (anywhere in the enum) or by an
/// earlier implicit variant is rejected by `next_discriminant`.
///
/// # Errors
/// The first explicit-discriminant parse/parity error in variant order is reported before
/// any implicit-assignment error (collision or running out of values).
fn generic_expand<T>(ctx: &Ctx, mut enum_item: ItemEnum) -> syn::Result<TokenStream>
where
    T: IntegerParity + darling::ToTokens + FromStr + Eq + std::hash::Hash + std::fmt::Debug + Ord,
    T::Err: Display,
{
    // Parse every explicit discriminant exactly once. `explicit_values` stays index-aligned
    // with the variants; `assigned` records each claimed value with the span that claimed
    // it (for duplicates the last span wins, matching a plain map insert).
    let mut explicit_values = Vec::with_capacity(enum_item.variants.len());
    let mut assigned = HashMap::new();
    for variant in &enum_item.variants {
        let value = match variant.discriminant.clone() {
            Some(disc) => {
                let value = parse_discriminant::<T>(ctx, disc)?;
                assigned.insert(value, variant.span());
                Some(value)
            }
            None => None,
        };
        explicit_values.push(value);
    }

    let mut bpi = BitParityIter::<T>::new(ctx.parity);
    for (variant, explicit) in enum_item.variants.iter_mut().zip(explicit_values) {
        let value = match explicit {
            Some(value) => {
                bpi.set_override(value);
                value
            }
            None => {
                let value = next_discriminant(ctx, &mut bpi, variant, &assigned)?;
                // Record implicit assignments too, so a later implicit variant cannot
                // reuse the value after the iterator has been moved by an override.
                assigned.insert(value, variant.span());
                value
            }
        };
        variant.discriminant = Some((syn::token::Eq::default(), syn::parse_quote!(#value)));
    }
    Ok(quote! {#enum_item})
}
fn specialize_expand(ctx: &Ctx, enum_item: ItemEnum) -> syn::Result<TokenStream> {
match ctx.repr {
IntRepr::U8 => generic_expand::<u8>(ctx, enum_item),
IntRepr::U16 => generic_expand::<u16>(ctx, enum_item),
IntRepr::U32 => generic_expand::<u32>(ctx, enum_item),
IntRepr::U64 => generic_expand::<u64>(ctx, enum_item),
IntRepr::U128 => generic_expand::<u128>(ctx, enum_item),
IntRepr::Usize => generic_expand::<usize>(ctx, enum_item),
IntRepr::I8 => generic_expand::<i8>(ctx, enum_item),
IntRepr::I16 => generic_expand::<i16>(ctx, enum_item),
IntRepr::I32 => generic_expand::<i32>(ctx, enum_item),
IntRepr::I64 => generic_expand::<i64>(ctx, enum_item),
IntRepr::I128 => generic_expand::<i128>(ctx, enum_item),
IntRepr::Isize => generic_expand::<isize>(ctx, enum_item),
}
}
fn try_expand(args: &BitParityArgs, enum_item: ItemEnum) -> syn::Result<TokenStream> {
let repr = IntRepr::from_attributes(&enum_item.attrs)?;
let ctx = Ctx {
repr,
parity: args.parity,
allow_explicit_overrides: args.allow_explicit_overrides,
};
specialize_expand(&ctx, enum_item)
}
/// Entry point of the `#[bit_parity(...)]` attribute macro.
///
/// Parses the attribute arguments as `BitParityArgs` and the annotated item as an enum
/// (either parse failure is returned as a compile error immediately). Any expansion
/// error is turned into `compile_error!` tokens instead of the rewritten enum.
#[proc_macro_attribute]
pub fn bit_parity(
    args: proc_macro::TokenStream,
    input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let args = parse_macro_input!(args as BitParityArgs);
    let enum_item = parse_macro_input!(input as ItemEnum);
    match try_expand(&args, enum_item) {
        Ok(expanded) => expanded.into(),
        Err(error) => error.into_compile_error().into(),
    }
}