enumflags2_derive 0.7.4

Do not use directly, use the reexport in the `enumflags2` crate. This allows for better compatibility across versions.
Documentation
#![recursion_limit = "2048"]
extern crate proc_macro;
#[macro_use]
extern crate quote;

use proc_macro2::{Span, TokenStream};
use std::convert::TryFrom;
use syn::{
    parse::{Parse, ParseStream},
    parse_macro_input,
    spanned::Spanned,
    Expr, Ident, Item, ItemEnum, Token, Variant,
};

struct Flag<'a> {
    name: Ident,
    span: Span,
    value: FlagValue<'a>,
}

enum FlagValue<'a> {
    Literal(u128),
    Deferred,
    Inferred(&'a mut Variant),
}

impl FlagValue<'_> {
    // matches! is beyond our MSRV
    #[allow(clippy::match_like_matches_macro)]
    fn is_inferred(&self) -> bool {
        match self {
            FlagValue::Inferred(_) => true,
            _ => false,
        }
    }
}

struct Parameters {
    default: Vec<Ident>,
}

impl Parse for Parameters {
    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
        if input.is_empty() {
            return Ok(Parameters { default: vec![] });
        }

        input.parse::<Token![default]>()?;
        input.parse::<Token![=]>()?;
        let mut default = vec![input.parse()?];
        while !input.is_empty() {
            input.parse::<Token![|]>()?;
            default.push(input.parse()?);
        }

        Ok(Parameters { default })
    }
}

#[proc_macro_attribute]
pub fn bitflags_internal(
    attr: proc_macro::TokenStream,
    input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let Parameters { default } = parse_macro_input!(attr as Parameters);
    let mut ast = parse_macro_input!(input as Item);
    let output = match ast {
        Item::Enum(ref mut item_enum) => gen_enumflags(item_enum, default),
        _ => Err(syn::Error::new_spanned(
            &ast,
            "#[bitflags] requires an enum",
        )),
    };

    output
        .unwrap_or_else(|err| {
            let error = err.to_compile_error();
            quote! {
                #ast
                #error
            }
        })
        .into()
}

/// Try to evaluate the expression given.
fn fold_expr(expr: &syn::Expr) -> Option<u128> {
    match expr {
        Expr::Lit(ref expr_lit) => match expr_lit.lit {
            syn::Lit::Int(ref lit_int) => lit_int.base10_parse().ok(),
            _ => None,
        },
        Expr::Binary(ref expr_binary) => {
            let l = fold_expr(&expr_binary.left)?;
            let r = fold_expr(&expr_binary.right)?;
            match &expr_binary.op {
                syn::BinOp::Shl(_) => u32::try_from(r).ok().and_then(|r| l.checked_shl(r)),
                _ => None,
            }
        }
        Expr::Paren(syn::ExprParen { expr, .. }) | Expr::Group(syn::ExprGroup { expr, .. }) => {
            fold_expr(expr)
        }
        _ => None,
    }
}

fn collect_flags<'a>(
    variants: impl Iterator<Item = &'a mut Variant>,
) -> Result<Vec<Flag<'a>>, syn::Error> {
    variants
        .map(|variant| {
            // MSRV: Would this be cleaner with `matches!`?
            match variant.fields {
                syn::Fields::Unit => (),
                _ => {
                    return Err(syn::Error::new_spanned(
                        &variant.fields,
                        "Bitflag variants cannot contain additional data",
                    ))
                }
            }

            let name = variant.ident.clone();
            let span = variant.span();
            let value = if let Some(ref expr) = variant.discriminant {
                if let Some(n) = fold_expr(&expr.1) {
                    FlagValue::Literal(n)
                } else {
                    FlagValue::Deferred
                }
            } else {
                FlagValue::Inferred(variant)
            };

            Ok(Flag { name, span, value })
        })
        .collect()
}

