1use bevy_app::{App, MainScheduleOrder, Plugin, PreStartup, PreUpdate, SubApp};
2use bevy_ecs::{message::Messages, schedule::IntoScheduleConfigs, world::FromWorld};
3use bevy_utils::once;
4use log::warn;
5
6use crate::{
7 state::{
8 setup_state_transitions_in_world, ComputedStates, FreelyMutableState, NextState, State,
9 StateTransition, StateTransitionEvent, StateTransitionSystems, States, SubStates,
10 },
11 state_scoped::{despawn_entities_on_enter_state, despawn_entities_on_exit_state},
12};
13
14#[cfg(feature = "bevy_reflect")]
15use bevy_reflect::{FromReflect, GetTypeRegistration, Typed};
16
/// Extension methods for installing states on an [`App`] or [`SubApp`].
///
/// Each installer first checks whether [`StatesPlugin`] has been added and logs a
/// (one-time) warning if it has not.
pub trait AppExtStates {
    /// Initializes a [`State`] of type `S`, taking its starting value from [`FromWorld`].
    ///
    /// Alongside `State<S>` this adds a [`NextState<S>`](NextState) resource and a
    /// [`StateTransitionEvent<S>`](StateTransitionEvent) message channel, registers the
    /// state's systems in the [`StateTransition`] schedule, writes an initial transition
    /// message (`exited: None`, `entered: Some(<starting value>)`), and sets up
    /// despawning of state-scoped entities for `S`.
    ///
    /// If `State<S>` already exists, a warning is logged and nothing else changes.
    ///
    /// # Panics
    ///
    /// Panics if the [`StateTransition`] schedule is missing (for example because
    /// [`StatesPlugin`] was not added before this call).
    fn init_state<S: FreelyMutableState + FromWorld>(&mut self) -> &mut Self;

    /// Inserts `state` as the current [`State<S>`](State).
    ///
    /// On first insertion this performs the same setup as [`init_state`](Self::init_state)
    /// (resources, message channel, schedule systems, scoped-entity cleanup) but with the
    /// provided value. If a `State<S>` already exists, it is overwritten, any pending
    /// `StateTransitionEvent<S>` messages are cleared, and a single fresh initial transition
    /// message for the new value is written.
    ///
    /// # Panics
    ///
    /// On first insertion, panics if the [`StateTransition`] schedule is missing (for example
    /// because [`StatesPlugin`] was not added before this call).
    fn insert_state<S: FreelyMutableState>(&mut self, state: S) -> &mut Self;

    /// Sets up a state type implementing [`ComputedStates`].
    ///
    /// Adds the `StateTransitionEvent<S>` message channel, registers the computed-state
    /// systems in the [`StateTransition`] schedule, writes an initial transition message whose
    /// `entered` is the current `State<S>` value if that resource exists (otherwise `None`),
    /// and sets up state-scoped entity cleanup.
    ///
    /// If the message channel for `S` already exists, a warning is logged instead.
    ///
    /// # Panics
    ///
    /// Panics if the [`StateTransition`] schedule is missing.
    fn add_computed_state<S: ComputedStates>(&mut self) -> &mut Self;

    /// Sets up a state type implementing [`SubStates`].
    ///
    /// Adds a [`NextState<S>`](NextState) resource and the `StateTransitionEvent<S>` message
    /// channel, registers the sub-state systems in the [`StateTransition`] schedule, writes an
    /// initial transition message whose `entered` is the current `State<S>` value if that
    /// resource exists (otherwise `None`), and sets up state-scoped entity cleanup.
    ///
    /// If the message channel for `S` already exists, a warning is logged instead.
    ///
    /// # Panics
    ///
    /// Panics if the [`StateTransition`] schedule is missing.
    fn add_sub_state<S: SubStates>(&mut self) -> &mut Self;

