//! bevy_event_chain_derive 0.2.0
//!
//! Procedural macros for `bevy_event_chain`.
use bevy_macro_utils::BevyManifest;
use proc_macro::TokenStream;
use proc_macro_crate::{FoundCrate, crate_name};
use quote::quote;
use syn::spanned::Spanned as _;
use syn::{Data, DataStruct, DeriveInput, Fields, Index, Member, Path, Type, parse_quote};

/// Name of the container attribute read by the `ChainEvent` derive.
pub const CHAIN_EVENT: &'static str = "chain_event";
/// Name of the container attribute read by the `RelatedChainEvent` derive.
pub const RELATED_CHAIN_EVENT: &'static str = "related_chain_event";
/// Key inside `#[related_chain_event(...)]` naming the relationship component type.
pub const RELATIONSHIP: &'static str = "relationship";
/// Key inside `#[related_chain_event(...)]` naming the relationship-target component type.
pub const RELATIONSHIP_TARGET: &'static str = "relationship_target";
/// Key inside `#[chain_event(...)]` overriding the event's `Trigger` type.
pub const TRIGGER: &'static str = "trigger";
/// Field-level marker attribute selecting which field stores the event chain.
pub const EVENT_TARGET: &'static str = "event_target";

pub fn derive_chain_event(input: TokenStream) -> TokenStream {
    match chain_event(syn::parse_macro_input!(input as DeriveInput)) {
        Ok(expr) => expr,
        Err(err) => err.to_compile_error().into(),
    }
}

