//! `bitbag_derive` 0.3.0
//!
//! Derive macros for the `bitbag` crate.
//!
//! See the `bitbag` crate documentation for usage.
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{
    parse::{Nothing, Parse, ParseStream},
    parse_macro_input, DataEnum, DeriveInput, Fields, Ident, LitStr,
};

/// Derives `bitbag::Flags` for an enum.
///
/// The enum must be field-less and carry exactly one `#[repr(<integer>)]` attribute.
/// Any validation failure is turned into a `compile_error!` at the call site.
#[proc_macro_derive(Flags)]
pub fn derive_flags(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let enum_def = parse_macro_input!(input as DeriveInput);
    let output = match expand_flags(&enum_def) {
        Ok(tokens) => tokens,
        Err(error) => error.into_compile_error(),
    };
    proc_macro::TokenStream::from(output)
}

/// Derives `core::ops::BitOr` so that `flag | flag` and `flag | bag` produce a
/// `bitbag::BitBag<Self>`.
///
/// Any expansion error is turned into a `compile_error!` at the call site.
#[proc_macro_derive(BitOr)]
pub fn derive_bitor(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let type_def = parse_macro_input!(input as DeriveInput);
    let output = match expand_bitor(&type_def) {
        Ok(tokens) => tokens,
        Err(error) => error.into_compile_error(),
    };
    proc_macro::TokenStream::from(output)
}

#[proc_macro_attribute]
pub fn check(
    attr: proc_macro::TokenStream,
    item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let input = parse_macro_input!(item as DeriveInput);
    parse_macro_input!(attr as Nothing);
    expand_check(&input)
        .unwrap_or_else(syn::Error::into_compile_error)
        .into()
}

#[derive(Debug, Clone)]
struct ReprIntIdent {
    ident: Ident,
}

impl ToTokens for ReprIntIdent {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        self.ident.to_tokens(tokens)
    }
}

impl Parse for ReprIntIdent {
    fn parse(tokens: ParseStream) -> syn::Result<Self> {
        let ident = tokens.parse::<Ident>()?;

        macro_rules! impl_parse {
            ($first_candidate:ident, $($candidate:ident),* $(,)?) => {
                if ident == stringify!($first_candidate) {
                    return Ok(Self{ident})
                }
                $(
                    if ident == stringify!($candidate) {
                        return Ok(Self{ident})
                    }
                )*
                return Err(syn::Error::new_spanned(ident, concat!(
                    "bitbag: ident must be one of [",
                    stringify!($first_candidate),
                    $(
                        ", ",
                        stringify!($candidate),
                    )*
                    "]"
                )))

            };
        }

        impl_parse!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128, isize, usize);
    }
}

fn get_repr_ident(input: &DeriveInput) -> syn::Result<ReprIntIdent> {
    let mut repr_idents = Vec::new();
    for attr in &input.attrs {
        if attr.path().is_ident("repr") {
            repr_idents.push(attr.parse_args::<ReprIntIdent>()?);
        }
    }
    match repr_idents.len() {
        0 => Err(syn::Error::new_spanned(
            input,
            "bitbag: must have a #[repr(..)] attribute",
        )),
        1 => Ok(repr_idents.remove(0)),
        _ => Err(syn::Error::new_spanned(
            input,
            "bitbag: must have only one #[repr(..)] attribute",
        )),
    }
}

fn extract_enum_and_repr(input: &DeriveInput) -> syn::Result<(&DataEnum, ReprIntIdent)> {
    let syn::Data::Enum(data) = &input.data else {
        return Err(syn::Error::new_spanned(
            input,
            "bitbag: only enums are supported",
        ));
    };
    let repr = get_repr_ident(input)?;

    let mut error = None;
    for variant in &data.variants {
        if let Fields::Named(_) | Fields::Unnamed(_) = variant.fields {
            error
                .get_or_insert(syn::Error::new_spanned(
                    &data.variants,
                    "bitbag: only field-less enums are supported",
                ))
                .combine(syn::Error::new_spanned(
                    &variant.fields,
                    "bitbag: cannot have fields",
                ));
        };
    }
    match error {
        Some(err) => Err(err),
        None => Ok((data, repr)),
    }
}

