1#![allow(clippy::needless_doctest_main)]
2
3extern crate proc_macro;
50
51use syn::{AttrStyle, Attribute, Data, Expr, ExprLit, Ident, Lit, LitInt, Meta, NestedMeta, Path};
52use {
53 self::proc_macro::TokenStream,
54 proc_macro2::{self, Span},
55 quote::*,
56 syn::{parse_macro_input, DeriveInput},
57};
58
59#[proc_macro_attribute]
60pub fn enum_flags(_args: TokenStream, input: TokenStream) -> TokenStream {
61 impl_flags(parse_macro_input!(input as DeriveInput))
62}
63
64fn impl_flags(mut ast: DeriveInput) -> TokenStream {
65 let enum_name = &ast.ident;
66
67 let num = if let Some(repr) = extract_repr(&ast.attrs) {
68 repr
69 } else {
70 ast.attrs.push(Attribute {
71 pound_token: Default::default(),
72 style: AttrStyle::Outer,
73 bracket_token: Default::default(),
74 path: Path::from(syn::Ident::new("repr", Span::call_site())),
75 tokens: syn::parse2(quote! { (usize) }).unwrap(),
76 });
77 syn::Ident::new("usize", Span::call_site())
78 };
79
80 let vis = &ast.vis;
81
82 if let Data::Enum(ref mut data_enum) = &mut ast.data {
83 let mut i = 0;
84
85 for variant in &mut data_enum.variants {
86 if let Some((_, ref expr)) = variant.discriminant {
87 i = if let Expr::Lit(ExprLit {
88 lit: Lit::Int(ref lit_int),
89 ..
90 }) = expr
91 {
92 lit_int
93 .to_string()
94 .parse::<u128>()
95 .expect("Invalid literal")
96 + 1
97 } else {
98 panic!("Unsupported discriminant type, only integer are supported.")
99 }
100 } else {
101 variant.discriminant = Some((
103 syn::token::Eq(Span::call_site()),
104 Expr::Lit(ExprLit {
105 lit: Lit::Int(LitInt::new(i.to_string().as_str(), Span::call_site())),
106 attrs: vec![],
107 }),
108 ));
109 i += 1;
110 }
111 }
112
113 data_enum
114 .variants
115 .push(syn::parse2(quote! {__Composed__(#num)}).unwrap());
116 } else {
117 panic!("`EnumFlags` has to be used with enums");
118 }
119
120
121
122 {
124 let dervies = extract_derives(&ast.attrs);
125
126 let dervies = ["Copy", "Clone", "PartialEq"]
127 .iter()
128 .filter(|x| dervies.iter().all(|d| d.ne(x)))
129 .map(|x| Ident::new(x, Span::call_site()))
130 .collect::<Vec<_>>();
131
132 if dervies.len() > 0 {
133 ast.attrs.push(Attribute {
134 pound_token: Default::default(),
135 style: AttrStyle::Outer,
136 bracket_token: Default::default(),
137 path: Path::from(syn::Ident::new("derive", Span::call_site())),
138 tokens: syn::parse2(quote! { (#(#dervies),* )}).unwrap(),
139 });
140 }
141 }
142
143 let result = match &ast.data {
144 Data::Enum(ref data_enum) => {
145 let (enum_items, enum_values): (Vec<&syn::Ident>, Vec<&syn::Expr>) = data_enum
146 .variants
147 .iter()
148 .filter(|f| f.ident.ne("__Composed__"))
149 .map(|v| (&v.ident, &v.discriminant.as_ref().expect("").1))
150 .unzip();
151
152 let has_enum_items = enum_items
153 .iter()
154 .map(|x| {
155 let mut n = to_snake_case(&x.to_string());
156 n.insert_str(0, "has_");
157 Ident::new(n.as_str(), enum_name.span().clone())
158 })
159 .collect::<Vec<syn::Ident>>();
160
161 let enum_names = enum_items
162 .iter()
163 .map(|x| {
164 let mut n = enum_name.to_string();
165 n.push_str("::");
166 n.push_str(&x.to_string());
167 n
168 })
169 .collect::<Vec<String>>();
170
171 quote! {
172
173 #ast
174
175 impl #enum_name {
176 #(
177 #[inline]
178 #vis fn #has_enum_items(&self)-> bool {
179 self.contains(#enum_name::#enum_items)
180 }
181 )*
182
183 #[inline]
185 #vis fn has_flag(&self, other: Self) -> bool {
186 self.contains(other)
187 }
188
189 #[inline]
191 #vis fn is_empty(&self) -> bool {
192 #num::from(self) == 0
193 }
194
195 #[inline]
197 #vis fn is_all(&self) -> bool {
198 use #enum_name::*;
199 let mut v = Self::from(0);
200 #(
201 v |= #enum_items;
202 )*
203 *self == v
204 }
205
206 #[inline]
208 #vis fn contains(&self, other: Self) -> bool {
209 let a: #num = self.into();
210 let b: #num = other.into();
211 if a == 0 {
212 b == 0
213 } else {
214 (a & b) != 0
215 }
216 }
217
218 #[inline]
219 #vis fn clear(&mut self) {
220 *self = Self::from(0);
221 }
222
223 #[inline]
225 #vis fn insert(&mut self, other: Self) {
226 *self |= other;
227 }
228
229 #[inline]
231 #vis fn remove(&mut self, other: Self) {
232 *self &= !other;
233 }
234
235 #[inline]
237 #vis fn set(&mut self, other: Self, value: bool) {
238 if value {
239 self.insert(other);
240 } else {
241 self.remove(other);
242 }
243 }
244
245 #[inline]
247 #vis fn toggle(&mut self, other: Self) {
248 *self ^= other;
249 }
250
251 #[inline]
253 #vis fn intersection(&self, other: Self) -> Self {
254 *self & other
255 }
256
257 #[inline]
259 #vis fn union(&self, other: Self) -> Self {
260 *self | other
261 }
262
263 #[inline]
265 #vis fn difference(&self, other: Self) -> Self {
266 *self & !other
267 }
268
269 #[inline]
272 #vis fn symmetric_difference(&self, other: Self) -> Self {
273 *self ^ other
274 }
275
276 #[inline]
277 #vis fn from_num(n: #num) -> Self {
278 n.into()
279 }
280
281 #[inline]
282 #vis fn as_num(&self) -> #num {
283 self.into()
284 }
285 }
286
287 impl From<#num> for #enum_name {
288 #[inline]
289 fn from(n: #num) -> Self {
290 use #enum_name::*;
291 match n {
292 #(
293 #enum_values => #enum_items,
294 )*
295 _ => __Composed__(n)
296 }
297 }
298 }
299
300 impl From<#enum_name> for #num {
301 #[inline]
302 fn from(s: #enum_name) -> Self {
303 use #enum_name::__Composed__;
304 match s {
305 __Composed__(n) => n,
306 _ => unsafe { *(&s as *const #enum_name as *const #num) }
307 }
308 }
309 }
310
311 impl From<&#enum_name> for #num {
312 #[inline]
313 fn from(s: &#enum_name) -> Self {
314 (*s).into()
315 }
316 }
317
318 impl core::ops::BitOr for #enum_name {
319 type Output = Self;
320 #[inline]
321 fn bitor(self, rhs: Self) -> Self::Output {
322 let a: #num = self.into();
323 let b: #num = rhs.into();
324 let c = a | b;
325 Self::from(c)
326 }
327 }
328
329 impl core::ops::BitAnd for #enum_name {
330 type Output = Self;
331 #[inline]
332 fn bitand(self, rhs: Self) -> Self::Output {
333 let a: #num = self.into();
334 let b: #num = rhs.into();
335 let c = a & b;
336 Self::from(c)
337 }
338 }
339
340 impl core::ops::BitXor for #enum_name {
341 type Output = Self;
342 #[inline]
343 fn bitxor(self, rhs: Self) -> Self::Output {
344 let a: #num = self.into();
345 let b: #num = rhs.into();
346 let c = a ^ b;
347 Self::from(c)
348 }
349 }
350
351 impl core::ops::Not for #enum_name {
352 type Output = Self;
353
354 #[inline]
355 fn not(self) -> Self::Output {
356 let a: #num = self.into();
357 Self::from(!a)
358 }
359 }
360
361 impl core::ops::Sub for #enum_name {
362 type Output = Self;
363
364 #[inline]
365 fn sub(self, rhs: Self) -> Self::Output {
366 self & (!rhs)
367 }
368 }
369
370 impl core::ops::BitOrAssign for #enum_name {
371 #[inline]
372 fn bitor_assign(&mut self, rhs: Self) {
373 *self = *self | rhs;
374 }
375 }
376
377 impl core::ops::BitAndAssign for #enum_name {
378 #[inline]
379 fn bitand_assign(&mut self, rhs: Self) {
380 *self = *self & rhs;
381 }
382 }
383
384 impl core::ops::BitXorAssign for #enum_name {
385 #[inline]
386 fn bitxor_assign(&mut self, rhs: Self) {
387 *self = *self ^ rhs;
388 }
389 }
390
391 impl core::ops::SubAssign for #enum_name {
392 #[inline]
393 fn sub_assign(&mut self, rhs: Self) {
394 *self = *self - rhs
395 }
396 }
397
398 impl core::fmt::Debug for #enum_name {
399 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
400 let mut first = true;
401 write!(f, "(")?;
402 #(
403 if self.#has_enum_items() {
404 if first {
405 first = false;
406 }else {
407 write!(f, " | ")?;
408 }
409 write!(f, "{}", #enum_names)?;
410 }
411 )*
412 write!(f, ")")
413 }
414 }
415
416 impl core::cmp::PartialEq<#num> for #enum_name {
417 #[inline]
418 fn eq(&self, other: &#num) -> bool {
419 #num::from(self) == *other
420 }
421 }
422
423 impl core::cmp::PartialEq<#enum_name> for #num {
424 #[inline]
425 fn eq(&self, other: &#enum_name) -> bool {
426 *self == #num::from(other)
427 }
428 }
429
430 }
431 }
432 _ => panic!("`EnumFlags` has to be used with enums"),
433 };
434
435 result.into()
436}
437
438fn extract_repr(attrs: &[Attribute]) -> Option<Ident> {
439 attrs
440 .iter()
441 .find_map(|attr| match attr.parse_meta() {
442 Err(why) => panic!("{:?}", syn::Error::new_spanned(
443 attr,
444 format!("Couldn't parse attribute: {}", why),
445 )),
446 Ok(Meta::List(ref meta)) if meta.path.is_ident("repr") => {
447 meta.nested.iter().find_map(|mi| match mi {
448 NestedMeta::Meta(Meta::Path(path)) => path.get_ident().cloned(),
449 _ => None,
450 })
451 }
452 Ok(_) => None,
453 })
454}
455
456fn extract_derives(attrs: &[Attribute]) -> Vec<Ident> {
457 attrs
458 .iter()
459 .flat_map(|attr| attr.parse_meta())
460 .flat_map(|ref meta| match meta {
461 Meta::List(ref meta) if meta.path.is_ident("derive") => {
462 meta.nested.iter().filter_map(|mi| match mi {
463 NestedMeta::Meta(Meta::Path(path)) => path.get_ident().cloned(),
464 _ => None,
465 })
466 .collect::<Vec<_>>()
467 }
468 _ => Default::default(),
469 })
470 .collect::<Vec<_>>()
471}
472
/// Converts a CamelCase identifier to snake_case.
///
/// Every ASCII uppercase letter is lowercased and, unless it is the first
/// character, preceded by `_`. All other characters (including non-ASCII
/// uppercase letters and existing underscores) are copied unchanged.
fn to_snake_case(str: &str) -> String {
    let mut out = String::with_capacity(str.len());
    for ch in str.chars() {
        if !ch.is_ascii_uppercase() {
            out.push(ch);
            continue;
        }
        // Every earlier character pushed something, so an empty buffer means
        // this is the first character and needs no separator.
        if !out.is_empty() {
            out.push('_');
        }
        out.push(ch.to_ascii_lowercase());
    }
    out
}