1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{DeriveInput, parse_macro_input};
4
5#[proc_macro_derive(BitFlag)]
20pub fn derive_bitflag(input: TokenStream) -> TokenStream {
21 let input = parse_macro_input!(input as DeriveInput);
22 match impl_bitflag(&input) {
23 Ok(ts) => ts.into(),
24 Err(e) => e.to_compile_error().into(),
25 }
26}
27
28fn impl_bitflag(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
29 let name = &input.ident;
30
31 let data = match &input.data {
32 syn::Data::Enum(data) => data,
33 _ => {
34 return Err(syn::Error::new_spanned(
35 name,
36 "BitFlag can only be derived for enums",
37 ));
38 }
39 };
40
41 let variants: Vec<&syn::Ident> = data
42 .variants
43 .iter()
44 .map(|v| {
45 if !matches!(v.fields, syn::Fields::Unit) {
46 return Err(syn::Error::new_spanned(
47 &v.ident,
48 "BitFlag variants must be unit variants",
49 ));
50 }
51 Ok(&v.ident)
52 })
53 .collect::<syn::Result<Vec<_>>>()?;
54
55 let variant_names: Vec<String> = variants.iter().map(|v| v.to_string()).collect();
56
57 let flags_entries = variants.iter().zip(variant_names.iter()).map(|(v, s)| {
58 quote! { ::bitflagset::Flag::new(#s, #name::#v) }
59 });
60
61 let try_from_arms = variants.iter().map(|v| {
62 quote! { x if x == #name::#v as u8 => Ok(#name::#v) }
63 });
64
65 let max_value_arms = variants.iter().map(|v| {
66 quote! {
67 let value = #name::#v as u8;
68 if value > max {
69 max = value;
70 }
71 }
72 });
73
74 Ok(quote! {
75 const _: () = assert!(
76 core::mem::size_of::<#name>() == core::mem::size_of::<u8>(),
77 "BitFlag enum must use #[repr(u8)]"
78 );
79
80 impl From<#name> for u8 {
81 #[inline]
82 fn from(v: #name) -> u8 { v as u8 }
83 }
84
85 impl TryFrom<u8> for #name {
86 type Error = ();
87 fn try_from(v: u8) -> Result<Self, ()> {
88 match v {
89 #(#try_from_arms,)*
90 _ => Err(()),
91 }
92 }
93 }
94
95 impl ::bitflagset::BitFlag for #name {
96 type Mask = u8;
97 const FLAGS: &'static [::bitflagset::Flag<Self>] = &[
98 #(#flags_entries),*
99 ];
100 const MAX_VALUE: u8 = {
101 let mut max: u8 = 0;
102 #(#max_value_arms)*
103 max
104 };
105 }
106 })
107}
108
109#[proc_macro_derive(BitFlagSet, attributes(bitflagset))]
120pub fn derive_bitflagset(input: TokenStream) -> TokenStream {
121 let input = parse_macro_input!(input as DeriveInput);
122 match impl_bitflagset(&input) {
123 Ok(ts) => ts.into(),
124 Err(e) => e.to_compile_error().into(),
125 }
126}
127
128struct BitFlagSetArgs {
129 element: syn::Path,
130}
131
132fn parse_bitflagset_args(input: &DeriveInput) -> syn::Result<BitFlagSetArgs> {
133 let mut element: Option<syn::Path> = None;
134
135 for attr in &input.attrs {
136 if !attr.path().is_ident("bitflagset") {
137 continue;
138 }
139 attr.parse_nested_meta(|meta| {
140 if meta.path.is_ident("element") {
141 let value = meta.value()?;
142 element = Some(value.parse()?);
143 Ok(())
144 } else {
145 Err(meta.error("expected `element`"))
146 }
147 })?;
148 }
149
150 let element = element.ok_or_else(|| {
151 syn::Error::new_spanned(&input.ident, "missing #[bitflagset(element = Type)]")
152 })?;
153 Ok(BitFlagSetArgs { element })
154}
155
156fn impl_bitflagset(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
157 let args = parse_bitflagset_args(input)?;
158 let name = &input.ident;
159 let typ = &args.element;
160
161 let fields = match &input.data {
162 syn::Data::Struct(data) => &data.fields,
163 _ => {
164 return Err(syn::Error::new_spanned(
165 name,
166 "BitFlagSet can only be derived for structs",
167 ));
168 }
169 };
170 let repr = match fields {
171 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
172 &fields.unnamed.first().unwrap().ty
173 }
174 _ => {
175 return Err(syn::Error::new_spanned(
176 name,
177 "BitFlagSet struct must have exactly one unnamed field, e.g. `struct Foo(u8)`",
178 ));
179 }
180 };
181
182 Ok(quote! {
183 ::bitflagset::bitflagset!(@__derive_impls #name, #repr, #typ);
184 })
185}