use std::{
any::{type_name, Any, TypeId},
fmt::Debug,
marker::PhantomData,
};
use bevy::{
ecs::system::{Command, EntityCommands, SystemState},
tasks::{ComputeTaskPool, ParallelSliceMut},
utils::HashMap,
};
use crate::{
prelude::*,
set::StateSet,
state::{Insert, OnEvent},
};
/// Registers the state-machine transition system in `PostUpdate`, inside
/// [`StateSet::Transition`].
pub(crate) fn machine_plugin(app: &mut App) {
    let transition_system = transition.in_set(StateSet::Transition);
    app.add_systems(PostUpdate, transition_system);
}
/// Type-erased transition stored inside a [`StateMachine`].
trait Transition: 'static + Send + Sync + Debug {
    /// Prepares world-dependent state; called with exclusive world access before `run`.
    fn init(&mut self, world: &mut World);

    /// Evaluates the transition for `entity`.
    ///
    /// Returns the boxed next state together with its `TypeId` when the transition
    /// fires, or `None` when it does not.
    fn run(&mut self, world: &World, entity: Entity) -> Option<(Box<dyn Insert>, TypeId)>;
}
/// Concrete [`Transition`]: when `trigger` succeeds for an entity in state `Prev`,
/// `builder` is asked to produce the `Next` state.
struct TransitionImpl<Trig, Prev, Build, Next>
where
    Trig: Trigger,
    Prev: MachineState,
    Next: Component + MachineState,
    Build: Fn(&Prev, Trig::Ok) -> Option<Next> + Send + Sync + 'static,
{
    /// Condition checked every time the machine runs.
    pub trigger: Trig,
    /// Produces the next state from the current state and the trigger's `Ok` value.
    pub builder: Build,
    /// Cached system state for the trigger's parameters; `None` until `init` runs.
    system_state: Option<SystemState<Trig::Param<'static, 'static>>>,
    /// Ties the `Prev` type parameter to the struct without storing a value.
    phantom: PhantomData<Prev>,
}
impl<Trig, Prev, Build, Next> Debug for TransitionImpl<Trig, Prev, Build, Next>
where
Trig: Trigger,
Prev: MachineState,
Build: Fn(&Prev, Trig::Ok) -> Option<Next> + Send + Sync,
Next: Component + MachineState,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TransitionImpl")
.field("trigger", &self.trigger.type_id())
.field("builder", &self.builder.type_id())
.field("system_state", &self.system_state.type_id())
.field("phantom", &self.phantom)
.finish()
}
}
impl<Trig, Prev, Build, Next> Transition for TransitionImpl<Trig, Prev, Build, Next>
where
    Trig: Trigger,
    Prev: MachineState,
    Build: Fn(&Prev, Trig::Ok) -> Option<Next> + Send + Sync,
    Next: Component + MachineState,
{
    /// Lazily builds the trigger's `SystemState`. Repeated calls are no-ops, so this
    /// can safely be invoked every frame.
    fn init(&mut self, world: &mut World) {
        self.system_state
            .get_or_insert_with(|| SystemState::new(world));
    }

    /// Runs the trigger for `entity`. If it succeeds, the builder is given the
    /// entity's current `Prev` state plus the trigger output.
    ///
    /// # Panics
    ///
    /// Panics if `init` has not been called first.
    fn run(&mut self, world: &World, entity: Entity) -> Option<(Box<dyn Insert>, TypeId)> {
        let param_state = self
            .system_state
            .as_mut()
            .expect("`Transition::init` must be called before `Transition::run`");
        // A failed trigger means this transition does not fire.
        let output = self.trigger.trigger(entity, param_state.get(world)).ok()?;
        let prev = Prev::from_entity(entity, world);
        (self.builder)(prev, output)
            .map(|next| (Box::new(next) as Box<dyn Insert>, TypeId::of::<Next>()))
    }
}
impl<Trig, Prev, Build, Next> TransitionImpl<Trig, Prev, Build, Next>
where
    Trig: Trigger,
    Prev: MachineState,
    Build: Fn(&Prev, Trig::Ok) -> Option<Next> + Send + Sync,
    Next: Component + MachineState,
{
    /// Creates a transition that fires on `trigger` and builds the next state with `builder`.
    ///
    /// The trigger's system state starts out empty; `Transition::init` fills it in.
    pub fn new(trigger: Trig, builder: Build) -> Self {
        let system_state = None;
        let phantom = PhantomData;
        Self {
            phantom,
            system_state,
            builder,
            trigger,
        }
    }
}
/// Per-state bookkeeping held by a [`StateMachine`].
#[derive(Debug)]
struct StateMetadata {
    /// State name, used in panic and log messages.
    name: String,
    /// Events triggered when the machine transitions into this state.
    on_enter: Vec<OnEvent>,
    /// Events triggered when the machine transitions out of this state.
    on_exit: Vec<OnEvent>,
}

