enum-flags 0.3.0

EnumFlags is a C#-like enum flags implementation.
Documentation
#![allow(clippy::needless_doctest_main)]

//!
//! EnumFlags is a C#-like enum flags implementation.
//!
//! # Example
//! ```rust
//! #![feature(arbitrary_enum_discriminant)]
//! use enum_flags::enum_flags;
//!
//! #[repr(u8)]  // default: #[repr(usize)]
//! #[enum_flags]
//! #[derive(Copy, Clone, PartialEq)]   // can be omitted
//! enum Flags{
//!     None = 0,
//!     A = 1,
//!     B, // 2
//!     C = 4
//! }
//! fn main() {
//!     let e1: Flags = Flags::A | Flags::C;
//!     let e2 = Flags::B | Flags::C;
//!
//!     assert_eq!(e1 | e2, Flags::A | Flags::B | Flags::C); // union
//!     assert_eq!(e1 & e2, Flags::C); // intersection
//!     assert_eq!(e1 ^ e2, Flags::A | Flags::B); // toggle
//!     assert_eq!(e1 & (!Flags::C), Flags::A); // deletion
//!     assert_eq!(e1 - Flags::C, Flags::A); // deletion
//!
//!     assert_eq!(format!("{:?}", e1).as_str(), "(Flags::A | Flags::C)");
//!     assert!(e1.has_a());
//!     assert!(!e1.has_b());
//!     assert!(e1.has_flag(Flags::C));
//!     assert!(e1.contains(Flags::C));
//!     assert_eq!(match Flags::A | Flags::C {
//!         Flags::None => "None",
//!         Flags::A => "A",
//!         Flags::B => "B",
//!         Flags::C => "C",
//!         Flags::__Composed__(v) if v == Flags::A | Flags::B => "A and B",
//!         Flags::__Composed__(v) if v == Flags::A | Flags::C => "A and C",
//!         _ => "Others"
//!     }, "A and C")
//! }
//! ```

extern crate proc_macro;

use syn::{AttrStyle, Attribute, Data, Expr, ExprLit, Ident, Lit, LitInt, Meta, NestedMeta, Path};
use {
    self::proc_macro::TokenStream,
    proc_macro2::{self, Span},
    quote::*,
    syn::{parse_macro_input, DeriveInput},
};

/// Entry point of the `#[enum_flags]` attribute macro.
///
/// The attribute arguments (`_args`) are ignored. The annotated item is parsed as a
/// [`DeriveInput`] (on a parse error `parse_macro_input!` returns the error as the
/// expansion) and handed to `impl_flags`, which produces the final tokens.
#[proc_macro_attribute]
pub fn enum_flags(_args: TokenStream, input: TokenStream) -> TokenStream {
    let item = parse_macro_input!(input as DeriveInput);
    impl_flags(item)
}

