1use std::collections::BTreeSet;
2
3use itertools::Itertools;
4use proc_macro2::{Span, TokenStream};
5use proc_macro_crate::FoundCrate;
6use quote::{format_ident, quote};
7use syn::{
8 bracketed, parse::Parse, punctuated::Punctuated, spanned::Spanned, Attribute, DataStruct,
9 DeriveInput, Error, Field, GenericParam, Generics, Ident, ImplGenerics, Index, Lifetime,
10 LifetimeParam, Result, Token, Type, TypeGenerics, TypeParam, Visibility,
11};
12
13#[proc_macro_derive(Fetch, attributes(fetch))]
32pub fn derive_fetch(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
33 let crate_name = match proc_macro_crate::crate_name("flax").expect("Failed to get crate name") {
34 FoundCrate::Itself => Ident::new("crate", Span::call_site()),
35 FoundCrate::Name(name) => Ident::new(&name, Span::call_site()),
36 };
37 do_derive_fetch(crate_name, input.into()).into()
38}
39
40fn do_derive_fetch(crate_name: Ident, input: TokenStream) -> TokenStream {
41 let input = match syn::parse2::<DeriveInput>(input) {
42 Ok(input) => input,
43 Err(err) => return err.to_compile_error(),
44 };
45
46 match input.data {
47 syn::Data::Struct(ref data) => derive_data_struct(crate_name, &input, data)
48 .unwrap_or_else(|err| err.to_compile_error()),
49 syn::Data::Enum(_) => todo!(),
50 syn::Data::Union(_) => todo!(),
51 }
52}
53
54fn derive_data_struct(
55 crate_name: Ident,
56 input: &DeriveInput,
57 data: &DataStruct,
58) -> Result<TokenStream> {
59 let attrs = Attrs::get(&input.attrs)?;
60
61 match data.fields {
62 syn::Fields::Named(_) => {
63 let params = Params::new(&crate_name, &input.vis, input, &attrs)?;
64
65 let prepared_derive = derive_prepared_struct(¶ms);
66
67 let fetch_derive = derive_fetch_struct(¶ms);
68
69 let union_derive = derive_union(¶ms);
70
71 let transforms_derive = derive_transform(¶ms)?;
72
73 Ok(quote! {
74 #fetch_derive
75
76 #prepared_derive
77
78 #union_derive
79
80 #transforms_derive
81 })
82 }
83 syn::Fields::Unnamed(_) => Err(Error::new(
84 Span::call_site(),
85 "Deriving fetch for a tuple struct is not supported",
86 )),
87 syn::Fields::Unit => Err(Error::new(
88 Span::call_site(),
89 "Deriving fetch for a unit struct is not supported",
90 )),
91 }
92}
93
94fn derive_fetch_struct(params: &Params) -> TokenStream {
95 let Params {
96 crate_name,
97 vis,
98 fetch_name,
99 item_name,
100 prepared_name,
101 q_generics,
102 fields,
103 field_names,
104 field_types,
105 attrs,
106 ..
107 } = params;
108
109 let item_ty = params.q_ty();
110 let item_impl = params.q_impl();
111 let item_msg = format!("The item returned by {fetch_name}");
112
113 let prep_ty = params.w_ty();
114
115 let extras = match &attrs.item_derives {
116 Some(extras) => {
117 quote! { #[derive(#extras)]}
118 }
119 None => quote! {},
120 };
121
122 let fetch_impl = params.w_impl();
123 let fetch_ty = params.base_ty();
124
125 let item_fields = fields
126 .iter()
127 .map(|v| {
128 let vis = v.vis;
129 let ident = v.ident;
130 let ty = v.ty;
131 quote! {
132 #vis #ident: <#ty as #crate_name::fetch::FetchItem<'q>>::Item,
133 }
134 })
135 .collect::<TokenStream>();
136
137 quote! {
138 #[doc = #item_msg]
139 #extras
140 #vis struct #item_name #q_generics {
141 #item_fields
142 }
143
144 #[automatically_derived]
149 impl #item_impl #crate_name::fetch::FetchItem<'q> for #fetch_name #fetch_ty {
150 type Item = #item_name #item_ty;
151 }
152
153 #[automatically_derived]
154 impl #fetch_impl #crate_name::Fetch<'w> for #fetch_name #fetch_ty
155 where #(#field_types: 'static,)*
156 {
157 const MUTABLE: bool = #(<#field_types as #crate_name::Fetch <'w>>::MUTABLE)||*;
158
159 type Prepared = #prepared_name #prep_ty;
160
161 #[inline]
162 fn prepare( &'w self, data: #crate_name::fetch::FetchPrepareData<'w>
163 ) -> Option<Self::Prepared> {
164 Some(Self::Prepared {
165 #(#field_names: #crate_name::Fetch::prepare(&self.#field_names, data)?,)*
166 })
167 }
168
169 #[inline]
170 fn filter_arch(&self, data: #crate_name::fetch::FetchAccessData) -> bool {
171 #(#crate_name::Fetch::filter_arch(&self.#field_names, data))&&*
172 }
173
174 fn describe(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
175 let mut s = f.debug_struct(stringify!(#fetch_name));
176
177 #(
178 s.field(stringify!(#field_names), &#crate_name::fetch::FmtQuery(&self.#field_names));
179 )*
180
181 s.finish()
182 }
183
184 fn access(&self, data: #crate_name::fetch::FetchAccessData, dst: &mut Vec<#crate_name::system::Access>) {
185 #(#crate_name::Fetch::access(&self.#field_names, data, dst));*
186 }
187
188 fn searcher(&self, searcher: &mut #crate_name::query::ArchetypeSearcher) {
189 #(#crate_name::Fetch::searcher(&self.#field_names, searcher);)*
190 }
191 }
192 }
193}
194
195fn prepend_generics(prepend: &[GenericParam], generics: &Generics) -> Generics {
196 let mut generics = generics.clone();
197 generics.params = prepend.iter().cloned().chain(generics.params).collect();
198
199 generics
200}
201
202fn derive_union(params: &Params) -> TokenStream {
204 let Params {
205 crate_name,
206 fields,
207 prepared_name,
208 ..
209 } = params;
210
211 let impl_generics = params.wq_impl();
212
213 let prep_ty = params.w_ty();
214
215 let filter_fields = fields.iter().filter(|v| !v.attrs.ignore).map(|v| v.ident);
217 let filter_types = fields.iter().filter(|v| !v.attrs.ignore).map(|v| v.ty);
218
219 quote! {
220 #[automatically_derived]
221 impl #impl_generics #crate_name::fetch::UnionFilter for #prepared_name #prep_ty where #prepared_name #prep_ty: #crate_name::fetch::PreparedFetch<'q> {
222 const HAS_UNION_FILTER: bool = #(<<#filter_types as #crate_name::fetch::Fetch<'w>>::Prepared as #crate_name::fetch::PreparedFetch<'q>>::HAS_FILTER)&&*;
223
224 unsafe fn filter_union(&mut self, slots: #crate_name::archetype::Slice) -> #crate_name::archetype::Slice {
225 #crate_name::fetch::PreparedFetch::filter_slots(&mut #crate_name::filter::Union((#(&mut self.#filter_fields,)*)), slots)
226 }
227 }
228 }
229}
230
231fn derive_transform(params: &Params) -> Result<TokenStream> {
233 let Params {
234 crate_name,
235 vis,
236 fields,
237 fetch_name,
238 attrs,
239 ..
240 } = params;
241
242 if attrs.transforms.is_empty() {
243 return Ok(quote! {});
244 }
245
246 let ty_generics = ('A'..='Z')
248 .zip(fields)
249 .filter(|(_, v)| !v.attrs.ignore)
250 .map(|(c, _)| format_ident!("{}", c))
251 .map(|v| GenericParam::Type(TypeParam::from(v)))
252 .collect_vec();
253
254 let transformed_name = format_ident!("{fetch_name}Transformed");
255 use quote::ToTokens;
256
257 let transformed_struct = {
258 let fields = ('A'..='Z').zip(fields).map(|(c, field)| {
259 let ty = if field.attrs.ignore {
260 field.ty.to_token_stream()
261 } else {
262 format_ident!("{}", c).to_token_stream()
263 };
264
265 let vis = field.vis;
266 let ident = field.ident;
267 quote! {
268 #vis #ident: #ty,
269 }
270 });
271
272 quote! {
273 #vis struct #transformed_name<#(#ty_generics: for<'x> #crate_name::fetch::Fetch<'x>),*>{
274 #(#fields)*
275 }
276 }
277 };
278
279 let input =
280 syn::parse2::<DeriveInput>(transformed_struct).expect("Generated struct is always valid");
281
282 let transformed_attrs = Attrs::default();
283
284 let mut transformed_params = Params::new(crate_name, vis, &input, &transformed_attrs)?;
285 for (dst, src) in transformed_params.fields.iter_mut().zip(fields) {
286 dst.attrs = src.attrs.clone();
287 }
288
289 let fetch = derive_fetch_struct(&transformed_params);
290
291 let prepared = derive_prepared_struct(&transformed_params);
292 let union = derive_union(&transformed_params);
293
294 let transforms = attrs
295 .transforms
296 .iter()
297 .map(|method| {
298 let method = method.to_tokens(crate_name);
299
300 let trait_name = quote! { #crate_name::fetch::TransformFetch<#method> };
301
302 let types = fields
303 .iter()
304 .filter_map(|field| {
305 if field.attrs.ignore {
306 None
307 } else {
308 let ty = field.ty;
309 Some(quote! {
310 <#ty as #trait_name>::Output
311 })
312 }
313 })
314 .collect_vec();
315
316 let initializers = fields
317 .iter()
318 .map(|field| {
319 let ident = field.ident;
320 let ty = field.ty;
321 if field.attrs.ignore {
322 quote! {
323 #ident: self.#ident
324 }
325 } else {
326 quote! {
327 #ident: <#ty as #trait_name>::transform_fetch(self.#ident, method)
328 }
329 }
330 })
331 .collect_vec();
332
333 quote! {
334 #[automatically_derived]
335 impl #trait_name for #fetch_name
336 {
337 type Output = #crate_name::filter::Union<#transformed_name<#(#types,)*>>;
338 fn transform_fetch(self, method: #method) -> Self::Output {
339 #crate_name::filter::Union(#transformed_name {
340 #(#initializers,)*
341 })
342 }
343 }
344 }
345 })
346 .collect_vec();
347
348 Ok(quote! {
349 #input
350
351 #fetch
352
353 #prepared
354
355 #union
356
357 #(#transforms)*
358 })
359}
360
361fn derive_prepared_struct(params: &Params) -> TokenStream {
362 let Params {
363 crate_name,
364 vis,
365 fetch_name,
366 item_name,
367 prepared_name,
368 fields,
369 field_names,
370 field_types,
371 w_generics,
372 ..
373 } = params;
374
375 let msg = format!("The prepared fetch for {fetch_name}");
376
377 let prep_impl = params.wq_impl();
378 let prep_ty = params.w_ty();
379 let item_ty = params.q_ty();
380
381 let field_idx = (0..field_names.len()).map(Index::from);
382 let filter_fields = fields.iter().filter(|v| !v.attrs.ignore).map(|v| v.ident);
383
384 quote! {
385 #[doc = #msg]
386 #vis struct #prepared_name #w_generics {
387 #(#field_names: <#field_types as #crate_name::Fetch <'w>>::Prepared,)*
388 }
389
390 #[automatically_derived]
391 impl #prep_impl #crate_name::fetch::PreparedFetch<'q> for #prepared_name #prep_ty
392 where #(#field_types: 'static,)*
393 {
394 type Item = #item_name #item_ty;
395 type Chunk = (#(<<#field_types as #crate_name::fetch::Fetch<'w>>::Prepared as #crate_name::fetch::PreparedFetch<'q>>::Chunk,)*);
396
397 const HAS_FILTER: bool = #(<<#field_types as #crate_name::fetch::Fetch<'w>>::Prepared as #crate_name::fetch::PreparedFetch<'q>>::HAS_FILTER)||*;
398
399 #[inline]
400 unsafe fn fetch_next(chunk: &mut Self::Chunk) -> Self::Item {
401 Self::Item {
402 #(#field_names: <<#field_types as #crate_name::fetch::Fetch<'w>>::Prepared as #crate_name::fetch::PreparedFetch<'q>>::fetch_next(&mut chunk.#field_idx),)*
403 }
404 }
405
406 #[inline]
407 unsafe fn filter_slots(&mut self, slots: #crate_name::archetype::Slice) -> #crate_name::archetype::Slice {
408 #crate_name::fetch::PreparedFetch::filter_slots(&mut (#(&mut self.#filter_fields,)*), slots)
409 }
410
411 #[inline]
412 unsafe fn create_chunk(&'q mut self, slots: #crate_name::archetype::Slice) -> Self::Chunk {
413 (
414 #(#crate_name::fetch::PreparedFetch::create_chunk(&mut self.#field_names, slots),)*
415 )
416 }
417 }
418 }
419}
420
421#[derive(Clone)]
422struct ParsedField<'a> {
423 vis: &'a Visibility,
424 ty: &'a Type,
425 ident: &'a Ident,
426 attrs: FieldAttrs,
427}
428
429impl<'a> ParsedField<'a> {
430 fn get(field: &'a Field) -> Result<Self> {
431 let attrs = FieldAttrs::get(&field.attrs)?;
432
433 let ident = field
434 .ident
435 .as_ref()
436 .ok_or(Error::new(field.span(), "Only named fields are supported"))?;
437
438 Ok(Self {
439 vis: &field.vis,
440 ty: &field.ty,
441 ident,
442 attrs,
443 })
444 }
445}
446
447#[derive(Default, Debug, Clone)]
448struct FieldAttrs {
449 ignore: bool,
450}
451
452impl FieldAttrs {
453 fn get(input: &[Attribute]) -> Result<Self> {
454 let mut res = Self::default();
455
456 for attr in input {
457 if !attr.path().is_ident("fetch") {
458 continue;
459 }
460
461 match &attr.meta {
462 syn::Meta::List(list) => {
463 list.parse_nested_meta(|meta| {
466 if meta.path.is_ident("ignore") {
468 res.ignore = true;
469 Ok(())
470 } else {
471 Err(Error::new(
472 meta.path.span(),
473 "Unknown fetch field attribute",
474 ))
475 }
476 })?;
477 }
478 _ => {
479 return Err(Error::new(
480 Span::call_site(),
481 "Expected a MetaList for `fetch`",
482 ))
483 }
484 };
485 }
486
487 Ok(res)
488 }
489}
490
491#[derive(Default)]
492struct Attrs {
493 item_derives: Option<Punctuated<Ident, Token![,]>>,
494 transforms: BTreeSet<TransformIdent>,
495}
496
497impl Attrs {
498 fn get(input: &[Attribute]) -> Result<Self> {
499 let mut res = Self::default();
500
501 for attr in input {
502 if !attr.path().is_ident("fetch") {
503 continue;
504 }
505
506 match &attr.meta {
507 syn::Meta::List(list) => {
508 list.parse_nested_meta(|meta| {
511 if meta.path.is_ident("item_derives") {
513 let value = meta.value()?;
514 let content;
515 bracketed!(content in value);
516 let content =
517 <Punctuated<Ident, Token![,]>>::parse_terminated(&content)?;
518
519 res.item_derives = Some(content);
520 Ok(())
521 } else if meta.path.is_ident("transforms") {
522 let value = meta.value()?;
523 let content;
524 bracketed!(content in value);
525 let content =
526 <Punctuated<TransformIdent, Token![,]>>::parse_terminated(
527 &content,
528 )?;
529
530 res.transforms.extend(content);
531 Ok(())
532 } else {
533 Err(Error::new(meta.path.span(), "Unknown fetch attribute"))
534 }
535 })?;
536 }
537 _ => {
538 return Err(Error::new(
539 Span::call_site(),
540 "Expected a MetaList for `fetch`",
541 ))
542 }
543 };
544 }
545
546 Ok(res)
547 }
548}
549
550#[derive(Clone)]
551struct Params<'a> {
552 crate_name: &'a Ident,
553 vis: &'a Visibility,
554
555 fetch_name: Ident,
556 item_name: Ident,
557 prepared_name: Ident,
558
559 generics: &'a Generics,
560 w_generics: Generics,
561 q_generics: Generics,
562 wq_generics: Generics,
563
564 fields: Vec<ParsedField<'a>>,
565 field_names: Vec<&'a Ident>,
566 field_types: Vec<&'a Type>,
567
568 attrs: &'a Attrs,
569}
570
571impl<'a> Params<'a> {
572 fn new(
573 crate_name: &'a Ident,
574 vis: &'a Visibility,
575 input: &'a DeriveInput,
576 attrs: &'a Attrs,
577 ) -> Result<Self> {
578 let fields = match &input.data {
579 syn::Data::Struct(data) => match &data.fields {
580 syn::Fields::Named(fields) => fields,
581 _ => unreachable!(),
582 },
583
584 _ => unreachable!(),
585 };
586
587 let fetch_name = input.ident.clone();
588
589 let w_lf = LifetimeParam::new(Lifetime::new("'w", Span::call_site()));
590 let q_lf = LifetimeParam::new(Lifetime::new("'q", Span::call_site()));
591
592 let fields = fields
593 .named
594 .iter()
595 .map(ParsedField::get)
596 .collect::<Result<Vec<_>>>()?;
597
598 let field_names = fields.iter().map(|v| v.ident).collect_vec();
599 let field_types = fields.iter().map(|v| v.ty).collect_vec();
600
601 Ok(Self {
602 crate_name,
603 vis,
604 generics: &input.generics,
605 fields,
606 field_names,
607 field_types,
608 attrs,
609 item_name: format_ident!("{fetch_name}Item"),
610 prepared_name: format_ident!("Prepared{fetch_name}"),
611 fetch_name,
612 w_generics: prepend_generics(&[GenericParam::Lifetime(w_lf.clone())], &input.generics),
613 q_generics: prepend_generics(&[GenericParam::Lifetime(q_lf.clone())], &input.generics),
614
615 wq_generics: prepend_generics(
616 &[
617 GenericParam::Lifetime(w_lf.clone()),
618 GenericParam::Lifetime(q_lf.clone()),
619 ],
620 &input.generics,
621 ),
622 })
623 }
624
625 fn q_impl(&self) -> ImplGenerics {
626 self.q_generics.split_for_impl().0
627 }
628
629 fn wq_impl(&self) -> ImplGenerics {
630 self.wq_generics.split_for_impl().0
631 }
632
633 fn w_impl(&self) -> ImplGenerics {
634 self.w_generics.split_for_impl().0
635 }
636
637 fn base_ty(&self) -> TypeGenerics {
638 self.generics.split_for_impl().1
639 }
640
641 fn q_ty(&self) -> TypeGenerics {
642 self.q_generics.split_for_impl().1
643 }
644
645 fn w_ty(&self) -> TypeGenerics {
646 self.w_generics.split_for_impl().1
647 }
648}
649
650#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
651enum TransformIdent {
652 Modified,
653 Added,
654}
655
656impl TransformIdent {
657 fn to_tokens(&self, crate_name: &Ident) -> TokenStream {
658 match self {
659 Self::Modified => quote!(#crate_name::fetch::Modified),
660 Self::Added => quote!(#crate_name::fetch::Added),
661 }
662 }
663}
664
665impl Parse for TransformIdent {
666 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
667 let ident = input.parse::<Ident>()?;
668 if ident == "Modified" {
669 Ok(Self::Modified)
670 } else if ident == "Added" {
671 Ok(Self::Added)
672 } else {
673 Err(Error::new(
674 ident.span(),
675 format!("Unknown transform {ident}"),
676 ))
677 }
678 }
679}