1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{ItemTrait, Result, TraitItem, parse_quote};
5
6#[proc_macro_attribute]
34pub fn queryable(attr: TokenStream, item: TokenStream) -> TokenStream {
35 impl_trait_query(attr, item)
36 .unwrap_or_else(syn::Error::into_compile_error)
37 .into()
38}
39
40fn impl_trait_query(arg: TokenStream, item: TokenStream) -> Result<TokenStream2> {
41 syn::custom_keyword!(no_bounds);
42 let no_bounds: Option<no_bounds> = syn::parse(arg).map_err(|e| {
43 syn::Error::new(
44 e.span(),
45 "Valid forms are: `#[queryable]` and `#[queryable(no_bounds)]`",
46 )
47 })?;
48
49 let mut trait_definition = syn::parse::<ItemTrait>(item)?;
50 let trait_name = trait_definition.ident.clone();
51
52 if no_bounds.is_none() {
54 trait_definition.supertraits.push(parse_quote!('static));
55
56 for param in &mut trait_definition.generics.params {
57 if let syn::GenericParam::Type(param) = param {
59 param.bounds.push(parse_quote!('static));
60 }
61 }
62
63 for item in &mut trait_definition.items {
64 if let TraitItem::Type(assoc) = item {
66 assoc.bounds.push(parse_quote!('static));
67 }
68 }
69 }
70
71 let mut impl_generics_list = vec![];
72 let mut trait_generics_list = vec![];
73 let where_clause = trait_definition.generics.where_clause.clone();
74
75 for param in &trait_definition.generics.params {
76 impl_generics_list.push(param.clone());
77 match param {
78 syn::GenericParam::Type(param) => {
79 let ident = ¶m.ident;
80 trait_generics_list.push(quote! { #ident });
81 }
82 syn::GenericParam::Lifetime(param) => {
83 let ident = ¶m.lifetime;
84 trait_generics_list.push(quote! { #ident });
85 }
86 syn::GenericParam::Const(param) => {
87 let ident = ¶m.ident;
88 trait_generics_list.push(quote! { #ident });
89 }
90 }
91 }
92
93 for item in &trait_definition.items {
95 if let TraitItem::Type(assoc) = item {
96 if !assoc.generics.params.is_empty() {
97 return Err(syn::Error::new(
98 assoc.ident.span(),
99 "Generic associated types are not supported in trait queries",
100 ));
101 }
102 let ident = &assoc.ident;
103 let lower_ident = format_ident!("__{ident}");
104 let bound = &assoc.bounds;
105 impl_generics_list.push(parse_quote! { #lower_ident: #bound });
106 trait_generics_list.push(quote! { #ident = #lower_ident });
107 }
108 }
109
110 let impl_generics = quote! { <#( #impl_generics_list ,)*> };
111 let trait_generics = quote! { <#( #trait_generics_list ,)*> };
112
113 let trait_object = quote! { dyn #trait_name #trait_generics };
114
115 let my_crate = proc_macro_crate::crate_name("bevy-trait-query").unwrap();
116 let my_crate = match my_crate {
117 proc_macro_crate::FoundCrate::Itself => quote! { bevy_trait_query },
118 proc_macro_crate::FoundCrate::Name(x) => {
119 let ident = quote::format_ident!("{x}");
120 quote! { #ident }
121 }
122 };
123
124 let imports = quote! { #my_crate::imports };
125
126 let trait_query = quote! { #my_crate::TraitQuery };
127
128 let mut marker_impl_generics_list = impl_generics_list.clone();
129 marker_impl_generics_list
130 .push(parse_quote!(__Component: #trait_name #trait_generics + #imports::Component));
131 let marker_impl_generics = quote! { <#( #marker_impl_generics_list ,)*> };
132
133 let marker_impl_code = quote! {
134 impl #impl_generics #trait_query for #trait_object #where_clause {}
135
136 impl #marker_impl_generics #my_crate::TraitQueryMarker::<#trait_object> for (__Component,)
137 #where_clause
138 {
139 type Covered = __Component;
140 fn cast(ptr: *mut u8) -> *mut #trait_object {
141 ptr as *mut __Component as *mut _
142 }
143 }
144 };
145
146 let mut impl_generics_with_lifetime = impl_generics_list.clone();
147 impl_generics_with_lifetime.insert(0, parse_quote!('__a));
148 let impl_generics_with_lifetime = quote! { <#( #impl_generics_with_lifetime ,)*> };
149
150 let trait_object_query_code = quote! {
151 unsafe impl #impl_generics #imports::QueryData for &#trait_object
152 #where_clause
153 {
154 type ReadOnly = Self;
155
156 const IS_READ_ONLY: bool = true;
157
158 type Item<'__w, '__s> = #my_crate::ReadTraits<'__w, #trait_object>;
159
160 #[inline]
161 fn shrink<'wlong: 'wshort, 'wshort, 's>(
162 item: Self::Item<'wlong, 's>,
163 ) -> Self::Item<'wshort, 's> {
164 item
165 }
166
167 #[inline]
168 unsafe fn fetch<'w, 's>(
169 state: &'s Self::State,
170 fetch: &mut Self::Fetch<'w>,
171 entity: #imports::Entity,
172 table_row: #imports::TableRow,
173 ) -> Self::Item<'w, 's> {
174 <#my_crate::All<&#trait_object> as #imports::QueryData>::fetch(
175 state,
176 fetch,
177 entity,
178 table_row,
179 )
180 }
181 }
182 unsafe impl #impl_generics #imports::ReadOnlyQueryData for &#trait_object
183 #where_clause
184 {}
185
186 unsafe impl #impl_generics_with_lifetime #imports::WorldQuery for &'__a #trait_object
187 #where_clause
188 {
189 type Fetch<'__w> = <#my_crate::All<&'__a #trait_object> as #imports::WorldQuery>::Fetch<'__w>;
190 type State = #my_crate::TraitQueryState<#trait_object>;
191
192 #[inline]
193 unsafe fn init_fetch<'w>(
194 world: #imports::UnsafeWorldCell<'w>,
195 state: &Self::State,
196 last_run: #imports::Tick,
197 this_run: #imports::Tick,
198 ) -> Self::Fetch<'w> {
199 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::init_fetch(
200 world,
201 state,
202 last_run,
203 this_run,
204 )
205 }
206
207 const IS_DENSE: bool = <#my_crate::All<&#trait_object> as #imports::WorldQuery>::IS_DENSE;
208
209 #[inline]
210 unsafe fn set_archetype<'w>(
211 fetch: &mut Self::Fetch<'w>,
212 state: &Self::State,
213 archetype: &'w #imports::Archetype,
214 tables: &'w #imports::Table,
215 ) {
216 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::set_archetype(
217 fetch, state, archetype, tables,
218 );
219 }
220
221 #[inline]
222 unsafe fn set_table<'w>(
223 fetch: &mut Self::Fetch<'w>,
224 state: &Self::State,
225 table: &'w #imports::Table,
226 ) {
227 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::set_table(fetch, state, table);
228 }
229
230 #[inline]
231 fn update_component_access(
232 state: &Self::State,
233 access: &mut #imports::FilteredAccess,
234 ) {
235 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::update_component_access(
236 state, access,
237 );
238 }
239
240 #[inline]
241 fn init_state(world: &mut #imports::World) -> Self::State {
242 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::init_state(world)
243 }
244
245 #[inline]
246 fn get_state(_: &#imports::Components) -> Option<Self::State> {
247 panic!("transmuting and any other operations concerning the state of a query are currently broken and shouldn't be used. See https://github.com/JoJoJet/bevy-trait-query/issues/59");
249 }
250
251 #[inline]
252 fn matches_component_set(
253 state: &Self::State,
254 set_contains_id: &impl Fn(#imports::ComponentId) -> bool,
255 ) -> bool {
256 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::matches_component_set(state, set_contains_id)
257 }
258
259 #[inline]
260 fn shrink_fetch<'wlong: 'wshort, 'wshort>(fetch: Self::Fetch<'wlong>) -> Self::Fetch<'wshort> {
261 fetch
262 }
263 }
264
265 unsafe impl #impl_generics_with_lifetime #imports::QueryData for &'__a mut #trait_object
266 #where_clause
267 {
268 type ReadOnly = &'__a #trait_object;
269
270 type Item<'__w, '__s> = #my_crate::WriteTraits<'__w, #trait_object>;
271
272 const IS_READ_ONLY: bool = false;
273
274 #[inline]
275 fn shrink<'wlong: 'wshort, 'wshort, 's>(
276 item: Self::Item<'wlong, 's>,
277 ) -> Self::Item<'wshort, 's> {
278 item
279 }
280
281 #[inline]
282 unsafe fn fetch<'w, 's>(
283 state: &'s Self::State,
284 fetch: &mut Self::Fetch<'w>,
285 entity: #imports::Entity,
286 table_row: #imports::TableRow,
287 ) -> Self::Item<'w, 's> {
288 <#my_crate::All<&mut #trait_object> as #imports::QueryData>::fetch(
289 state,
290 fetch,
291 entity,
292 table_row,
293 )
294 }
295 }
296
297 unsafe impl #impl_generics_with_lifetime #imports::WorldQuery for &'__a mut #trait_object
298 #where_clause
299 {
300 type Fetch<'__w> = <#my_crate::All<&'__a #trait_object> as #imports::WorldQuery>::Fetch<'__w>;
301 type State = #my_crate::TraitQueryState<#trait_object>;
302
303 #[inline]
304 unsafe fn init_fetch<'w>(
305 world: #imports::UnsafeWorldCell<'w>,
306 state: &Self::State,
307 last_run: #imports::Tick,
308 this_run: #imports::Tick,
309 ) -> Self::Fetch<'w> {
310 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::init_fetch(
311 world,
312 state,
313 last_run,
314 this_run,
315 )
316 }
317
318 const IS_DENSE: bool = <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::IS_DENSE;
319
320 #[inline]
321 unsafe fn set_archetype<'w>(
322 fetch: &mut Self::Fetch<'w>,
323 state: &Self::State,
324 archetype: &'w #imports::Archetype,
325 table: &'w #imports::Table,
326 ) {
327 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::set_archetype(
328 fetch, state, archetype, table,
329 );
330 }
331
332 #[inline]
333 unsafe fn set_table<'w>(
334 fetch: &mut Self::Fetch<'w>,
335 state: &Self::State,
336 table: &'w #imports::Table,
337 ) {
338 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::set_table(fetch, state, table);
339 }
340
341 #[inline]
342 fn update_component_access(
343 state: &Self::State,
344 access: &mut #imports::FilteredAccess,
345 ) {
346 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::update_component_access(
347 state, access,
348 );
349 }
350
351 #[inline]
352 fn init_state(world: &mut #imports::World) -> Self::State {
353 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::init_state(world)
354 }
355
356 #[inline]
357 fn get_state(_: &#imports::Components) -> Option<Self::State> {
358 panic!("transmuting and any other operations concerning the state of a query are currently broken and shouldn't be used. See https://github.com/JoJoJet/bevy-trait-query/issues/59");
360 }
361
362 #[inline]
363 fn matches_component_set(
364 state: &Self::State,
365 set_contains_id: &impl Fn(#imports::ComponentId) -> bool,
366 ) -> bool {
367 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::matches_component_set(state, set_contains_id)
368 }
369
370 #[inline]
371 fn shrink_fetch<'wlong: 'wshort, 'wshort>(fetch: Self::Fetch<'wlong>) -> Self::Fetch<'wshort> {
372 fetch
373 }
374 }
375 };
376
377 Ok(quote! {
378 #trait_definition
379
380 #marker_impl_code
381
382 #trait_object_query_code
383 })
384}