#[cfg(feature = "reflect")]
use core::any::type_name;
use core::marker::PhantomData;
#[cfg(feature = "reflect")]
use bevy::reflect::utility::GenericTypePathCell;
use bevy::{
prelude::*,
state::state::{StateTransitionEvent, StateTransitionSystems, States},
};
use log::debug;
use smallvec::SmallVec;
use crate::prelude::*;
/// Extension methods on [`App`] for driving a context's activity from a Bevy state.
pub trait StateContextAppExt {
    /// Keeps `ContextActivity<C>` in sync with `ActiveInStates<C, S>` on every entity
    /// that carries an `ActiveInStates<C, S>` component.
    ///
    /// Activity is recomputed when `ActiveInStates<C, S>` is inserted and whenever
    /// state `S` transitions. Returns `self` for chaining.
    fn sync_context_to_state<C, S>(&mut self) -> &mut Self
    where
        C: Component,
        S: States;
}
impl StateContextAppExt for App {
    /// Registers the insert observer and the `StateTransition` system that keep
    /// `ContextActivity<C>` aligned with the current value of state `S`.
    fn sync_context_to_state<C, S>(&mut self) -> &mut Self
    where
        C: Component,
        S: States,
    {
        debug!(
            "registering state sync for `{}` with `{}`",
            ShortName::of::<C>(),
            ShortName::of::<S>(),
        );
        // Ordered after dependent-state transitions and before the exit schedules.
        // Presumably this lets computed/sub states settle first, and lets OnExit/OnEnter
        // systems observe the updated activity (confirm if relied upon).
        let on_transition = sync_state_contexts::<C, S>
            .after(StateTransitionSystems::DependentTransitions)
            .before(StateTransitionSystems::ExitSchedules);
        self.add_observer(sync_on_insert::<C, S>)
            .add_systems(StateTransition, on_transition)
    }
}
/// Observer run when `ActiveInStates<C, S>` is inserted on an entity.
///
/// Immediately sets that entity's `ContextActivity<C>` from the current value of
/// `State<S>`. If the `State<S>` resource does not exist, the context is made inactive.
fn sync_on_insert<C: Component, S: States>(
    insert: On<Insert, ActiveInStates<C, S>>,
    mut commands: Commands,
    current_state: Option<Res<State<S>>>,
    contexts: Query<&ActiveInStates<C, S>>,
    activity: Query<&ContextActivity<C>>,
) {
    let entity = insert.entity;
    if let Ok(active_in) = contexts.get(entity) {
        let state = current_state.as_deref().map(|state| state.get());
        let should_be_active = active_in.matches_state(state);
        set_context_activity(&mut commands, &activity, entity, should_be_active);
    }
}
/// System in `StateTransition` that re-evaluates every `ActiveInStates<C, S>` after a
/// transition of state `S`.
///
/// Does nothing when no transition message is pending. When the entered state is
/// `None` (the state was removed), every context becomes inactive.
fn sync_state_contexts<C: Component, S: States>(
    mut commands: Commands,
    mut transitions: MessageReader<StateTransitionEvent<S>>,
    contexts: Query<(Entity, &ActiveInStates<C, S>)>,
    activity: Query<&ContextActivity<C>>,
) {
    // `.last()` drains every pending message (so all are marked read) but keeps only
    // the newest; only the most recently entered state decides activity.
    if let Some(latest) = transitions.read().last() {
        let entered = latest.entered.as_ref();
        contexts.iter().for_each(|(entity, active_in)| {
            let active = active_in.matches_state(entered);
            set_context_activity(&mut commands, &activity, entity, active);
        });
    }
}
/// Queues insertion of `ContextActivity::<C>::new(active)` on `entity`.
///
/// Skips the insertion when `activity` shows the entity already holds that exact
/// value, so unchanged activity is never re-inserted. Presumably this also avoids
/// needless change detection and insert observers; confirm if callers rely on that.
fn set_context_activity<C: Component>(
    commands: &mut Commands,
    activity: &Query<&ContextActivity<C>>,
    entity: Entity,
    active: bool,
) {
    let unchanged = activity
        .get(entity)
        .is_ok_and(|current| **current == active);
    if !unchanged {
        debug!(
            "setting `{}` on `{entity}` to `{active}`",
            ShortName::of::<C>(),
        );
        commands
            .entity(entity)
            .insert(ContextActivity::<C>::new(active));
    }
}
/// Component that makes context `C` active only while state `S` is one of `states`.
///
/// Once `sync_context_to_state::<C, S>` is registered on the app, the entity's
/// `ContextActivity<C>` is kept in sync with this component. Activity is `true` when
/// the current state is listed, and `false` when it is not listed or the state does
/// not exist.
// `type_path = false`: `TypePath` is implemented manually later in this file, which
// spells `C` and `S` via `type_name` instead of requiring them to implement `TypePath`.
#[derive(Component)]
#[cfg_attr(
    feature = "reflect",
    derive(Reflect),
    reflect(Clone, Component, type_path = false)
)]
pub struct ActiveInStates<C: Component, S: States> {
    /// States in which the context is considered active.
    /// Inline storage for one state avoids a heap allocation in the common case.
    states: SmallVec<[S; 1]>,
    /// Ties this component to the context type `C` without storing a `C`;
    /// excluded from reflection.
    #[cfg_attr(feature = "reflect", reflect(ignore))]
    _marker: PhantomData<C>,
}
impl<C: Component, S: States> ActiveInStates<C, S> {
    /// Creates a marker that is active in exactly one state.
    #[must_use]
    pub fn single(state: S) -> Self {
        // A one-element array fills the inline buffer exactly, so nothing is heap-allocated.
        let states = SmallVec::from_buf([state]);
        Self {
            states,
            _marker: PhantomData,
        }
    }