fn inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr {
    let tokens = if previous_variants.is_empty() {
        quote!(1)
    } else {
        quote!(::enumflags2::_internal::next_bit(
                #(#type_name::#previous_variants as u128)|*
        ) as #repr)
    };

    syn::parse2(tokens).expect("couldn't parse inferred value")
}

fn infer_values(flags: &mut [Flag], type_name: &Ident, repr: &Ident) {
    let mut previous_variants: Vec<Ident> = flags
        .iter()
        .filter(|flag| !flag.value.is_inferred())
        .map(|flag| flag.name.clone())
        .collect();

    for flag in flags {
        if let FlagValue::Inferred(ref mut variant) = flag.value {
            variant.discriminant = Some((
                <Token![=]>::default(),
                inferred_value(type_name, &previous_variants, repr),
            ));
            previous_variants.push(flag.name.clone());
        }
    }
}

/// Given a list of attributes, find the `repr`, if any, and return the integer
/// type specified.
fn extract_repr(attrs: &[syn::Attribute]) -> Result<Option<Ident>, syn::Error> {
    use syn::{Meta, NestedMeta};
    attrs
        .iter()
        .find_map(|attr| match attr.parse_meta() {
            Err(why) => Some(Err(syn::Error::new_spanned(
                attr,
                format!("Couldn't parse attribute: {}", why),
            ))),
            Ok(Meta::List(ref meta)) if meta.path.is_ident("repr") => {
                meta.nested.iter().find_map(|mi| match mi {
                    NestedMeta::Meta(Meta::Path(path)) => path.get_ident().cloned().map(Ok),
                    _ => None,
                })
            }
            Ok(_) => None,
        })
        .transpose()
}

/// Check the repr and return the number of bits available
fn type_bits(ty: &Ident) -> Result<u8, syn::Error> {
    // This would be so much easier if we could just match on an Ident...
    if ty == "usize" {
        Err(syn::Error::new_spanned(
            ty,
            "#[repr(usize)] is not supported. Use u32 or u64 instead.",
        ))
    } else if ty == "i8"
        || ty == "i16"
        || ty == "i32"
        || ty == "i64"
        || ty == "i128"
        || ty == "isize"
    {
        Err(syn::Error::new_spanned(
            ty,
            "Signed types in a repr are not supported.",
        ))
    } else if ty == "u8" {
        Ok(8)
    } else if ty == "u16" {
        Ok(16)
    } else if ty == "u32" {
        Ok(32)
    } else if ty == "u64" {
        Ok(64)
    } else if ty == "u128" {
        Ok(128)
    } else {
        Err(syn::Error::new_spanned(
            ty,
            "repr must be an integer type for #[bitflags].",
        ))
    }
}

/// Returns deferred checks
fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenStream>, syn::Error> {
    use FlagValue::*;
    match flag.value {
        Literal(n) => {
            if !n.is_power_of_two() {
                Err(syn::Error::new(
                    flag.span,
                    "Flags must have exactly one set bit",
                ))
            } else if bits < 128 && n >= 1 << bits {
                Err(syn::Error::new(
                    flag.span,
                    format!("Flag value out of range for u{}", bits),
                ))
            } else {
                Ok(None)
            }
        }
        Inferred(_) => Ok(None),
        Deferred => {
            let variant_name = &flag.name;
            // MSRV: Use an unnamed constant (`const _: ...`).
            let assertion_name = syn::Ident::new(
                &format!("__enumflags_assertion_{}_{}", type_name, flag.name),
                Span::call_site(),
            ); // call_site because def_site is unstable

            Ok(Some(quote_spanned!(flag.span =>
                #[doc(hidden)]
                const #assertion_name:
                    <<[(); (
                        (#type_name::#variant_name as u128).is_power_of_two()
                    ) as usize] as enumflags2::_internal::AssertionHelper>
                        ::Status as enumflags2::_internal::ExactlyOneBitSet>::X
                    = ();
            )))
        }
    }
}

fn gen_enumflags(ast: &mut ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn::Error> {
    let ident = &ast.ident;

    let span = Span::call_site();

    let repr = extract_repr(&ast.attrs)?
        .ok_or_else(|| syn::Error::new_spanned(&ident,
                        "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?;
    let bits = type_bits(&repr)?;

    let mut variants = collect_flags(ast.variants.iter_mut())?;
    let deferred = variants
        .iter()
        .flat_map(|variant| check_flag(ident, variant, bits).transpose())
        .collect::<Result<Vec<_>, _>>()?;

    infer_values(&mut variants, ident, &repr);

    if (bits as usize) < variants.len() {
        return Err(syn::Error::new_spanned(
            &repr,
            format!("Not enough bits for {} flags", variants.len()),
        ));
    }

    let std = quote_spanned!(span => ::enumflags2::_internal::core);
    let variant_names = ast.variants.iter().map(|v| &v.ident).collect::<Vec<_>>();

    Ok(quote_spanned! {
        span =>
            #ast
            #(#deferred)*
            impl #std::ops::Not for #ident {
                type Output = ::enumflags2::BitFlags<Self>;
                #[inline(always)]
                fn not(self) -> Self::Output {
                    use ::enumflags2::BitFlags;
                    BitFlags::from_flag(self).not()
                }
            }

            impl #std::ops::BitOr for #ident {
                type Output = ::enumflags2::BitFlags<Self>;
                #[inline(always)]
                fn bitor(self, other: Self) -> Self::Output {
                    use ::enumflags2::BitFlags;
                    BitFlags::from_flag(self) | other
                }
            }

            impl #std::ops::BitAnd for #ident {
                type Output = ::enumflags2::BitFlags<Self>;
                #[inline(always)]
                fn bitand(self, other: Self) -> Self::Output {
                    use ::enumflags2::BitFlags;
                    BitFlags::from_flag(self) & other
                }
            }

            impl #std::ops::BitXor for #ident {
                type Output = ::enumflags2::BitFlags<Self>;
                #[inline(always)]
                fn bitxor(self, other: Self) -> Self::Output {
                    use ::enumflags2::BitFlags;
                    BitFlags::from_flag(self) ^ other
                }
            }

            unsafe impl ::enumflags2::_internal::RawBitFlags for #ident {
                type Numeric = #repr;

                const EMPTY: Self::Numeric = 0;

                const DEFAULT: Self::Numeric =
                    0 #(| (Self::#default as #repr))*;

                const ALL_BITS: Self::Numeric =
                    0 #(| (Self::#variant_names as #repr))*;

                const BITFLAGS_TYPE_NAME : &'static str =
                    concat!("BitFlags<", stringify!(#ident), ">");

                fn bits(self) -> Self::Numeric {
                    self as #repr
                }
            }

            impl ::enumflags2::BitFlag for #ident {}
    })
}