impl StateMetadata {
    /// Creates metadata for state `S`, named after its type.
    ///
    /// It starts with no enter events and a single exit event that calls `S::remove`
    /// on the entity.
    fn new<S: MachineState>() -> Self {
        let remove_state = OnEvent::Entity(Box::new(|entity: &mut EntityCommands| {
            S::remove(entity);
        }));
        Self {
            name: String::from(type_name::<S>()),
            on_enter: Vec::new(),
            on_exit: vec![remove_state],
        }
    }
}
/// Component that drives an entity through a set of state components.
///
/// The entity's current state is whichever registered state component it has.
/// Exactly one must be present whenever the transition system runs.
#[derive(Component, Debug)]
pub struct StateMachine {
    /// Metadata for every registered state, keyed by the state component's `TypeId`.
    states: HashMap<TypeId, StateMetadata>,
    /// Transitions in registration order, each tagged with the `TypeId` of the
    /// state it leaves (or `AnyState`).
    transitions: Vec<(TypeId, Box<dyn Transition>)>,
    /// Whether each transition is logged.
    log_transitions: bool,
}
/// An empty machine that only knows about `AnyState`.
///
/// `AnyState` has no enter or exit events of its own.
impl Default for StateMachine {
    fn default() -> Self {
        let any_state = StateMetadata {
            name: String::from("AnyState"),
            on_enter: Vec::new(),
            on_exit: Vec::new(),
        };
        let mut states = HashMap::default();
        states.insert(TypeId::of::<AnyState>(), any_state);
        Self {
            states,
            transitions: Vec::new(),
            log_transitions: false,
        }
    }
}
impl StateMachine {
    /// Registers `S` as a state of this machine without adding any transitions.
    ///
    /// States that appear in a transition are registered automatically. Use this
    /// for states that take part in none; an entity whose state is unregistered is
    /// treated as being in no state.
    pub fn with_state<S: Clone + Component>(mut self) -> Self {
        self.metadata_mut::<S>();
        self
    }

    /// Adds a transition from `S` to a clone of `state`, taken whenever `trigger` succeeds.
    pub fn trans<S: MachineState>(
        self,
        trigger: impl Trigger,
        state: impl Clone + Component,
    ) -> Self {
        self.trans_builder(trigger, move |_: &S, _| Some(state.clone()))
    }

    /// Returns the metadata for state `S`, creating it on first use.
    fn metadata_mut<S: MachineState>(&mut self) -> &mut StateMetadata {
        // `or_insert_with` avoids building (and immediately dropping) a name string
        // and boxed exit closure when the state is already registered.
        self.states
            .entry(TypeId::of::<S>())
            .or_insert_with(StateMetadata::new::<S>)
    }

    /// Adds a transition from `Prev` whose target state is computed by `builder`.
    ///
    /// `builder` receives the current state and the trigger's `Ok` value. Returning
    /// `None` means this transition does not fire, and the next candidate
    /// transition is tried. Both `Prev` and `Next` are registered as states.
    pub fn trans_builder<Prev: MachineState, Trig: Trigger, Next: Clone + Component>(
        mut self,
        trigger: Trig,
        builder: impl 'static + Clone + Fn(&Prev, Trig::Ok) -> Option<Next> + Send + Sync,
    ) -> Self {
        self.metadata_mut::<Prev>();
        self.metadata_mut::<Next>();
        let transition = TransitionImpl::<_, Prev, _, _>::new(trigger, builder);
        self.transitions.push((
            TypeId::of::<Prev>(),
            Box::new(transition) as Box<dyn Transition>,
        ));
        self
    }

