flag_mast_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro_error::*;
3use syn;
4use quote::quote;
5
6struct Flag {
7    value: TokenStream,
8    name: String,
9    method_name: Option<syn::Ident>,
10    doc: Option<String>,
11}
12
13impl Flag {
14    fn method_name(&self) -> syn::Ident {
15	if let Some(ident) = &self.method_name {
16	    ident.clone()
17	} else {
18	    syn::Ident::new(&self.name, proc_macro2::Span::call_site())
19	}
20    }
21
22}
23
/// Which `Debug` implementation, if any, the derive should generate.
enum DebugMode {
    /// No `Debug` implementation is generated.
    None,
    /// Struct-style output listing every flag together with its boolean state.
    Standard,
    /// Set-style output listing only the names of the flags that are set.
    Compact,
}
29
30struct FlagImpl {
31    struct_name: syn::Ident,
32    backing_field_name: syn::Member,
33    flags: Vec<Flag>,
34    debug_mode: DebugMode
35}
36
37fn get_value(lit: &syn::Lit, value_type: &syn::Type) -> TokenStream {
38    use syn::Lit::*;
39    
40    let result = match lit {
41	Int(_) => quote!{
42	    #lit as #value_type
43	},
44	Str(s) => {
45	    let expr: syn::Expr = match syn::parse_str(&s.value()) {
46		Ok(expr) => expr,
47		_ => {
48		    abort!(lit, "String must contain a valid expression");
49	        }
50	    };
51	    quote! {
52		(#expr) as #value_type
53	    }
54	},
55	_ => abort!(lit, "Bad value, must be an integer literal or string.")
56    };
57    result.into()
58}
59
60fn get_name(lit: &syn::Lit) -> String {
61    use syn::Lit::*;
62
63    match &lit {
64	Str(s) => s.value(),
65	_ => panic!("Bad name")
66    }
67}
68
69fn get_method_name(lit: &syn::Lit) -> syn::Ident {
70    use syn::{Lit::*, Ident};
71
72    match &lit {
73	Str(s) => Ident::new(&s.value(), lit.span()),
74	_ => panic!("Bad method_name")
75    }
76}
77
78fn get_doc(lit: &syn::Lit) -> String {
79    use syn::Lit::*;
80
81    match &lit {
82	Str(s) => s.value(),
83	_ => panic!("Bad doc attribute")
84    }
85}
86
87fn parse_flag(attr: syn::Meta, value_type: &syn::Type) -> Flag {
88    let mut name = None;
89    let mut value = None;
90    let mut method_name = None;
91    let mut doc = None;
92    
93    if let syn::Meta::List(attr) = &attr {
94	use syn::{Meta::NameValue, NestedMeta::Meta};
95	let args = &attr.nested;
96	
97	for arg in args {
98	    if let Meta(NameValue(m)) = arg {
99	    	if let Some(n) = m.path.get_ident() {
100		    match n.to_string().as_str() {
101			"name" => name = Some(get_name(&m.lit)),
102			"value" => value = Some(get_value(&m.lit, value_type)),
103			"method_name" => method_name = Some(get_method_name(&m.lit)),
104			"doc" => doc = Some(get_doc(&m.lit)),
105			s => abort!(arg, r#"Unknown configuration option "{}". Expected one of [name, value, method_name, doc]"#, s)
106		    }
107		}
108	    }
109	}
110    }
111
112    if let (Some(name), Some(value)) = (name, value) {
113	Flag {
114	    name,
115	    value,
116	    method_name,
117	    doc
118	}
119    } else {
120	abort!(attr, "Missing name or value argument for flag.")
121    }
122}
123
124fn get_backing_field(input: &syn::DeriveInput) -> (syn::Member, syn::Field) {
125    let st = if let syn::Data::Struct(ds) = &input.data {
126	ds
127    } else {
128	abort!(input, "Must be a struct")
129    };
130
131    let candidates: Vec<(syn::Member, &syn::Field)> = match &st.fields {
132	syn::Fields::Named(named) => {
133	    named.named.iter()
134		.filter(|f| f.attrs.iter().any(|a| a.path.is_ident("flag_backing_field")))
135		.map(|f| (syn::Member::Named(f.ident.clone().unwrap()), f))
136		.collect()
137	},
138	syn::Fields::Unnamed(unnamed) => {
139	    unnamed.unnamed.iter()
140		.enumerate()
141		.filter(|(_, f)| f.attrs.iter().any(|a| a.path.is_ident("flag_backing_field")))
142		.map(|(i, f)| (syn::Member::Unnamed(syn::Index::from(i)), f))
143		.collect()
144	},
145	_ => vec![]
146    };
147
148    if candidates.len() == 1 {
149	let (ident, field) = candidates.first().unwrap();
150	(ident.clone().into(), (*field).clone().into())
151    } else {
152	abort!(input, r#"Exactly one backing field must have the "flag_backing_field" attribute"#)
153    }
154}
155
156fn parse_impl(input: TokenStream) -> FlagImpl {
157    use syn::Meta::*;
158    
159    let ast: syn::DeriveInput = syn::parse(input).unwrap();
160    let (backing_field_name, backing_field) = get_backing_field(&ast);
161    let struct_name = ast.ident.clone();
162    let mut flags = vec![];
163    let mut debug_mode = DebugMode::None;
164    
165    for attr in ast.attrs {
166	if let Some(name) = attr.path.get_ident() {
167	    match name.to_string().as_str() {
168		"flag" => {
169		    let meta = attr.parse_meta().unwrap_or_else(|_| abort!(attr, "Bad attribute arguments"));
170		    let flag = parse_flag(meta, &backing_field.ty);
171		    flags.push(flag);
172		},
173		"flag_debug" => {
174		    let meta = attr.parse_meta();
175		    match meta {
176			Ok(Path(_)) => debug_mode = DebugMode::Standard,
177			Ok(List(ml)) => {
178			    if let Some(syn::NestedMeta::Meta(m)) = ml.nested.first() {
179				if ml.nested.len() == 1 && m.path().is_ident("compact") {
180				    debug_mode = DebugMode::Compact;
181				    continue;
182				} else {
183				    abort!(ml, "Bad option for flag_meta attribute");
184				}
185			    } else {
186				debug_mode = DebugMode::Standard;
187			    }
188			}
189			_ => abort!(attr, "Bad attribute arguments")
190		    }
191		}
192		_ => ()
193	    }
194	}
195    }
196
197    FlagImpl {
198	struct_name,
199	backing_field_name,
200	flags,
201	debug_mode
202    }
203}
204
205#[proc_macro_derive(Flags, attributes(flag, flag_backing_field, flag_debug))]
206#[proc_macro_error]
207pub fn derive_flags(input: TokenStream) -> TokenStream {
208    let mut flag_impl = parse_impl(input);
209    let backing_field_name = flag_impl.backing_field_name;
210    let struct_name = flag_impl.struct_name;
211
212    let mut methods = vec![];
213
214    let mut debug_fragments = vec![];
215
216    for flag in flag_impl.flags.drain(..) {
217	use quote::format_ident;
218	let name = flag.name.clone();
219	let method_name = flag.method_name();
220	let value: proc_macro2::TokenStream = flag.value.into();
221	
222	match flag_impl.debug_mode {
223	    DebugMode::None => (),
224	    DebugMode::Standard => {
225		debug_fragments.push(quote!{
226		    .field(stringify!(#method_name), &self.#method_name())
227		});
228	    },
229	    DebugMode::Compact => {
230		debug_fragments.push(quote!{
231		    if self.#method_name() {
232			dbg.entry(&#name);
233		    }
234		});
235	    }
236	}
237
238	let (doc, set_doc, only_doc) = {
239	    let doc_template = "Gets the value for the flag.";
240	    let set_template = "Sets the flag to the given value.";
241	    let only_template = "Checks if this flag is the only one set.";
242	    
243	    if let Some(doc) = flag.doc {
244		let doc_str = format!("{}\n\n{}", doc, doc_template);
245		let set_str = format!("{}\n\n{}", doc, set_template);
246		let only_str = format!("{}\n\n{}", doc, only_template);
247		(
248		    quote!{
249			#[doc = #doc_str]
250		    },
251		    quote!{
252			#[doc = #set_str]
253		    },
254		    quote!{
255			#[doc = #only_str]
256		    }
257		)
258	    } else {
259		(
260		    quote!{
261			#[doc = #doc_template]
262		    },
263		    quote!{
264			#[doc = #set_template]
265		    },
266		    quote!{
267			#[doc = #only_template]
268		    }
269		)
270	    }
271	};
272
273	let setter_name = format_ident!("set_{}", method_name);
274	let exclusive_name = format_ident!("only_{}", method_name);
275	let flag_methods = quote!{
276	    #doc
277	    pub fn #method_name(&self) -> bool {
278		self.#backing_field_name & (#value) == (#value)
279	    }
280	    #only_doc
281	    pub fn #exclusive_name(&self) -> bool {
282		self.#backing_field_name | (#value) == (#value)
283	    }
284	    #set_doc
285	    pub fn #setter_name(&mut self, value: bool) -> &Self {
286		if value {
287		    self.#backing_field_name |= (#value);
288		} else {
289		    self.#backing_field_name &= !(#value)
290		}
291		self
292	    }
293	};
294	
295	methods.push(flag_methods);
296	
297    }
298
299    let main_impl = quote!{
300	impl #struct_name {
301	    #(#methods)*
302	}
303    };
304
305    let debug_impl = match flag_impl.debug_mode {
306	DebugMode::None => quote!{},
307	DebugMode::Standard => quote!{
308	    impl core::fmt::Debug for #struct_name {
309		fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
310		    f.debug_struct(stringify!(#struct_name))
311			#(#debug_fragments)*
312		    .finish()
313		}
314	    }
315	},
316	DebugMode::Compact => quote!{
317	    impl core::fmt::Debug for #struct_name {
318		fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
319		    write!(f, "{} ", stringify!(#struct_name))?;
320		    let mut dbg = f.debug_set();
321		    #(#debug_fragments)*
322		    dbg.finish()
323		}
324	    }
325	}
326    };
327
328    (quote!{
329	#main_impl
330
331	#debug_impl
332    }).into()
333}