use bevy_macro_utils::BevyManifest;
use proc_macro::TokenStream;
use proc_macro_crate::{FoundCrate, crate_name};
use quote::quote;
use syn::spanned::Spanned as _;
use syn::{Data, DataStruct, DeriveInput, Fields, Index, Member, Path, Type, parse_quote};
/// Struct-level attribute read by `chain_event`, e.g. `#[chain_event(trigger = MyTrigger)]`.
pub const CHAIN_EVENT: &str = "chain_event";
/// Struct-level attribute read by `related_chain_event`,
/// e.g. `#[related_chain_event(relationship = MyRelationship)]`.
pub const RELATED_CHAIN_EVENT: &str = "related_chain_event";
/// Key inside `#[related_chain_event(...)]` naming the relationship component type.
pub const RELATIONSHIP: &str = "relationship";
/// Key inside `#[related_chain_event(...)]` naming the relationship-target component type.
pub const RELATIONSHIP_TARGET: &str = "relationship_target";
/// Key inside `#[chain_event(...)]` overriding the generated `Event::Trigger` type.
pub const TRIGGER: &str = "trigger";
/// Field-level attribute marking the field that stores the chain target
/// (looked up by `get_event_target_field`).
pub const EVENT_TARGET: &str = "event_target";
pub fn derive_chain_event(input: TokenStream) -> TokenStream {
match chain_event(syn::parse_macro_input!(input as DeriveInput)) {
Ok(expr) => expr,
Err(err) => err.to_compile_error().into(),
}
}
/// Expands the `ChainEvent` derive for `ast`.
///
/// Generates three impls for the annotated struct:
/// - `bevy_ecs::event::Event`, whose `Trigger` is the type given by
///   `#[chain_event(trigger = ...)]`, or `bevy_ecs::event::EntityTrigger` when absent;
/// - `bevy_ecs::event::EntityEvent`, forwarding `event_target` to the chain field;
/// - `bevy_event_chain::ChainEvent`, whose `next` clones `self` and replaces the
///   chain field with that field's own `next()`.
///
/// The chain field is located by `get_event_target_field`.
///
/// # Errors
///
/// Returns a `syn::Error` when a `#[chain_event(...)]` attribute contains an
/// unknown or repeated key, when the `trigger` value does not parse as a type,
/// or when no chain field can be found.
fn chain_event(mut ast: DeriveInput) -> syn::Result<TokenStream> {
    // Every generated impl is bounded on the event being `Send + Sync + 'static`.
    ast.generics
        .make_where_clause()
        .predicates
        .push(parse_quote! { Self: Send + Sync + 'static });

    // Path the generated code uses to reach the runtime crate. A leading `::`
    // forces resolution through the extern prelude, so a local item that
    // happens to share the crate's (possibly renamed) name cannot shadow it.
    let bevy_event_chain_path = match crate_name("bevy_event_chain") {
        Ok(FoundCrate::Name(name)) => {
            let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
            quote! { ::#ident }
        }
        Ok(FoundCrate::Itself) => quote! { crate },
        Err(_) => quote! { ::bevy_event_chain },
    };
    let bevy_ecs_path: Path = BevyManifest::shared(|manifest| manifest.get_path("bevy_ecs"));

    let mut trigger: Option<Type> = None;
    // Keys already seen across *all* `#[chain_event(...)]` attributes, so a key
    // repeated in a second attribute is still reported as a duplicate.
    let mut processed_attrs = Vec::new();
    for attr in ast.attrs.iter().filter(|attr| attr.path().is_ident(CHAIN_EVENT)) {
        attr.parse_nested_meta(|meta| match meta.path.get_ident() {
            Some(ident) if processed_attrs.iter().any(|i| ident == i) => {
                Err(meta.error(format!("duplicate attribute: {ident}")))
            }
            Some(ident) if ident == TRIGGER => {
                trigger = Some(meta.value()?.parse()?);
                processed_attrs.push(TRIGGER);
                Ok(())
            }
            Some(ident) => Err(meta.error(format!("unsupported attribute: {ident}"))),
            None => Err(meta.error("expected identifier")),
        })?;
    }

    let chain_field = get_event_target_field(&ast)?;
    let struct_name = &ast.ident;
    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
    let trigger = match trigger {
        Some(trigger) => quote! { #trigger },
        None => quote! { #bevy_ecs_path::event::EntityTrigger },
    };

    Ok(quote! {
        impl #impl_generics #bevy_ecs_path::event::Event for #struct_name #type_generics #where_clause {
            type Trigger<'a> = #trigger;
        }
        impl #impl_generics #bevy_ecs_path::event::EntityEvent for #struct_name #type_generics #where_clause {
            fn event_target(&self) -> #bevy_ecs_path::entity::Entity {
                self.#chain_field.event_target()
            }
        }
        impl #impl_generics #bevy_event_chain_path::ChainEvent for #struct_name #type_generics #where_clause {
            fn next(&self) -> Self {
                let mut next = self.clone();
                next.#chain_field = self.#chain_field.next();
                next
            }
        }
    }
    .into())
}
pub fn derive_related_chain_event(input: TokenStream) -> TokenStream {
match related_chain_event(syn::parse_macro_input!(input as DeriveInput)) {
Ok(expr) => expr,
Err(err) => err.to_compile_error().into(),
}
}
/// Expands the `RelatedChainEvent` derive for `ast`.
///
/// Reads `#[related_chain_event(relationship = R)]` and/or
/// `#[related_chain_event(relationship_target = RT)]`. When only one is given,
/// the other is taken from the `bevy_ecs` `Relationship::RelationshipTarget` /
/// `RelationshipTarget::Relationship` associated type.
///
/// Generates impls of `bevy_ecs::event::Event` (triggered by
/// `bevy_event_chain::EntityComponentTrigger`), `bevy_ecs::event::EntityEvent`,
/// and `bevy_event_chain::RelatedChainEvent`, all forwarding to the chain field
/// located by `get_event_target_field`.
///
/// # Errors
///
/// Returns a `syn::Error` for unknown or repeated attribute keys, values that do
/// not parse as types, when neither `relationship` nor `relationship_target`
/// is given, or when no chain field can be found.
fn related_chain_event(mut ast: DeriveInput) -> syn::Result<TokenStream> {
    // Every generated impl is bounded on the event being `Send + Sync + 'static`.
    ast.generics
        .make_where_clause()
        .predicates
        .push(parse_quote! { Self: Send + Sync + 'static });

    // Path the generated code uses to reach the runtime crate. A leading `::`
    // forces resolution through the extern prelude, so a local item that
    // happens to share the crate's (possibly renamed) name cannot shadow it.
    let bevy_event_chain_path = match crate_name("bevy_event_chain") {
        Ok(FoundCrate::Name(name)) => {
            let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
            quote! { ::#ident }
        }
        Ok(FoundCrate::Itself) => quote! { crate },
        Err(_) => quote! { ::bevy_event_chain },
    };
    let bevy_ecs_path: Path = BevyManifest::shared(|manifest| manifest.get_path("bevy_ecs"));

    let mut relationship: Option<Type> = None;
    let mut relationship_target: Option<Type> = None;
    // Keys already seen across *all* `#[related_chain_event(...)]` attributes,
    // so a key repeated in a second attribute is still reported as a duplicate.
    let mut processed_attrs = Vec::new();
    for attr in ast
        .attrs
        .iter()
        .filter(|attr| attr.path().is_ident(RELATED_CHAIN_EVENT))
    {
        attr.parse_nested_meta(|meta| match meta.path.get_ident() {
            Some(ident) if processed_attrs.iter().any(|i| ident == i) => {
                Err(meta.error(format!("duplicate attribute: {ident}")))
            }
            Some(ident) if ident == RELATIONSHIP => {
                relationship = Some(meta.value()?.parse()?);
                processed_attrs.push(RELATIONSHIP);
                Ok(())
            }
            Some(ident) if ident == RELATIONSHIP_TARGET => {
                relationship_target = Some(meta.value()?.parse()?);
                processed_attrs.push(RELATIONSHIP_TARGET);
                Ok(())
            }
            Some(ident) => Err(meta.error(format!("unsupported attribute: {ident}"))),
            None => Err(meta.error("expected identifier")),
        })?;
    }

    // Fill in whichever half of the relationship pair was omitted from the
    // other half's associated type.
    let (relationship, relationship_target) = match (relationship, relationship_target) {
        (Some(r), Some(rt)) => (r, rt),
        (Some(r), None) => {
            let rt = parse_quote! { <#r as #bevy_ecs_path::relationship::Relationship>::RelationshipTarget };
            (r, rt)
        }
        (None, Some(rt)) => {
            let r = parse_quote! { <#rt as #bevy_ecs_path::relationship::RelationshipTarget>::Relationship };
            (r, rt)
        }
        (None, None) => {
            // The example must name the attribute this macro actually reads
            // (`related_chain_event`); `chain_event` is rejected here.
            return Err(syn::Error::new(
                ast.span(),
                "RelatedChainEvent derive requires either a 'relationship' or 'relationship_target' attribute, e.g. #[related_chain_event(relationship = MyRelationship)]",
            ));
        }
    };

    let chain_field = get_event_target_field(&ast)?;
    let struct_name = &ast.ident;
    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();

    Ok(quote! {
        impl #impl_generics #bevy_ecs_path::event::Event for #struct_name #type_generics #where_clause {
            type Trigger<'a> = #bevy_event_chain_path::EntityComponentTrigger;
        }
        impl #impl_generics #bevy_ecs_path::event::EntityEvent for #struct_name #type_generics #where_clause {
            fn event_target(&self) -> #bevy_ecs_path::entity::Entity {
                self.#chain_field.event_target()
            }
        }
        impl #impl_generics #bevy_event_chain_path::RelatedChainEvent for #struct_name #type_generics #where_clause {
            type Relationship = #relationship;
            type RelationshipTarget = #relationship_target;
            fn next(&self) -> Self {
                let mut next = self.clone();
                next.#chain_field = self.#chain_field.next();
                next
            }
            fn get_trigger(&self) -> #bevy_event_chain_path::EntityComponentTrigger {
                self.#chain_field.get_trigger()
            }
        }
    }
    .into())
}
fn get_event_target_field(ast: &DeriveInput) -> syn::Result<Member> {
let Data::Struct(DataStruct { fields, .. }) = &ast.data else {
return Err(syn::Error::new(
ast.span(),
"ChainEvent can only be derived for structs.",
));
};
match fields {
Fields::Named(fields) => fields.named.iter().find_map(|field| {
if field.ident.as_ref().is_some_and(|i| i == "chain") || field
.attrs
.iter()
.any(|attr| attr.path().is_ident(EVENT_TARGET)) {
Some(Member::Named(field.ident.clone()?))
} else {
None
}
}).ok_or(syn::Error::new(
fields.span(),
"ChainEvent derive expected a field name 'chain' or a field annotated with #[event_target]."
)),
Fields::Unnamed(fields) if fields.unnamed.len() == 1 => Ok(Member::Unnamed(Index::from(0))),
Fields::Unnamed(fields) => fields.unnamed.iter().enumerate().find_map(|(index, field)| {
if field
.attrs
.iter()
.any(|attr| attr.path().is_ident(EVENT_TARGET)) {
Some(Member::Unnamed(Index::from(index)))
} else {
None
}
})
.ok_or(syn::Error::new(
fields.span(),
"ChainEvent derive expected unnamed structs with one field or with a field annotated with #[event_target].",
)),
Fields::Unit => Err(syn::Error::new(
fields.span(),
"ChainEvent derive does not work on unit structs. Your type must have a field to store the `EventChain` target, such as `Attack(EventChain)` or `Attack { chain: EventChain }`.",
)),
}
}