Skip to main content

bitflags_extras/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput};
4
5#[proc_macro_derive(BitflagExtras, attributes(bitflags_repr))]
6pub fn derive_bitflag_extras(input: TokenStream) -> TokenStream {
7    let input_ast = parse_macro_input!(input as DeriveInput);
8    let flag_name = input_ast.ident;
9
10    // Parse #[bitflags_repr(u32)] as syn::Type
11    let mut repr_type: Option<syn::Type> = None;
12
13    for attribute in input_ast.attrs {
14        if attribute.path().is_ident("bitflags_repr") {
15            let parsed_type = attribute
16                .parse_args::<syn::Type>()
17                .expect("Invalid #[bitflags_repr(...)] syntax. Example: #[bitflags_repr(u32)]");
18            repr_type = Some(parsed_type);
19            break;
20        }
21    }
22
23    let repr_type = repr_type.expect("Missing #[bitflags_repr(u8|u16|u32|u64)]");
24
25    let expanded = quote! {
26        impl binrw::BinRead for #flag_name {
27            type Args<'a> = ();
28
29            fn read_options<R: std::io::Read + std::io::Seek>(
30                reader: &mut R,
31                endian: binrw::Endian,
32                _: Self::Args<'_>,
33            ) -> binrw::BinResult<Self> {
34                let bits = <#repr_type as binrw::BinRead>::read_options(reader, endian, ())?;
35                Ok(Self::from_bits_truncate(bits))
36            }
37        }
38
39        impl binrw::BinWrite for #flag_name {
40            type Args<'a> = ();
41
42            fn write_options<W: std::io::Write>(
43                &self,
44                writer: &mut W,
45                endian: binrw::Endian,
46                _: Self::Args<'_>,
47            ) -> binrw::BinResult<()> {
48                let bits: #repr_type = self.bits();
49                match endian {
50                    binrw::Endian::Little => writer.write_all(&bits.to_le_bytes())?,
51                    binrw::Endian::Big => writer.write_all(&bits.to_be_bytes())?,
52                }
53                Ok(())
54            }
55        }
56
57        impl serde::Serialize for #flag_name {
58            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
59            where
60                S: serde::Serializer,
61            {
62                let mut string_flags = String::new();
63                let mut is_first = true;
64
65                for flag in self.iter() {
66                    if !is_first {
67                        string_flags.push('|');
68                    }
69                    is_first = false;
70
71                    string_flags.push_str(
72                        &format!("{:?}", flag)
73                            .replace(stringify!(#flag_name), "")
74                            .replace('(', "")
75                            .replace(')', ""),
76                    );
77                }
78
79                serializer.serialize_str(&string_flags)
80            }
81        }
82
83        impl crate::client::CalculateMetadata for #flag_name {
84            fn calculate<'a>(
85                &self,
86                context: &'a mut crate::client::MetadataContext,
87            ) -> &'a mut crate::client::MetadataContext {
88                self.bits().calculate(context)
89            }
90        }
91    };
92
93    TokenStream::from(expanded)
94}