    /// Creates a marker that is active in any of the given states.
    ///
    /// An empty iterator yields a marker that never matches.
    #[must_use]
    pub fn new(states: impl IntoIterator<Item = S>) -> Self {
        let states = SmallVec::from_iter(states);
        Self {
            states,
            _marker: PhantomData,
        }
    }

    /// Returns `true` if `current` is one of the configured states.
    #[must_use]
    pub fn matches(&self, current: &S) -> bool {
        self.states.contains(current)
    }

    /// Like [`Self::matches`], but treats a missing state (`None`) as "not active".
    #[must_use]
    fn matches_state(&self, current: Option<&S>) -> bool {
        current.is_some_and(|state| self.matches(state))
    }
}
// Written by hand because `#[derive(Clone)]` would add a `C: Clone` bound, even though
// `C` only appears inside `PhantomData` and is never stored.
impl<C: Component, S: States> Clone for ActiveInStates<C, S> {
    fn clone(&self) -> Self {
        let states = self.states.clone();
        Self {
            states,
            _marker: PhantomData,
        }
    }
}
/// Manual `TypePath` so `C` and `S` need not implement `TypePath` themselves.
///
/// Generic arguments come from [`type_name`]. Its output is best-effort and not
/// guaranteed stable across compiler versions, so these paths are suitable for display
/// and diagnostics. Each distinct `ActiveInStates<C, S>` gets its own cached string,
/// because `GenericTypePathCell` is keyed on `Self`.
#[cfg(feature = "reflect")]
impl<C: Component, S: States> TypePath for ActiveInStates<C, S> {
    // Fully qualified path: `<this module>::ActiveInStates<full C path, full S path>`.
    fn type_path() -> &'static str {
        static CELL: GenericTypePathCell = GenericTypePathCell::new();
        CELL.get_or_insert::<Self, _>(|| {
            format!(
                concat!(module_path!(), "::ActiveInStates<{}, {}>"),
                type_name::<C>(),
                type_name::<S>()
            )
        })
    }

    // Short path. Per the `TypePath` contract, the generic arguments are shortened too.
    // `ShortName` strips module prefixes, including inside nested generics, so this
    // yields e.g. `ActiveInStates<Foo, Bar>` rather than `ActiveInStates<a::b::Foo, a::Bar>`.
    fn short_type_path() -> &'static str {
        static CELL: GenericTypePathCell = GenericTypePathCell::new();
        CELL.get_or_insert::<Self, _>(|| {
            format!(
                "ActiveInStates<{}, {}>",
                ShortName::of::<C>(),
                ShortName::of::<S>()
            )
        })
    }

    fn type_ident() -> Option<&'static str> {
        Some("ActiveInStates")
    }

    fn module_path() -> Option<&'static str> {
        Some(module_path!())
    }

    fn crate_name() -> Option<&'static str> {
        // First segment of the module path. `split` always yields at least one item,
        // so the `unwrap` cannot fail.
        Some(module_path!().split(':').next().unwrap())
    }
}