    /// Registers a callback run with the entity's `EntityCommands` when entering `S`.
    pub fn on_enter<S: MachineState>(
        mut self,
        on_enter: impl 'static + Fn(&mut EntityCommands) + Send + Sync,
    ) -> Self {
        self.metadata_mut::<S>()
            .on_enter
            .push(OnEvent::Entity(Box::new(on_enter)));
        self
    }

    /// Registers a callback run with the entity's `EntityCommands` when exiting `S`.
    pub fn on_exit<S: MachineState>(
        mut self,
        on_exit: impl 'static + Fn(&mut EntityCommands) + Send + Sync,
    ) -> Self {
        self.metadata_mut::<S>()
            .on_exit
            .push(OnEvent::Entity(Box::new(on_exit)));
        self
    }

    /// Registers `command` to be triggered when entering `S`.
    pub fn command_on_enter<S: MachineState>(
        mut self,
        command: impl Clone + Command + Sync,
    ) -> Self {
        self.metadata_mut::<S>()
            .on_enter
            .push(OnEvent::Command(Box::new(command)));
        self
    }

    /// Registers `command` to be triggered when exiting `S`.
    pub fn command_on_exit<S: MachineState>(
        mut self,
        command: impl Clone + Command + Sync,
    ) -> Self {
        self.metadata_mut::<S>()
            .on_exit
            .push(OnEvent::Command(Box::new(command)));
        self
    }

    /// Enables or disables an `info!` log line for every transition this machine takes.
    pub fn set_trans_logging(mut self, log_transitions: bool) -> Self {
        self.log_transitions = log_transitions;
        self
    }

    /// Initializes every transition's trigger state. Needs exclusive world access,
    /// so it runs before the parallel `run` phase.
    fn init_transitions(&mut self, world: &mut World) {
        for (_, transition) in &mut self.transitions {
            transition.init(world);
        }
    }

    /// Evaluates this machine for `entity` and queues the resulting state change.
    ///
    /// Transitions leaving the current state or `AnyState` are tried in registration
    /// order, and the first one that fires wins. When a transition fires, the
    /// following are queued on `commands` in order:
    /// 1. the old state's exit events;
    /// 2. insertion of the new state;
    /// 3. the new state's enter events.
    ///
    /// # Panics
    ///
    /// Panics if the entity has none of the registered state components, or more than one.
    fn run(&mut self, world: &World, entity: Entity, commands: &mut Commands) {
        let entity_ref = world.entity(entity);
        // Registered states whose component is present on the entity.
        let mut present = self
            .states
            .keys()
            .copied()
            .filter(|&state| entity_ref.contains_type_id(state));
        let Some(current) = present.next() else {
            panic!("Entity {entity:?} is in no state");
        };
        let from = &self.states[&current];
        if let Some(other) = present.next() {
            let state = &from.name;
            let other = &self.states[&other].name;
            panic!("{entity:?} is in multiple states: {state} and {other}");
        }

        let any_state = TypeId::of::<AnyState>();
        let Some((insert, next_state)) = self
            .transitions
            .iter_mut()
            .filter(|(type_id, _)| *type_id == current || *type_id == any_state)
            .find_map(|(_, transition)| transition.run(world, entity))
        else {
            return;
        };

        // `next_state` is always registered: `trans_builder` registers `Next`.
        let to = &self.states[&next_state];
        for event in &from.on_exit {
            event.trigger(entity, commands);
        }
        insert.insert(&mut commands.entity(entity));
        for event in &to.on_enter {
            event.trigger(entity, commands);
        }
        if self.log_transitions {
            info!("{entity:?} transitioned from {} to {}", from.name, to.name);
        }
    }

