#![allow(unexpected_cfgs)]
extern crate proc_macro;
use proc_macro::TokenStream;
use quote::{ToTokens, format_ident, quote};
use syn::{Ident, ItemEnum, parse_macro_input};
/// Builds the serde-related fragments spliced into the code generated by
/// `bit_index`.
///
/// Returns, in order:
/// 1. the attribute placed on the generated enum (derives `Serialize` and
///    `Deserialize`);
/// 2. the attribute placed on the flags struct (derives `Serialize` with
///    `#[serde(transparent)]`, so it serializes as its single integer field);
/// 3. a hand-written `Deserialize` impl for the flags struct that reads a
///    `repr_type` integer and passes it through `try_from_bits`, reporting a
///    rejected value as `invalid_value`.
///
/// All serde paths are fully qualified (`::serde::...`) so the invoking crate
/// does not need `Serialize`/`Deserialize` imported at the call site. The
/// emitted `cfg_attr(feature = "serde", ..)` / `cfg(feature = "serde")` are
/// evaluated against the features of the crate that invokes the macro.
#[allow(unexpected_cfgs)]
#[cfg(feature = "serde")]
fn emit_serde_code(
    repr_type: &Ident,
    struct_name: &Ident,
) -> (
    proc_macro2::TokenStream,
    proc_macro2::TokenStream,
    proc_macro2::TokenStream,
)
{
    (
        quote! {
            #[allow(unexpected_cfgs)]
            #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
        },
        quote! {
            #[allow(unexpected_cfgs)]
            #[cfg_attr(feature = "serde", derive(::serde::Serialize), serde(transparent))]
        },
        quote! {
            #[allow(unexpected_cfgs)]
            #[cfg(feature = "serde")]
            impl<'de> ::serde::Deserialize<'de> for #struct_name {
                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
                where
                    D: ::serde::Deserializer<'de>,
                {
                    let bits = <#repr_type as ::serde::Deserialize>::deserialize(deserializer)?;
                    Self::try_from_bits(bits).map_err(|invalid_bits| {
                        <D::Error as ::serde::de::Error>::invalid_value(
                            ::serde::de::Unexpected::Unsigned(invalid_bits as u64),
                            &"a valid bitflag value",
                        )
                    })
                }
            }
        },
    )
}
/// Fallback used when this crate is built without its `serde` feature.
///
/// Every fragment is empty, so the generated enum and flags struct receive no
/// serde derive attributes and no `Deserialize` impl.
#[allow(unexpected_cfgs)]
#[cfg(not(feature = "serde"))]
fn emit_serde_code(
    _repr_type: &Ident,
    _struct_name: &Ident,
) -> (
    proc_macro2::TokenStream,
    proc_macro2::TokenStream,
    proc_macro2::TokenStream,
)
{
    let empty = proc_macro2::TokenStream::new;
    (empty(), empty(), empty())
}
/// Turns a field-less enum into a set of bit indices and generates a companion
/// `<EnumName>Flags` bit-set type.
///
/// The backing integer is the first identifier of the enum's `#[repr(..)]`
/// attribute (default `u8`); only `u8`, `u16`, `u32` and `u64` are accepted.
///
/// * A variant with no discriminant, or an integer-literal discriminant, is
///   kept as an enum variant whose discriminant is its bit index. Indices keep
///   counting upward from the last literal seen.
/// * A variant with any other discriminant expression (for example
///   `Self::A | Self::B`) is not emitted as a variant. Instead it becomes an
///   associated constant of both the enum and the flags type, with paths in
///   the expression rewritten by `add_bits_field`.
///
/// Only `#[doc]` attributes of the enum and its variants are carried over to
/// the output; `#[repr]` is re-emitted from the detected type.
///
/// The flags type provides set operations (`|`, `&`, `^`, `!`,
/// `insert`/`remove`/`toggle`, ...), formatting impls, conversions to and from
/// the enum and the raw integer, and iteration over the contained variants.
/// The generated code calls operator traits through fully-qualified paths, so
/// the invoking module does not need `std::ops::*` in scope.
///
/// # Errors
///
/// Emits a compile error (rather than panicking) when the repr type is not
/// supported, when an integer discriminant cannot be parsed as `usize`, or when
/// a bit index does not fit in the repr type.
#[proc_macro_attribute]
pub fn bit_index(_attr: TokenStream, item: TokenStream) -> TokenStream
{
    let input = parse_macro_input!(item as ItemEnum);
    let repr_type = find_repr_type(&input.attrs).unwrap_or_else(|| syn::parse_quote!(u8));
    let visibility = &input.vis;
    let enum_name = &input.ident;
    let struct_name = format_ident!("{enum_name}Flags");
    // Documentation is the only user attribute re-emitted on the enum.
    let enum_docs: Vec<&syn::Attribute> =
        input.attrs.iter().filter(|attr| attr.path().is_ident("doc")).collect();

    // Number of distinct bit positions available in the backing integer.
    let max_bits: usize = match repr_type.to_string().as_str()
    {
        "u8" => 8,
        "u16" => 16,
        "u32" => 32,
        "u64" => 64,
        _ =>
        {
            return syn::Error::new_spanned(
                &repr_type,
                format!("Unsupported repr type: {}", repr_type),
            )
            .to_compile_error()
            .into();
        }
    };

    let mut enum_variants = Vec::new();
    let mut enum_variants_name = Vec::new();
    let mut struct_non_composite_const = Vec::new();
    let mut struct_composite_const = Vec::new();
    let mut index: usize = 0;

    for variant in &input.variants
    {
        let flag_ident = &variant.ident;
        let docs: Vec<&syn::Attribute> =
            variant.attrs.iter().filter(|attr| attr.path().is_ident("doc")).collect();
        if let Some((_, expr)) = &variant.discriminant
        {
            match expr
            {
                syn::Expr::Lit(syn::ExprLit {
                    lit: syn::Lit::Int(lit_int),
                    ..
                }) =>
                {
                    index = match lit_int.base10_parse::<usize>()
                    {
                        Ok(value) => value,
                        Err(err) => return err.to_compile_error().into(),
                    };
                }
                _ =>
                {
                    // Composite flag: no enum variant, only a named constant.
                    let bits = add_bits_field(expr);
                    struct_composite_const.push(quote! {
                        #(#docs)*
                        #[allow(non_upper_case_globals)]
                        pub const #flag_ident: #struct_name = #struct_name { _bits_do_not_use_it: #bits };
                    });
                    continue;
                }
            }
        }
        // `1 << index` must fit in the repr type.
        if index >= max_bits
        {
            return syn::Error::new_spanned(
                flag_ident,
                format!(
                    "Too many enum variants for {}: {} exceeds max bits ({}) for repr type {}",
                    enum_name, index, max_bits, repr_type
                ),
            )
            .to_compile_error()
            .into();
        }
        struct_non_composite_const.push(quote! {
            #(#docs)*
            pub const #flag_ident: #struct_name = #struct_name { _bits_do_not_use_it: 1 << (#enum_name::#flag_ident as #repr_type) };
        });
        enum_variants_name.push(flag_ident.clone());
        enum_variants.push(quote! {
            #(#docs)*
            #flag_ident = #index as #repr_type,
        });
        index += 1;
    }

    let nb_variant = enum_variants.len();
    let (serde_serialize_deserialize, serde_serialize_transparent, serde_deserialize_flags) =
        emit_serde_code(&repr_type, &struct_name);

    let output = quote! {
        #(#enum_docs)*
        #serde_serialize_deserialize
        #[repr(#repr_type)]
        #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
        #visibility enum #enum_name
        {
            #(#enum_variants)*
        }

        impl ::std::cmp::PartialEq<#struct_name> for #enum_name
        {
            fn eq(&self, other: &#struct_name) -> bool { self.flags() == *other }
        }

        impl #enum_name
        {
            pub const ALL: [Self; #nb_variant] = [
                #(Self::#enum_variants_name,)*
            ];
            pub const BITS: [#struct_name; #nb_variant] = [
                #(#struct_name { _bits_do_not_use_it: 1 << Self::#enum_variants_name as #repr_type },)*
            ];
            #(#struct_composite_const)*
            pub const fn index(self) -> usize { self as #repr_type as usize }
            pub const fn flags(self) -> #struct_name { #struct_name { _bits_do_not_use_it: self.bits() } }
            pub const fn bits(self) -> #repr_type
            {
                match self
                {
                    #(Self::#enum_variants_name => (1 << Self::#enum_variants_name as #repr_type) as #repr_type,)*
                }
            }
            pub fn from_index(index: #repr_type) -> Option<Self>
            {
                match index
                {
                    #(x if x == Self::#enum_variants_name as #repr_type => Some(Self::#enum_variants_name),)*
                    _ => None,
                }
            }
            pub unsafe fn from_index_unchecked(index: #repr_type) -> Option<Self>
            {
                // NOTE(review): reads size_of::<Option<Self>>() bytes from `index`; this relies
                // on Option<Self> having the same layout as the repr type (niche) — confirm.
                unsafe { ::std::mem::transmute_copy(&index) }
            }
            pub fn from_flags(flags: #struct_name) -> Option<Self>
            {
                Self::from_bits(flags.bits())
            }
            pub unsafe fn from_flags_unchecked(flags: #struct_name) -> Self
            {
                unsafe { Self::from_bits_unchecked(flags.bits()) }
            }
            pub fn from_bits(bits: #repr_type) -> Option<Self>
            {
                match bits
                {
                    #(x if x == Self::#enum_variants_name.bits() => Some(Self::#enum_variants_name),)*
                    _ => None,
                }
            }
            pub unsafe fn from_bits_unchecked(bits: #repr_type) -> Self
            {
                match bits
                {
                    #(x if x == Self::#enum_variants_name.bits() => Self::#enum_variants_name,)*
                    _ => unreachable!(),
                }
            }
        }

        // Operators on the enum forward to the flags type. Fully-qualified calls keep the
        // generated code compiling without `std::ops::*` imported at the call site.
        impl<T> ::std::ops::BitOr<T> for #enum_name where T: Into<#struct_name>
        {
            type Output = #struct_name;
            fn bitor(self, other: T) -> Self::Output { ::std::ops::BitOr::bitor(self.flags(), other) }
        }
        impl<T> ::std::ops::BitAnd<T> for #enum_name where T: Into<#struct_name>
        {
            type Output = #struct_name;
            fn bitand(self, other: T) -> Self::Output { ::std::ops::BitAnd::bitand(self.flags(), other) }
        }
        impl<T> ::std::ops::BitXor<T> for #enum_name where T: Into<#struct_name>
        {
            type Output = #struct_name;
            fn bitxor(self, other: T) -> Self::Output { ::std::ops::BitXor::bitxor(self.flags(), other) }
        }
        impl ::std::ops::Not for #enum_name
        {
            type Output = #struct_name;
            fn not(self) -> Self::Output { ::std::ops::Not::not(self.flags()) }
        }
        impl ::std::convert::TryFrom<#struct_name> for #enum_name
        {
            type Error = ();
            fn try_from(value: #struct_name) -> Result<Self, Self::Error>
            {
                Self::from_flags(value).ok_or(())
            }
        }
        impl ::std::convert::From<#enum_name> for #struct_name
        {
            fn from(value: #enum_name) -> Self { value.flags() }
        }

        #serde_serialize_transparent
        #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
        #visibility struct #struct_name
        {
            #[doc(hidden)]
            _bits_do_not_use_it: #repr_type
        }

        impl ::std::cmp::PartialEq<#enum_name> for #struct_name
        {
            fn eq(&self, other: &#enum_name) -> bool { *self == other.flags() }
        }

        #serde_deserialize_flags

        impl ::std::fmt::Debug for #struct_name
        {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result
            {
                f.write_str(stringify!(#struct_name))?;
                write!(f, "({:#b}", self._bits_do_not_use_it)?;
                if self.is_not_empty()
                {
                    write!(f, ", ")?;
                    let mut it = self.iter().peekable();
                    while let Some(v) = it.next()
                    {
                        write!(f, "{:?}", v)?;
                        if it.peek().is_some()
                        {
                            write!(f, " | ")?;
                        }
                    }
                }
                write!(f, ")")
            }
        }
        impl ::std::fmt::Binary for #struct_name
        {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result
            {
                ::std::fmt::Binary::fmt(&self._bits_do_not_use_it, f)
            }
        }
        impl ::std::fmt::LowerHex for #struct_name
        {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result
            {
                ::std::fmt::LowerHex::fmt(&self._bits_do_not_use_it, f)
            }
        }
        impl ::std::fmt::UpperHex for #struct_name
        {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result
            {
                ::std::fmt::UpperHex::fmt(&self._bits_do_not_use_it, f)
            }
        }
        impl ::std::fmt::LowerExp for #struct_name
        {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result
            {
                ::std::fmt::LowerExp::fmt(&self._bits_do_not_use_it, f)
            }
        }
        impl ::std::fmt::UpperExp for #struct_name
        {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result
            {
                ::std::fmt::UpperExp::fmt(&self._bits_do_not_use_it, f)
            }
        }
        impl ::std::fmt::Octal for #struct_name
        {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result
            {
                ::std::fmt::Octal::fmt(&self._bits_do_not_use_it, f)
            }
        }
        impl ::std::convert::TryFrom<#repr_type> for #struct_name
        {
            type Error = #repr_type;
            fn try_from(bits: #repr_type) -> Result<Self, Self::Error>
            {
                Self::try_from_bits(bits)
            }
        }

        #[allow(non_upper_case_globals)]
        impl #struct_name
        {
            #(#struct_non_composite_const)*
            #(#struct_composite_const)*
            pub const ALL: Self = Self { _bits_do_not_use_it: #( #enum_name::#enum_variants_name.bits() |)* 0 };
            pub const EMPTY: Self = Self { _bits_do_not_use_it: 0 };
            pub const ZERO : Self = Self::EMPTY;
        }

        impl #struct_name
        {
            #[inline(always)]
            pub const fn bits(self) -> #repr_type { self._bits_do_not_use_it }
            pub const unsafe fn from_bits_unchecked(bits: #repr_type) -> Self { Self { _bits_do_not_use_it: bits } }
            pub const fn from_bits(bits: #repr_type) -> Self
            {
                unsafe { Self::from_bits_unchecked(bits & Self::ALL._bits_do_not_use_it) }
            }
            pub fn try_from_bits(bits: #repr_type) -> Result<Self, #repr_type>
            {
                if bits == (bits & Self::ALL.bits())
                {
                    Ok(unsafe { Self::from_bits_unchecked(bits) })
                }
                else
                {
                    // Report only the bits that do not belong to any variant.
                    Err(bits & !Self::ALL.bits())
                }
            }
            #[inline(always)]
            pub const fn is_empty(self) -> bool { self.bits() == 0 }
            #[inline(always)]
            pub const fn is_not_empty(self) -> bool { self.bits() != 0 }
            #[inline(always)]
            pub const fn len(self) -> usize { self.bits().count_ones() as usize }
            pub fn iter(self) -> Self { self }
            #[must_use]
            #[inline(always)]
            pub const fn union(self, other: Self) -> Self { #struct_name { _bits_do_not_use_it: self.bits() | other.bits() } }
            #[must_use]
            #[inline(always)]
            pub const fn intersection(self, other: Self) -> Self { #struct_name { _bits_do_not_use_it: self.bits() & other.bits() } }
            #[must_use]
            #[inline(always)]
            pub const fn complement(self) -> Self { #struct_name { _bits_do_not_use_it: !self.bits() & Self::ALL.bits() } }
            #[inline(always)]
            pub fn intersects<T>(self, other: T) -> bool where T: Into<Self>
            {
                let other: Self = other.into();
                (self & other).is_not_empty()
            }
            #[inline(always)]
            pub fn contains<T>(self, other: T) -> bool where T: Into<Self> { self.contains_all(other) }
            #[inline(always)]
            pub fn contains_all<T>(self, other: T) -> bool where T: Into<Self>
            {
                let other: Self = other.into();
                (self & other) == other
            }
            #[inline(always)]
            pub fn contains_any<T>(self, other: T) -> bool where T: Into<Self>
            {
                let other: Self = other.into();
                (self & other).is_not_empty()
            }
            #[must_use]
            #[inline(always)]
            pub fn toggled<T>(self, other: T) -> Self where T: Into<Self> { self ^ other }
            #[inline(always)]
            pub fn toggle<T>(&mut self, other: T) -> &mut Self where T: Into<Self> { *self ^= other; self }
            #[must_use]
            #[inline(always)]
            pub fn inserted<T>(self, other: T) -> Self where T: Into<Self> { self | other }
            #[inline(always)]
            pub fn insert<T>(&mut self, other: T) -> &mut Self where T: Into<Self> { *self |= other; self }
            // Clears the bits of `other`; the by-value counterpart of `remove`.
            #[must_use]
            #[inline(always)]
            pub fn removed<T>(self, other: T) -> Self where T: Into<Self>
            {
                let other: Self = other.into();
                self & !other
            }
            #[inline(always)]
            pub fn remove<T>(&mut self, other: T) -> &mut Self where T: Into<Self>
            {
                let other: Self = other.into();
                *self &= !other;
                self
            }
            #[must_use]
            #[inline(always)]
            pub fn with<T>(self, other: T, insert: bool) -> Self where T: Into<Self>
            {
                if insert { self.inserted(other) } else { self.removed(other) }
            }
            #[inline(always)]
            pub fn set<T>(&mut self, other: T, insert: bool) -> &mut Self where T: Into<Self>
            {
                if insert { self.insert(other) } else { self.remove(other) }
            }
            pub fn clear(&mut self) { *self = Self::EMPTY; }
        }

        impl<T> ::std::ops::BitOr<T> for #struct_name where T: Into<Self>
        {
            type Output = Self;
            fn bitor(self, other: T) -> Self::Output
            {
                let other: Self = other.into();
                #struct_name { _bits_do_not_use_it: self.bits() | other.bits() }
            }
        }
        impl<T> ::std::ops::BitOrAssign<T> for #struct_name where T: Into<Self>
        {
            fn bitor_assign(&mut self, other: T) { *self = <Self as ::std::ops::BitOr<T>>::bitor(*self, other); }
        }
        impl<T> ::std::ops::BitXor<T> for #struct_name where T: Into<Self>
        {
            type Output = Self;
            fn bitxor(self, other: T) -> Self::Output
            {
                let other: Self = other.into();
                #struct_name { _bits_do_not_use_it: self.bits() ^ other.bits() }
            }
        }
        impl<T> ::std::ops::BitXorAssign<T> for #struct_name where T: Into<Self>
        {
            fn bitxor_assign(&mut self, other: T) { *self = <Self as ::std::ops::BitXor<T>>::bitxor(*self, other); }
        }
        impl<T> ::std::ops::BitAnd<T> for #struct_name where T: Into<Self>
        {
            type Output = Self;
            fn bitand(self, other: T) -> Self::Output
            {
                let other: Self = other.into();
                #struct_name { _bits_do_not_use_it: self.bits() & other.bits() }
            }
        }
        impl<T> ::std::ops::BitAndAssign<T> for #struct_name where T: Into<Self>
        {
            fn bitand_assign(&mut self, other: T) { *self = <Self as ::std::ops::BitAnd<T>>::bitand(*self, other); }
        }
        impl ::std::ops::Not for #struct_name
        {
            type Output = Self;
            // Masked with ALL so unused high bits never become set.
            fn not(self) -> Self::Output { Self { _bits_do_not_use_it: !self.bits() & Self::ALL.bits() } }
        }

        impl ::std::iter::Iterator for #struct_name
        {
            type Item = #enum_name;
            // Yields variants from the lowest set bit upward, consuming them from `self`.
            fn next(&mut self) -> Option<Self::Item>
            {
                if self.is_empty()
                {
                    return None;
                }
                let bits = self.bits();
                // Isolate the lowest set bit, then clear it from the remaining set.
                let lsb = bits & bits.wrapping_neg();
                self._bits_do_not_use_it = bits & !lsb;
                // SAFETY: the enum's `from_bits_unchecked` matches on the variant bits and
                // falls back to `unreachable!()` (a panic) for anything else.
                Some(unsafe { #enum_name::from_bits_unchecked(lsb) })
            }
        }
    };
    output.into()
}
/// Rewrites a composite discriminant expression into an integer expression over
/// raw flag bits, so it can initialise a `const` flags value.
///
/// - Every path (e.g. `Self::A`) becomes `Self::A.bits()`.
/// - Binary, parenthesised and unary expressions are rebuilt recursively with
///   the same operator.
/// - Any other expression (such as an integer literal) is emitted unchanged.
///
/// NOTE(review): a unary `!` yields the bitwise complement of the repr integer,
/// which also sets bits that belong to no variant; the result is not masked.
fn add_bits_field(expr: &syn::Expr) -> proc_macro2::TokenStream
{
    match expr
    {
        syn::Expr::Path(path) => quote! { #path.bits() },
        syn::Expr::Binary(bin) =>
        {
            let left = add_bits_field(&bin.left);
            let right = add_bits_field(&bin.right);
            let op = bin.op.to_token_stream();
            quote! { (#left #op #right) }
        }
        syn::Expr::Paren(paren) =>
        {
            let inner = add_bits_field(&paren.expr);
            quote! { (#inner) }
        }
        syn::Expr::Unary(unary) =>
        {
            // Apply the operator to the rewritten operand so the whole expression stays a
            // plain integer expression (and therefore const-evaluable).
            let op = &unary.op;
            let inner = add_bits_field(&unary.expr);
            quote! { (#op #inner) }
        }
        _ => quote! { #expr },
    }
}
/// Returns the first identifier of the first `#[repr(...)]` attribute whose
/// arguments parse as a comma-separated list of identifiers.
///
/// `repr` attributes that do not parse that way (e.g. `#[repr(align(8))]`) are
/// skipped. If the first parseable `repr` list is empty, `None` is returned
/// without looking at later attributes.
fn find_repr_type(attrs: &[syn::Attribute]) -> Option<syn::Ident>
{
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("repr"))
        .find_map(|attr| {
            attr.parse_args_with(
                syn::punctuated::Punctuated::<syn::Ident, syn::Token![,]>::parse_terminated,
            )
            .ok()
            .map(|idents| idents.first().cloned())
        })
        .flatten()
}