/// Generates the `bitbag::Flags` impl for a field-less, integer-`repr` enum.
///
/// The generated impl provides:
/// * `Repr`: the enum's repr type.
/// * `to_repr`: the variant's integer value.
/// * `VARIANTS`: `(name, variant, value)` for every variant.
/// * `NAME`: the enum's name.
/// * `ALL`: the bitwise OR of every variant's value (`0` if there are none).
///
/// # Errors
/// Propagates any error from [`extract_enum_and_repr`].
fn expand_flags(input: &DeriveInput) -> syn::Result<TokenStream> {
    let (data, repr) = extract_enum_and_repr(input)?;
    let user_ident = &input.ident;

    let names_and_values = data.variants.iter().map(|variant| {
        let ident = &variant.ident;
        let name = syn::LitStr::new(&ident.to_string(), ident.span());
        quote! {
            (#name, Self::#ident, Self::#ident as Self::Repr)
        }
    });

    // `to_repr` matches on the variant and casts it, rather than reading the
    // discriminant through a raw pointer inside `unsafe`. The value is the same, but
    // the generated code contains no `unsafe`. It therefore also compiles in user
    // crates that set `#![forbid(unsafe_code)]`, and it does not require `Self: Copy`.
    let to_repr_arms = data.variants.iter().map(|variant| {
        let ident = &variant.ident;
        quote! {
            Self::#ident => Self::#ident as Self::Repr
        }
    });

    let idents = data.variants.iter().map(|it| &it.ident);

    let name = syn::LitStr::new(&user_ident.to_string(), user_ident.span());

    Ok(quote! {
        #[automatically_derived]
        impl bitbag::Flags for #user_ident {
            type Repr = #repr;
            fn to_repr(&self) -> Self::Repr {
                match *self {
                    #(#to_repr_arms,)*
                }
            }
            const VARIANTS: &'static [(&'static str, Self, Self::Repr)] = &[
                    #(#names_and_values,)*
                ];
            const NAME: &'static str = #name;
            const ALL: Self::Repr = #(Self::#idents as Self::Repr|)* 0;
        }
    })
}

/// Generates two `core::ops::BitOr` impls for `input`, both yielding `bitbag::BitBag<Self>`.
///
/// * `flag | flag`: both flags are set in a fresh, empty bag.
/// * `flag | bag`: the flag is set in the given bag.
///
/// Nothing is validated here. Both impls carry a `Self: bitbag::Flags` bound, so the
/// type must implement `Flags` as well.
///
/// # Errors
/// Never fails; the `Result` mirrors the other `expand_*` functions.
fn expand_bitor(input: &DeriveInput) -> syn::Result<TokenStream> {
    let target = &input.ident;

    let flag_or_flag = quote! {
        #[automatically_derived]
        impl core::ops::BitOr<Self> for #target
        where
            Self: bitbag::Flags,
        {
            type Output = bitbag::BitBag<Self>;
            fn bitor(self, rhs: Self) -> Self::Output {
                *bitbag::BitBag::empty()
                    .set(self)
                    .set(rhs)
            }
        }
    };

    let flag_or_bag = quote! {
        #[automatically_derived]
        impl core::ops::BitOr<bitbag::BitBag<Self>> for #target
        where
            Self: bitbag::Flags,
        {
            type Output = bitbag::BitBag<Self>;
            fn bitor(self, mut rhs: bitbag::BitBag<Self>) -> Self::Output {
                *rhs.set(self)
            }
        }
    };

    Ok(quote! {
        #flag_or_flag

        #flag_or_bag
    })
}

/// Expands `#[check]`: re-emits the enum unchanged, followed by an anonymous `const`.
///
/// Evaluating that `const` panics, which is a compile-time error, if any two variants
/// share a bit or if any variant's value is zero.
///
/// The overlap assertions come first and cover every unordered pair of variants. The
/// outer walk goes from the last variant down to the first; each is paired with every
/// earlier variant in declaration order. The zero assertions follow, in declaration order.
///
/// # Errors
/// Propagates any error from [`extract_enum_and_repr`].
fn expand_check(input: &DeriveInput) -> syn::Result<TokenStream> {
    let (data, repr) = extract_enum_and_repr(input)?;
    let enum_ident = &input.ident;

    let variant_list: Vec<_> = data.variants.iter().map(|variant| &variant.ident).collect();
    let variants = variant_list.as_slice();

    let overlap_checkers = (0..variants.len())
        .rev()
        .flat_map(move |right_ix| {
            (0..right_ix).map(move |left_ix| (variants[left_ix], variants[right_ix]))
        })
        .map(|(left, right)| {
            let panic_msg = LitStr::new(
                &format!("{enum_ident}::{left} and {enum_ident}::{right} have overlapping bits"),
                Span::call_site(),
            );
            quote! {
                {
                    let left = #enum_ident::#left as #repr;
                    let right = #enum_ident::#right as #repr;
                    if left & right != 0 {
                        panic!(#panic_msg)
                    }
                }
            }
        });

    let nonzero_checkers = variants.iter().map(|variant| {
        let panic_msg = LitStr::new(
            &format!("{enum_ident}::{variant} has no bits set"),
            Span::call_site(),
        );
        quote! {
            if #enum_ident::#variant as #repr == 0 {
                panic!(#panic_msg)
            }
        }
    });

    Ok(quote! {
        #input
        #[allow(warnings)]
        const _: () = {
            #(#overlap_checkers)*
            #(#nonzero_checkers)*
        };
    })
}

#[cfg(test)]
mod tests {
    /// Compile-tests the macros with `trybuild`.
    ///
    /// Every file matching `trybuild/pass/**/*.rs` is registered as expected to compile,
    /// and every file matching `trybuild/fail/**/*.rs` as expected to fail.
    #[test]
    fn trybuild() {
        let cases = trybuild::TestCases::new();
        cases.pass("trybuild/pass/**/*.rs");
        cases.compile_fail("trybuild/fail/**/*.rs");
    }
}