    /// Deprecated no-op kept for backwards compatibility.
    ///
    /// State-scoped entity cleanup is already set up when a state is first added through
    /// [`init_state`](Self::init_state), [`insert_state`](Self::insert_state),
    /// [`add_computed_state`](Self::add_computed_state) or
    /// [`add_sub_state`](Self::add_sub_state).
    #[doc(hidden)]
    #[deprecated(
        since = "0.17.0",
        note = "State scoped entities are enabled by default. This method does nothing anymore, you can safely remove it."
    )]
    fn enable_state_scoped_entities<S: States>(&mut self) -> &mut Self;

    /// Registers `S` and `State<S>` for reflection and attaches `ReflectState` type data to `S`.
    #[cfg(feature = "bevy_reflect")]
    fn register_type_state<S>(&mut self) -> &mut Self
    where
        S: States + FromReflect + GetTypeRegistration + Typed;

    /// Registers `S`, `State<S>` and `NextState<S>` for reflection and attaches both
    /// `ReflectState` and `ReflectFreelyMutableState` type data to `S`.
    #[cfg(feature = "bevy_reflect")]
    fn register_type_mutable_state<S>(&mut self) -> &mut Self
    where
        S: FreelyMutableState + FromReflect + GetTypeRegistration + Typed;
}
89
90fn warn_if_no_states_plugin_installed(app: &SubApp) {
92 if !app.is_plugin_added::<StatesPlugin>() {
93 once!(warn!(
94 "States were added to the app, but `StatesPlugin` is not installed."
95 ));
96 }
97}
98
99impl AppExtStates for SubApp {
100 fn init_state<S: FreelyMutableState + FromWorld>(&mut self) -> &mut Self {
101 warn_if_no_states_plugin_installed(self);
102 if !self.world().contains_resource::<State<S>>() {
103 self.init_resource::<State<S>>()
104 .init_resource::<NextState<S>>()
105 .add_message::<StateTransitionEvent<S>>();
106 let schedule = self.get_schedule_mut(StateTransition).expect(
107 "The `StateTransition` schedule is missing. Did you forget to add StatesPlugin or DefaultPlugins before calling init_state?"
108 );
109 S::register_state(schedule);
110 let state = self.world().resource::<State<S>>().get().clone();
111 self.world_mut().write_message(StateTransitionEvent {
112 exited: None,
113 entered: Some(state),
114 });
115 enable_state_scoped_entities::<S>(self);
116 } else {
117 let name = core::any::type_name::<S>();
118 warn!("State {name} is already initialized.");
119 }
120
121 self
122 }
123
124 fn insert_state<S: FreelyMutableState>(&mut self, state: S) -> &mut Self {
125 warn_if_no_states_plugin_installed(self);
126 if !self.world().contains_resource::<State<S>>() {
127 self.insert_resource::<State<S>>(State::new(state.clone()))
128 .init_resource::<NextState<S>>()
129 .add_message::<StateTransitionEvent<S>>();
130 let schedule = self.get_schedule_mut(StateTransition).expect(
131 "The `StateTransition` schedule is missing. Did you forget to add StatesPlugin or DefaultPlugins before calling insert_state?"
132 );
133 S::register_state(schedule);
134 self.world_mut().write_message(StateTransitionEvent {
135 exited: None,
136 entered: Some(state),
137 });
138 enable_state_scoped_entities::<S>(self);
139 } else {
140 self.insert_resource::<State<S>>(State::new(state.clone()));
142 self.world_mut()
143 .resource_mut::<Messages<StateTransitionEvent<S>>>()
144 .clear();
145 self.world_mut().write_message(StateTransitionEvent {
146 exited: None,
147 entered: Some(state),
148 });
149 }
150
151 self
152 }
153
154 fn add_computed_state<S: ComputedStates>(&mut self) -> &mut Self {
155 warn_if_no_states_plugin_installed(self);
156 if !self
157 .world()
158 .contains_resource::<Messages<StateTransitionEvent<S>>>()
159 {
160 self.add_message::<StateTransitionEvent<S>>();
161 let schedule = self.get_schedule_mut(StateTransition).expect(
162 "The `StateTransition` schedule is missing. Did you forget to add StatesPlugin or DefaultPlugins before calling add_computed_state?"
163 );
164 S::register_computed_state_systems(schedule);
165 let state = self
166 .world()
167 .get_resource::<State<S>>()
168 .map(|s| s.get().clone());
169 self.world_mut().write_message(StateTransitionEvent {
170 exited: None,
171 entered: state,
172 });
173 enable_state_scoped_entities::<S>(self);
174 } else {
175 let name = core::any::type_name::<S>();
176 warn!("Computed state {name} is already initialized.");
177 }
178
179 self
180 }
181
182 fn add_sub_state<S: SubStates>(&mut self) -> &mut Self {
183 warn_if_no_states_plugin_installed(self);
184 if !self
185 .world()
186 .contains_resource::<Messages<StateTransitionEvent<S>>>()
187 {
188 self.init_resource::<NextState<S>>();
189 self.add_message::<StateTransitionEvent<S>>();
190 let schedule = self.get_schedule_mut(StateTransition).expect(
191 "The `StateTransition` schedule is missing. Did you forget to add StatesPlugin or DefaultPlugins before calling add_sub_state?"
192 );
193 S::register_sub_state_systems(schedule);
194 let state = self
195 .world()
196 .get_resource::<State<S>>()
197 .map(|s| s.get().clone());
198 self.world_mut().write_message(StateTransitionEvent {
199 exited: None,
200 entered: state,
201 });
202 enable_state_scoped_entities::<S>(self);
203 } else {
204 let name = core::any::type_name::<S>();
205 warn!("Sub state {name} is already initialized.");
206 }
207
208 self
209 }
210
211 #[doc(hidden)]
212 fn enable_state_scoped_entities<S: States>(&mut self) -> &mut Self {
213 self
214 }
215
216 #[cfg(feature = "bevy_reflect")]
217 fn register_type_state<S>(&mut self) -> &mut Self
218 where
219 S: States + FromReflect + GetTypeRegistration + Typed,
220 {
221 self.register_type::<S>();
222 self.register_type::<State<S>>();
223 self.register_type_data::<S, crate::reflect::ReflectState>();
224 self
225 }
226
227 #[cfg(feature = "bevy_reflect")]
228 fn register_type_mutable_state<S>(&mut self) -> &mut Self
229 where
230 S: FreelyMutableState + FromReflect + GetTypeRegistration + Typed,
231 {
232 self.register_type::<S>();
233 self.register_type::<State<S>>();
234 self.register_type::<NextState<S>>();
235 self.register_type_data::<S, crate::reflect::ReflectState>();
236 self.register_type_data::<S, crate::reflect::ReflectFreelyMutableState>();
237 self
238 }
239}
240
241fn enable_state_scoped_entities<S: States>(app: &mut SubApp) {
242 if !app
243 .world()
244 .contains_resource::<Messages<StateTransitionEvent<S>>>()
245 {
246 let name = core::any::type_name::<S>();
247 warn!("State scoped entities are enabled for state `{name}`, but the state wasn't initialized in the app!");
248 }
249
250 app.add_systems(
254 StateTransition,
255 despawn_entities_on_exit_state::<S>.in_set(StateTransitionSystems::ExitSchedules),
256 )
257 .add_systems(
261 StateTransition,
262 despawn_entities_on_enter_state::<S>.in_set(StateTransitionSystems::EnterSchedules),
263 );
264}
265
266impl AppExtStates for App {
267 fn init_state<S: FreelyMutableState + FromWorld>(&mut self) -> &mut Self {
268 self.main_mut().init_state::<S>();
269 self
270 }
271
272 fn insert_state<S: FreelyMutableState>(&mut self, state: S) -> &mut Self {
273 self.main_mut().insert_state::<S>(state);
274 self
275 }
276
277 fn add_computed_state<S: ComputedStates>(&mut self) -> &mut Self {
278 self.main_mut().add_computed_state::<S>();
279 self
280 }
281
282 fn add_sub_state<S: SubStates>(&mut self) -> &mut Self {
283 self.main_mut().add_sub_state::<S>();
284 self
285 }
286
287 #[doc(hidden)]
288 fn enable_state_scoped_entities<S: States>(&mut self) -> &mut Self {
289 self
290 }
291
292 #[cfg(feature = "bevy_reflect")]
293 fn register_type_state<S>(&mut self) -> &mut Self
294 where
295 S: States + FromReflect + GetTypeRegistration + Typed,
296 {
297 self.main_mut().register_type_state::<S>();
298 self
299 }
300
301 #[cfg(feature = "bevy_reflect")]
302 fn register_type_mutable_state<S>(&mut self) -> &mut Self
303 where
304 S: FreelyMutableState + FromReflect + GetTypeRegistration + Typed,
305 {
306 self.main_mut().register_type_mutable_state::<S>();
307 self
308 }
309}
310
311#[derive(Default)]
313pub struct StatesPlugin;
314
315impl Plugin for StatesPlugin {
316 fn build(&self, app: &mut App) {
317 let mut schedule = app.world_mut().resource_mut::<MainScheduleOrder>();
318 schedule.insert_after(PreUpdate, StateTransition);
319 schedule.insert_startup_before(PreStartup, StateTransition);
320 setup_state_transitions_in_world(app.world_mut());
321 }
322}
323
#[cfg(test)]
mod tests {
    use crate::{
        app::StatesPlugin,
        state::{State, StateTransition, StateTransitionEvent},
    };
    use bevy_app::App;
    use bevy_ecs::message::Messages;
    use bevy_state_macros::States;

