// flagger_macros/lib.rs

1use std::collections::HashMap;
2
3use proc_macro::TokenStream;
4use syn::{DeriveInput, Fields, Expr, Ident, Lit, BinOp};
5use quote::*;
6
7#[derive(Clone)]
8enum FlagValue {
9    Expr(Expr),
10    Value(u128),
11    Implicit
12}
13
14#[proc_macro_attribute]
15pub fn flags(_attr: TokenStream, item: TokenStream) -> TokenStream {
16    let ast: DeriveInput = syn::parse(item).unwrap();
17
18    let data_enum = match ast.data {
19        syn::Data::Enum(data) => data,
20        _ => panic!("Flags macro only works on enums")
21    };
22
23    let attributes = ast.attrs;
24
25    let mut unprocessed_flags = Vec::<(Ident, FlagValue)>::new();
26    let mut processed_flags = HashMap::<Ident, FlagValue>::new();
27
28    for variant in data_enum.variants {
29        match variant.fields {
30            Fields::Unit => (),
31            _ => panic!("Variants with fields are not allowed")
32        }
33
34        let flag_value = match variant.discriminant {
35            Some(expr) => FlagValue::Expr(expr.1),
36            None => FlagValue::Implicit
37        };
38
39        unprocessed_flags.push((variant.ident, flag_value))
40    }
41
42    let mut processed_any = true;
43    while processed_any {
44        processed_any = false;
45        
46        let mut i = 0usize;
47        while i < unprocessed_flags.len() {
48            let (ident, value) = &unprocessed_flags[i as usize];
49
50            let processed_flag_value = process_discriminant(ident, value, &processed_flags);
51
52            match processed_flag_value {
53                FlagValue::Value(_) => {
54                    processed_flags.insert(ident.clone(), processed_flag_value);
55                    unprocessed_flags.remove(i as usize);
56                    processed_any = true;
57                }
58                _ => {
59                    i += 1;
60                }
61            }
62        }
63    }
64
65    let highest_bit = 31;
66
67    if unprocessed_flags.len() > 0 {
68        let (ident, value) = &unprocessed_flags[0];
69        match value {
70            FlagValue::Value(_) => (),
71            _ => panic!("Unable to determine value for \"{ident}\"")
72        }
73    }
74    
75    let variants: Vec<_> = processed_flags.into_iter()
76        .map(|(ident, value)| {
77            (ident, match value {
78                FlagValue::Value(value) => value as u32,
79                _ => unreachable!()
80            })
81        })
82        .map(|(ident, value)| {
83            quote! {
84                pub const #ident: Self = Self(#value);
85            }
86        })
87        .collect();
88
89    let representation = match highest_bit {
90        0..=7 => quote! { u8 },
91        8..=15 => quote! { u16 },
92        16..=31 => quote! { u32 },
93        32..=63 => quote! { u64 },
94        64..=127 => quote! { u128 },
95        _ => panic!("Cannot repr flags of this size")
96    };
97
98    let name = ast.ident;
99    let visibility = ast.vis;
100
101    quote! {
102        #[derive(Clone, Copy, Eq, PartialEq)]
103        #(#attributes)*
104        #visibility struct #name (#representation);
105
106        impl std::convert::From<#name> for #representation {
107            fn from(value: #name) -> Self {
108                value.0
109            }
110        }
111
112        #[allow(non_upper_case_globals)]
113        impl #name {
114            pub const None: Self = Self(0);
115            pub const All: Self = Self(#representation::MAX);
116
117            #(#variants)*
118
119            pub fn intersects(&self, flags: Self) -> bool {
120                (self.0 & flags.0) != 0
121            }
122        
123            pub fn contains(&self, flags: Self) -> bool {
124                (self.0 & flags.0) == flags.0
125            }
126        }
127
128        impl std::default::Default for #name {
129            fn default() -> Self {
130                #name::None
131            }
132        }
133
134        impl std::ops::BitAnd for #name {
135            type Output = Self;
136            fn bitand(self, rhs: Self) -> Self {
137                Self(self.0 & rhs.0)
138            }
139        }
140
141        impl std::ops::BitAndAssign for #name {
142            fn bitand_assign(&mut self, rhs: Self) {
143                self.0 &= rhs.0;
144            }
145        }
146
147        impl std::ops::BitOr for #name {
148            type Output = Self;
149            fn bitor(self, rhs: Self) -> Self {
150                Self(self.0 | rhs.0)
151            }
152        }
153
154        impl std::ops::BitOrAssign for #name {
155            fn bitor_assign(&mut self, rhs: Self) {
156                self.0 |= rhs.0;
157            }
158        }
159
160        impl std::ops::BitXor for #name {
161            type Output = Self;
162            fn bitxor(self, rhs: Self) -> Self {
163                Self(self.0 ^ rhs.0)
164            }
165        }
166
167        impl std::ops::BitXorAssign for #name {
168            fn bitxor_assign(&mut self, rhs: Self) {
169                self.0 ^= rhs.0;
170            }
171        }
172
173        impl std::ops::Not for #name {
174            type Output = Self;
175            fn not(self) -> Self{
176                Self(!self.0)
177            }
178        }
179    }.into()
180}
181
182fn process_discriminant(ident: &Ident, value: &FlagValue, processed_flags: &HashMap<Ident, FlagValue>) -> FlagValue {
183    match value {
184        FlagValue::Expr(expr) => {
185            match parse_discriminant(ident, expr, processed_flags) {
186                Some(value) => FlagValue::Value(value),
187                None => value.clone()
188            }
189        },
190        _ => value.clone()
191    }
192}
193
194fn parse_discriminant(ident: &Ident, expr: &Expr, processed_flags: &HashMap<Ident, FlagValue>) -> Option<u128> {
195    match expr {
196        Expr::Lit(expr_lit) => {
197            match &expr_lit.lit {
198                Lit::Int(lit_int) => {
199                    Some(lit_int.base10_digits().parse::<u128>().unwrap())
200                },
201                _ => panic!("Invalid discriminant for {ident}")
202            }
203        },
204        Expr::Path(expr_path) => {
205            let segments = &expr_path.path.segments;
206            if segments.len() != 2 {
207                panic!("Invalid discriminant for {ident}")
208            }
209
210            if segments[0].ident != *ident && segments[0].ident.to_string() != "Self" {
211                panic!("Invalid discriminant for {ident}")
212            }
213
214            match processed_flags.get(&segments[1].ident) {
215                Some(flag_value) => match flag_value {
216                    FlagValue::Value(value) => Some(*value),
217                    _ => panic!("Invalid discriminant for {ident}")
218                },
219                None => None
220            }
221        },
222        Expr::Binary(binary) => {
223            let lhs = parse_discriminant(ident, &*binary.left, processed_flags)?;
224            let rhs = parse_discriminant(ident, &*binary.right, processed_flags)?;
225            Some(match binary.op {
226                BinOp::BitAnd(_) => lhs & rhs,
227                BinOp::BitOr(_) => lhs | rhs,
228                BinOp::BitXor(_) => lhs ^ rhs,
229                _ => panic!("Invalid discriminant for {ident}")
230            })
231        }
232        _ => panic!("Invalid discriminant for {ident}")
233    }
234}