1use proc_macro::TokenStream;
2use proc_macro_error::*;
3use syn;
4use quote::quote;
5
6struct Flag {
7 value: TokenStream,
8 name: String,
9 method_name: Option<syn::Ident>,
10 doc: Option<String>,
11}
12
13impl Flag {
14 fn method_name(&self) -> syn::Ident {
15 if let Some(ident) = &self.method_name {
16 ident.clone()
17 } else {
18 syn::Ident::new(&self.name, proc_macro2::Span::call_site())
19 }
20 }
21
22}
23
/// Selects which `Debug` implementation, if any, the derive emits.
enum DebugMode {
    /// No `Debug` implementation is generated.
    None,
    /// `debug_struct` output with one boolean field per flag.
    Standard,
    /// `debug_set` output that lists only the names of the flags that are set.
    Compact,
}
29
30struct FlagImpl {
31 struct_name: syn::Ident,
32 backing_field_name: syn::Member,
33 flags: Vec<Flag>,
34 debug_mode: DebugMode
35}
36
37fn get_value(lit: &syn::Lit, value_type: &syn::Type) -> TokenStream {
38 use syn::Lit::*;
39
40 let result = match lit {
41 Int(_) => quote!{
42 #lit as #value_type
43 },
44 Str(s) => {
45 let expr: syn::Expr = match syn::parse_str(&s.value()) {
46 Ok(expr) => expr,
47 _ => {
48 abort!(lit, "String must contain a valid expression");
49 }
50 };
51 quote! {
52 (#expr) as #value_type
53 }
54 },
55 _ => abort!(lit, "Bad value, must be an integer literal or string.")
56 };
57 result.into()
58}
59
60fn get_name(lit: &syn::Lit) -> String {
61 use syn::Lit::*;
62
63 match &lit {
64 Str(s) => s.value(),
65 _ => panic!("Bad name")
66 }
67}
68
69fn get_method_name(lit: &syn::Lit) -> syn::Ident {
70 use syn::{Lit::*, Ident};
71
72 match &lit {
73 Str(s) => Ident::new(&s.value(), lit.span()),
74 _ => panic!("Bad method_name")
75 }
76}
77
78fn get_doc(lit: &syn::Lit) -> String {
79 use syn::Lit::*;
80
81 match &lit {
82 Str(s) => s.value(),
83 _ => panic!("Bad doc attribute")
84 }
85}
86
87fn parse_flag(attr: syn::Meta, value_type: &syn::Type) -> Flag {
88 let mut name = None;
89 let mut value = None;
90 let mut method_name = None;
91 let mut doc = None;
92
93 if let syn::Meta::List(attr) = &attr {
94 use syn::{Meta::NameValue, NestedMeta::Meta};
95 let args = &attr.nested;
96
97 for arg in args {
98 if let Meta(NameValue(m)) = arg {
99 if let Some(n) = m.path.get_ident() {
100 match n.to_string().as_str() {
101 "name" => name = Some(get_name(&m.lit)),
102 "value" => value = Some(get_value(&m.lit, value_type)),
103 "method_name" => method_name = Some(get_method_name(&m.lit)),
104 "doc" => doc = Some(get_doc(&m.lit)),
105 s => abort!(arg, r#"Unknown configuration option "{}". Expected one of [name, value, method_name, doc]"#, s)
106 }
107 }
108 }
109 }
110 }
111
112 if let (Some(name), Some(value)) = (name, value) {
113 Flag {
114 name,
115 value,
116 method_name,
117 doc
118 }
119 } else {
120 abort!(attr, "Missing name or value argument for flag.")
121 }
122}
123
124fn get_backing_field(input: &syn::DeriveInput) -> (syn::Member, syn::Field) {
125 let st = if let syn::Data::Struct(ds) = &input.data {
126 ds
127 } else {
128 abort!(input, "Must be a struct")
129 };
130
131 let candidates: Vec<(syn::Member, &syn::Field)> = match &st.fields {
132 syn::Fields::Named(named) => {
133 named.named.iter()
134 .filter(|f| f.attrs.iter().any(|a| a.path.is_ident("flag_backing_field")))
135 .map(|f| (syn::Member::Named(f.ident.clone().unwrap()), f))
136 .collect()
137 },
138 syn::Fields::Unnamed(unnamed) => {
139 unnamed.unnamed.iter()
140 .enumerate()
141 .filter(|(_, f)| f.attrs.iter().any(|a| a.path.is_ident("flag_backing_field")))
142 .map(|(i, f)| (syn::Member::Unnamed(syn::Index::from(i)), f))
143 .collect()
144 },
145 _ => vec![]
146 };
147
148 if candidates.len() == 1 {
149 let (ident, field) = candidates.first().unwrap();
150 (ident.clone().into(), (*field).clone().into())
151 } else {
152 abort!(input, r#"Exactly one backing field must have the "flag_backing_field" attribute"#)
153 }
154}
155
156fn parse_impl(input: TokenStream) -> FlagImpl {
157 use syn::Meta::*;
158
159 let ast: syn::DeriveInput = syn::parse(input).unwrap();
160 let (backing_field_name, backing_field) = get_backing_field(&ast);
161 let struct_name = ast.ident.clone();
162 let mut flags = vec![];
163 let mut debug_mode = DebugMode::None;
164
165 for attr in ast.attrs {
166 if let Some(name) = attr.path.get_ident() {
167 match name.to_string().as_str() {
168 "flag" => {
169 let meta = attr.parse_meta().unwrap_or_else(|_| abort!(attr, "Bad attribute arguments"));
170 let flag = parse_flag(meta, &backing_field.ty);
171 flags.push(flag);
172 },
173 "flag_debug" => {
174 let meta = attr.parse_meta();
175 match meta {
176 Ok(Path(_)) => debug_mode = DebugMode::Standard,
177 Ok(List(ml)) => {
178 if let Some(syn::NestedMeta::Meta(m)) = ml.nested.first() {
179 if ml.nested.len() == 1 && m.path().is_ident("compact") {
180 debug_mode = DebugMode::Compact;
181 continue;
182 } else {
183 abort!(ml, "Bad option for flag_meta attribute");
184 }
185 } else {
186 debug_mode = DebugMode::Standard;
187 }
188 }
189 _ => abort!(attr, "Bad attribute arguments")
190 }
191 }
192 _ => ()
193 }
194 }
195 }
196
197 FlagImpl {
198 struct_name,
199 backing_field_name,
200 flags,
201 debug_mode
202 }
203}
204
205#[proc_macro_derive(Flags, attributes(flag, flag_backing_field, flag_debug))]
206#[proc_macro_error]
207pub fn derive_flags(input: TokenStream) -> TokenStream {
208 let mut flag_impl = parse_impl(input);
209 let backing_field_name = flag_impl.backing_field_name;
210 let struct_name = flag_impl.struct_name;
211
212 let mut methods = vec![];
213
214 let mut debug_fragments = vec![];
215
216 for flag in flag_impl.flags.drain(..) {
217 use quote::format_ident;
218 let name = flag.name.clone();
219 let method_name = flag.method_name();
220 let value: proc_macro2::TokenStream = flag.value.into();
221
222 match flag_impl.debug_mode {
223 DebugMode::None => (),
224 DebugMode::Standard => {
225 debug_fragments.push(quote!{
226 .field(stringify!(#method_name), &self.#method_name())
227 });
228 },
229 DebugMode::Compact => {
230 debug_fragments.push(quote!{
231 if self.#method_name() {
232 dbg.entry(&#name);
233 }
234 });
235 }
236 }
237
238 let (doc, set_doc, only_doc) = {
239 let doc_template = "Gets the value for the flag.";
240 let set_template = "Sets the flag to the given value.";
241 let only_template = "Checks if this flag is the only one set.";
242
243 if let Some(doc) = flag.doc {
244 let doc_str = format!("{}\n\n{}", doc, doc_template);
245 let set_str = format!("{}\n\n{}", doc, set_template);
246 let only_str = format!("{}\n\n{}", doc, only_template);
247 (
248 quote!{
249 #[doc = #doc_str]
250 },
251 quote!{
252 #[doc = #set_str]
253 },
254 quote!{
255 #[doc = #only_str]
256 }
257 )
258 } else {
259 (
260 quote!{
261 #[doc = #doc_template]
262 },
263 quote!{
264 #[doc = #set_template]
265 },
266 quote!{
267 #[doc = #only_template]
268 }
269 )
270 }
271 };
272
273 let setter_name = format_ident!("set_{}", method_name);
274 let exclusive_name = format_ident!("only_{}", method_name);
275 let flag_methods = quote!{
276 #doc
277 pub fn #method_name(&self) -> bool {
278 self.#backing_field_name & (#value) == (#value)
279 }
280 #only_doc
281 pub fn #exclusive_name(&self) -> bool {
282 self.#backing_field_name | (#value) == (#value)
283 }
284 #set_doc
285 pub fn #setter_name(&mut self, value: bool) -> &Self {
286 if value {
287 self.#backing_field_name |= (#value);
288 } else {
289 self.#backing_field_name &= !(#value)
290 }
291 self
292 }
293 };
294
295 methods.push(flag_methods);
296
297 }
298
299 let main_impl = quote!{
300 impl #struct_name {
301 #(#methods)*
302 }
303 };
304
305 let debug_impl = match flag_impl.debug_mode {
306 DebugMode::None => quote!{},
307 DebugMode::Standard => quote!{
308 impl core::fmt::Debug for #struct_name {
309 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
310 f.debug_struct(stringify!(#struct_name))
311 #(#debug_fragments)*
312 .finish()
313 }
314 }
315 },
316 DebugMode::Compact => quote!{
317 impl core::fmt::Debug for #struct_name {
318 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
319 write!(f, "{} ", stringify!(#struct_name))?;
320 let mut dbg = f.debug_set();
321 #(#debug_fragments)*
322 dbg.finish()
323 }
324 }
325 }
326 };
327
328 (quote!{
329 #main_impl
330
331 #debug_impl
332 }).into()
333}