bevy_enum_filter_derive/
lib.rs

1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use proc_macro2::Ident;
4use proc_macro_crate::{crate_name, FoundCrate};
5use quote::{format_ident, quote};
6use syn::punctuated::Punctuated;
7use syn::spanned::Spanned;
8use syn::{parse_macro_input, Data, DeriveInput, Path, PathSegment, Token};
9
10/// Derive the `EnumFilter` trait on the given enum.
11///
12/// This will do a couple things:
13/// 1. It will, of course, implement the `EnumFilter` trait
14/// 2. It will generate a module with filter components for each enum variant
15///
16/// The generated module will have the name of the enum (snake-cased), appended by
17/// `_filters`. So the enum, `MyEnum`, would generate a module called `my_enum_filters`.
18///
19/// The module will contain a zero-sized marker component struct for each variant.
20/// For example, given the following enum:
21///
22/// ```
23/// enum Foo {
24///   Bar,
25///   Baz(i32),
26/// }
27/// ```
28///
29/// We would end up generating the module `foo_filters` which contains the markers `Bar` and `Baz`.
30///
31/// See the [`Enum!`] macro for how to properly use this generated module.
32#[proc_macro_derive(EnumFilter)]
33pub fn derive_enum_filter(item: TokenStream) -> TokenStream {
34    let input = parse_macro_input!(item as DeriveInput);
35
36    let data = match input.data {
37        Data::Enum(data) => data,
38        Data::Struct(data) => {
39            return syn::Error::new(
40                data.struct_token.span,
41                "Cannot derive `EnumTrait` on struct type",
42            )
43            .into_compile_error()
44            .into();
45        }
46        Data::Union(data) => {
47            return syn::Error::new(
48                data.union_token.span,
49                "Cannot derive `EnumTrait` on union type",
50            )
51            .into_compile_error()
52            .into();
53        }
54    };
55
56    let vis = &input.vis;
57    let ident = &input.ident;
58    let mod_ident = get_mod_ident(ident);
59    let bevy_enum_filter = get_crate("bevy_enum_filter");
60    let bevy = get_crate("bevy");
61
62    let variants = data
63        .variants
64        .iter()
65        .map(|variant| &variant.ident)
66        .collect::<Vec<_>>();
67
68    let docs = variants.iter().map(|variant| {
69        format!(
70            "Marker component generated for [`{}::{}`][super::{}::{}]",
71            ident, variant, ident, variant
72        )
73    });
74    let mod_doc = format!(
75        "Auto-generated module containing marker components for each variant of [`{}`][super::{}]",
76        ident, ident
77    );
78
79    let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl();
80
81    TokenStream::from(quote! {
82        impl #impl_generics #bevy_enum_filter::EnumFilter for #ident #ty_generics #where_clause {
83            fn set_marker(commands: &mut #bevy::ecs::system::EntityCommands, value: &Self) {
84                #(if matches!(value, #ident::#variants{..}) {
85                    let entity = commands.id();
86                    let mut commands = commands.commands();
87
88                    commands.add(move |world: &mut #bevy::ecs::world::World| {
89                        let mut entity_mut = world.entity_mut(entity);
90                        if !entity_mut.contains::<#mod_ident::#variants>() {
91                            // Only insert the marker if it doesn't already exist
92                            entity_mut.insert(#mod_ident::#variants);
93                        }
94                    });
95                } else {
96                    commands.remove::<#mod_ident::#variants>();
97                })*
98            }
99        }
100
101        #[doc = #mod_doc]
102        #[doc(hidden)]
103        #vis mod #mod_ident {
104            #(
105                #[doc = #docs]
106                #[doc(hidden)]
107                #[derive(#bevy::prelude::Component)]
108                pub struct #variants;
109            )*
110        }
111    })
112}
113
114/// This macro can be used to retrieve the marker component generated by the [`EnumFilter`] derive for
115/// the given enum value.
116///
117/// Because this macro relies on the module generated by the [`EnumFilter`] derive macro, you must
118/// make sure it is in scope. Otherwise, you'll likely run into a compile error.
119///
120/// # Example
121///
122/// The basic usage of this macro looks like this:
123///
124/// ```ignore
125/// type Marker = Enum!(Enum::Variant);
126/// // or, Enum!(path::to::Enum::Variant)
127/// ```
128///
129/// > Note: It doesn't matter whether `Enum::Variant` is a unit, tuple, or struct variant—
130/// > you do __not__ need to specify any fields. Treat all variants like a unit variant.
131///
132/// ```ignore
133/// // Make sure everything is in scope
134/// use path::to::{Foo, foo_filters};
135/// type Marker = Enum!(Foo::Baz);
136/// ```
137///
138/// [`EnumFilter`]: derive@EnumFilter
139#[allow(non_snake_case)]
140#[proc_macro]
141pub fn Enum(item: TokenStream) -> TokenStream {
142    let input = parse_macro_input!(item as Path);
143
144    let path_len = input.segments.len();
145
146    if path_len < 2 {
147        return syn::Error::new(
148            input.span(),
149            "expected a valid enum expression (i.e. `Foo::Bar`)",
150        )
151        .into_compile_error()
152        .into();
153    }
154
155    let ident = input.segments[path_len - 2].ident.clone();
156    let variant = input.segments[path_len - 1].ident.clone();
157    let path_prefix = Punctuated::<PathSegment, Token![::]>::from_iter(
158        input.segments.iter().take(path_len - 2).cloned(),
159    );
160
161    let mod_ident = get_mod_ident(&ident);
162
163    let mod_path = if path_prefix.is_empty() {
164        quote!(#mod_ident)
165    } else {
166        quote!(#path_prefix::#mod_ident)
167    };
168
169    TokenStream::from(quote! {
170        #mod_path::#variant
171    })
172}
173
174fn get_mod_ident(enum_ident: &Ident) -> Ident {
175    format_ident!("{}_filters", enum_ident.to_string().to_case(Case::Snake))
176}
177
178fn get_crate(name: &str) -> proc_macro2::TokenStream {
179    let found_crate = crate_name(name).expect(&format!("`{}` is present in `Cargo.toml`", name));
180
181    match found_crate {
182        FoundCrate::Itself => quote!(crate),
183        FoundCrate::Name(name) => {
184            let ident = format_ident!("{}", &name);
185            quote!( #ident )
186        }
187    }
188}