1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{ItemTrait, Result, TraitItem, parse_quote};
5
6#[proc_macro_attribute]
34pub fn queryable(attr: TokenStream, item: TokenStream) -> TokenStream {
35 impl_trait_query(attr, item)
36 .unwrap_or_else(syn::Error::into_compile_error)
37 .into()
38}
39
40fn impl_trait_query(arg: TokenStream, item: TokenStream) -> Result<TokenStream2> {
41 syn::custom_keyword!(no_bounds);
42 let no_bounds: Option<no_bounds> = syn::parse(arg).map_err(|e| {
43 syn::Error::new(
44 e.span(),
45 "Valid forms are: `#[queryable]` and `#[queryable(no_bounds)]`",
46 )
47 })?;
48
49 let mut trait_definition = syn::parse::<ItemTrait>(item)?;
50 let trait_name = trait_definition.ident.clone();
51
52 if no_bounds.is_none() {
54 trait_definition.supertraits.push(parse_quote!('static));
55
56 for param in &mut trait_definition.generics.params {
57 if let syn::GenericParam::Type(param) = param {
59 param.bounds.push(parse_quote!('static));
60 }
61 }
62
63 for item in &mut trait_definition.items {
64 if let TraitItem::Type(assoc) = item {
66 assoc.bounds.push(parse_quote!('static));
67 }
68 }
69 }
70
71 let mut impl_generics_list = vec![];
72 let mut trait_generics_list = vec![];
73 let where_clause = trait_definition.generics.where_clause.clone();
74
75 for param in &trait_definition.generics.params {
76 impl_generics_list.push(param.clone());
77 match param {
78 syn::GenericParam::Type(param) => {
79 let ident = ¶m.ident;
80 trait_generics_list.push(quote! { #ident });
81 }
82 syn::GenericParam::Lifetime(param) => {
83 let ident = ¶m.lifetime;
84 trait_generics_list.push(quote! { #ident });
85 }
86 syn::GenericParam::Const(param) => {
87 let ident = ¶m.ident;
88 trait_generics_list.push(quote! { #ident });
89 }
90 }
91 }
92
93 for item in &trait_definition.items {
95 if let TraitItem::Type(assoc) = item {
96 if !assoc.generics.params.is_empty() {
97 return Err(syn::Error::new(
98 assoc.ident.span(),
99 "Generic associated types are not supported in trait queries",
100 ));
101 }
102 let ident = &assoc.ident;
103 let lower_ident = format_ident!("__{ident}");
104 let bound = &assoc.bounds;
105 impl_generics_list.push(parse_quote! { #lower_ident: #bound });
106 trait_generics_list.push(quote! { #ident = #lower_ident });
107 }
108 }
109
110 let impl_generics = quote! { <#( #impl_generics_list ,)*> };
111 let trait_generics = quote! { <#( #trait_generics_list ,)*> };
112
113 let trait_object = quote! { dyn #trait_name #trait_generics };
114
115 let my_crate = proc_macro_crate::crate_name("bevy-trait-query").unwrap();
116 let my_crate = match my_crate {
117 proc_macro_crate::FoundCrate::Itself => quote! { bevy_trait_query },
118 proc_macro_crate::FoundCrate::Name(x) => {
119 let ident = quote::format_ident!("{x}");
120 quote! { #ident }
121 }
122 };
123
124 let imports = quote! { #my_crate::imports };
125
126 let trait_query = quote! { #my_crate::TraitQuery };
127
128 let mut marker_impl_generics_list = impl_generics_list.clone();
129 marker_impl_generics_list
130 .push(parse_quote!(__Component: #trait_name #trait_generics + #imports::Component));
131 let marker_impl_generics = quote! { <#( #marker_impl_generics_list ,)*> };
132
133 let marker_impl_code = quote! {
134 impl #impl_generics #trait_query for #trait_object #where_clause {}
135
136 impl #marker_impl_generics #my_crate::TraitQueryMarker::<#trait_object> for (__Component,)
137 #where_clause
138 {
139 type Covered = __Component;
140 fn cast(ptr: *mut u8) -> *mut #trait_object {
141 ptr as *mut __Component as *mut _
142 }
143 }
144 };
145
146 let mut impl_generics_with_lifetime = impl_generics_list.clone();
147 impl_generics_with_lifetime.insert(0, parse_quote!('__a));
148 let impl_generics_with_lifetime = quote! { <#( #impl_generics_with_lifetime ,)*> };
149
150 let trait_object_query_code = quote! {
151 unsafe impl #impl_generics #imports::QueryData for &#trait_object
152 #where_clause
153 {
154 type ReadOnly = Self;
155
156 const IS_READ_ONLY: bool = true;
157 const IS_ARCHETYPAL: bool = false;
158
159 type Item<'__w, '__s> = #my_crate::ReadTraits<'__w, #trait_object>;
160
161 #[inline]
162 fn shrink<'wlong: 'wshort, 'wshort, 's>(
163 item: Self::Item<'wlong, 's>,
164 ) -> Self::Item<'wshort, 's> {
165 item
166 }
167
168 #[inline]
169 unsafe fn fetch<'w, 's>(
170 state: &'s Self::State,
171 fetch: &mut Self::Fetch<'w>,
172 entity: #imports::Entity,
173 table_row: #imports::TableRow,
174 ) -> Option<Self::Item<'w, 's>> {
175 <#my_crate::All<&#trait_object> as #imports::QueryData>::fetch(
176 state,
177 fetch,
178 entity,
179 table_row,
180 )
181 }
182
183 fn iter_access(
184 _state: &Self::State,
185 ) -> impl Iterator<Item = bevy_ecs::query::EcsAccessType<'_>> {
186 std::iter::empty()
187 }
188 }
189 unsafe impl #impl_generics #imports::ReadOnlyQueryData for &#trait_object
190 #where_clause
191 {}
192
193 unsafe impl #impl_generics_with_lifetime #imports::WorldQuery for &'__a #trait_object
194 #where_clause
195 {
196 type Fetch<'__w> = <#my_crate::All<&'__a #trait_object> as #imports::WorldQuery>::Fetch<'__w>;
197 type State = #my_crate::TraitQueryState<#trait_object>;
198
199 #[inline]
200 unsafe fn init_fetch<'w>(
201 world: #imports::UnsafeWorldCell<'w>,
202 state: &Self::State,
203 last_run: #imports::Tick,
204 this_run: #imports::Tick,
205 ) -> Self::Fetch<'w> {
206 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::init_fetch(
207 world,
208 state,
209 last_run,
210 this_run,
211 )
212 }
213
214 const IS_DENSE: bool = <#my_crate::All<&#trait_object> as #imports::WorldQuery>::IS_DENSE;
215
216 #[inline]
217 unsafe fn set_archetype<'w>(
218 fetch: &mut Self::Fetch<'w>,
219 state: &Self::State,
220 archetype: &'w #imports::Archetype,
221 tables: &'w #imports::Table,
222 ) {
223 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::set_archetype(
224 fetch, state, archetype, tables,
225 );
226 }
227
228 #[inline]
229 unsafe fn set_table<'w>(
230 fetch: &mut Self::Fetch<'w>,
231 state: &Self::State,
232 table: &'w #imports::Table,
233 ) {
234 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::set_table(fetch, state, table);
235 }
236
237 #[inline]
238 fn update_component_access(
239 state: &Self::State,
240 access: &mut #imports::FilteredAccess,
241 ) {
242 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::update_component_access(
243 state, access,
244 );
245 }
246
247 #[inline]
248 fn init_state(world: &mut #imports::World) -> Self::State {
249 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::init_state(world)
250 }
251
252 #[inline]
253 fn get_state(_: &#imports::Components) -> Option<Self::State> {
254 panic!("transmuting and any other operations concerning the state of a query are currently broken and shouldn't be used. See https://github.com/JoJoJet/bevy-trait-query/issues/59");
256 }
257
258 #[inline]
259 fn matches_component_set(
260 state: &Self::State,
261 set_contains_id: &impl Fn(#imports::ComponentId) -> bool,
262 ) -> bool {
263 <#my_crate::All<&#trait_object> as #imports::WorldQuery>::matches_component_set(state, set_contains_id)
264 }
265
266 #[inline]
267 fn shrink_fetch<'wlong: 'wshort, 'wshort>(fetch: Self::Fetch<'wlong>) -> Self::Fetch<'wshort> {
268 fetch
269 }
270 }
271
272 unsafe impl #impl_generics_with_lifetime #imports::QueryData for &'__a mut #trait_object
273 #where_clause
274 {
275 type ReadOnly = &'__a #trait_object;
276
277 type Item<'__w, '__s> = #my_crate::WriteTraits<'__w, #trait_object>;
278
279 const IS_READ_ONLY: bool = false;
280 const IS_ARCHETYPAL: bool = false;
281
282 #[inline]
283 fn shrink<'wlong: 'wshort, 'wshort, 's>(
284 item: Self::Item<'wlong, 's>,
285 ) -> Self::Item<'wshort, 's> {
286 item
287 }
288
289 #[inline]
290 unsafe fn fetch<'w, 's>(
291 state: &'s Self::State,
292 fetch: &mut Self::Fetch<'w>,
293 entity: #imports::Entity,
294 table_row: #imports::TableRow,
295 ) -> Option<Self::Item<'w, 's>> {
296 <#my_crate::All<&mut #trait_object> as #imports::QueryData>::fetch(
297 state,
298 fetch,
299 entity,
300 table_row,
301 )
302 }
303
304 fn iter_access(
305 _state: &Self::State,
306 ) -> impl Iterator<Item = bevy_ecs::query::EcsAccessType<'_>> {
307 std::iter::empty()
308 }
309 }
310
311 unsafe impl #impl_generics_with_lifetime #imports::WorldQuery for &'__a mut #trait_object
312 #where_clause
313 {
314 type Fetch<'__w> = <#my_crate::All<&'__a #trait_object> as #imports::WorldQuery>::Fetch<'__w>;
315 type State = #my_crate::TraitQueryState<#trait_object>;
316
317 #[inline]
318 unsafe fn init_fetch<'w>(
319 world: #imports::UnsafeWorldCell<'w>,
320 state: &Self::State,
321 last_run: #imports::Tick,
322 this_run: #imports::Tick,
323 ) -> Self::Fetch<'w> {
324 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::init_fetch(
325 world,
326 state,
327 last_run,
328 this_run,
329 )
330 }
331
332 const IS_DENSE: bool = <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::IS_DENSE;
333
334 #[inline]
335 unsafe fn set_archetype<'w>(
336 fetch: &mut Self::Fetch<'w>,
337 state: &Self::State,
338 archetype: &'w #imports::Archetype,
339 table: &'w #imports::Table,
340 ) {
341 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::set_archetype(
342 fetch, state, archetype, table,
343 );
344 }
345
346 #[inline]
347 unsafe fn set_table<'w>(
348 fetch: &mut Self::Fetch<'w>,
349 state: &Self::State,
350 table: &'w #imports::Table,
351 ) {
352 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::set_table(fetch, state, table);
353 }
354
355 #[inline]
356 fn update_component_access(
357 state: &Self::State,
358 access: &mut #imports::FilteredAccess,
359 ) {
360 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::update_component_access(
361 state, access,
362 );
363 }
364
365 #[inline]
366 fn init_state(world: &mut #imports::World) -> Self::State {
367 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::init_state(world)
368 }
369
370 #[inline]
371 fn get_state(_: &#imports::Components) -> Option<Self::State> {
372 panic!("transmuting and any other operations concerning the state of a query are currently broken and shouldn't be used. See https://github.com/JoJoJet/bevy-trait-query/issues/59");
374 }
375
376 #[inline]
377 fn matches_component_set(
378 state: &Self::State,
379 set_contains_id: &impl Fn(#imports::ComponentId) -> bool,
380 ) -> bool {
381 <#my_crate::All<&mut #trait_object> as #imports::WorldQuery>::matches_component_set(state, set_contains_id)
382 }
383
384 #[inline]
385 fn shrink_fetch<'wlong: 'wshort, 'wshort>(fetch: Self::Fetch<'wlong>) -> Self::Fetch<'wshort> {
386 fetch
387 }
388 }
389 };
390
391 Ok(quote! {
392 #trait_definition
393
394 #marker_impl_code
395
396 #trait_object_query_code
397 })
398}