    /// Returns an empty placeholder machine, left in the world while the real
    /// machine is temporarily moved out for evaluation.
    fn stub(&self) -> Self {
        Self {
            states: default(),
            log_transitions: false,
            transitions: default(),
        }
    }
}
/// Exclusive system that evaluates every [`StateMachine`] in parallel.
///
/// It runs in four steps:
/// 1. Each machine is moved out of its entity and replaced with a stub, so machines
///    can be mutated concurrently while the world is shared immutably.
/// 2. Trigger state is initialized with exclusive access.
/// 3. All machines run on the compute task pool.
/// 4. The machines are put back and the queued commands are applied.
pub(crate) fn transition(
    world: &mut World,
    system_state: &mut SystemState<ParallelCommands>,
    machine_query: &mut QueryState<(Entity, &mut StateMachine)>,
) {
    let mut machines: Vec<(Entity, StateMachine)> = Vec::new();
    for (entity, mut machine) in machine_query.iter_mut(world) {
        let placeholder = machine.stub();
        machines.push((entity, std::mem::replace(machine.as_mut(), placeholder)));
    }

    for (_, machine) in &mut machines {
        machine.init_transitions(world);
    }

    let par_commands = system_state.get(world);
    machines.par_splat_map_mut(ComputeTaskPool::get(), None, |chunk| {
        for (entity, machine) in chunk.iter_mut() {
            par_commands.command_scope(|mut commands| machine.run(world, *entity, &mut commands));
        }
    });

    for (entity, machine) in machines {
        let (_, mut slot) = machine_query.get_mut(world, entity).unwrap();
        *slot = machine;
    }
    system_state.apply(world);
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Component, Clone)]
    struct StateOne;

    #[derive(Component, Clone)]
    struct StateTwo;

    #[derive(Component, Clone)]
    struct StateThree;

    #[derive(Resource)]
    struct SomeResource;

    /// Trigger that succeeds whenever `SomeResource` exists in the world.
    struct ResourcePresent;

    impl BoolTrigger for ResourcePresent {
        type Param<'w, 's> = Option<Res<'w, SomeResource>>;

        fn trigger(&self, _entity: Entity, param: Self::Param<'_, '_>) -> bool {
            param.is_some()
        }
    }

    /// Builds an app that runs the transition system during `Update`.
    fn app_with_transitions() -> App {
        let mut app = App::new();
        app.add_systems(Update, transition);
        app
    }

    /// Whether `entity` currently has component `C`.
    fn has<C: Component>(app: &App, entity: Entity) -> bool {
        app.world.get::<C>(entity).is_some()
    }

    #[test]
    fn test_sets_initial_state() {
        let mut app = app_with_transitions();
        let entity = app
            .world
            .spawn((StateMachine::default().with_state::<StateOne>(), StateOne))
            .id();
        app.update();
        assert!(
            has::<StateOne>(&app, entity),
            "StateMachine should have the initial component"
        );
    }

    #[test]
    fn test_machine() {
        let mut app = app_with_transitions();
        let machine = StateMachine::default()
            .trans::<StateOne>(AlwaysTrigger, StateTwo)
            .trans::<StateTwo>(ResourcePresent, StateThree);
        let entity = app.world.spawn((machine, StateOne)).id();
        assert!(has::<StateOne>(&app, entity));

        // The always-firing trigger moves the entity on the first update.
        app.update();
        assert!(!has::<StateOne>(&app, entity));
        assert!(has::<StateTwo>(&app, entity));

        // Without the resource, the second transition must not fire.
        app.update();
        assert!(has::<StateTwo>(&app, entity));
        assert!(!has::<StateThree>(&app, entity));

        app.world.insert_resource(SomeResource);
        app.update();
        assert!(!has::<StateTwo>(&app, entity));
        assert!(has::<StateThree>(&app, entity));
    }

    #[test]
    fn test_self_transition() {
        let mut app = app_with_transitions();
        let machine = StateMachine::default().trans::<StateOne>(AlwaysTrigger, StateOne);
        let entity = app.world.spawn((machine, StateOne)).id();
        app.update();
        assert!(
            has::<StateOne>(&app, entity),
            "transitioning from a state to itself should work"
        );
    }
}