Skip to main content

bitflagset_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{DeriveInput, parse_macro_input};
4
5/// Derive alternative to [`bitflagset::bitflag!`] for `#[repr(u8)]` enums.
6///
7/// Generates `From<Enum> for u8`, `TryFrom<u8> for Enum`, and `impl BitFlag`.
8/// Useful for tools like cbindgen that cannot expand `macro_rules!` invocations.
9///
10/// ```ignore
11/// #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, BitFlag)]
12/// #[repr(u8)]
13/// enum Color {
14///     Red,
15///     Green,
16///     Blue,
17/// }
18/// ```
19#[proc_macro_derive(BitFlag)]
20pub fn derive_bitflag(input: TokenStream) -> TokenStream {
21    let input = parse_macro_input!(input as DeriveInput);
22    match impl_bitflag(&input) {
23        Ok(ts) => ts.into(),
24        Err(e) => e.to_compile_error().into(),
25    }
26}
27
28fn impl_bitflag(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
29    let name = &input.ident;
30
31    let data = match &input.data {
32        syn::Data::Enum(data) => data,
33        _ => {
34            return Err(syn::Error::new_spanned(
35                name,
36                "BitFlag can only be derived for enums",
37            ));
38        }
39    };
40
41    let variants: Vec<&syn::Ident> = data
42        .variants
43        .iter()
44        .map(|v| {
45            if !matches!(v.fields, syn::Fields::Unit) {
46                return Err(syn::Error::new_spanned(
47                    &v.ident,
48                    "BitFlag variants must be unit variants",
49                ));
50            }
51            Ok(&v.ident)
52        })
53        .collect::<syn::Result<Vec<_>>>()?;
54
55    let variant_names: Vec<String> = variants.iter().map(|v| v.to_string()).collect();
56
57    let flags_entries = variants.iter().zip(variant_names.iter()).map(|(v, s)| {
58        quote! { ::bitflagset::Flag::new(#s, #name::#v) }
59    });
60
61    let try_from_arms = variants.iter().map(|v| {
62        quote! { x if x == #name::#v as u8 => Ok(#name::#v) }
63    });
64
65    let max_value_arms = variants.iter().map(|v| {
66        quote! {
67            let value = #name::#v as u8;
68            if value > max {
69                max = value;
70            }
71        }
72    });
73
74    Ok(quote! {
75        const _: () = assert!(
76            core::mem::size_of::<#name>() == core::mem::size_of::<u8>(),
77            "BitFlag enum must use #[repr(u8)]"
78        );
79
80        impl From<#name> for u8 {
81            #[inline]
82            fn from(v: #name) -> u8 { v as u8 }
83        }
84
85        impl TryFrom<u8> for #name {
86            type Error = ();
87            fn try_from(v: u8) -> Result<Self, ()> {
88                match v {
89                    #(#try_from_arms,)*
90                    _ => Err(()),
91                }
92            }
93        }
94
95        impl ::bitflagset::BitFlag for #name {
96            type Mask = u8;
97            const FLAGS: &'static [::bitflagset::Flag<Self>] = &[
98                #(#flags_entries),*
99            ];
100            const MAX_VALUE: u8 = {
101                let mut max: u8 = 0;
102                #(#max_value_arms)*
103                max
104            };
105        }
106    })
107}
108
109/// Derive alternative to the enum form of [`bitflagset::bitflagset!`].
110///
111/// Useful for tools like cbindgen that cannot expand `macro_rules!` invocations.
112/// The struct must be a tuple struct with a single primitive field.
113///
114/// ```ignore
115/// #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, BitFlagSet)]
116/// #[bitflagset(element = Color)]
117/// struct ColorSet(u8);
118/// ```
119#[proc_macro_derive(BitFlagSet, attributes(bitflagset))]
120pub fn derive_bitflagset(input: TokenStream) -> TokenStream {
121    let input = parse_macro_input!(input as DeriveInput);
122    match impl_bitflagset(&input) {
123        Ok(ts) => ts.into(),
124        Err(e) => e.to_compile_error().into(),
125    }
126}
127
128struct BitFlagSetArgs {
129    element: syn::Path,
130}
131
132fn parse_bitflagset_args(input: &DeriveInput) -> syn::Result<BitFlagSetArgs> {
133    let mut element: Option<syn::Path> = None;
134
135    for attr in &input.attrs {
136        if !attr.path().is_ident("bitflagset") {
137            continue;
138        }
139        attr.parse_nested_meta(|meta| {
140            if meta.path.is_ident("element") {
141                let value = meta.value()?;
142                element = Some(value.parse()?);
143                Ok(())
144            } else {
145                Err(meta.error("expected `element`"))
146            }
147        })?;
148    }
149
150    let element = element.ok_or_else(|| {
151        syn::Error::new_spanned(&input.ident, "missing #[bitflagset(element = Type)]")
152    })?;
153    Ok(BitFlagSetArgs { element })
154}
155
156fn impl_bitflagset(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
157    let args = parse_bitflagset_args(input)?;
158    let name = &input.ident;
159    let typ = &args.element;
160
161    let fields = match &input.data {
162        syn::Data::Struct(data) => &data.fields,
163        _ => {
164            return Err(syn::Error::new_spanned(
165                name,
166                "BitFlagSet can only be derived for structs",
167            ));
168        }
169    };
170    let repr = match fields {
171        syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
172            &fields.unnamed.first().unwrap().ty
173        }
174        _ => {
175            return Err(syn::Error::new_spanned(
176                name,
177                "BitFlagSet struct must have exactly one unnamed field, e.g. `struct Foo(u8)`",
178            ));
179        }
180    };
181
182    Ok(quote! {
183        ::bitflagset::bitflagset!(@__derive_impls #name, #repr, #typ);
184    })
185}