bevy_trait_query_impl/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{parse_quote, ItemTrait, Result, TraitItem};
5
6/// When added to a trait declaration, generates the impls required to use that trait in queries.
7///
8/// # Poor use cases
9///
10/// You should avoid using trait queries for very simple cases that can be solved with more direct solutions.
11///
12/// One naive use would be querying for a trait that looks something like:
13///
14/// ```
15/// trait Person {
16///     fn name(&self) -> &str;
17/// }
18/// ```
19///
20/// A far better way of expressing this would be to store the name in a separate component
21/// and query for that directly, making `Person` a simple marker component.
22///
23/// Trait queries are often the most *obvious* solution to a problem, but not always the best one.
24/// For examples of strong real-world use-cases, check out the RFC for trait queries in `bevy`:
25/// <https://github.com/bevyengine/rfcs/pull/39>.
26///
27/// # Note
28///
29/// This will add the trait bound `'static` to the trait and all of its type parameters.
30///
31/// You may opt out of this by using the form `#[queryable(no_bounds)]`,
32/// but you will have to add the bounds yourself to make it compile.
33#[proc_macro_attribute]
34pub fn queryable(attr: TokenStream, item: TokenStream) -> TokenStream {
35    impl_trait_query(attr, item)
36        .unwrap_or_else(syn::Error::into_compile_error)
37        .into()
38}
39
40fn impl_trait_query(arg: TokenStream, item: TokenStream) -> Result<TokenStream2> {
41    syn::custom_keyword!(no_bounds);
42    let no_bounds: Option<no_bounds> = syn::parse(arg).map_err(|e| {
43        syn::Error::new(
44            e.span(),
45            "Valid forms are: `#[queryable]` and `#[queryable(no_bounds)]`",
46        )
47    })?;
48
49    let mut trait_definition = syn::parse::<ItemTrait>(item)?;
50    let trait_name = trait_definition.ident.clone();
51
52    // Add `'static` bounds, unless the user asked us not to.
53    if no_bounds.is_none() {
54        trait_definition.supertraits.push(parse_quote!('static));
55
56        for param in &mut trait_definition.generics.params {
57            // Make sure the parameters to the trait are `'static`.
58            if let syn::GenericParam::Type(param) = param {
59                param.bounds.push(parse_quote!('static));
60            }
61        }
62
63        for item in &mut trait_definition.items {
64            // Make sure all associated types are `'static`.
65            if let TraitItem::Type(assoc) = item {
66                assoc.bounds.push(parse_quote!('static));
67            }
68        }
69    }
70
71    let mut impl_generics_list = vec![];
72    let mut trait_generics_list = vec![];
73    let where_clause = trait_definition.generics.where_clause.clone();
74
75    for param in &trait_definition.generics.params {
76        impl_generics_list.push(param.clone());
77        match param {
78            syn::GenericParam::Type(param) => {
79                let ident = &param.ident;
80                trait_generics_list.push(quote! { #ident });
81            }
82            syn::GenericParam::Lifetime(param) => {
83                let ident = &param.lifetime;
84                trait_generics_list.push(quote! { #ident });
85            }
86            syn::GenericParam::Const(param) => {
87                let ident = &param.ident;
88                trait_generics_list.push(quote! { #ident });
89            }
90        }
91    }
92
93    // Add generics for unbounded associated types.
94    for item in &trait_definition.items {
95        if let TraitItem::Type(assoc) = item {
96            if !assoc.generics.params.is_empty() {
97                return Err(syn::Error::new(
98                    assoc.ident.span(),
99                    "Generic associated types are not supported in trait queries",
100                ));
101            }
102            let ident = &assoc.ident;
103            let lower_ident = format_ident!("__{ident}");
104            let bound = &assoc.bounds;
105            impl_generics_list.push(parse_quote! { #lower_ident: #bound });
106            trait_generics_list.push(quote! { #ident = #lower_ident });
107        }
108    }
109
110    let impl_generics = quote! { <#( #impl_generics_list ,)*> };
111    let trait_generics = quote! { <#( #trait_generics_list ,)*> };
112
113    let trait_object = quote! { dyn #trait_name #trait_generics };
114
115    let my_crate = proc_macro_crate::crate_name("bevy-trait-query").unwrap();
116    let my_crate = match my_crate {
117        proc_macro_crate::FoundCrate::Itself => quote! { bevy_trait_query },
118        proc_macro_crate::FoundCrate::Name(x) => {
119            let ident = quote::format_ident!("{x}");
120            quote! { #ident }
121        }
122    };
123
124    let imports = quote! { #my_crate::imports };
125
126    let trait_query = quote! { #my_crate::TraitQuery };
127
128    let mut marker_impl_generics_list = impl_generics_list.clone();
129    marker_impl_generics_list
130        .push(parse_quote!(__Component: #trait_name #trait_generics + #imports::Component));
131    let marker_impl_generics = quote! { <#( #marker_impl_generics_list ,)*> };
132
133    let marker_impl_code = quote! {
134        impl #impl_generics #trait_query for #trait_object #where_clause {}
135
136        impl #marker_impl_generics #my_crate::TraitQueryMarker::<#trait_object> for (__Component,)
137        #where_clause
138        {
139            type Covered = __Component;
140            fn cast(ptr: *mut u8) -> *mut #trait_object {
141                ptr as *mut __Component as *mut _
142            }
143        }
144    };
145
146    let mut impl_generics_with_lifetime = impl_generics_list.clone();
147    impl_generics_with_lifetime.insert(0, parse_quote!('__a));
148    let impl_generics_with_lifetime = quote! { <#( #impl_generics_with_lifetime ,)*> };
149
150    let trait_object_query_code = quote! {
151        unsafe impl #impl_generics #imports::QueryData for &#trait_object
152        #where_clause
153        {
154            type ReadOnly = Self;
155
156            const IS_READ_ONLY: bool = true;
157
158            type Item<'__w> = #my_crate::ReadTraits<'__w, #trait_object>;
159
160            #[inline]
161            fn shrink<'wlong: 'wshort, 'wshort>(
162                item: Self::Item<'wlong>,
163            ) -> Self::Item<'wshort> {
164                item
165            }
166
167            #[inline]
168            unsafe fn fetch<'w>(
169                fetch: &mut Self::Fetch<'w>,
170                entity: #imports::Entity,
171                table_row: #imports::TableRow,
172            ) -> Self::Item<'w> {
173                <#my_crate::All<&#trait_object> as #imports::QueryData>::fetch(
174                    fetch,
175                    entity,
176                    table_row,
177                )
178            }
179        }
180        unsafe impl #impl_generics #imports::ReadOnlyQueryData for &#trait_object
181        #where_clause
182        {}
183
184        unsafe impl #impl_generics_with_lifetime #imports::WorldQuery for &'__a #trait_object
185        #where_clause
186        {
187            type Fetch<'__w> = <#my_crate::All<&'__a #trait_object> as #imports::WorldQuery>::Fetch<'__w>;
188            type State = #my_crate::TraitQueryState<#trait_object>;
189
190            #[inline]
191            unsafe fn init_fetch<'w>(
192                world: #imports::UnsafeWorldCell<'w>,
193                state: &Self::State,
194                last_run: #imports::Tick,
195                this_run: #imports::Tick,
196            ) -> Self::Fetch<'w> {
197                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::init_fetch(
198                    world,
199                    state,
200                    last_run,
201                    this_run,
202                )
203            }
204
205            const IS_DENSE: bool = <#my_crate::All<&#trait_object> as #imports::WorldQuery>::IS_DENSE;
206
207            #[inline]
208            unsafe fn set_archetype<'w>(
209                fetch: &mut Self::Fetch<'w>,
210                state: &Self::State,
211                archetype: &'w #imports::Archetype,
212                tables: &'w #imports::Table,
213            ) {
214                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::set_archetype(
215                    fetch, state, archetype, tables,
216                );
217            }
218
219            #[inline]
220            unsafe fn set_table<'w>(
221                fetch: &mut Self::Fetch<'w>,
222                state: &Self::State,
223                table: &'w #imports::Table,
224            ) {
225                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::set_table(fetch, state, table);
226            }
227
228            #[inline]
229            fn update_component_access(
230                state: &Self::State,
231                access: &mut #imports::FilteredAccess<#imports::ComponentId>,
232            ) {
233                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::update_component_access(
234                    state, access,
235                );
236            }
237
238            #[inline]
239            fn init_state(world: &mut #imports::World) -> Self::State {
240                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::init_state(world)
241            }
242
243            #[inline]
244            fn get_state(_: &#imports::Components) -> Option<Self::State> {
245                // TODO: fix this https://github.com/bevyengine/bevy/issues/13798
246                panic!("transmuting and any other operations concerning the state of a query are currently broken and shouldn't be used. See https://github.com/JoJoJet/bevy-trait-query/issues/59");
247            }
248
249            #[inline]
250            fn matches_component_set(
251                state: &Self::State,
252                set_contains_id: &impl Fn(#imports::ComponentId) -> bool,
253            ) -> bool {
254                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::matches_component_set(state, set_contains_id)
255            }
256
257            #[inline]
258            fn shrink_fetch<'wlong: 'wshort, 'wshort>(fetch: Self::Fetch<'wlong>) -> Self::Fetch<'wshort> {
259                fetch
260            }
261        }
262
263        unsafe impl #impl_generics_with_lifetime #imports::QueryData for &'__a mut #trait_object
264        #where_clause
265        {
266            type ReadOnly = &'__a #trait_object;
267
268            type Item<'__w> = #my_crate::WriteTraits<'__w, #trait_object>;
269
270            const IS_READ_ONLY: bool = false;
271
272            #[inline]
273            fn shrink<'wlong: 'wshort, 'wshort>(
274                item: Self::Item<'wlong>,
275            ) -> Self::Item<'wshort> {
276                item
277            }
278
279            #[inline]
280            unsafe fn fetch<'w>(
281                fetch: &mut Self::Fetch<'w>,
282                entity: #imports::Entity,
283                table_row: #imports::TableRow,
284            ) -> Self::Item<'w> {
285                <#my_crate::All<&mut #trait_object> as #imports::QueryData>::fetch(
286                    fetch,
287                    entity,
288                    table_row,
289                )
290            }
291        }
292
293        unsafe impl #impl_generics_with_lifetime #imports::WorldQuery for &'__a mut #trait_object
294        #where_clause
295        {
296            type Fetch<'__w> = <#my_crate::All<&'__a #trait_object> as #imports::WorldQuery>::Fetch<'__w>;
297            type State = #my_crate::TraitQueryState<#trait_object>;
298
299            #[inline]
300            unsafe fn init_fetch<'w>(
301                world: #imports::UnsafeWorldCell<'w>,
302                state: &Self::State,
303                last_run: #imports::Tick,
304                this_run: #imports::Tick,
305            ) -> Self::Fetch<'w> {
306                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::init_fetch(
307                    world,
308                    state,
309                    last_run,
310                    this_run,
311                )
312            }
313
314            const IS_DENSE: bool = <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::IS_DENSE;
315
316            #[inline]
317            unsafe fn set_archetype<'w>(
318                fetch: &mut Self::Fetch<'w>,
319                state: &Self::State,
320                archetype: &'w #imports::Archetype,
321                table: &'w #imports::Table,
322            ) {
323                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::set_archetype(
324                    fetch, state, archetype, table,
325                );
326            }
327
328            #[inline]
329            unsafe fn set_table<'w>(
330                fetch: &mut Self::Fetch<'w>,
331                state: &Self::State,
332                table: &'w #imports::Table,
333            ) {
334                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::set_table(fetch, state, table);
335            }
336
337            #[inline]
338            fn update_component_access(
339                state: &Self::State,
340                access: &mut #imports::FilteredAccess<#imports::ComponentId>,
341            ) {
342                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::update_component_access(
343                    state, access,
344                );
345            }
346
347            #[inline]
348            fn init_state(world: &mut #imports::World) -> Self::State {
349                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::init_state(world)
350            }
351
352            #[inline]
353            fn get_state(_: &#imports::Components) -> Option<Self::State> {
354                // TODO: fix this https://github.com/bevyengine/bevy/issues/13798
355                panic!("transmuting and any other operations concerning the state of a query are currently broken and shouldn't be used. See https://github.com/JoJoJet/bevy-trait-query/issues/59");
356            }
357
358            #[inline]
359            fn matches_component_set(
360                state: &Self::State,
361                set_contains_id: &impl Fn(#imports::ComponentId) -> bool,
362            ) -> bool {
363                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::matches_component_set(state, set_contains_id)
364            }
365
366            #[inline]
367            fn shrink_fetch<'wlong: 'wshort, 'wshort>(fetch: Self::Fetch<'wlong>) -> Self::Fetch<'wshort> {
368                fetch
369            }
370        }
371    };
372
373    Ok(quote! {
374        #trait_definition
375
376        #marker_impl_code
377
378        #trait_object_query_code
379    })
380}