enum_flags/
lib.rs

1#![allow(clippy::needless_doctest_main)]
2
3//!
//! `enum_flags` is an implementation of C#-style flag enums (`[Flags]`): the variants of an annotated enum can be combined with bitwise operators.
5//!
6//! The generated code is `no_std` compatible.
7//!
8//! # Example
9//! ```rust
10//! #![feature(arbitrary_enum_discriminant)]
11//! use enum_flags::enum_flags;
12//!
13//! #[repr(u8)]  // default: #[repr(usize)]
14//! #[enum_flags]
15//! #[derive(Copy, Clone, PartialEq)]   // can be omitted
16//! enum Flags{
17//!     None = 0,
18//!     A = 1,
19//!     B, // 2
20//!     C = 4
21//! }
22//! fn main() {
23//!     let e1: Flags = Flags::A | Flags::C;
24//!     let e2 = Flags::B | Flags::C;
25//!
26//!     assert_eq!(e1 | e2, Flags::A | Flags::B | Flags::C); // union
27//!     assert_eq!(e1 & e2, Flags::C); // intersection
28//!     assert_eq!(e1 ^ e2, Flags::A | Flags::B); // toggle
29//!     assert_eq!(e1 & (!Flags::C), Flags::A); // deletion
30//!     assert_eq!(e1 - Flags::C, Flags::A); // deletion
31//!
32//!     assert_eq!(format!("{:?}", e1).as_str(), "(Flags::A | Flags::C)");
33//!     assert!(e1.has_a());
34//!     assert!(!e1.has_b());
35//!     assert!(e1.has_flag(Flags::C));
36//!     assert!(e1.contains(Flags::C));
37//!     assert_eq!(match Flags::A | Flags::C {
38//!         Flags::None => "None",
39//!         Flags::A => "A",
40//!         Flags::B => "B",
41//!         Flags::C => "C",
42//!         Flags::__Composed__(v) if v == Flags::A | Flags::B => "A and B",
43//!         Flags::__Composed__(v) if v == Flags::A | Flags::C => "A and C",
44//!         _ => "Others"
45//!     }, "A and C")
46//! }
47//! ```
48
49extern crate proc_macro;
50
51use syn::{AttrStyle, Attribute, Data, Expr, ExprLit, Ident, Lit, LitInt, Meta, NestedMeta, Path};
52use {
53    self::proc_macro::TokenStream,
54    proc_macro2::{self, Span},
55    quote::*,
56    syn::{parse_macro_input, DeriveInput},
57};
58
59#[proc_macro_attribute]
60pub fn enum_flags(_args: TokenStream, input: TokenStream) -> TokenStream {
61    impl_flags(parse_macro_input!(input as DeriveInput))
62}
63
64fn impl_flags(mut ast: DeriveInput) -> TokenStream {
65    let enum_name = &ast.ident;
66
67    let num = if let Some(repr) = extract_repr(&ast.attrs) {
68        repr
69    } else {
70        ast.attrs.push(Attribute {
71            pound_token: Default::default(),
72            style: AttrStyle::Outer,
73            bracket_token: Default::default(),
74            path: Path::from(syn::Ident::new("repr", Span::call_site())),
75            tokens: syn::parse2(quote! { (usize) }).unwrap(),
76        });
77        syn::Ident::new("usize", Span::call_site())
78    };
79
80    let vis = &ast.vis;
81
82    if let Data::Enum(ref mut data_enum) = &mut ast.data {
83        let mut i = 0;
84
85        for variant in &mut data_enum.variants {
86            if let Some((_, ref expr)) = variant.discriminant {
87                i = if let Expr::Lit(ExprLit {
88                    lit: Lit::Int(ref lit_int),
89                    ..
90                }) = expr
91                {
92                    lit_int
93                        .to_string()
94                        .parse::<u128>()
95                        .expect("Invalid literal")
96                        + 1
97                } else {
98                    panic!("Unsupported discriminant type, only integer are supported.")
99                }
100            } else {
101                // println!("{}:{}", variant.ident, i);
102                variant.discriminant = Some((
103                    syn::token::Eq(Span::call_site()),
104                    Expr::Lit(ExprLit {
105                        lit: Lit::Int(LitInt::new(i.to_string().as_str(), Span::call_site())),
106                        attrs: vec![],
107                    }),
108                ));
109                i += 1;
110            }
111        }
112
113        data_enum
114            .variants
115            .push(syn::parse2(quote! {__Composed__(#num)}).unwrap());
116    } else {
117        panic!("`EnumFlags` has to be used with enums");
118    }
119
120
121
122    // try to derive Copy,Clone,PartialEq automatically
123    {
124        let dervies = extract_derives(&ast.attrs);
125
126        let dervies = ["Copy", "Clone", "PartialEq"]
127            .iter()
128            .filter(|x| dervies.iter().all(|d| d.ne(x)))
129            .map(|x| Ident::new(x, Span::call_site()))
130            .collect::<Vec<_>>();
131
132        if dervies.len() > 0 {
133            ast.attrs.push(Attribute {
134                pound_token: Default::default(),
135                style: AttrStyle::Outer,
136                bracket_token: Default::default(),
137                path: Path::from(syn::Ident::new("derive", Span::call_site())),
138                tokens: syn::parse2(quote! { (#(#dervies),* )}).unwrap(),
139            });
140        }
141    }
142
143    let result = match &ast.data {
144        Data::Enum(ref data_enum) => {
145            let (enum_items, enum_values): (Vec<&syn::Ident>, Vec<&syn::Expr>) = data_enum
146                .variants
147                .iter()
148                .filter(|f| f.ident.ne("__Composed__"))
149                .map(|v| (&v.ident, &v.discriminant.as_ref().expect("").1))
150                .unzip();
151
152            let has_enum_items = enum_items
153                .iter()
154                .map(|x| {
155                    let mut n = to_snake_case(&x.to_string());
156                    n.insert_str(0, "has_");
157                    Ident::new(n.as_str(), enum_name.span().clone())
158                })
159                .collect::<Vec<syn::Ident>>();
160
161            let enum_names = enum_items
162                .iter()
163                .map(|x| {
164                    let mut n = enum_name.to_string();
165                    n.push_str("::");
166                    n.push_str(&x.to_string());
167                    n
168                })
169                .collect::<Vec<String>>();
170
171            quote! {
172
173                #ast
174
175                impl #enum_name {
176                    #(
177                        #[inline]
178                        #vis fn #has_enum_items(&self)-> bool {
179                            self.contains(#enum_name::#enum_items)
180                        }
181                    )*
182
183                    /// Returns `true` if all of the flags in `other` are contained within `self`.
184                    #[inline]
185                    #vis fn has_flag(&self, other: Self) -> bool {
186                        self.contains(other)
187                    }
188
189                    /// Returns `true` if no flags are currently stored.
190                    #[inline]
191                    #vis fn is_empty(&self) -> bool {
192                        #num::from(self) == 0
193                    }
194
195                    /// Returns `true` if all flags are currently set.
196                    #[inline]
197                    #vis fn is_all(&self) -> bool {
198                        use #enum_name::*;
199                        let mut v = Self::from(0);
200                        #(
201                            v |= #enum_items;
202                        )*
203                        *self == v
204                    }
205
206                    /// Returns `true` if all of the flags in `other` are contained within `self`.
207                    #[inline]
208                    #vis fn contains(&self, other: Self) -> bool {
209                        let a: #num = self.into();
210                        let b: #num = other.into();
211                        if a == 0 {
212                            b == 0
213                        } else {
214                            (a & b) != 0
215                        }
216                    }
217
218                    #[inline]
219                    #vis fn clear(&mut self) {
220                        *self = Self::from(0);
221                    }
222
223                    /// Inserts the specified flags in-place.
224                    #[inline]
225                    #vis fn insert(&mut self, other: Self) {
226                        *self |= other;
227                    }
228
229                    /// Removes the specified flags in-place.
230                    #[inline]
231                    #vis fn remove(&mut self, other: Self) {
232                        *self &= !other;
233                    }
234
235                    /// Inserts or removes the specified flags depending on the passed value.
236                    #[inline]
237                    #vis fn set(&mut self, other: Self, value: bool) {
238                        if value {
239                            self.insert(other);
240                        } else {
241                            self.remove(other);
242                        }
243                    }
244
245                    /// Toggles the specified flags in-place.
246                    #[inline]
247                    #vis fn toggle(&mut self, other: Self) {
248                        *self ^= other;
249                    }
250
251                    /// Returns the intersection between the flags in `self` and
252                    #[inline]
253                    #vis fn intersection(&self, other: Self) -> Self {
254                        *self & other
255                    }
256
257                    /// Returns the union of between the flags in `self` and `other`.
258                    #[inline]
259                    #vis fn union(&self, other: Self) -> Self {
260                        *self | other
261                    }
262
263                    /// Returns the difference between the flags in `self` and `other`.
264                    #[inline]
265                    #vis fn difference(&self, other: Self) -> Self {
266                        *self & !other
267                    }
268
269                    /// Returns the [symmetric difference][sym-diff] between the flags
270                    /// in `self` and `other`.
271                    #[inline]
272                    #vis fn symmetric_difference(&self, other: Self) -> Self {
273                        *self ^ other
274                    }
275
276                    #[inline]
277                    #vis fn from_num(n: #num) -> Self {
278                        n.into()
279                    }
280
281                    #[inline]
282                    #vis fn as_num(&self) -> #num {
283                        self.into()
284                    }
285                }
286
287                impl From<#num> for #enum_name {
288                    #[inline]
289                    fn from(n: #num) -> Self {
290                        use #enum_name::*;
291                        match n {
292                            #(
293                                #enum_values => #enum_items,
294                            )*
295                            _ => __Composed__(n)
296                        }
297                    }
298                }
299
300                impl From<#enum_name> for #num {
301                    #[inline]
302                    fn from(s: #enum_name) -> Self {
303                        use #enum_name::__Composed__;
304                        match s {
305                            __Composed__(n) => n,
306                            _ => unsafe { *(&s as *const #enum_name as *const #num) }
307                        }
308                    }
309                }
310
311                impl From<&#enum_name> for #num {
312                    #[inline]
313                    fn from(s: &#enum_name) -> Self {
314                        (*s).into()
315                    }
316                }
317
318                impl core::ops::BitOr for #enum_name {
319                    type Output = Self;
320                    #[inline]
321                    fn bitor(self, rhs: Self) -> Self::Output {
322                        let a: #num = self.into();
323                        let b: #num = rhs.into();
324                        let c = a | b;
325                        Self::from(c)
326                    }
327                }
328
329                impl core::ops::BitAnd for #enum_name {
330                    type Output = Self;
331                    #[inline]
332                    fn bitand(self, rhs: Self) -> Self::Output {
333                        let a: #num = self.into();
334                        let b: #num = rhs.into();
335                        let c = a & b;
336                        Self::from(c)
337                    }
338                }
339
340                impl core::ops::BitXor for #enum_name {
341                    type Output = Self;
342                    #[inline]
343                    fn bitxor(self, rhs: Self) -> Self::Output {
344                        let a: #num = self.into();
345                        let b: #num = rhs.into();
346                        let c = a ^ b;
347                        Self::from(c)
348                    }
349                }
350
351                impl core::ops::Not for #enum_name {
352                    type Output = Self;
353
354                    #[inline]
355                    fn not(self) -> Self::Output {
356                        let a: #num = self.into();
357                        Self::from(!a)
358                    }
359                }
360
361                impl core::ops::Sub for #enum_name {
362                    type Output = Self;
363
364                    #[inline]
365                    fn sub(self, rhs: Self) -> Self::Output {
366                        self & (!rhs)
367                    }
368                }
369
370                impl core::ops::BitOrAssign for #enum_name {
371                    #[inline]
372                    fn bitor_assign(&mut self, rhs: Self) {
373                        *self = *self | rhs;
374                    }
375                }
376
377                impl core::ops::BitAndAssign for #enum_name {
378                    #[inline]
379                    fn bitand_assign(&mut self, rhs: Self) {
380                        *self = *self & rhs;
381                    }
382                }
383
384                impl core::ops::BitXorAssign for #enum_name {
385                    #[inline]
386                    fn bitxor_assign(&mut self, rhs: Self) {
387                        *self = *self ^ rhs;
388                    }
389                }
390
391                impl core::ops::SubAssign for #enum_name {
392                    #[inline]
393                    fn sub_assign(&mut self, rhs: Self) {
394                        *self = *self - rhs
395                    }
396                }
397
398                impl core::fmt::Debug for #enum_name {
399                    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
400                        let mut first = true;
401                        write!(f, "(")?;
402                        #(
403                            if self.#has_enum_items() {
404                                if first {
405                                    first = false;
406                                }else {
407                                    write!(f, " | ")?;
408                                }
409                                write!(f, "{}", #enum_names)?;
410                            }
411                        )*
412                        write!(f, ")")
413                    }
414                }
415
416                impl core::cmp::PartialEq<#num> for #enum_name {
417                    #[inline]
418                    fn eq(&self, other: &#num) -> bool {
419                        #num::from(self) == *other
420                    }
421                }
422
423                impl core::cmp::PartialEq<#enum_name> for #num {
424                    #[inline]
425                    fn eq(&self, other: &#enum_name) -> bool {
426                        *self == #num::from(other)
427                    }
428                }
429
430            }
431        }
432        _ => panic!("`EnumFlags` has to be used with enums"),
433    };
434
435    result.into()
436}
437
438fn extract_repr(attrs: &[Attribute]) -> Option<Ident> {
439    attrs
440        .iter()
441        .find_map(|attr| match attr.parse_meta() {
442            Err(why) => panic!("{:?}", syn::Error::new_spanned(
443                attr,
444                format!("Couldn't parse attribute: {}", why),
445            )),
446            Ok(Meta::List(ref meta)) if meta.path.is_ident("repr") => {
447                meta.nested.iter().find_map(|mi| match mi {
448                    NestedMeta::Meta(Meta::Path(path)) => path.get_ident().cloned(),
449                    _ => None,
450                })
451            }
452            Ok(_) => None,
453        })
454}
455
456fn extract_derives(attrs: &[Attribute]) -> Vec<Ident> {
457    attrs
458        .iter()
459        .flat_map(|attr| attr.parse_meta())
460        .flat_map(|ref meta| match meta {
461            Meta::List(ref meta) if meta.path.is_ident("derive") => {
462                meta.nested.iter().filter_map(|mi| match mi {
463                    NestedMeta::Meta(Meta::Path(path)) => path.get_ident().cloned(),
464                    _ => None,
465                })
466                .collect::<Vec<_>>()
467            }
468            _ => Default::default(),
469        })
470        .collect::<Vec<_>>()
471}
472
/// Converts an `UpperCamelCase` identifier to `snake_case`.
///
/// Each ASCII uppercase letter is lowercased. Except at the very start, it is
/// also preceded by `_`. All other characters, including non-ASCII letters, are
/// copied unchanged, so `ReadOnly` becomes `read_only` and `IO` becomes `i_o`.
fn to_snake_case(name: &str) -> String {
    name.chars()
        .enumerate()
        .fold(String::with_capacity(name.len()), |mut out, (pos, ch)| {
            if ch.is_ascii_uppercase() {
                if pos != 0 {
                    out.push('_');
                }
                out.push(ch.to_ascii_lowercase());
            } else {
                out.push(ch);
            }
            out
        })
}