bevy_trait_query_impl/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{ItemTrait, Result, TraitItem, parse_quote};
5
6/// When added to a trait declaration, generates the impls required to use that trait in queries.
7///
8/// # Poor use cases
9///
10/// You should avoid using trait queries for very simple cases that can be solved with more direct solutions.
11///
12/// One naive use would be querying for a trait that looks something like:
13///
14/// ```
15/// trait Person {
16///     fn name(&self) -> &str;
17/// }
18/// ```
19///
20/// A far better way of expressing this would be to store the name in a separate component
21/// and query for that directly, making `Person` a simple marker component.
22///
23/// Trait queries are often the most *obvious* solution to a problem, but not always the best one.
24/// For examples of strong real-world use-cases, check out the RFC for trait queries in `bevy`:
25/// <https://github.com/bevyengine/rfcs/pull/39>.
26///
27/// # Note
28///
29/// This will add the trait bound `'static` to the trait and all of its type parameters.
30///
31/// You may opt out of this by using the form `#[queryable(no_bounds)]`,
32/// but you will have to add the bounds yourself to make it compile.
33#[proc_macro_attribute]
34pub fn queryable(attr: TokenStream, item: TokenStream) -> TokenStream {
35    impl_trait_query(attr, item)
36        .unwrap_or_else(syn::Error::into_compile_error)
37        .into()
38}
39
40fn impl_trait_query(arg: TokenStream, item: TokenStream) -> Result<TokenStream2> {
41    syn::custom_keyword!(no_bounds);
42    let no_bounds: Option<no_bounds> = syn::parse(arg).map_err(|e| {
43        syn::Error::new(
44            e.span(),
45            "Valid forms are: `#[queryable]` and `#[queryable(no_bounds)]`",
46        )
47    })?;
48
49    let mut trait_definition = syn::parse::<ItemTrait>(item)?;
50    let trait_name = trait_definition.ident.clone();
51
52    // Add `'static` bounds, unless the user asked us not to.
53    if no_bounds.is_none() {
54        trait_definition.supertraits.push(parse_quote!('static));
55
56        for param in &mut trait_definition.generics.params {
57            // Make sure the parameters to the trait are `'static`.
58            if let syn::GenericParam::Type(param) = param {
59                param.bounds.push(parse_quote!('static));
60            }
61        }
62
63        for item in &mut trait_definition.items {
64            // Make sure all associated types are `'static`.
65            if let TraitItem::Type(assoc) = item {
66                assoc.bounds.push(parse_quote!('static));
67            }
68        }
69    }
70
71    let mut impl_generics_list = vec![];
72    let mut trait_generics_list = vec![];
73    let where_clause = trait_definition.generics.where_clause.clone();
74
75    for param in &trait_definition.generics.params {
76        impl_generics_list.push(param.clone());
77        match param {
78            syn::GenericParam::Type(param) => {
79                let ident = &param.ident;
80                trait_generics_list.push(quote! { #ident });
81            }
82            syn::GenericParam::Lifetime(param) => {
83                let ident = &param.lifetime;
84                trait_generics_list.push(quote! { #ident });
85            }
86            syn::GenericParam::Const(param) => {
87                let ident = &param.ident;
88                trait_generics_list.push(quote! { #ident });
89            }
90        }
91    }
92
93    // Add generics for unbounded associated types.
94    for item in &trait_definition.items {
95        if let TraitItem::Type(assoc) = item {
96            if !assoc.generics.params.is_empty() {
97                return Err(syn::Error::new(
98                    assoc.ident.span(),
99                    "Generic associated types are not supported in trait queries",
100                ));
101            }
102            let ident = &assoc.ident;
103            let lower_ident = format_ident!("__{ident}");
104            let bound = &assoc.bounds;
105            impl_generics_list.push(parse_quote! { #lower_ident: #bound });
106            trait_generics_list.push(quote! { #ident = #lower_ident });
107        }
108    }
109
110    let impl_generics = quote! { <#( #impl_generics_list ,)*> };
111    let trait_generics = quote! { <#( #trait_generics_list ,)*> };
112
113    let trait_object = quote! { dyn #trait_name #trait_generics };
114
115    let my_crate = proc_macro_crate::crate_name("bevy-trait-query").unwrap();
116    let my_crate = match my_crate {
117        proc_macro_crate::FoundCrate::Itself => quote! { bevy_trait_query },
118        proc_macro_crate::FoundCrate::Name(x) => {
119            let ident = quote::format_ident!("{x}");
120            quote! { #ident }
121        }
122    };
123
124    let imports = quote! { #my_crate::imports };
125
126    let trait_query = quote! { #my_crate::TraitQuery };
127
128    let mut marker_impl_generics_list = impl_generics_list.clone();
129    marker_impl_generics_list
130        .push(parse_quote!(__Component: #trait_name #trait_generics + #imports::Component));
131    let marker_impl_generics = quote! { <#( #marker_impl_generics_list ,)*> };
132
133    let marker_impl_code = quote! {
134        impl #impl_generics #trait_query for #trait_object #where_clause {}
135
136        impl #marker_impl_generics #my_crate::TraitQueryMarker::<#trait_object> for (__Component,)
137        #where_clause
138        {
139            type Covered = __Component;
140            fn cast(ptr: *mut u8) -> *mut #trait_object {
141                ptr as *mut __Component as *mut _
142            }
143        }
144    };
145
146    let mut impl_generics_with_lifetime = impl_generics_list.clone();
147    impl_generics_with_lifetime.insert(0, parse_quote!('__a));
148    let impl_generics_with_lifetime = quote! { <#( #impl_generics_with_lifetime ,)*> };
149
150    let trait_object_query_code = quote! {
151        unsafe impl #impl_generics #imports::QueryData for &#trait_object
152        #where_clause
153        {
154            type ReadOnly = Self;
155
156            const IS_READ_ONLY: bool = true;
157            const IS_ARCHETYPAL: bool = false;
158
159            type Item<'__w, '__s> = #my_crate::ReadTraits<'__w, #trait_object>;
160
161            #[inline]
162            fn shrink<'wlong: 'wshort, 'wshort, 's>(
163                item: Self::Item<'wlong, 's>,
164            ) -> Self::Item<'wshort, 's> {
165                item
166            }
167
168            #[inline]
169            unsafe fn fetch<'w, 's>(
170                state: &'s Self::State,
171                fetch: &mut Self::Fetch<'w>,
172                entity: #imports::Entity,
173                table_row: #imports::TableRow,
174            ) -> Option<Self::Item<'w, 's>> {
175                <#my_crate::All<&#trait_object> as #imports::QueryData>::fetch(
176                    state,
177                    fetch,
178                    entity,
179                    table_row,
180                )
181            }
182
183            fn iter_access(
184                _state: &Self::State,
185            ) -> impl Iterator<Item = bevy_ecs::query::EcsAccessType<'_>> {
186                std::iter::empty()
187            }
188        }
189        unsafe impl #impl_generics #imports::ReadOnlyQueryData for &#trait_object
190        #where_clause
191        {}
192
193        unsafe impl #impl_generics_with_lifetime #imports::WorldQuery for &'__a #trait_object
194        #where_clause
195        {
196            type Fetch<'__w> = <#my_crate::All<&'__a #trait_object> as #imports::WorldQuery>::Fetch<'__w>;
197            type State = #my_crate::TraitQueryState<#trait_object>;
198
199            #[inline]
200            unsafe fn init_fetch<'w>(
201                world: #imports::UnsafeWorldCell<'w>,
202                state: &Self::State,
203                last_run: #imports::Tick,
204                this_run: #imports::Tick,
205            ) -> Self::Fetch<'w> {
206                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::init_fetch(
207                    world,
208                    state,
209                    last_run,
210                    this_run,
211                )
212            }
213
214            const IS_DENSE: bool = <#my_crate::All<&#trait_object> as #imports::WorldQuery>::IS_DENSE;
215
216            #[inline]
217            unsafe fn set_archetype<'w>(
218                fetch: &mut Self::Fetch<'w>,
219                state: &Self::State,
220                archetype: &'w #imports::Archetype,
221                tables: &'w #imports::Table,
222            ) {
223                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::set_archetype(
224                    fetch, state, archetype, tables,
225                );
226            }
227
228            #[inline]
229            unsafe fn set_table<'w>(
230                fetch: &mut Self::Fetch<'w>,
231                state: &Self::State,
232                table: &'w #imports::Table,
233            ) {
234                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::set_table(fetch, state, table);
235            }
236
237            #[inline]
238            fn update_component_access(
239                state: &Self::State,
240                access: &mut #imports::FilteredAccess,
241            ) {
242                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::update_component_access(
243                    state, access,
244                );
245            }
246
247            #[inline]
248            fn init_state(world: &mut #imports::World) -> Self::State {
249                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::init_state(world)
250            }
251
252            #[inline]
253            fn get_state(_: &#imports::Components) -> Option<Self::State> {
254                // TODO: fix this https://github.com/bevyengine/bevy/issues/13798
255                panic!("transmuting and any other operations concerning the state of a query are currently broken and shouldn't be used. See https://github.com/JoJoJet/bevy-trait-query/issues/59");
256            }
257
258            #[inline]
259            fn matches_component_set(
260                state: &Self::State,
261                set_contains_id: &impl Fn(#imports::ComponentId) -> bool,
262            ) -> bool {
263                <#my_crate::All<&#trait_object> as #imports::WorldQuery>::matches_component_set(state, set_contains_id)
264            }
265
266            #[inline]
267            fn shrink_fetch<'wlong: 'wshort, 'wshort>(fetch: Self::Fetch<'wlong>) -> Self::Fetch<'wshort> {
268                fetch
269            }
270        }
271
272        unsafe impl #impl_generics_with_lifetime #imports::QueryData for &'__a mut #trait_object
273        #where_clause
274        {
275            type ReadOnly = &'__a #trait_object;
276
277            type Item<'__w, '__s> = #my_crate::WriteTraits<'__w, #trait_object>;
278
279            const IS_READ_ONLY: bool = false;
280            const IS_ARCHETYPAL: bool = false;
281
282            #[inline]
283            fn shrink<'wlong: 'wshort, 'wshort, 's>(
284                item: Self::Item<'wlong, 's>,
285            ) -> Self::Item<'wshort, 's> {
286                item
287            }
288
289            #[inline]
290            unsafe fn fetch<'w, 's>(
291                state: &'s Self::State,
292                fetch: &mut Self::Fetch<'w>,
293                entity: #imports::Entity,
294                table_row: #imports::TableRow,
295            ) -> Option<Self::Item<'w, 's>> {
296                <#my_crate::All<&mut #trait_object> as #imports::QueryData>::fetch(
297                    state,
298                    fetch,
299                    entity,
300                    table_row,
301                )
302            }
303
304            fn iter_access(
305                _state: &Self::State,
306            ) -> impl Iterator<Item = bevy_ecs::query::EcsAccessType<'_>> {
307                std::iter::empty()
308            }
309        }
310
311        unsafe impl #impl_generics_with_lifetime #imports::WorldQuery for &'__a mut #trait_object
312        #where_clause
313        {
314            type Fetch<'__w> = <#my_crate::All<&'__a #trait_object> as #imports::WorldQuery>::Fetch<'__w>;
315            type State = #my_crate::TraitQueryState<#trait_object>;
316
317            #[inline]
318            unsafe fn init_fetch<'w>(
319                world: #imports::UnsafeWorldCell<'w>,
320                state: &Self::State,
321                last_run: #imports::Tick,
322                this_run: #imports::Tick,
323            ) -> Self::Fetch<'w> {
324                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::init_fetch(
325                    world,
326                    state,
327                    last_run,
328                    this_run,
329                )
330            }
331
332            const IS_DENSE: bool = <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::IS_DENSE;
333
334            #[inline]
335            unsafe fn set_archetype<'w>(
336                fetch: &mut Self::Fetch<'w>,
337                state: &Self::State,
338                archetype: &'w #imports::Archetype,
339                table: &'w #imports::Table,
340            ) {
341                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::set_archetype(
342                    fetch, state, archetype, table,
343                );
344            }
345
346            #[inline]
347            unsafe fn set_table<'w>(
348                fetch: &mut Self::Fetch<'w>,
349                state: &Self::State,
350                table: &'w #imports::Table,
351            ) {
352                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::set_table(fetch, state, table);
353            }
354
355            #[inline]
356            fn update_component_access(
357                state: &Self::State,
358                access: &mut #imports::FilteredAccess,
359            ) {
360                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::update_component_access(
361                    state, access,
362                );
363            }
364
365            #[inline]
366            fn init_state(world: &mut #imports::World) -> Self::State {
367                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::init_state(world)
368            }
369
370            #[inline]
371            fn get_state(_: &#imports::Components) -> Option<Self::State> {
372                // TODO: fix this https://github.com/bevyengine/bevy/issues/13798
373                panic!("transmuting and any other operations concerning the state of a query are currently broken and shouldn't be used. See https://github.com/JoJoJet/bevy-trait-query/issues/59");
374            }
375
376            #[inline]
377            fn matches_component_set(
378                state: &Self::State,
379                set_contains_id: &impl Fn(#imports::ComponentId) -> bool,
380            ) -> bool {
381                <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::matches_component_set(state, set_contains_id)
382            }
383
384            #[inline]
385            fn shrink_fetch<'wlong: 'wshort, 'wshort>(fetch: Self::Fetch<'wlong>) -> Self::Fetch<'wshort> {
386                fetch
387            }
388        }
389    };
390
391    Ok(quote! {
392        #trait_definition
393
394        #marker_impl_code
395
396        #trait_object_query_code
397    })
398}