1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{parse_quote, ItemTrait, Result, TraitItem};
5
6#[proc_macro_attribute]
34pub fn queryable(attr: TokenStream, item: TokenStream) -> TokenStream {
35 impl_trait_query(attr, item)
36 .unwrap_or_else(syn::Error::into_compile_error)
37 .into()
38}
39
40fn impl_trait_query(arg: TokenStream, item: TokenStream) -> Result<TokenStream2> {
41 syn::custom_keyword!(no_bounds);
42 let no_bounds: Option<no_bounds> = syn::parse(arg).map_err(|e| {
43 syn::Error::new(
44 e.span(),
45 "Valid forms are: `#[queryable]` and `#[queryable(no_bounds)]`",
46 )
47 })?;
48
49 let mut trait_definition = syn::parse::<ItemTrait>(item)?;
50 let trait_name = trait_definition.ident.clone();
51
52 if no_bounds.is_none() {
54 trait_definition.supertraits.push(parse_quote!('static));
55
56 for param in &mut trait_definition.generics.params {
57 if let syn::GenericParam::Type(param) = param {
59 param.bounds.push(parse_quote!('static));
60 }
61 }
62
63 for item in &mut trait_definition.items {
64 if let TraitItem::Type(assoc) = item {
66 assoc.bounds.push(parse_quote!('static));
67 }
68 }
69 }
70
71 let mut impl_generics_list = vec![];
72 let mut trait_generics_list = vec![];
73 let where_clause = trait_definition.generics.where_clause.clone();
74
75 for param in &trait_definition.generics.params {
76 impl_generics_list.push(param.clone());
77 match param {
78 syn::GenericParam::Type(param) => {
79 let ident = ¶m.ident;
80 trait_generics_list.push(quote! { #ident });
81 }
82 syn::GenericParam::Lifetime(param) => {
83 let ident = ¶m.lifetime;
84 trait_generics_list.push(quote! { #ident });
85 }
86 syn::GenericParam::Const(param) => {
87 let ident = ¶m.ident;
88 trait_generics_list.push(quote! { #ident });
89 }
90 }
91 }
92
93 for item in &trait_definition.items {
95 if let TraitItem::Type(assoc) = item {
96 if !assoc.generics.params.is_empty() {
97 return Err(syn::Error::new(
98 assoc.ident.span(),
99 "Generic associated types are not supported in trait queries",
100 ));
101 }
102 let ident = &assoc.ident;
103 let lower_ident = format_ident!("__{ident}");
104 let bound = &assoc.bounds;
105 impl_generics_list.push(parse_quote! { #lower_ident: #bound });
106 trait_generics_list.push(quote! { #ident = #lower_ident });
107 }
108 }
109
110 let impl_generics = quote! { <#( #impl_generics_list ,)*> };
111 let trait_generics = quote! { <#( #trait_generics_list ,)*> };
112
113 let trait_object = quote! { dyn #trait_name #trait_generics };
114
115 let my_crate = proc_macro_crate::crate_name("bevy-trait-query").unwrap();
116 let my_crate = match my_crate {
117 proc_macro_crate::FoundCrate::Itself => quote! { bevy_trait_query },
118 proc_macro_crate::FoundCrate::Name(x) => {
119 let ident = quote::format_ident!("{x}");
120 quote! { #ident }
121 }
122 };
123
124 let imports = quote! { #my_crate::imports };
125
126 let trait_query = quote! { #my_crate::TraitQuery };
127
128 let mut marker_impl_generics_list = impl_generics_list.clone();
129 marker_impl_generics_list
130 .push(parse_quote!(__Component: #trait_name #trait_generics + #imports::Component));
131 let marker_impl_generics = quote! { <#( #marker_impl_generics_list ,)*> };
132
133 let marker_impl_code = quote! {
134 impl #impl_generics #trait_query for #trait_object #where_clause {}
135
136 impl #marker_impl_generics #my_crate::TraitQueryMarker::<#trait_object> for (__Component,)
137 #where_clause
138 {
139 type Covered = __Component;
140 fn cast(ptr: *mut u8) -> *mut #trait_object {
141 ptr as *mut __Component as *mut _
142 }
143 }
144 };
145
146 let mut impl_generics_with_lifetime = impl_generics_list.clone();
147 impl_generics_with_lifetime.insert(0, parse_quote!('__a));
148 let impl_generics_with_lifetime = quote! { <#( #impl_generics_with_lifetime ,)*> };
149
150 let trait_object_query_code = quote! {
151 unsafe impl #impl_generics #imports::QueryData for &#trait_object
152 #where_clause
153 {
154 type ReadOnly = Self;
155
156 const IS_READ_ONLY: bool = true;
157
158 type Item<'__w> = #my_crate::ReadTraits<'__w, #trait_object>;
159
160 #[inline]
161 fn shrink<'wlong: 'wshort, 'wshort>(
162 item: Self::Item<'wlong>,
163 ) -> Self::Item<'wshort> {
164 item
165 }
166
167 #[inline]
168 unsafe fn fetch<'w>(
169 fetch: &mut Self::Fetch<'w>,
170 entity: #imports::Entity,
171 table_row: #imports::TableRow,
172 ) -> Self::Item<'w> {
173 <#my_crate::All<&#trait_object> as #imports::QueryData>::fetch(
174 fetch,
175 entity,
176 table_row,
177 )
178 }
179 }
180 unsafe impl #impl_generics #imports::ReadOnlyQueryData for &#trait_object
181 #where_clause
182 {}
183
184 unsafe impl #impl_generics_with_lifetime #imports::WorldQuery for &'__a #trait_object
185 #where_clause
186 {
187 type Fetch<'__w> = <#my_crate::All<&'__a #trait_object> as #imports::WorldQuery>::Fetch<'__w>;
188 type State = #my_crate::TraitQueryState<#trait_object>;
189
190 #[inline]
191 unsafe fn init_fetch<'w>(
192 world: #imports::UnsafeWorldCell<'w>,
193 state: &Self::State,
194 last_run: #imports::Tick,
195 this_run: #imports::Tick,
196 ) -> Self::Fetch<'w> {
197 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::init_fetch(
198 world,
199 state,
200 last_run,
201 this_run,
202 )
203 }
204
205 const IS_DENSE: bool = <#my_crate::All<&#trait_object> as #imports::WorldQuery>::IS_DENSE;
206
207 #[inline]
208 unsafe fn set_archetype<'w>(
209 fetch: &mut Self::Fetch<'w>,
210 state: &Self::State,
211 archetype: &'w #imports::Archetype,
212 tables: &'w #imports::Table,
213 ) {
214 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::set_archetype(
215 fetch, state, archetype, tables,
216 );
217 }
218
219 #[inline]
220 unsafe fn set_table<'w>(
221 fetch: &mut Self::Fetch<'w>,
222 state: &Self::State,
223 table: &'w #imports::Table,
224 ) {
225 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::set_table(fetch, state, table);
226 }
227
228 #[inline]
229 fn update_component_access(
230 state: &Self::State,
231 access: &mut #imports::FilteredAccess<#imports::ComponentId>,
232 ) {
233 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::update_component_access(
234 state, access,
235 );
236 }
237
238 #[inline]
239 fn init_state(world: &mut #imports::World) -> Self::State {
240 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::init_state(world)
241 }
242
243 #[inline]
244 fn get_state(_: &#imports::Components) -> Option<Self::State> {
245 panic!("transmuting and any other operations concerning the state of a query are currently broken and shouldn't be used. See https://github.com/JoJoJet/bevy-trait-query/issues/59");
247 }
248
249 #[inline]
250 fn matches_component_set(
251 state: &Self::State,
252 set_contains_id: &impl Fn(#imports::ComponentId) -> bool,
253 ) -> bool {
254 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::matches_component_set(state, set_contains_id)
255 }
256
257 #[inline]
258 fn shrink_fetch<'wlong: 'wshort, 'wshort>(fetch: Self::Fetch<'wlong>) -> Self::Fetch<'wshort> {
259 fetch
260 }
261 }
262
263 unsafe impl #impl_generics_with_lifetime #imports::QueryData for &'__a mut #trait_object
264 #where_clause
265 {
266 type ReadOnly = &'__a #trait_object;
267
268 type Item<'__w> = #my_crate::WriteTraits<'__w, #trait_object>;
269
270 const IS_READ_ONLY: bool = false;
271
272 #[inline]
273 fn shrink<'wlong: 'wshort, 'wshort>(
274 item: Self::Item<'wlong>,
275 ) -> Self::Item<'wshort> {
276 item
277 }
278
279 #[inline]
280 unsafe fn fetch<'w>(
281 fetch: &mut Self::Fetch<'w>,
282 entity: #imports::Entity,
283 table_row: #imports::TableRow,
284 ) -> Self::Item<'w> {
285 <#my_crate::All<&mut #trait_object> as #imports::QueryData>::fetch(
286 fetch,
287 entity,
288 table_row,
289 )
290 }
291 }
292
293 unsafe impl #impl_generics_with_lifetime #imports::WorldQuery for &'__a mut #trait_object
294 #where_clause
295 {
296 type Fetch<'__w> = <#my_crate::All<&'__a #trait_object> as #imports::WorldQuery>::Fetch<'__w>;
297 type State = #my_crate::TraitQueryState<#trait_object>;
298
299 #[inline]
300 unsafe fn init_fetch<'w>(
301 world: #imports::UnsafeWorldCell<'w>,
302 state: &Self::State,
303 last_run: #imports::Tick,
304 this_run: #imports::Tick,
305 ) -> Self::Fetch<'w> {
306 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::init_fetch(
307 world,
308 state,
309 last_run,
310 this_run,
311 )
312 }
313
314 const IS_DENSE: bool = <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::IS_DENSE;
315
316 #[inline]
317 unsafe fn set_archetype<'w>(
318 fetch: &mut Self::Fetch<'w>,
319 state: &Self::State,
320 archetype: &'w #imports::Archetype,
321 table: &'w #imports::Table,
322 ) {
323 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::set_archetype(
324 fetch, state, archetype, table,
325 );
326 }
327
328 #[inline]
329 unsafe fn set_table<'w>(
330 fetch: &mut Self::Fetch<'w>,
331 state: &Self::State,
332 table: &'w #imports::Table,
333 ) {
334 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::set_table(fetch, state, table);
335 }
336
337 #[inline]
338 fn update_component_access(
339 state: &Self::State,
340 access: &mut #imports::FilteredAccess<#imports::ComponentId>,
341 ) {
342 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::update_component_access(
343 state, access,
344 );
345 }
346
347 #[inline]
348 fn init_state(world: &mut #imports::World) -> Self::State {
349 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::init_state(world)
350 }
351
352 #[inline]
353 fn get_state(_: &#imports::Components) -> Option<Self::State> {
354 panic!("transmuting and any other operations concerning the state of a query are currently broken and shouldn't be used. See https://github.com/JoJoJet/bevy-trait-query/issues/59");
356 }
357
358 #[inline]
359 fn matches_component_set(
360 state: &Self::State,
361 set_contains_id: &impl Fn(#imports::ComponentId) -> bool,
362 ) -> bool {
363 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::matches_component_set(state, set_contains_id)
364 }
365
366 #[inline]
367 fn shrink_fetch<'wlong: 'wshort, 'wshort>(fetch: Self::Fetch<'wlong>) -> Self::Fetch<'wshort> {
368 fetch
369 }
370 }
371 };
372
373 Ok(quote! {
374 #trait_definition
375
376 #marker_impl_code
377
378 #trait_object_query_code
379 })
380}