1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use proc_macro2::Ident;
use proc_macro_crate::{crate_name, FoundCrate};
use quote::{format_ident, quote};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{parse_macro_input, Data, DeriveInput, Path, PathSegment, Token};

/// Derive the `EnumFilter` trait on the given enum.
///
/// This will do a couple of things:
/// 1. It will, of course, implement the `EnumFilter` trait
/// 2. It will generate a module with filter components for each enum variant
///
/// The generated module will have the name of the enum (snake-cased), appended by
/// `_filters`. So the enum, `MyEnum`, would generate a module called `my_enum_filters`.
///
/// The module will contain a zero-sized marker component struct for each variant.
/// For example, given the following enum:
///
/// ```
/// enum Foo {
///   Bar,
///   Baz(i32),
/// }
/// ```
///
/// We would end up generating the module `foo_filters` which contains the markers `Bar` and `Baz`.
///
/// The generated `set_marker` inserts the marker of the variant `value` currently holds (via a
/// command that skips the insert if the entity already has that marker) and calls
/// `commands.remove` for the marker of every other variant.
///
/// Deriving on a struct or union emits a compile error pointing at the `struct`/`union` keyword.
///
/// See the [`Enum!`] macro for how to properly use this generated module.
#[proc_macro_derive(EnumFilter)]
pub fn derive_enum_filter(item: TokenStream) -> TokenStream {
    let input = parse_macro_input!(item as DeriveInput);

    // Only enums have variants to build markers for; reject everything else with a spanned error.
    let data = match input.data {
        Data::Enum(data) => data,
        Data::Struct(data) => {
            return syn::Error::new(
                data.struct_token.span,
                "Cannot derive `EnumFilter` on struct type",
            )
            .into_compile_error()
            .into();
        }
        Data::Union(data) => {
            return syn::Error::new(
                data.union_token.span,
                "Cannot derive `EnumFilter` on union type",
            )
            .into_compile_error()
            .into();
        }
    };

    let vis = &input.vis;
    let ident = &input.ident;
    let mod_ident = get_mod_ident(ident);
    let bevy_enum_filter = get_crate("bevy_enum_filter");
    let bevy = get_crate("bevy");

    // Only the variant names matter: markers are unit structs regardless of the variant's fields.
    let variants = data
        .variants
        .iter()
        .map(|variant| &variant.ident)
        .collect::<Vec<_>>();

    let docs = variants.iter().map(|variant| {
        format!(
            "Marker component generated for [`{}::{}`][super::{}::{}]",
            ident, variant, ident, variant
        )
    });
    let mod_doc = format!(
        "Auto-generated module containing marker components for each variant of [`{}`][super::{}]",
        ident, ident
    );

    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    TokenStream::from(quote! {
        impl #impl_generics #bevy_enum_filter::EnumFilter for #ident #ty_generics #where_clause {
            fn set_marker(commands: &mut #bevy::ecs::system::EntityCommands, value: &Self) {
                // `{..}` lets the same pattern match unit, tuple, and struct variants.
                // `::core::matches!` is fully qualified so a user-defined `matches!` can't shadow it.
                #(if ::core::matches!(value, #ident::#variants{..}) {
                    let entity = commands.id();
                    let mut commands = commands.commands();

                    commands.add(move |world: &mut #bevy::ecs::world::World| {
                        let mut entity_mut = world.entity_mut(entity);
                        if !entity_mut.contains::<#mod_ident::#variants>() {
                            // Only insert the marker if it doesn't already exist
                            entity_mut.insert(#mod_ident::#variants);
                        }
                    });
                } else {
                    commands.remove::<#mod_ident::#variants>();
                })*
            }
        }

        #[doc = #mod_doc]
        #[doc(hidden)]
        #vis mod #mod_ident {
            #(
                #[doc = #docs]
                #[doc(hidden)]
                #[derive(#bevy::prelude::Component)]
                pub struct #variants;
            )*
        }
    })
}

/// This macro can be used to retrieve the marker component generated by the [`EnumFilter`] derive for
/// the given enum value.
///
/// Because this macro relies on the module generated by the [`EnumFilter`] derive macro, you must
/// make sure it is in scope. Otherwise, you'll likely run into a compile error.
///
/// # Example
///
/// The basic usage of this macro looks like this:
///
/// ```ignore
/// type Marker = Enum!(Enum::Variant);
/// // or, Enum!(path::to::Enum::Variant)
/// // or, Enum!(::some_crate::Enum::Variant) — a leading `::` is preserved
/// ```
///
/// > Note: It doesn't matter whether `Enum::Variant` is a unit, tuple, or struct variant—
/// > you do __not__ need to specify any fields. Treat all variants like a unit variant.
///
/// ```ignore
/// // Make sure everything is in scope
/// use path::to::{Foo, foo_filters};
/// type Marker = Enum!(Foo::Baz);
/// ```
///
/// A path with fewer than two segments produces a compile error.
///
/// [`EnumFilter`]: derive@EnumFilter
// `Enum!` intentionally mirrors a type name, so the snake_case lint is silenced.
#[allow(non_snake_case)]
#[proc_macro]
pub fn Enum(item: TokenStream) -> TokenStream {
    let input = parse_macro_input!(item as Path);

    let path_len = input.segments.len();

    // We need at least `Enum::Variant`; anything shorter cannot name a variant.
    if path_len < 2 {
        return syn::Error::new(
            input.span(),
            "expected a valid enum expression (i.e. `Foo::Bar`)",
        )
        .into_compile_error()
        .into();
    }

    // `a::b::Foo::Bar` → prefix `a::b`, enum `Foo`, variant `Bar`.
    let ident = input.segments[path_len - 2].ident.clone();
    let variant = input.segments[path_len - 1].ident.clone();
    let path_prefix = Punctuated::<PathSegment, Token![::]>::from_iter(
        input.segments.iter().take(path_len - 2).cloned(),
    );

    // Give the module ident the enum's span so a missing `*_filters` module is reported at the
    // user's enum name instead of the whole macro invocation.
    let mut mod_ident = get_mod_ident(&ident);
    mod_ident.set_span(ident.span());

    // Keep a leading `::` so absolute paths (e.g. `::my_crate::Foo::Bar`) stay absolute.
    let leading_colon = &input.leading_colon;

    let mod_path = if path_prefix.is_empty() {
        quote!(#leading_colon #mod_ident)
    } else {
        quote!(#leading_colon #path_prefix::#mod_ident)
    };

    TokenStream::from(quote! {
        #mod_path::#variant
    })
}

fn get_mod_ident(enum_ident: &Ident) -> Ident {
    format_ident!("{}_filters", enum_ident.to_string().to_case(Case::Snake))
}

/// Resolve how the crate `name` should be referenced from the code being expanded.
///
/// Returns `crate` when `proc_macro_crate::crate_name` reports that the expansion is happening
/// inside `name` itself, and otherwise an identifier built from the name it reports.
///
/// # Panics
///
/// Panics if `crate_name(name)` fails (for example when `name` is not listed in the invoking
/// crate's `Cargo.toml`). The panic message includes the underlying error.
fn get_crate(name: &str) -> proc_macro2::TokenStream {
    // Build the panic message lazily so the successful path doesn't allocate it.
    let found_crate = crate_name(name)
        .unwrap_or_else(|err| panic!("`{}` must be present in `Cargo.toml`: {}", name, err));

    match found_crate {
        FoundCrate::Itself => quote!(crate),
        FoundCrate::Name(dep_name) => {
            let ident = format_ident!("{}", dep_name);
            quote!( #ident )
        }
    }
}