// sawp_flags_derive/lib.rs

extern crate proc_macro;
use proc_macro2::{Ident, Span, TokenStream, TokenTree};
use quote::quote;

5#[proc_macro_derive(BitFlags)]
6pub fn derive_sawp_flags(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
7    let ast: syn::DeriveInput = syn::parse(input).unwrap();
8    impl_sawp_flags(&ast).into()
9}
10
11fn impl_sawp_flags(ast: &syn::DeriveInput) -> TokenStream {
12    let name = &ast.ident;
13    let repr = if let Some(repr) = get_repr(ast) {
14        repr
15    } else {
16        panic!("BitFlags enum must have a `repr` attribute with numeric argument");
17    };
18    match &ast.data {
19        syn::Data::Enum(data) => impl_enum_traits(name, &repr, data),
20        _ => panic!("Bitflags is only supported on enums"),
21    }
22}
23
24fn get_repr(ast: &syn::DeriveInput) -> Option<Ident> {
25    ast.attrs.iter().find_map(|attr| {
26        if let Some(path) = attr.path.get_ident() {
27            if path == "repr" {
28                if let Some(tree) = attr.tokens.clone().into_iter().next() {
29                    match tree {
30                        TokenTree::Group(group) => {
31                            if let Some(ident) = group.stream().into_iter().next() {
32                                match ident {
33                                    TokenTree::Ident(ident) => Some(ident),
34                                    _ => None,
35                                }
36                            } else {
37                                None
38                            }
39                        }
40                        _ => None,
41                    }
42                } else {
43                    None
44                }
45            } else {
46                None
47            }
48        } else {
49            None
50        }
51    })
52}
53
54fn impl_enum_traits(name: &syn::Ident, repr: &Ident, data: &syn::DataEnum) -> TokenStream {
55    // TODO: compile error when these items are reused.
56    let list_items = data.variants.iter().map(|variant| &variant.ident);
57    let list_all = list_items.clone();
58    let display_items = list_items.clone();
59    let from_str_items = list_items.clone();
60    let from_str_items_str = list_items.clone().map(|variant| {
61        Ident::new(
62            variant.to_string().to_lowercase().as_str(),
63            Span::call_site(),
64        )
65    });
66
67    quote! {
68        impl Flag for #name {
69            type Primitive = #repr;
70
71            const ITEMS: &'static [Self] = &[#(#name::#list_items),*];
72
73            fn bits(self) -> Self::Primitive {
74                self as #repr
75            }
76
77            fn none() -> Flags<Self> {
78                Flags::from_bits(0)
79            }
80
81            fn all() -> Flags<Self> {
82                Flags::from_bits(#(#name::#list_all as Self::Primitive)|*)
83            }
84        }
85
86        impl std::ops::BitOr for #name {
87            type Output = Flags<#name>;
88
89            fn bitor(self, other: Self) -> Self::Output {
90                Flags::from_bits(self.bits() | other.bits())
91            }
92        }
93
94        impl std::ops::BitAnd for #name {
95            type Output = Flags<#name>;
96
97            fn bitand(self, other: Self) -> Self::Output {
98                Flags::from_bits(self.bits() & other.bits())
99            }
100        }
101
102        impl std::ops::BitXor for #name {
103            type Output = Flags<#name>;
104
105            fn bitxor(self, other: Self) -> Self::Output {
106                Flags::from_bits(self.bits() ^ other.bits())
107            }
108        }
109
110        impl std::ops::Not for #name {
111            type Output = Flags<#name>;
112
113            fn not(self) -> Self::Output {
114                Flags::from_bits(!self.bits())
115            }
116        }
117
118        impl std::fmt::Display for #name {
119            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120                let empty = self.bits() == Self::none().bits();
121                let mut first = true;
122                #(
123                    if self.bits() & #name::#display_items.bits() == #name::#display_items.bits() {
124                        write!(f, "{}{}", if first { "" } else { " | " }, stringify!(#display_items))?;
125                        first = false;
126
127                        if empty {
128                            return Ok(());
129                        }
130                    }
131                )*
132
133                if empty {
134                    write!(f, "NONE")?;
135                }
136
137                Ok(())
138            }
139        }
140
141        impl std::str::FromStr for #name {
142            type Err = ();
143            fn from_str(val: &str) -> std::result::Result<#name, Self::Err> {
144                match val.to_lowercase().as_str() {
145                    #(stringify!(#from_str_items_str) => Ok(#name::#from_str_items),)*
146                    _ => Err(()),
147                }
148            }
149        }
150
151        impl PartialEq<Flags<Self>> for #name {
152            fn eq(&self, other: &Flags<Self>) -> bool {
153                self.bits() == other.bits()
154            }
155        }
156
157        impl std::fmt::Binary for #name {
158            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159                std::fmt::Binary::fmt(&self.bits(), f)
160            }
161        }
162    }
163}
164
/// Unit tests for the `BitFlags` derive expansion.
///
/// A `#[derive(BitFlags)]` cannot be applied inside the crate that defines
/// it, so these tests call `impl_sawp_flags` directly on parsed input.
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` as a derive input and expands it.
    fn expand(src: &str) -> TokenStream {
        let parsed: syn::DeriveInput = syn::parse_str(src).unwrap();
        impl_sawp_flags(&parsed)
    }

    #[test]
    fn test_macro_enum() {
        expand(
            r#"
            #[repr(u8)]
            enum Test {
                A = 0b0000,
                B = 0b0001,
                C = 0b0010,
                D = 0b0100,
            }
        "#,
        );
    }

    #[test]
    #[should_panic(expected = "BitFlags enum must have a `repr` attribute")]
    fn test_macro_repr_panic() {
        expand(
            r#"
            enum Test {
                A = 0b0000,
                B = 0b0001,
                C = 0b0010,
                D = 0b0100,
            }
        "#,
        );
    }

    #[test]
    #[should_panic(expected = "Bitflags is only supported on enums")]
    fn test_macro_not_enum_panic() {
        expand(
            r#"
            #[repr(u8)]
            struct Test {
            }
        "#,
        );
    }
}