use crate::prelude::*;
use beet_core::prelude::*;
use proc_macro2::TokenStream;
use quote::quote;
use syn::Ident;
use syn::ItemFn;
use syn::Result;
use syn::ReturnType;
/// Expands a template function into its generated props struct and impls.
///
/// Any error produced while expanding is turned into a `compile_error!`
/// token stream instead of being propagated.
pub fn template_macro(input: ItemFn) -> TokenStream {
	match parse(input) {
		Ok(expanded) => expanded,
		Err(error) => error.into_compile_error(),
	}
}
/// Name of the generated struct: the function's name converted to
/// UpperCamelCase, carrying the span of the original function identifier.
fn template_ident(input: &ItemFn) -> Ident {
	use heck::ToUpperCamelCase;
	let fn_ident = &input.sig.ident;
	let camel_case = fn_ident.to_string().to_upper_camel_case();
	Ident::new(&camel_case, fn_ident.span())
}
/// Builds the full expansion for a template function.
///
/// The expansion contains an optional prelude import, the props struct and
/// the `system`/`IntoBundle` impls. Errors from parsing the arguments or
/// generating either part are propagated.
fn parse(input: ItemFn) -> Result<TokenStream> {
	let fields = NodeField::parse_item_fn(&input)?;
	let struct_ident = template_ident(&input);
	let struct_def = define_struct(&input, &struct_ident, &fields)?;
	let bundle_impl = impl_template_bundle(&input, &struct_ident, &fields)?;
	// Only when `pkg_ext::is_internal()` is false does the expansion bring
	// `beet::prelude::*` into scope; otherwise nothing is imported.
	let imports = (!pkg_ext::is_internal())
		.then(|| quote! { use beet::prelude::*; })
		.unwrap_or_default();
	Ok(quote! {
		#imports
		#struct_def
		#bundle_impl
	})
}
/// Emits the props struct generated from the template function.
///
/// The struct copies the function's attributes, visibility and generic
/// parameter list, and always derives `Props`. It gets one `pub` field per
/// prop argument (see [`prop_fields`]); `Entity` and system-param arguments
/// are not stored on it.
fn define_struct(
	func: &ItemFn,
	ident: &Ident,
	fields: &[NodeField],
) -> Result<TokenStream> {
	let attrs = &func.attrs;
	let vis = &func.vis;
	// Declare the struct with the full generic parameter list rather than
	// `TypeGenerics`. `TypeGenerics` prints parameter *names* only, so e.g.
	// `const N: usize` would come out as a bare `N` and the struct would no
	// longer match the `impl` blocks built from `split_for_impl`. `Generics`
	// does not print its where clause, so that is appended separately.
	let generics = &func.sig.generics;
	let where_clause = &generics.where_clause;
	let fields = prop_fields(fields).map(|f| {
		let ident = &f.ident;
		let attrs = &f.attrs;
		let ty = f.ty;
		quote! {
			#(#attrs)*
			pub #ident: #ty
		}
	});
	Ok(quote! {
		#(#attrs)*
		#[derive(Props)]
		#vis struct #ident #generics #where_clause {
			#(#fields),*
		}
	})
}
fn impl_template_bundle(
func: &ItemFn,
ident: &Ident,
fields: &[NodeField],
) -> Result<TokenStream> {
let destructure_props = prop_fields(fields).map(|field| {
let mutability = field.mutability;
let ident = &field.ident;
quote! {
#mutability #ident
}
});
let param_fields = system_param_fields(fields).map(|field| {
let ident = &field.ident;
let ty = &field.ty;
let attrs = &field.non_field_attrs();
let mutability = field.mutability;
quote! {
#(#attrs)*
#mutability #ident: #ty
}
});
let (impl_generics, type_generics, where_clause) =
func.sig.generics.split_for_impl();
let return_type = with_captured_lifetimes(&func);
let body = &func.block.stmts;
let entity_ident = entity_param_ident(fields)
.cloned()
.unwrap_or_else(|| Ident::new("entity", ident.span()));
let err_msg =
format!("Error running template system for `{}`: {{}}", ident);
let returns_result = if let ReturnType::Type(_, ty) = &return_type
&& let syn::Type::Path(type_path) = &**ty
{
type_path
.path
.segments
.last()
.map_or(false, |segment| segment.ident == "Result")
} else {
false
};
let maybe_unwrap = if returns_result {
quote! { .unwrap_or_exit() }
} else {
Default::default()
};
let maybe_generics = if returns_result {
quote!(::<_, Result<_>, _, _>)
} else {
Default::default()
};
Ok(quote! {
impl #impl_generics #ident #type_generics #where_clause {
pub fn system(#[allow(unused_variables, unused_assignments)]In((#entity_ident, input)): In<(Entity,Self)>, #(#param_fields),*) #return_type{
#[allow(unused_variables, unused_assignments)]
let Self { #(#destructure_props),* } = input;
#(#body)*
}
}
impl #impl_generics IntoBundle<Self> for #ident #type_generics #where_clause {
fn into_bundle(self) -> impl Bundle {
OnSpawn::new(move |entity_world_mut: &mut EntityWorldMut| {
let id = entity_world_mut.id();
let bundle = entity_world_mut.world_scope(|world| {
world.run_system_cached_with #maybe_generics(Self::system, (id,self)).map_err(|err|
bevyhow!(#err_msg, err)
).unwrap_or_exit()
})#maybe_unwrap;
entity_world_mut.insert(bundle.into_bundle());
})
}
}
})
}
/// Returns a copy of `func`'s return type with precise-capturing bounds added.
///
/// Every `impl Trait` gets a `use<..>` bound listing all of the function's
/// generic parameters, or `use<>` when it has none. This covers `impl Trait`
/// at the top level or nested inside path generic arguments, parentheses,
/// invisible groups or tuples. An `impl Trait` that already carries a
/// `use<..>` bound is left untouched, since a second one is a compile error.
///
/// NOTE(review): presumably this keeps the opaque type of the generated
/// `system` method from capturing the elided lifetimes of its system params
/// (`&mut World`, `Res<_>`, ...) under edition-2024 capture rules. Confirm.
fn with_captured_lifetimes(func: &ItemFn) -> ReturnType {
	/// Walks `ty`, adding `bound` to each `impl Trait` found.
	fn visit(ty: &mut syn::Type, bound: &syn::TypeParamBound) {
		match ty {
			syn::Type::Path(type_path) => {
				for segment in &mut type_path.path.segments {
					if let syn::PathArguments::AngleBracketed(args) =
						&mut segment.arguments
					{
						for arg in &mut args.args {
							if let syn::GenericArgument::Type(ty) = arg {
								visit(ty, bound);
							}
						}
					}
				}
			}
			syn::Type::Tuple(tuple) => {
				for elem in &mut tuple.elems {
					visit(elem, bound);
				}
			}
			syn::Type::Paren(paren) => visit(&mut paren.elem, bound),
			// Types forwarded through `macro_rules!` `$t:ty` arrive wrapped
			// in an invisible group.
			syn::Type::Group(group) => visit(&mut group.elem, bound),
			syn::Type::ImplTrait(impl_trait) => {
				let already_captures = impl_trait.bounds.iter().any(|b| {
					matches!(b, syn::TypeParamBound::PreciseCapture(_))
				});
				if !already_captures {
					impl_trait.bounds.push(bound.clone());
				}
			}
			_ => {}
		}
	}
	// The bound is the same for every `impl Trait`, so build it once.
	let bound: syn::TypeParamBound = if func.sig.generics.params.is_empty() {
		syn::parse_quote! { use<> }
	} else {
		let (_, type_generics, _) = func.sig.generics.split_for_impl();
		syn::parse_quote! { use #type_generics }
	};
	let mut return_type = func.sig.output.clone();
	if let ReturnType::Type(_, ty) = &mut return_type {
		visit(ty, &bound);
	}
	return_type
}
/// Last path segments of argument types that are always treated as system
/// parameters.
///
/// Such arguments become parameters of the generated `system` fn rather than
/// fields of the props struct. Any other system param must be marked with
/// `#[field(param)]` to be treated the same way.
const SYSTEM_PARAM_IDENTS: [&str; 7] = [
	"World",
	"Commands",
	"Res",
	"ResMut",
	"Query",
	"Populated",
	"When",
];
/// Arguments stored on the generated props struct.
///
/// These are all arguments that are neither `Entity`-typed nor system params.
fn prop_fields<'a>(
	fields: &'a [NodeField],
) -> impl Iterator<Item = &'a NodeField<'a>> {
	fields
		.iter()
		.filter(|field| !field.last_segment_matches("Entity"))
		.filter(|field| !is_system_param_field(field))
}
/// Arguments forwarded as extra parameters of the generated `system` fn.
///
/// These are the system params, excluding any `Entity`-typed argument.
fn system_param_fields<'a>(
	fields: &'a [NodeField],
) -> impl Iterator<Item = &'a NodeField<'a>> {
	fields
		.iter()
		.filter(|field| is_system_param_field(field))
		.filter(|field| !field.last_segment_matches("Entity"))
}
/// Whether `field` is a system parameter.
///
/// A field qualifies if it is marked with the `param` field attribute, or if
/// its type's last path segment is one of [`SYSTEM_PARAM_IDENTS`].
fn is_system_param_field(field: &NodeField) -> bool {
	if field.field_attributes.contains("param") {
		return true;
	}
	SYSTEM_PARAM_IDENTS
		.iter()
		.any(|known| field.last_segment_matches(known))
}
/// Name of the first `Entity`-typed argument, if the template declares one.
fn entity_param_ident<'a>(fields: &'a [NodeField]) -> Option<&'a Ident> {
	let entity_field = fields
		.iter()
		.find(|field| field.last_segment_matches("Entity"))?;
	Some(entity_field.ident)
}
// Unit tests for the template macro expansion helpers.
#[cfg(test)]
mod test {
	use super::with_captured_lifetimes;
	use crate::prelude::*;
	use beet_core::prelude::*;
	use syn::PathSegment;
	// `impl Trait` return types gain `use<..>` with the fn's generics, or
	// `use<>` when there are none, including when nested in a generic arg.
	#[test]
	fn capture_lifetimes_test() {
		with_captured_lifetimes(&syn::parse_quote! {
			fn foo<T>() -> impl Bundle {}
		})
		.xpect_eq(syn::parse_quote! {-> impl Bundle + use<T> });
		with_captured_lifetimes(
			&syn::parse_quote! {fn bar() -> Result<impl Bundle, ()>{} },
		)
		.xpect_eq(syn::parse_quote! {-> Result<impl Bundle + use<>, ()> });
	}
	// A path segment's `ident` excludes its generic arguments.
	#[test]
	fn segments() {
		let a: PathSegment = syn::parse_quote! {Foo};
		a.ident.xpect_eq("Foo");
		let a: PathSegment = syn::parse_quote! {Foo<Bar>};
		a.ident.xpect_eq("Foo");
	}
	// Snapshot of the expansion for a template with only prop arguments.
	#[test]
	fn simple() {
		template_macro(syn::parse_quote! {
			pub(crate) fn MyNode(
				foo:u32,
				mut bar:u32
			) -> impl Bundle{()}
		})
		.xpect_snapshot();
	}
	// Snapshot of the expansion mixing props, an `Entity` argument, built-in
	// system params and a `#[field(param)]`-marked custom system param.
	#[test]
	fn complex() {
		template_macro(syn::parse_quote! {
			pub(crate) fn MyNode(
				foo:u32,
				mut bar:u32,
				my_entity:Entity,
				world: &mut World,
				res: Res<Time>,
				mut query: Query<&mut Transform>,
				#[field(param)]
				custom_query: CustomQuery
			) -> impl Bundle{()}
		})
		.xpect_snapshot();
	}
}