    use super::AppExtStates;

    #[derive(States, Default, PartialEq, Eq, Hash, Debug, Clone)]
    enum TestState {
        #[default]
        A,
        B,
        C,
    }

    /// Creates a fresh `App` with `StatesPlugin` installed.
    fn app_with_states_plugin() -> App {
        let mut app = App::new();
        app.add_plugins(StatesPlugin);
        app
    }

    /// Runs `StateTransition` once, then asserts that `State<TestState>` equals `expected`
    /// and that exactly one transition message exists, entering `expected` from nothing.
    fn assert_single_initial_transition(app: &mut App, expected: TestState) {
        let world = app.world_mut();
        world.run_schedule(StateTransition);

        assert_eq!(world.resource::<State<TestState>>().0, expected);
        let messages = world.resource::<Messages<StateTransitionEvent<TestState>>>();
        assert_eq!(messages.len(), 1);
        let mut cursor = messages.get_cursor();
        let newest = cursor.read(messages).last().unwrap();
        assert_eq!(newest.exited, None);
        assert_eq!(newest.entered, Some(expected));
    }

    #[test]
    fn insert_state_can_overwrite_init_state() {
        let mut app = app_with_states_plugin();
        app.init_state::<TestState>();
        app.insert_state(TestState::B);

        assert_single_initial_transition(&mut app, TestState::B);
    }

    #[test]
    fn insert_state_can_overwrite_insert_state() {
        let mut app = app_with_states_plugin();
        app.insert_state(TestState::B);
        app.insert_state(TestState::C);

        assert_single_initial_transition(&mut app, TestState::C);
    }
}