fn impl_flags(mut ast: DeriveInput) -> TokenStream {
    let enum_name = &ast.ident;

    let num = if let Some(repr) = extract_repr(&ast.attrs) {
        repr
    } else {
        ast.attrs.push(Attribute {
            pound_token: Default::default(),
            style: AttrStyle::Outer,
            bracket_token: Default::default(),
            path: Path::from(syn::Ident::new("repr", Span::call_site())),
            tokens: syn::parse2(quote! { (usize) }).unwrap(),
        });
        syn::Ident::new("usize", Span::call_site())
    };

    let vis = &ast.vis;

    if let Data::Enum(ref mut data_enum) = &mut ast.data {
        let mut i = 0;

        for variant in &mut data_enum.variants {
            if let Some((_, ref expr)) = variant.discriminant {
                i = if let Expr::Lit(ExprLit {
                    lit: Lit::Int(ref lit_int),
                    ..
                }) = expr
                {
                    lit_int
                        .to_string()
                        .parse::<u128>()
                        .expect("Invalid literal")
                        + 1
                } else {
                    panic!("Unsupported discriminant type, only integer are supported.")
                }
            } else {
                // println!("{}:{}", variant.ident, i);
                variant.discriminant = Some((
                    syn::token::Eq(Span::call_site()),
                    Expr::Lit(ExprLit {
                        lit: Lit::Int(LitInt::new(i.to_string().as_str(), Span::call_site())),
                        attrs: vec![],
                    }),
                ));
                i += 1;
            }
        }

        data_enum
            .variants
            .push(syn::parse2(quote! {__Composed__(#num)}).unwrap());
    } else {
        panic!("`EnumFlags` has to be used with enums");
    }



    // try to derive Copy,Clone,PartialEq automatically
    {
        let dervies = extract_derives(&ast.attrs);

        let dervies = ["Copy", "Clone", "PartialEq"]
            .iter()
            .filter(|x| dervies.iter().all(|d| d.ne(x)))
            .map(|x| Ident::new(x, Span::call_site()))
            .collect::<Vec<_>>();

        if dervies.len() > 0 {
            ast.attrs.push(Attribute {
                pound_token: Default::default(),
                style: AttrStyle::Outer,
                bracket_token: Default::default(),
                path: Path::from(syn::Ident::new("derive", Span::call_site())),
                tokens: syn::parse2(quote! { (#(#dervies),* )}).unwrap(),
            });
        }
    }

    let result = match &ast.data {
        Data::Enum(ref data_enum) => {
            let (enum_items, enum_values): (Vec<&syn::Ident>, Vec<&syn::Expr>) = data_enum
                .variants
                .iter()
                .filter(|f| f.ident.ne("__Composed__"))
                .map(|v| (&v.ident, &v.discriminant.as_ref().expect("").1))
                .unzip();

            let has_enum_items = enum_items
                .iter()
                .map(|x| {
                    let mut n = to_snake_case(&x.to_string());
                    n.insert_str(0, "has_");
                    Ident::new(n.as_str(), enum_name.span().clone())
                })
                .collect::<Vec<syn::Ident>>();

            let enum_names = enum_items
                .iter()
                .map(|x| {
                    let mut n = enum_name.to_string();
                    n.push_str("::");
                    n.push_str(&x.to_string());
                    n
                })
                .collect::<Vec<String>>();

            quote! {

                #ast

                impl #enum_name {
                    #(
                        #[inline]
                        #vis fn #has_enum_items(&self)-> bool {
                            self.contains(#enum_name::#enum_items)
                        }
                    )*

                    /// Returns `true` if all of the flags in `other` are contained within `self`.
                    #[inline]
                    #vis fn has_flag(&self, other: Self) -> bool {
                        self.contains(other)
                    }

                    /// Returns `true` if no flags are currently stored.
                    #[inline]
                    #vis fn is_empty(&self) -> bool {
                        #num::from(self) == 0
                    }

                    /// Returns `true` if all flags are currently set.
                    #[inline]
                    #vis fn is_all(&self) -> bool {
                        use #enum_name::*;
                        let mut v = Self::from(0);
                        #(
                            v |= #enum_items;
                        )*
                        *self == v
                    }

                    /// Returns `true` if all of the flags in `other` are contained within `self`.
                    #[inline]
                    #vis fn contains(&self, other: Self) -> bool {
                        let a: #num = self.into();
                        let b: #num = other.into();
                        if a == 0 {
                            b == 0
                        } else {
                            (a & b) != 0
                        }
                    }

                    #[inline]
                    #vis fn clear(&mut self) {
                        *self = Self::from(0);
                    }

                    /// Inserts the specified flags in-place.
                    #[inline]
                    #vis fn insert(&mut self, other: Self) {
                        *self |= other;
                    }

                    /// Removes the specified flags in-place.
                    #[inline]
                    #vis fn remove(&mut self, other: Self) {
                        *self &= !other;
                    }

                    /// Inserts or removes the specified flags depending on the passed value.
                    #[inline]
                    #vis fn set(&mut self, other: Self, value: bool) {
                        if value {
                            self.insert(other);
                        } else {
                            self.remove(other);
                        }
                    }

                    /// Toggles the specified flags in-place.
                    #[inline]
                    #vis fn toggle(&mut self, other: Self) {
                        *self ^= other;
                    }

                    /// Returns the intersection between the flags in `self` and
                    #[inline]
                    #vis fn intersection(&self, other: Self) -> Self {
                        *self & other
                    }

                    /// Returns the union of between the flags in `self` and `other`.
                    #[inline]
                    #vis fn union(&self, other: Self) -> Self {
                        *self | other
                    }

                    /// Returns the difference between the flags in `self` and `other`.
                    #[inline]
                    #vis fn difference(&self, other: Self) -> Self {
                        *self & !other
                    }

                    /// Returns the [symmetric difference][sym-diff] between the flags
                    /// in `self` and `other`.
                    #[inline]
                    #vis fn symmetric_difference(&self, other: Self) -> Self {
                        *self ^ other
                    }

                    #[inline]
                    #vis fn from_num(n: #num) -> Self {
                        n.into()
                    }

                    #[inline]
                    #vis fn as_num(&self) -> #num {
                        self.into()
                    }
                }

                impl From<#num> for #enum_name {
                    #[inline]
                    fn from(n: #num) -> Self {
                        use #enum_name::*;
                        match n {
                            #(
                                #enum_values => #enum_items,
                            )*
                            _ => __Composed__(n)
                        }
                    }
                }

                impl From<#enum_name> for #num {
                    #[inline]
                    fn from(s: #enum_name) -> Self {
                        use #enum_name::__Composed__;
                        match s {
                            __Composed__(n) => n,
                            _ => unsafe { *(&s as *const #enum_name as *const #num) }
                        }
                    }
                }

                impl From<&#enum_name> for #num {
                    #[inline]
                    fn from(s: &#enum_name) -> Self {
                        (*s).into()
                    }
                }

                impl std::ops::BitOr for #enum_name {
                    type Output = Self;
                    #[inline]
                    fn bitor(self, rhs: Self) -> Self::Output {
                        let a: #num = self.into();
                        let b: #num = rhs.into();
                        let c = a | b;
                        Self::from(c)
                    }
                }

                impl std::ops::BitAnd for #enum_name {
                    type Output = Self;
                    #[inline]
                    fn bitand(self, rhs: Self) -> Self::Output {
                        let a: #num = self.into();
                        let b: #num = rhs.into();
                        let c = a & b;
                        Self::from(c)
                    }
                }

                impl std::ops::BitXor for #enum_name {
                    type Output = Self;
                    #[inline]
                    fn bitxor(self, rhs: Self) -> Self::Output {
                        let a: #num = self.into();
                        let b: #num = rhs.into();
                        let c = a ^ b;
                        Self::from(c)
                    }
                }

                impl std::ops::Not for #enum_name {
                    type Output = Self;

                    #[inline]
                    fn not(self) -> Self::Output {
                        let a: #num = self.into();
                        Self::from(!a)
                    }
                }

                impl std::ops::Sub for #enum_name {
                    type Output = Self;

                    #[inline]
                    fn sub(self, rhs: Self) -> Self::Output {
                        self & (!rhs)
                    }
                }

                impl std::ops::BitOrAssign for #enum_name {
                    #[inline]
                    fn bitor_assign(&mut self, rhs: Self) {
                        *self = *self | rhs;
                    }
                }

                impl std::ops::BitAndAssign for #enum_name {
                    #[inline]
                    fn bitand_assign(&mut self, rhs: Self) {
                        *self = *self & rhs;
                    }
                }

                impl std::ops::BitXorAssign for #enum_name {
                    #[inline]
                    fn bitxor_assign(&mut self, rhs: Self) {
                        *self = *self ^ rhs;
                    }
                }

                impl std::ops::SubAssign for #enum_name {
                    #[inline]
                    fn sub_assign(&mut self, rhs: Self) {
                        *self = *self - rhs
                    }
                }

                impl std::fmt::Debug for #enum_name {
                    #[inline]
                    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                        let mut v = Vec::new();
                        #(
                            if self.#has_enum_items() {
                                v.push(#enum_names)
                            }
                        )*
                        write!(f, "({})", v.join(" | "))
                    }
                }

                impl std::cmp::PartialEq<#num> for #enum_name {
                    #[inline]
                    fn eq(&self, other: &#num) -> bool {
                        #num::from(self) == *other
                    }
                }

                impl std::cmp::PartialEq<#enum_name> for #num {
                    #[inline]
                    fn eq(&self, other: &#enum_name) -> bool {
                        *self == #num::from(other)
                    }
                }

            }
        }
        _ => panic!("`EnumFlags` has to be used with enums"),
    };

    result.into()
}

