bevy_trait_query_impl/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{ItemTrait, Result, TraitItem, parse_quote};
5
6/// When added to a trait declaration, generates the impls required to use that trait in queries.
7///
8/// # Poor use cases
9///
10/// You should avoid using trait queries for very simple cases that can be solved with more direct solutions.
11///
12/// One naive use would be querying for a trait that looks something like:
13///
14/// ```
15/// trait Person {
16///     fn name(&self) -> &str;
17/// }
18/// ```
19///
20/// A far better way of expressing this would be to store the name in a separate component
21/// and query for that directly, making `Person` a simple marker component.
22///
23/// Trait queries are often the most *obvious* solution to a problem, but not always the best one.
24/// For examples of strong real-world use-cases, check out the RFC for trait queries in `bevy`:
25/// <https://github.com/bevyengine/rfcs/pull/39>.
26///
27/// # Note
28///
29/// This will add the trait bound `'static` to the trait and all of its type parameters.
30///
31/// You may opt out of this by using the form `#[queryable(no_bounds)]`,
32/// but you will have to add the bounds yourself to make it compile.
33#[proc_macro_attribute]
34pub fn queryable(attr: TokenStream, item: TokenStream) -> TokenStream {
35    impl_trait_query(attr, item)
36        .unwrap_or_else(syn::Error::into_compile_error)
37        .into()
38}
39
40fn impl_trait_query(arg: TokenStream, item: TokenStream) -> Result<TokenStream2> {
41    syn::custom_keyword!(no_bounds);
42    let no_bounds: Option<no_bounds> = syn::parse(arg).map_err(|e| {
43        syn::Error::new(
44            e.span(),
45            "Valid forms are: `#[queryable]` and `#[queryable(no_bounds)]`",
46        )
47    })?;
48
49    let mut trait_definition = syn::parse::<ItemTrait>(item)?;
50    let trait_name = trait_definition.ident.clone();
51
52    // Add `'static` bounds, unless the user asked us not to.
53    if no_bounds.is_none() {
54        trait_definition.supertraits.push(parse_quote!('static));
55
56        for param in &mut trait_definition.generics.params {
57            // Make sure the parameters to the trait are `'static`.
58            if let syn::GenericParam::Type(param) = param {
59                param.bounds.push(parse_quote!('static));
60            }
61        }
62
63        for item in &mut trait_definition.items {
64            // Make sure all associated types are `'static`.
65            if let TraitItem::Type(assoc) = item {
66                assoc.bounds.push(parse_quote!('static));
67            }
68        }
69    }
70
71    let mut impl_generics_list = vec![];
72    let mut trait_generics_list = vec![];
73    let where_clause = trait_definition.generics.where_clause.clone();
74
75    for param in &trait_definition.generics.params {
76        impl_generics_list.push(param.clone());
77        match param {
78            syn::GenericParam::Type(param) => {
79                let ident = &param.ident;
80                trait_generics_list.push(quote! { #ident });
81            }
82            syn::GenericParam::Lifetime(param) => {
83                let ident = &param.lifetime;
84                trait_generics_list.push(quote! { #ident });
85            }
86            syn::GenericParam::Const(param) => {
87                let ident = &param.ident;
88                trait_generics_list.push(quote! { #ident });
89            }
90        }
91    }
92
93    // Add generics for unbounded associated types.
94    for item in &trait_definition.items {
95        if let TraitItem::Type(assoc) = item {
96            if !assoc.generics.params.is_empty() {
97                return Err(syn::Error::new(
98                    assoc.ident.span(),
99                    "Generic associated types are not supported in trait queries",
100                ));
101            }
102            let ident = &assoc.ident;
103            let lower_ident = format_ident!("__{ident}");
104            let bound = &assoc.bounds;
105            impl_generics_list.push(parse_quote! { #lower_ident: #bound });
106            trait_generics_list.push(quote! { #ident = #lower_ident });
107        }
108    }
109
110    let impl_generics = quote! { <#( #impl_generics_list ,)*> };
111    let trait_generics = quote! { <#( #trait_generics_list ,)*> };
112
113    let trait_object = quote! { dyn #trait_name #trait_generics };
114
115    let my_crate = proc_macro_crate::crate_name("bevy-trait-query").unwrap();
116    let my_crate = match my_crate {
117        proc_macro_crate::FoundCrate::Itself => quote! { bevy_trait_query },
118        proc_macro_crate::FoundCrate::Name(x) => {
119            let ident = quote::format_ident!("{x}");
120            quote! { #ident }
121        }
122    };
123
124    let imports = quote! { #my_crate::imports };
125
126    let trait_query = quote! { #my_crate::TraitQuery };
127
128    let mut marker_impl_generics_list = impl_generics_list.clone();
129    marker_impl_generics_list
130        .push(parse_quote!(__Component: #trait_name #trait_generics + #imports::Component));
131    let marker_impl_generics = quote! { <#( #marker_impl_generics_list ,)*> };
132
133    let marker_impl_code = quote! {
134        impl #impl_generics #trait_query for #trait_object #where_clause {}
135
136        impl #marker_impl_generics #my_crate::TraitQueryMarker::<#trait_object> for (__Component,)
137        #where_clause
138        {
139            type Covered = __Component;
140            fn cast(ptr: *mut u8) -> *mut #trait_object {
141                ptr as *mut __Component as *mut _
142            }
143        }
144    };
145
146    let mut impl_generics_with_lifetime = impl_generics_list.clone();
147    impl_generics_with_lifetime.insert(0, parse_quote!('__a));
148    let impl_generics_with_lifetime = quote! { <#( #impl_generics_with_lifetime ,)*> };
149
150    let trait_object_query_code = quote! {
151        unsafe impl #impl_generics #imports::QueryData for &#trait_object
152        #where_clause
153        {
154            type ReadOnly = Self;
155
156            const IS_READ_ONLY: bool = true;
157
158            type Item<'__w, '__s> = #my_crate::ReadTraits<'__w, #trait_object>;
159
160            #[inline]
161            fn shrink<'wlong: 'wshort, 'wshort, 's>(
162                item: Self::Item<'wlong, 's>,
163            ) -> Self::Item<'wshort, 's> {
164                item
165            }
166
167            #[inline]
168            unsafe fn fetch<'w, 's>(
169                state: &'s Self::State,
170                fetch: &mut Self::Fetch<'w>,
171                entity: #imports::Entity,
172                table_row: #imports::TableRow,
173            ) -> Self::Item<'w, 's> {
174                <#my_crate::All<&#trait_object> as #imports::QueryData>::fetch(
175                    state,
176                    fetch,
177                    entity,
178                    table_row,
179                )
180            }
181        }
182        unsafe impl #impl_generics #imports::ReadOnlyQueryData for &#trait_object
183        #where_clause
184        {}
185
186        unsafe impl #impl_generics_with_lifetime #imports::WorldQuery for &'__a #trait_object
187        #where_clause
188        {
189            type Fetch<'__w> = <#my_crate::All<&'__a #trait_object> as #imports::WorldQuery>::Fetch<'__w>;
190            type State = #my_crate::TraitQueryState<#trait_object>;
191
192            #[inline]
193            unsafe fn init_fetch<'w>(
194                world: #imports::UnsafeWorldCell<'w>,
195                state: &Self::State,
196                last_run: #imports::Tick,
197                this_run: #imports::Tick,
198            ) -> Self::Fetch<'w> {
199                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::init_fetch(
200                    world,
201                    state,
202                    last_run,
203                    this_run,
204                )
205            }
206
207            const IS_DENSE: bool = <#my_crate::All<&#trait_object> as #imports::WorldQuery>::IS_DENSE;
208
209            #[inline]
210            unsafe fn set_archetype<'w>(
211                fetch: &mut Self::Fetch<'w>,
212                state: &Self::State,
213                archetype: &'w #imports::Archetype,
214                tables: &'w #imports::Table,
215            ) {
216                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::set_archetype(
217                    fetch, state, archetype, tables,
218                );
219            }
220
221            #[inline]
222            unsafe fn set_table<'w>(
223                fetch: &mut Self::Fetch<'w>,
224                state: &Self::State,
225                table: &'w #imports::Table,
226            ) {
227                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::set_table(fetch, state, table);
228            }
229
230            #[inline]
231            fn update_component_access(
232                state: &Self::State,
233                access: &mut #imports::FilteredAccess,
234            ) {
235                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::update_component_access(
236                    state, access,
237                );
238            }
239
240            #[inline]
241            fn init_state(world: &mut #imports::World) -> Self::State {
242                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::init_state(world)
243            }
244
245            #[inline]
246            fn get_state(_: &#imports::Components) -> Option<Self::State> {
247                // TODO: fix this https://github.com/bevyengine/bevy/issues/13798
248                panic!("transmuting and any other operations concerning the state of a query are currently broken and shouldn't be used. See https://github.com/JoJoJet/bevy-trait-query/issues/59");
249            }
250
251            #[inline]
252            fn matches_component_set(
253                state: &Self::State,
254                set_contains_id: &impl Fn(#imports::ComponentId) -> bool,
255            ) -> bool {
256                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::matches_component_set(state, set_contains_id)
257            }
258
259            #[inline]
260            fn shrink_fetch<'wlong: 'wshort, 'wshort>(fetch: Self::Fetch<'wlong>) -> Self::Fetch<'wshort> {
261                fetch
262            }
263        }
264
265        unsafe impl #impl_generics_with_lifetime #imports::QueryData for &'__a mut #trait_object
266        #where_clause
267        {
268            type ReadOnly = &'__a #trait_object;
269
270            type Item<'__w, '__s> = #my_crate::WriteTraits<'__w, #trait_object>;
271
272            const IS_READ_ONLY: bool = false;
273
274            #[inline]
275            fn shrink<'wlong: 'wshort, 'wshort, 's>(
276                item: Self::Item<'wlong, 's>,
277            ) -> Self::Item<'wshort, 's> {
278                item
279            }
280
281            #[inline]
282            unsafe fn fetch<'w, 's>(
283                state: &'s Self::State,
284                fetch: &mut Self::Fetch<'w>,
285                entity: #imports::Entity,
286                table_row: #imports::TableRow,
287            ) -> Self::Item<'w, 's> {
288                <#my_crate::All<&mut #trait_object> as #imports::QueryData>::fetch(
289                    state,
290                    fetch,
291                    entity,
292                    table_row,
293                )
294            }
295        }
296
297        unsafe impl #impl_generics_with_lifetime #imports::WorldQuery for &'__a mut #trait_object
298        #where_clause
299        {
300            type Fetch<'__w> = <#my_crate::All<&'__a #trait_object> as #imports::WorldQuery>::Fetch<'__w>;
301            type State = #my_crate::TraitQueryState<#trait_object>;
302
303            #[inline]
304            unsafe fn init_fetch<'w>(
305                world: #imports::UnsafeWorldCell<'w>,
306                state: &Self::State,
307                last_run: #imports::Tick,
308                this_run: #imports::Tick,
309            ) -> Self::Fetch<'w> {
310                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::init_fetch(
311                    world,
312                    state,
313                    last_run,
314                    this_run,
315                )
316            }
317
318            const IS_DENSE: bool = <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::IS_DENSE;
319
320            #[inline]
321            unsafe fn set_archetype<'w>(
322                fetch: &mut Self::Fetch<'w>,
323                state: &Self::State,
324                archetype: &'w #imports::Archetype,
325                table: &'w #imports::Table,
326            ) {
327                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::set_archetype(
328                    fetch, state, archetype, table,
329                );
330            }
331
332            #[inline]
333            unsafe fn set_table<'w>(
334                fetch: &mut Self::Fetch<'w>,
335                state: &Self::State,
336                table: &'w #imports::Table,
337            ) {
338                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::set_table(fetch, state, table);
339            }
340
341            #[inline]
342            fn update_component_access(
343                state: &Self::State,
344                access: &mut #imports::FilteredAccess,
345            ) {
346                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::update_component_access(
347                    state, access,
348                );
349            }
350
351            #[inline]
352            fn init_state(world: &mut #imports::World) -> Self::State {
353                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::init_state(world)
354            }
355
356            #[inline]
357            fn get_state(_: &#imports::Components) -> Option<Self::State> {
358                // TODO: fix this https://github.com/bevyengine/bevy/issues/13798
359                panic!("transmuting and any other operations concerning the state of a query are currently broken and shouldn't be used. See https://github.com/JoJoJet/bevy-trait-query/issues/59");
360            }
361
362            #[inline]
363            fn matches_component_set(
364                state: &Self::State,
365                set_contains_id: &impl Fn(#imports::ComponentId) -> bool,
366            ) -> bool {
367                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::matches_component_set(state, set_contains_id)
368            }
369
370            #[inline]
371            fn shrink_fetch<'wlong: 'wshort, 'wshort>(fetch: Self::Fetch<'wlong>) -> Self::Fetch<'wshort> {
372                fetch
373            }
374        }
375    };
376
377    Ok(quote! {
378        #trait_definition
379
380        #marker_impl_code
381
382        #trait_object_query_code
383    })
384}