bevy_state/
state_scoped_events.rs1use alloc::vec::Vec;
2use core::marker::PhantomData;
3
4use bevy_app::{App, SubApp};
5use bevy_ecs::{
6 message::{Message, MessageReader, Messages},
7 resource::Resource,
8 system::Commands,
9 world::World,
10};
11use bevy_platform::collections::HashMap;
12
13use crate::state::{OnEnter, OnExit, StateTransitionEvent, States};
14
15fn clear_message_queue<M: Message>(w: &mut World) {
16 if let Some(mut queue) = w.get_resource_mut::<Messages<M>>() {
17 queue.clear();
18 }
19}
20
/// Selects which side of a state transition a clearing callback is tied to.
#[derive(Clone, Copy)]
enum TransitionType {
    /// The callback runs when the registered state is left.
    OnExit,
    /// The callback runs when the registered state is entered.
    OnEnter,
}
26
27#[derive(Resource)]
28struct StateScopedMessages<S: States> {
29 on_exit: HashMap<S, Vec<fn(&mut World)>>,
31 on_enter: HashMap<S, Vec<fn(&mut World)>>,
33}
34
35impl<S: States> StateScopedMessages<S> {
36 fn add_message<M: Message>(&mut self, state: S, transition_type: TransitionType) {
37 let map = match transition_type {
38 TransitionType::OnExit => &mut self.on_exit,
39 TransitionType::OnEnter => &mut self.on_enter,
40 };
41 map.entry(state).or_default().push(clear_message_queue::<M>);
42 }
43
44 fn cleanup(&self, w: &mut World, state: S, transition_type: TransitionType) {
45 let map = match transition_type {
46 TransitionType::OnExit => &self.on_exit,
47 TransitionType::OnEnter => &self.on_enter,
48 };
49 let Some(fns) = map.get(&state) else {
50 return;
51 };
52 for callback in fns {
53 (*callback)(w);
54 }
55 }
56}
57
58impl<S: States> Default for StateScopedMessages<S> {
59 fn default() -> Self {
60 Self {
61 on_exit: HashMap::default(),
62 on_enter: HashMap::default(),
63 }
64 }
65}
66
67fn clear_messages_on_exit<S: States>(
68 mut c: Commands,
69 mut transitions: MessageReader<StateTransitionEvent<S>>,
70) {
71 let Some(transition) = transitions.read().last() else {
72 return;
73 };
74 if transition.entered == transition.exited {
75 return;
76 }
77 let Some(exited) = transition.exited.clone() else {
78 return;
79 };
80
81 c.queue(move |w: &mut World| {
82 w.resource_scope::<StateScopedMessages<S>, ()>(|w, messages| {
83 messages.cleanup(w, exited, TransitionType::OnExit);
84 });
85 });
86}
87
88fn clear_messages_on_enter<S: States>(
89 mut c: Commands,
90 mut transitions: MessageReader<StateTransitionEvent<S>>,
91) {
92 let Some(transition) = transitions.read().last() else {
93 return;
94 };
95 if transition.entered == transition.exited {
96 return;
97 }
98 let Some(entered) = transition.entered.clone() else {
99 return;
100 };
101
102 c.queue(move |w: &mut World| {
103 w.resource_scope::<StateScopedMessages<S>, ()>(|w, messages| {
104 messages.cleanup(w, entered, TransitionType::OnEnter);
105 });
106 });
107}
108
109fn clear_messages_on_state_transition<M: Message, S: States>(
110 app: &mut SubApp,
111 _p: PhantomData<M>,
112 state: S,
113 transition_type: TransitionType,
114) {
115 if !app.world().contains_resource::<StateScopedMessages<S>>() {
116 app.init_resource::<StateScopedMessages<S>>();
117 }
118 app.world_mut()
119 .resource_mut::<StateScopedMessages<S>>()
120 .add_message::<M>(state.clone(), transition_type);
121 match transition_type {
122 TransitionType::OnExit => app.add_systems(OnExit(state), clear_messages_on_exit::<S>),
123 TransitionType::OnEnter => app.add_systems(OnEnter(state), clear_messages_on_enter::<S>),
124 };
125}
126
127pub trait StateScopedMessagesAppExt {
129 fn clear_messages_on_exit<M: Message>(&mut self, state: impl States) -> &mut Self;
138
139 fn clear_messages_on_enter<M: Message>(&mut self, state: impl States) -> &mut Self;
148}
149
150impl StateScopedMessagesAppExt for App {
151 fn clear_messages_on_exit<M: Message>(&mut self, state: impl States) -> &mut Self {
152 clear_messages_on_state_transition(
153 self.main_mut(),
154 PhantomData::<M>,
155 state,
156 TransitionType::OnExit,
157 );
158 self
159 }
160
161 fn clear_messages_on_enter<M: Message>(&mut self, state: impl States) -> &mut Self {
162 clear_messages_on_state_transition(
163 self.main_mut(),
164 PhantomData::<M>,
165 state,
166 TransitionType::OnEnter,
167 );
168 self
169 }
170}
171
172impl StateScopedMessagesAppExt for SubApp {
173 fn clear_messages_on_exit<M: Message>(&mut self, state: impl States) -> &mut Self {
174 clear_messages_on_state_transition(self, PhantomData::<M>, state, TransitionType::OnExit);
175 self
176 }
177
178 fn clear_messages_on_enter<M: Message>(&mut self, state: impl States) -> &mut Self {
179 clear_messages_on_state_transition(self, PhantomData::<M>, state, TransitionType::OnEnter);
180 self
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::StatesPlugin;
    use bevy_ecs::message::Message;
    use bevy_state::prelude::*;

    #[derive(States, Default, Clone, Hash, Eq, PartialEq, Debug)]
    enum TestState {
        #[default]
        A,
        B,
    }

    #[derive(Message, Debug)]
    struct StandardMessage;

    #[derive(Message, Debug)]
    struct StateScopedMessage;

    #[derive(Message, Debug)]
    struct OtherStateScopedMessage;

    /// Builds an app starting in `TestState::A` with `StandardMessage` registered.
    fn setup_app() -> App {
        let mut app = App::new();
        app.add_plugins(StatesPlugin);
        app.init_state::<TestState>();
        app.add_message::<StandardMessage>();
        app
    }

    /// Returns whether the `Messages<M>` queue currently holds no messages.
    fn queue_is_empty<M: Message>(app: &App) -> bool {
        app.world().resource::<Messages<M>>().is_empty()
    }

    /// Requests a transition to `state` and runs one update to apply it.
    fn transition_to(app: &mut App, state: TestState) {
        app.world_mut()
            .resource_mut::<NextState<TestState>>()
            .set(state);
        app.update();
    }

    #[test]
    fn clear_message_on_exit_state() {
        let mut app = setup_app();
        app.add_message::<StateScopedMessage>()
            .clear_messages_on_exit::<StateScopedMessage>(TestState::A);

        app.world_mut().write_message(StandardMessage).unwrap();
        app.world_mut().write_message(StateScopedMessage).unwrap();
        assert!(!queue_is_empty::<StandardMessage>(&app));
        assert!(!queue_is_empty::<StateScopedMessage>(&app));

        transition_to(&mut app, TestState::B);

        assert!(!queue_is_empty::<StandardMessage>(&app));
        assert!(queue_is_empty::<StateScopedMessage>(&app));
    }

    #[test]
    fn clear_message_on_enter_state() {
        let mut app = setup_app();
        app.add_message::<StateScopedMessage>()
            .clear_messages_on_enter::<StateScopedMessage>(TestState::B);

        app.world_mut().write_message(StandardMessage).unwrap();
        app.world_mut().write_message(StateScopedMessage).unwrap();
        assert!(!queue_is_empty::<StandardMessage>(&app));
        assert!(!queue_is_empty::<StateScopedMessage>(&app));

        transition_to(&mut app, TestState::B);

        assert!(!queue_is_empty::<StandardMessage>(&app));
        assert!(queue_is_empty::<StateScopedMessage>(&app));
    }

    /// Several message types scoped to the same state must all be cleared.
    #[test]
    fn clear_multiple_messages_on_exit_same_state() {
        let mut app = setup_app();
        app.add_message::<StateScopedMessage>()
            .add_message::<OtherStateScopedMessage>()
            .clear_messages_on_exit::<StateScopedMessage>(TestState::A)
            .clear_messages_on_exit::<OtherStateScopedMessage>(TestState::A);

        app.world_mut().write_message(StandardMessage).unwrap();
        app.world_mut().write_message(StateScopedMessage).unwrap();
        app.world_mut().write_message(OtherStateScopedMessage).unwrap();

        transition_to(&mut app, TestState::B);

        assert!(!queue_is_empty::<StandardMessage>(&app));
        assert!(queue_is_empty::<StateScopedMessage>(&app));
        assert!(queue_is_empty::<OtherStateScopedMessage>(&app));
    }

    /// Leaving a state other than the registered one must not clear the queue.
    #[test]
    fn keep_messages_when_unregistered_state_exited() {
        let mut app = setup_app();
        app.add_message::<StateScopedMessage>()
            .clear_messages_on_exit::<StateScopedMessage>(TestState::B);

        app.world_mut().write_message(StateScopedMessage).unwrap();

        transition_to(&mut app, TestState::B);

        assert!(!queue_is_empty::<StateScopedMessage>(&app));
    }
}