1use std::collections::HashMap;
2
3use proc_macro::TokenStream;
4use syn::{DeriveInput, Fields, Expr, Ident, Lit, BinOp};
5use quote::*;
6
7#[derive(Clone)]
8enum FlagValue {
9 Expr(Expr),
10 Value(u128),
11 Implicit
12}
13
14#[proc_macro_attribute]
15pub fn flags(_attr: TokenStream, item: TokenStream) -> TokenStream {
16 let ast: DeriveInput = syn::parse(item).unwrap();
17
18 let data_enum = match ast.data {
19 syn::Data::Enum(data) => data,
20 _ => panic!("Flags macro only works on enums")
21 };
22
23 let attributes = ast.attrs;
24
25 let mut unprocessed_flags = Vec::<(Ident, FlagValue)>::new();
26 let mut processed_flags = HashMap::<Ident, FlagValue>::new();
27
28 for variant in data_enum.variants {
29 match variant.fields {
30 Fields::Unit => (),
31 _ => panic!("Variants with fields are not allowed")
32 }
33
34 let flag_value = match variant.discriminant {
35 Some(expr) => FlagValue::Expr(expr.1),
36 None => FlagValue::Implicit
37 };
38
39 unprocessed_flags.push((variant.ident, flag_value))
40 }
41
42 let mut processed_any = true;
43 while processed_any {
44 processed_any = false;
45
46 let mut i = 0usize;
47 while i < unprocessed_flags.len() {
48 let (ident, value) = &unprocessed_flags[i as usize];
49
50 let processed_flag_value = process_discriminant(ident, value, &processed_flags);
51
52 match processed_flag_value {
53 FlagValue::Value(_) => {
54 processed_flags.insert(ident.clone(), processed_flag_value);
55 unprocessed_flags.remove(i as usize);
56 processed_any = true;
57 }
58 _ => {
59 i += 1;
60 }
61 }
62 }
63 }
64
65 let highest_bit = 31;
66
67 if unprocessed_flags.len() > 0 {
68 let (ident, value) = &unprocessed_flags[0];
69 match value {
70 FlagValue::Value(_) => (),
71 _ => panic!("Unable to determine value for \"{ident}\"")
72 }
73 }
74
75 let variants: Vec<_> = processed_flags.into_iter()
76 .map(|(ident, value)| {
77 (ident, match value {
78 FlagValue::Value(value) => value as u32,
79 _ => unreachable!()
80 })
81 })
82 .map(|(ident, value)| {
83 quote! {
84 pub const #ident: Self = Self(#value);
85 }
86 })
87 .collect();
88
89 let representation = match highest_bit {
90 0..=7 => quote! { u8 },
91 8..=15 => quote! { u16 },
92 16..=31 => quote! { u32 },
93 32..=63 => quote! { u64 },
94 64..=127 => quote! { u128 },
95 _ => panic!("Cannot repr flags of this size")
96 };
97
98 let name = ast.ident;
99 let visibility = ast.vis;
100
101 quote! {
102 #[derive(Clone, Copy, Eq, PartialEq)]
103 #(#attributes)*
104 #visibility struct #name (#representation);
105
106 impl std::convert::From<#name> for #representation {
107 fn from(value: #name) -> Self {
108 value.0
109 }
110 }
111
112 #[allow(non_upper_case_globals)]
113 impl #name {
114 pub const None: Self = Self(0);
115 pub const All: Self = Self(#representation::MAX);
116
117 #(#variants)*
118
119 pub fn intersects(&self, flags: Self) -> bool {
120 (self.0 & flags.0) != 0
121 }
122
123 pub fn contains(&self, flags: Self) -> bool {
124 (self.0 & flags.0) == flags.0
125 }
126 }
127
128 impl std::default::Default for #name {
129 fn default() -> Self {
130 #name::None
131 }
132 }
133
134 impl std::ops::BitAnd for #name {
135 type Output = Self;
136 fn bitand(self, rhs: Self) -> Self {
137 Self(self.0 & rhs.0)
138 }
139 }
140
141 impl std::ops::BitAndAssign for #name {
142 fn bitand_assign(&mut self, rhs: Self) {
143 self.0 &= rhs.0;
144 }
145 }
146
147 impl std::ops::BitOr for #name {
148 type Output = Self;
149 fn bitor(self, rhs: Self) -> Self {
150 Self(self.0 | rhs.0)
151 }
152 }
153
154 impl std::ops::BitOrAssign for #name {
155 fn bitor_assign(&mut self, rhs: Self) {
156 self.0 |= rhs.0;
157 }
158 }
159
160 impl std::ops::BitXor for #name {
161 type Output = Self;
162 fn bitxor(self, rhs: Self) -> Self {
163 Self(self.0 ^ rhs.0)
164 }
165 }
166
167 impl std::ops::BitXorAssign for #name {
168 fn bitxor_assign(&mut self, rhs: Self) {
169 self.0 ^= rhs.0;
170 }
171 }
172
173 impl std::ops::Not for #name {
174 type Output = Self;
175 fn not(self) -> Self{
176 Self(!self.0)
177 }
178 }
179 }.into()
180}
181
182fn process_discriminant(ident: &Ident, value: &FlagValue, processed_flags: &HashMap<Ident, FlagValue>) -> FlagValue {
183 match value {
184 FlagValue::Expr(expr) => {
185 match parse_discriminant(ident, expr, processed_flags) {
186 Some(value) => FlagValue::Value(value),
187 None => value.clone()
188 }
189 },
190 _ => value.clone()
191 }
192}
193
194fn parse_discriminant(ident: &Ident, expr: &Expr, processed_flags: &HashMap<Ident, FlagValue>) -> Option<u128> {
195 match expr {
196 Expr::Lit(expr_lit) => {
197 match &expr_lit.lit {
198 Lit::Int(lit_int) => {
199 Some(lit_int.base10_digits().parse::<u128>().unwrap())
200 },
201 _ => panic!("Invalid discriminant for {ident}")
202 }
203 },
204 Expr::Path(expr_path) => {
205 let segments = &expr_path.path.segments;
206 if segments.len() != 2 {
207 panic!("Invalid discriminant for {ident}")
208 }
209
210 if segments[0].ident != *ident && segments[0].ident.to_string() != "Self" {
211 panic!("Invalid discriminant for {ident}")
212 }
213
214 match processed_flags.get(&segments[1].ident) {
215 Some(flag_value) => match flag_value {
216 FlagValue::Value(value) => Some(*value),
217 _ => panic!("Invalid discriminant for {ident}")
218 },
219 None => None
220 }
221 },
222 Expr::Binary(binary) => {
223 let lhs = parse_discriminant(ident, &*binary.left, processed_flags)?;
224 let rhs = parse_discriminant(ident, &*binary.right, processed_flags)?;
225 Some(match binary.op {
226 BinOp::BitAnd(_) => lhs & rhs,
227 BinOp::BitOr(_) => lhs | rhs,
228 BinOp::BitXor(_) => lhs ^ rhs,
229 _ => panic!("Invalid discriminant for {ident}")
230 })
231 }
232 _ => panic!("Invalid discriminant for {ident}")
233 }
234}