fn extract_repr(attrs: &[Attribute]) -> Option<Ident> {
    attrs
        .iter()
        .find_map(|attr| match attr.parse_meta() {
            Err(why) => panic!("{:?}", syn::Error::new_spanned(
                attr,
                format!("Couldn't parse attribute: {}", why),
            )),
            Ok(Meta::List(ref meta)) if meta.path.is_ident("repr") => {
                meta.nested.iter().find_map(|mi| match mi {
                    NestedMeta::Meta(Meta::Path(path)) => path.get_ident().cloned(),
                    _ => None,
                })
            }
            Ok(_) => None,
        })
}

/// Collects the names listed in the item's `#[derive(..)]` attributes.
///
/// Each derive is identified by its final path segment, so `#[derive(core::clone::Clone)]`
/// is reported as `Clone`. Without this, a path-qualified derive would go unnoticed and
/// a second, conflicting derive of the same trait could be added. Derive attributes
/// that fail to parse as meta items are ignored.
fn extract_derives(attrs: &[Attribute]) -> Vec<Ident> {
    attrs
        .iter()
        .filter(|attr| attr.path.is_ident("derive"))
        .filter_map(|attr| attr.parse_meta().ok())
        .flat_map(|meta| match meta {
            Meta::List(list) => list
                .nested
                .into_iter()
                .filter_map(|nested| match nested {
                    NestedMeta::Meta(Meta::Path(path)) => {
                        path.segments.last().map(|segment| segment.ident.clone())
                    }
                    _ => None,
                })
                .collect::<Vec<_>>(),
            _ => Vec::new(),
        })
        .collect()
}

/// Converts a CamelCase identifier to snake_case.
///
/// Each ASCII uppercase letter is lowercased. It is preceded by `_` unless it is the
/// first character. All other characters (including non-ASCII uppercase letters)
/// are copied unchanged, so `"ABC"` becomes `"a_b_c"`.
fn to_snake_case(str: &str) -> String {
    let mut snake = String::with_capacity(str.len());
    for ch in str.chars() {
        if !ch.is_ascii_uppercase() {
            snake.push(ch);
            continue;
        }
        // Every earlier character pushed at least one char, so a non-empty buffer
        // means this is not the first character.
        if !snake.is_empty() {
            snake.push('_');
        }
        snake.push(ch.to_ascii_lowercase());
    }
    snake
}