/// Builds the `Event`, `EntityEvent` and `ChainEvent` impls for a `#[derive(ChainEvent)]` type.
///
/// Supported container attribute: `#[chain_event(trigger = SomeTrigger)]`. When it is absent,
/// the event's trigger defaults to bevy's `EntityTrigger`. `EntityEvent::event_target` and
/// `ChainEvent::next` both delegate to the field chosen by `get_event_target_field`.
///
/// # Errors
/// Returns a `syn::Error` for a duplicate or unknown attribute key, a non-identifier key, or
/// when no chain field can be located.
fn chain_event(mut ast: DeriveInput) -> syn::Result<TokenStream> {
    // Same bound bevy's own `Event` derive adds: events must be thread-safe and `'static`.
    ast.generics
        .make_where_clause()
        .predicates
        .push(parse_quote! { Self: Send + Sync + 'static });

    // The user's Cargo.toml may rename `bevy_event_chain`, so resolve the path it is reachable by.
    let event_chain_crate = match crate_name("bevy_event_chain") {
        Ok(FoundCrate::Itself) => quote! { crate },
        Ok(FoundCrate::Name(renamed)) => {
            let crate_ident = syn::Ident::new(&renamed, proc_macro2::Span::call_site());
            quote! { #crate_ident }
        }
        Err(_) => quote! { ::bevy_event_chain },
    };
    let ecs: Path = BevyManifest::shared(|manifest| manifest.get_path("bevy_ecs"));

    // Parse every `#[chain_event(...)]` attribute, rejecting repeated or unknown keys.
    let mut custom_trigger: Option<Type> = None;
    let mut seen: Vec<&str> = Vec::new();
    let chain_attrs = ast
        .attrs
        .iter()
        .filter(|attr| attr.path().is_ident(CHAIN_EVENT));
    for attr in chain_attrs {
        attr.parse_nested_meta(|meta| {
            let Some(key) = meta.path.get_ident() else {
                return Err(meta.error("expected identifier"));
            };
            if seen.iter().any(|done| key == done) {
                return Err(meta.error(format!("duplicate attribute: {key}")));
            }
            if key != TRIGGER {
                return Err(meta.error(format!("unsupported attribute: {key}")));
            }
            custom_trigger = Some(meta.value()?.parse()?);
            seen.push(TRIGGER);
            Ok(())
        })?;
    }

    let target_member = get_event_target_field(&ast)?;
    let struct_ident = &ast.ident;
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
    let trigger_ty = custom_trigger.map_or_else(
        || quote! { #ecs::event::EntityTrigger },
        |ty| quote! { #ty },
    );

    let expanded = quote! {
        impl #impl_generics #ecs::event::Event for #struct_ident #ty_generics #where_clause {
            type Trigger<'a> = #trigger_ty;
        }

        impl #impl_generics #ecs::event::EntityEvent for #struct_ident #ty_generics #where_clause {
            fn event_target(&self) -> #ecs::entity::Entity {
                self.#target_member.event_target()
            }
        }

        impl #impl_generics #event_chain_crate::ChainEvent for #struct_ident #ty_generics #where_clause {
            fn next(&self) -> Self {
                let mut next = self.clone();
                next.#target_member = self.#target_member.next();
                next
            }
        }
    };
    Ok(expanded.into())
}

pub fn derive_related_chain_event(input: TokenStream) -> TokenStream {
    match related_chain_event(syn::parse_macro_input!(input as DeriveInput)) {
        Ok(expr) => expr,
        Err(err) => err.to_compile_error().into(),
    }
}

/// Builds the `Event`, `EntityEvent` and `RelatedChainEvent` impls for a
/// `#[derive(RelatedChainEvent)]` type.
///
/// Supported container attribute keys, both inside `#[related_chain_event(...)]`:
/// - `relationship = R`
/// - `relationship_target = RT`
///
/// At least one key is required. When only one is given, the other is taken from bevy's
/// `Relationship::RelationshipTarget` / `RelationshipTarget::Relationship` associated type.
/// The generated trigger type is `bevy_event_chain::EntityComponentTrigger`. `event_target`,
/// `next` and `get_trigger` all delegate to the field chosen by `get_event_target_field`.
///
/// # Errors
/// Returns a `syn::Error` for a duplicate or unknown attribute key, a non-identifier key,
/// when neither relationship key is supplied, or when no chain field can be located.
fn related_chain_event(mut ast: DeriveInput) -> syn::Result<TokenStream> {
    // Same bound bevy's own `Event` derive adds: events must be thread-safe and `'static`.
    ast.generics
        .make_where_clause()
        .predicates
        .push(parse_quote! { Self: Send + Sync + 'static });

    let mut relationship: Option<Type> = None;
    let mut relationship_target: Option<Type> = None;
    // The user's Cargo.toml may rename `bevy_event_chain`, so resolve the path it is reachable by.
    let bevy_event_chain_path = match crate_name("bevy_event_chain") {
        Ok(FoundCrate::Name(name)) => {
            let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
            quote! { #ident }
        }
        Ok(FoundCrate::Itself) => quote! { crate },
        Err(_) => quote! { ::bevy_event_chain },
    };
    let bevy_ecs_path: Path = BevyManifest::shared(|manifest| manifest.get_path("bevy_ecs"));

    // Keys already consumed, used to reject repeats across all `#[related_chain_event]` attrs.
    let mut processed_attrs = Vec::new();

    for attr in ast
        .attrs
        .iter()
        .filter(|attr| attr.path().is_ident(RELATED_CHAIN_EVENT))
    {
        attr.parse_nested_meta(|meta| match meta.path.get_ident() {
            Some(ident) if processed_attrs.iter().any(|i| ident == i) => {
                Err(meta.error(format!("duplicate attribute: {ident}")))
            }
            Some(ident) if ident == RELATIONSHIP => {
                relationship = Some(meta.value()?.parse()?);
                processed_attrs.push(RELATIONSHIP);
                Ok(())
            }
            Some(ident) if ident == RELATIONSHIP_TARGET => {
                relationship_target = Some(meta.value()?.parse()?);
                processed_attrs.push(RELATIONSHIP_TARGET);
                Ok(())
            }
            Some(ident) => Err(meta.error(format!("unsupported attribute: {ident}"))),
            None => Err(meta.error("expected identifier")),
        })?;
    }

    // Fill in whichever side of the relationship pair was omitted via bevy's associated types.
    let (relationship, relationship_target) = match (relationship, relationship_target) {
        (Some(r), Some(rt)) => (r, rt),
        (Some(r), None) => {
            let rt = parse_quote! { <#r as #bevy_ecs_path::relationship::Relationship>::RelationshipTarget };
            (r, rt)
        }
        (None, Some(rt)) => {
            let r = parse_quote! { <#rt as #bevy_ecs_path::relationship::RelationshipTarget>::Relationship };
            (r, rt)
        }
        (None, None) => {
            // This derive only reads `#[related_chain_event(...)]`, so the hint must name that
            // attribute; pointing users at `#[chain_event(...)]` would send them to an
            // attribute this derive never parses.
            return Err(syn::Error::new(
                ast.span(),
                "RelatedChainEvent derive requires either a 'relationship' or 'relationship_target' attribute, e.g. #[related_chain_event(relationship = MyRelationship)]",
            ));
        }
    };

    let chain_field = get_event_target_field(&ast)?;

    let struct_name = &ast.ident;
    let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl();

    Ok(quote! {
        impl #impl_generics #bevy_ecs_path::event::Event for #struct_name #type_generics #where_clause {
            type Trigger<'a> = #bevy_event_chain_path::EntityComponentTrigger;
        }

        impl #impl_generics #bevy_ecs_path::event::EntityEvent for #struct_name #type_generics #where_clause {
            fn event_target(&self) -> #bevy_ecs_path::entity::Entity {
                self.#chain_field.event_target()
            }
        }

        impl #impl_generics #bevy_event_chain_path::RelatedChainEvent for #struct_name #type_generics #where_clause {
            type Relationship = #relationship;
            type RelationshipTarget = #relationship_target;

            fn next(&self) -> Self {
                let mut next = self.clone();
                next.#chain_field = self.#chain_field.next();
                next
            }

            fn get_trigger(&self) -> #bevy_event_chain_path::EntityComponentTrigger {
                self.#chain_field.get_trigger()
            }
        }
    }
    .into())
}

/// Returns the field with the `#[event_target]` attribute, the only field if unnamed,
/// or the field with the name "entity".
fn get_event_target_field(ast: &DeriveInput) -> syn::Result<Member> {
    let Data::Struct(DataStruct { fields, .. }) = &ast.data else {
        return Err(syn::Error::new(
            ast.span(),
            "ChainEvent can only be derived for structs.",
        ));
    };
    match fields {
        Fields::Named(fields) => fields.named.iter().find_map(|field| {
            if field.ident.as_ref().is_some_and(|i| i == "chain") || field
                .attrs
                .iter()
                .any(|attr| attr.path().is_ident(EVENT_TARGET)) {
                    Some(Member::Named(field.ident.clone()?))
                } else {
                    None
                }
        }).ok_or(syn::Error::new(
            fields.span(),
            "ChainEvent derive expected a field name 'chain' or a field annotated with #[event_target]."
        )),
        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => Ok(Member::Unnamed(Index::from(0))),
        Fields::Unnamed(fields) => fields.unnamed.iter().enumerate().find_map(|(index, field)| {
                if field
                    .attrs
                    .iter()
                    .any(|attr| attr.path().is_ident(EVENT_TARGET)) {
                        Some(Member::Unnamed(Index::from(index)))
                    } else {
                        None
                    }
            })
            .ok_or(syn::Error::new(
                fields.span(),
                "ChainEvent derive expected unnamed structs with one field or with a field annotated with #[event_target].",
            )),
        Fields::Unit => Err(syn::Error::new(
            fields.span(),
            "ChainEvent derive does not work on unit structs. Your type must have a field to store the `EventChain` target, such as `Attack(EventChain)` or `Attack { chain: EventChain }`.",
        )),
    }
}