flax_derive/
lib.rs

1use std::collections::BTreeSet;
2
3use itertools::Itertools;
4use proc_macro2::{Span, TokenStream};
5use proc_macro_crate::FoundCrate;
6use quote::{format_ident, quote};
7use syn::{
8    bracketed, parse::Parse, punctuated::Punctuated, spanned::Spanned, Attribute, DataStruct,
9    DeriveInput, Error, Field, GenericParam, Generics, Ident, ImplGenerics, Index, Lifetime,
10    LifetimeParam, Result, Token, Type, TypeGenerics, TypeParam, Visibility,
11};
12
13/// ```rust,ignore
14/// #[derive(Fetch)]
15/// #[fetch(item_derives = [Debug], transforms = [Modified])]
16/// struct CustomFetch {
17///     #[fetch(ignore)]
18///     rotation: Mutable<glam::Quat>,
19///     position: Component<glam::Vec3>,
20///     id: EntityIds,
21/// }
22/// ```
23/// # Struct Attributes
24///
25/// - `item_derives`: Derive additional traits for the item returned by the fetch.
26/// - `transforms`: Implement `Transform` for the specified transform kinds.
27///
28/// # Field Attributes
29/// - `ignore`: ignore slot-filtering and transformations for a field.
30///     Useful for including a `Mutable` in a change query.
31#[proc_macro_derive(Fetch, attributes(fetch))]
32pub fn derive_fetch(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
33    let crate_name = match proc_macro_crate::crate_name("flax").expect("Failed to get crate name") {
34        FoundCrate::Itself => Ident::new("crate", Span::call_site()),
35        FoundCrate::Name(name) => Ident::new(&name, Span::call_site()),
36    };
37    do_derive_fetch(crate_name, input.into()).into()
38}
39
40fn do_derive_fetch(crate_name: Ident, input: TokenStream) -> TokenStream {
41    let input = match syn::parse2::<DeriveInput>(input) {
42        Ok(input) => input,
43        Err(err) => return err.to_compile_error(),
44    };
45
46    match input.data {
47        syn::Data::Struct(ref data) => derive_data_struct(crate_name, &input, data)
48            .unwrap_or_else(|err| err.to_compile_error()),
49        syn::Data::Enum(_) => todo!(),
50        syn::Data::Union(_) => todo!(),
51    }
52}
53
54fn derive_data_struct(
55    crate_name: Ident,
56    input: &DeriveInput,
57    data: &DataStruct,
58) -> Result<TokenStream> {
59    let attrs = Attrs::get(&input.attrs)?;
60
61    match data.fields {
62        syn::Fields::Named(_) => {
63            let params = Params::new(&crate_name, &input.vis, input, &attrs)?;
64
65            let prepared_derive = derive_prepared_struct(&params);
66
67            let fetch_derive = derive_fetch_struct(&params);
68
69            let union_derive = derive_union(&params);
70
71            let transforms_derive = derive_transform(&params)?;
72
73            Ok(quote! {
74                #fetch_derive
75
76                #prepared_derive
77
78                #union_derive
79
80                #transforms_derive
81            })
82        }
83        syn::Fields::Unnamed(_) => Err(Error::new(
84            Span::call_site(),
85            "Deriving fetch for a tuple struct is not supported",
86        )),
87        syn::Fields::Unit => Err(Error::new(
88            Span::call_site(),
89            "Deriving fetch for a unit struct is not supported",
90        )),
91    }
92}
93
94fn derive_fetch_struct(params: &Params) -> TokenStream {
95    let Params {
96        crate_name,
97        vis,
98        fetch_name,
99        item_name,
100        prepared_name,
101        q_generics,
102        fields,
103        field_names,
104        field_types,
105        attrs,
106        ..
107    } = params;
108
109    let item_ty = params.q_ty();
110    let item_impl = params.q_impl();
111    let item_msg = format!("The item returned by {fetch_name}");
112
113    let prep_ty = params.w_ty();
114
115    let extras = match &attrs.item_derives {
116        Some(extras) => {
117            quote! { #[derive(#extras)]}
118        }
119        None => quote! {},
120    };
121
122    let fetch_impl = params.w_impl();
123    let fetch_ty = params.base_ty();
124
125    let item_fields = fields
126        .iter()
127        .map(|v| {
128            let vis = v.vis;
129            let ident = v.ident;
130            let ty = v.ty;
131            quote! {
132                #vis #ident: <#ty as #crate_name::fetch::FetchItem<'q>>::Item,
133            }
134        })
135        .collect::<TokenStream>();
136
137    quote! {
138        #[doc = #item_msg]
139        #extras
140        #vis struct #item_name #q_generics {
141            #item_fields
142        }
143
144        // #vis struct #batch_name #wq_generics {
145        //     #(#field_names: <<#field_types as #crate_name::fetch::Fetch<'w>::Prepared> as #crate_name::fetch::PreparedFetch<#q_lf>>::Chunk,)*
146        // }
147
148        #[automatically_derived]
149        impl #item_impl #crate_name::fetch::FetchItem<'q> for #fetch_name #fetch_ty {
150            type Item = #item_name #item_ty;
151        }
152
153        #[automatically_derived]
154        impl #fetch_impl #crate_name::Fetch<'w> for #fetch_name #fetch_ty
155            where #(#field_types: 'static,)*
156        {
157            const MUTABLE: bool = #(<#field_types as #crate_name::Fetch <'w>>::MUTABLE)||*;
158
159            type Prepared = #prepared_name #prep_ty;
160
161            #[inline]
162            fn prepare( &'w self, data: #crate_name::fetch::FetchPrepareData<'w>
163            ) -> Option<Self::Prepared> {
164                Some(Self::Prepared {
165                    #(#field_names: #crate_name::Fetch::prepare(&self.#field_names, data)?,)*
166                })
167            }
168
169            #[inline]
170            fn filter_arch(&self, data: #crate_name::fetch::FetchAccessData) -> bool {
171                #(#crate_name::Fetch::filter_arch(&self.#field_names, data))&&*
172            }
173
174            fn describe(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
175                let mut s = f.debug_struct(stringify!(#fetch_name));
176
177                #(
178                    s.field(stringify!(#field_names), &#crate_name::fetch::FmtQuery(&self.#field_names));
179                )*
180
181                s.finish()
182            }
183
184            fn access(&self, data: #crate_name::fetch::FetchAccessData, dst: &mut Vec<#crate_name::system::Access>) {
185                 #(#crate_name::Fetch::access(&self.#field_names, data, dst));*
186            }
187
188            fn searcher(&self, searcher: &mut #crate_name::query::ArchetypeSearcher) {
189                #(#crate_name::Fetch::searcher(&self.#field_names, searcher);)*
190            }
191        }
192    }
193}
194
195fn prepend_generics(prepend: &[GenericParam], generics: &Generics) -> Generics {
196    let mut generics = generics.clone();
197    generics.params = prepend.iter().cloned().chain(generics.params).collect();
198
199    generics
200}
201
202/// Implements the filtering of the struct fields using a set union
203fn derive_union(params: &Params) -> TokenStream {
204    let Params {
205        crate_name,
206        fields,
207        prepared_name,
208        ..
209    } = params;
210
211    let impl_generics = params.wq_impl();
212
213    let prep_ty = params.w_ty();
214
215    // Make sure not to *or* ignored fields
216    let filter_fields = fields.iter().filter(|v| !v.attrs.ignore).map(|v| v.ident);
217    let filter_types = fields.iter().filter(|v| !v.attrs.ignore).map(|v| v.ty);
218
219    quote! {
220        #[automatically_derived]
221        impl #impl_generics #crate_name::fetch::UnionFilter for #prepared_name #prep_ty where #prepared_name #prep_ty: #crate_name::fetch::PreparedFetch<'q> {
222            const HAS_UNION_FILTER: bool = #(<<#filter_types as #crate_name::fetch::Fetch<'w>>::Prepared as #crate_name::fetch::PreparedFetch<'q>>::HAS_FILTER)&&*;
223
224            unsafe fn filter_union(&mut self, slots: #crate_name::archetype::Slice) -> #crate_name::archetype::Slice {
225                #crate_name::fetch::PreparedFetch::filter_slots(&mut #crate_name::filter::Union((#(&mut self.#filter_fields,)*)), slots)
226            }
227        }
228    }
229}
230
231/// Implements the filtering of the struct fields using a set union
232fn derive_transform(params: &Params) -> Result<TokenStream> {
233    let Params {
234        crate_name,
235        vis,
236        fields,
237        fetch_name,
238        attrs,
239        ..
240    } = params;
241
242    if attrs.transforms.is_empty() {
243        return Ok(quote! {});
244    }
245
246    // Replace all the fields with generics to allow transforming into different types
247    let ty_generics = ('A'..='Z')
248        .zip(fields)
249        .filter(|(_, v)| !v.attrs.ignore)
250        .map(|(c, _)| format_ident!("{}", c))
251        .map(|v| GenericParam::Type(TypeParam::from(v)))
252        .collect_vec();
253
254    let transformed_name = format_ident!("{fetch_name}Transformed");
255    use quote::ToTokens;
256
257    let transformed_struct = {
258        let fields = ('A'..='Z').zip(fields).map(|(c, field)| {
259            let ty = if field.attrs.ignore {
260                field.ty.to_token_stream()
261            } else {
262                format_ident!("{}", c).to_token_stream()
263            };
264
265            let vis = field.vis;
266            let ident = field.ident;
267            quote! {
268               #vis #ident: #ty,
269            }
270        });
271
272        quote! {
273            #vis struct #transformed_name<#(#ty_generics: for<'x> #crate_name::fetch::Fetch<'x>),*>{
274                #(#fields)*
275            }
276        }
277    };
278
279    let input =
280        syn::parse2::<DeriveInput>(transformed_struct).expect("Generated struct is always valid");
281
282    let transformed_attrs = Attrs::default();
283
284    let mut transformed_params = Params::new(crate_name, vis, &input, &transformed_attrs)?;
285    for (dst, src) in transformed_params.fields.iter_mut().zip(fields) {
286        dst.attrs = src.attrs.clone();
287    }
288
289    let fetch = derive_fetch_struct(&transformed_params);
290
291    let prepared = derive_prepared_struct(&transformed_params);
292    let union = derive_union(&transformed_params);
293
294    let transforms = attrs
295        .transforms
296        .iter()
297        .map(|method| {
298            let method = method.to_tokens(crate_name);
299
300            let trait_name = quote! { #crate_name::fetch::TransformFetch<#method> };
301
302            let types = fields
303                .iter()
304                .filter_map(|field| {
305                    if field.attrs.ignore {
306                        None
307                    } else {
308                        let ty = field.ty;
309                        Some(quote! {
310                            <#ty as #trait_name>::Output
311                        })
312                    }
313                })
314                .collect_vec();
315
316            let initializers = fields
317                .iter()
318                .map(|field| {
319                    let ident = field.ident;
320                    let ty = field.ty;
321                    if field.attrs.ignore {
322                        quote! {
323                            #ident: self.#ident
324                        }
325                    } else {
326                        quote! {
327                            #ident: <#ty as #trait_name>::transform_fetch(self.#ident, method)
328                        }
329                    }
330                })
331                .collect_vec();
332
333            quote! {
334                #[automatically_derived]
335                impl #trait_name for #fetch_name
336                {
337                    type Output = #crate_name::filter::Union<#transformed_name<#(#types,)*>>;
338                    fn transform_fetch(self, method: #method) -> Self::Output {
339                        #crate_name::filter::Union(#transformed_name {
340                            #(#initializers,)*
341                        })
342                    }
343                }
344            }
345        })
346        .collect_vec();
347
348    Ok(quote! {
349        #input
350
351        #fetch
352
353        #prepared
354
355        #union
356
357        #(#transforms)*
358    })
359}
360
361fn derive_prepared_struct(params: &Params) -> TokenStream {
362    let Params {
363        crate_name,
364        vis,
365        fetch_name,
366        item_name,
367        prepared_name,
368        fields,
369        field_names,
370        field_types,
371        w_generics,
372        ..
373    } = params;
374
375    let msg = format!("The prepared fetch for {fetch_name}");
376
377    let prep_impl = params.wq_impl();
378    let prep_ty = params.w_ty();
379    let item_ty = params.q_ty();
380
381    let field_idx = (0..field_names.len()).map(Index::from);
382    let filter_fields = fields.iter().filter(|v| !v.attrs.ignore).map(|v| v.ident);
383
384    quote! {
385        #[doc = #msg]
386        #vis struct #prepared_name #w_generics {
387            #(#field_names: <#field_types as #crate_name::Fetch <'w>>::Prepared,)*
388        }
389
390        #[automatically_derived]
391        impl #prep_impl #crate_name::fetch::PreparedFetch<'q> for #prepared_name #prep_ty
392            where #(#field_types: 'static,)*
393        {
394            type Item = #item_name #item_ty;
395            type Chunk = (#(<<#field_types as #crate_name::fetch::Fetch<'w>>::Prepared as #crate_name::fetch::PreparedFetch<'q>>::Chunk,)*);
396
397            const HAS_FILTER: bool = #(<<#field_types as #crate_name::fetch::Fetch<'w>>::Prepared as #crate_name::fetch::PreparedFetch<'q>>::HAS_FILTER)||*;
398
399            #[inline]
400            unsafe fn fetch_next(chunk: &mut Self::Chunk) -> Self::Item {
401                Self::Item {
402                    #(#field_names: <<#field_types as #crate_name::fetch::Fetch<'w>>::Prepared as #crate_name::fetch::PreparedFetch<'q>>::fetch_next(&mut chunk.#field_idx),)*
403                }
404            }
405
406            #[inline]
407            unsafe fn filter_slots(&mut self, slots: #crate_name::archetype::Slice) -> #crate_name::archetype::Slice {
408                #crate_name::fetch::PreparedFetch::filter_slots(&mut (#(&mut self.#filter_fields,)*), slots)
409            }
410
411            #[inline]
412            unsafe fn create_chunk(&'q mut self, slots: #crate_name::archetype::Slice) -> Self::Chunk {
413                (
414                    #(#crate_name::fetch::PreparedFetch::create_chunk(&mut self.#field_names, slots),)*
415                )
416            }
417        }
418    }
419}
420
421#[derive(Clone)]
422struct ParsedField<'a> {
423    vis: &'a Visibility,
424    ty: &'a Type,
425    ident: &'a Ident,
426    attrs: FieldAttrs,
427}
428
429impl<'a> ParsedField<'a> {
430    fn get(field: &'a Field) -> Result<Self> {
431        let attrs = FieldAttrs::get(&field.attrs)?;
432
433        let ident = field
434            .ident
435            .as_ref()
436            .ok_or(Error::new(field.span(), "Only named fields are supported"))?;
437
438        Ok(Self {
439            vis: &field.vis,
440            ty: &field.ty,
441            ident,
442            attrs,
443        })
444    }
445}
446
447#[derive(Default, Debug, Clone)]
448struct FieldAttrs {
449    ignore: bool,
450}
451
452impl FieldAttrs {
453    fn get(input: &[Attribute]) -> Result<Self> {
454        let mut res = Self::default();
455
456        for attr in input {
457            if !attr.path().is_ident("fetch") {
458                continue;
459            }
460
461            match &attr.meta {
462                syn::Meta::List(list) => {
463                    // Parse list
464
465                    list.parse_nested_meta(|meta| {
466                        // item = [Debug, PartialEq]
467                        if meta.path.is_ident("ignore") {
468                            res.ignore = true;
469                            Ok(())
470                        } else {
471                            Err(Error::new(
472                                meta.path.span(),
473                                "Unknown fetch field attribute",
474                            ))
475                        }
476                    })?;
477                }
478                _ => {
479                    return Err(Error::new(
480                        Span::call_site(),
481                        "Expected a MetaList for `fetch`",
482                    ))
483                }
484            };
485        }
486
487        Ok(res)
488    }
489}
490
491#[derive(Default)]
492struct Attrs {
493    item_derives: Option<Punctuated<Ident, Token![,]>>,
494    transforms: BTreeSet<TransformIdent>,
495}
496
497impl Attrs {
498    fn get(input: &[Attribute]) -> Result<Self> {
499        let mut res = Self::default();
500
501        for attr in input {
502            if !attr.path().is_ident("fetch") {
503                continue;
504            }
505
506            match &attr.meta {
507                syn::Meta::List(list) => {
508                    // Parse list
509
510                    list.parse_nested_meta(|meta| {
511                        // item = [Debug, PartialEq]
512                        if meta.path.is_ident("item_derives") {
513                            let value = meta.value()?;
514                            let content;
515                            bracketed!(content in value);
516                            let content =
517                                <Punctuated<Ident, Token![,]>>::parse_terminated(&content)?;
518
519                            res.item_derives = Some(content);
520                            Ok(())
521                        } else if meta.path.is_ident("transforms") {
522                            let value = meta.value()?;
523                            let content;
524                            bracketed!(content in value);
525                            let content =
526                                <Punctuated<TransformIdent, Token![,]>>::parse_terminated(
527                                    &content,
528                                )?;
529
530                            res.transforms.extend(content);
531                            Ok(())
532                        } else {
533                            Err(Error::new(meta.path.span(), "Unknown fetch attribute"))
534                        }
535                    })?;
536                }
537                _ => {
538                    return Err(Error::new(
539                        Span::call_site(),
540                        "Expected a MetaList for `fetch`",
541                    ))
542                }
543            };
544        }
545
546        Ok(res)
547    }
548}
549
550#[derive(Clone)]
551struct Params<'a> {
552    crate_name: &'a Ident,
553    vis: &'a Visibility,
554
555    fetch_name: Ident,
556    item_name: Ident,
557    prepared_name: Ident,
558
559    generics: &'a Generics,
560    w_generics: Generics,
561    q_generics: Generics,
562    wq_generics: Generics,
563
564    fields: Vec<ParsedField<'a>>,
565    field_names: Vec<&'a Ident>,
566    field_types: Vec<&'a Type>,
567
568    attrs: &'a Attrs,
569}
570
571impl<'a> Params<'a> {
572    fn new(
573        crate_name: &'a Ident,
574        vis: &'a Visibility,
575        input: &'a DeriveInput,
576        attrs: &'a Attrs,
577    ) -> Result<Self> {
578        let fields = match &input.data {
579            syn::Data::Struct(data) => match &data.fields {
580                syn::Fields::Named(fields) => fields,
581                _ => unreachable!(),
582            },
583
584            _ => unreachable!(),
585        };
586
587        let fetch_name = input.ident.clone();
588
589        let w_lf = LifetimeParam::new(Lifetime::new("'w", Span::call_site()));
590        let q_lf = LifetimeParam::new(Lifetime::new("'q", Span::call_site()));
591
592        let fields = fields
593            .named
594            .iter()
595            .map(ParsedField::get)
596            .collect::<Result<Vec<_>>>()?;
597
598        let field_names = fields.iter().map(|v| v.ident).collect_vec();
599        let field_types = fields.iter().map(|v| v.ty).collect_vec();
600
601        Ok(Self {
602            crate_name,
603            vis,
604            generics: &input.generics,
605            fields,
606            field_names,
607            field_types,
608            attrs,
609            item_name: format_ident!("{fetch_name}Item"),
610            prepared_name: format_ident!("Prepared{fetch_name}"),
611            fetch_name,
612            w_generics: prepend_generics(&[GenericParam::Lifetime(w_lf.clone())], &input.generics),
613            q_generics: prepend_generics(&[GenericParam::Lifetime(q_lf.clone())], &input.generics),
614
615            wq_generics: prepend_generics(
616                &[
617                    GenericParam::Lifetime(w_lf.clone()),
618                    GenericParam::Lifetime(q_lf.clone()),
619                ],
620                &input.generics,
621            ),
622        })
623    }
624
625    fn q_impl(&self) -> ImplGenerics {
626        self.q_generics.split_for_impl().0
627    }
628
629    fn wq_impl(&self) -> ImplGenerics {
630        self.wq_generics.split_for_impl().0
631    }
632
633    fn w_impl(&self) -> ImplGenerics {
634        self.w_generics.split_for_impl().0
635    }
636
637    fn base_ty(&self) -> TypeGenerics {
638        self.generics.split_for_impl().1
639    }
640
641    fn q_ty(&self) -> TypeGenerics {
642        self.q_generics.split_for_impl().1
643    }
644
645    fn w_ty(&self) -> TypeGenerics {
646        self.w_generics.split_for_impl().1
647    }
648}
649
650#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
651enum TransformIdent {
652    Modified,
653    Added,
654}
655
656impl TransformIdent {
657    fn to_tokens(&self, crate_name: &Ident) -> TokenStream {
658        match self {
659            Self::Modified => quote!(#crate_name::fetch::Modified),
660            Self::Added => quote!(#crate_name::fetch::Added),
661        }
662    }
663}
664
665impl Parse for TransformIdent {
666    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
667        let ident = input.parse::<Ident>()?;
668        if ident == "Modified" {
669            Ok(Self::Modified)
670        } else if ident == "Added" {
671            Ok(Self::Added)
672        } else {
673            Err(Error::new(
674                ident.span(),
675                format!("Unknown transform {ident}"),
676            ))
677        }
678    }
679}