Skip to main content

bevy_state/
state_scoped_events.rs

1use alloc::vec::Vec;
2use core::marker::PhantomData;
3
4use bevy_app::{App, SubApp};
5use bevy_ecs::{
6    message::{Message, MessageReader, Messages},
7    resource::Resource,
8    system::Commands,
9    world::World,
10};
11use bevy_platform::collections::HashMap;
12
13use crate::state::{OnEnter, OnExit, StateTransitionEvent, States};
14
15fn clear_message_queue<M: Message>(w: &mut World) {
16    if let Some(mut queue) = w.get_resource_mut::<Messages<M>>() {
17        queue.clear();
18    }
19}
20
/// Selects which side of a state transition a scoped message is cleared on.
#[derive(Clone, Copy)]
enum TransitionType {
    /// Clear when the associated state is exited.
    OnExit,
    /// Clear when the associated state is entered.
    OnEnter,
}
26
27#[derive(Resource)]
28struct StateScopedMessages<S: States> {
29    /// Keeps track of which messages need to be reset when the state is exited.
30    on_exit: HashMap<S, Vec<fn(&mut World)>>,
31    /// Keeps track of which messages need to be reset when the state is entered.
32    on_enter: HashMap<S, Vec<fn(&mut World)>>,
33}
34
35impl<S: States> StateScopedMessages<S> {
36    fn add_message<M: Message>(&mut self, state: S, transition_type: TransitionType) {
37        let map = match transition_type {
38            TransitionType::OnExit => &mut self.on_exit,
39            TransitionType::OnEnter => &mut self.on_enter,
40        };
41        map.entry(state).or_default().push(clear_message_queue::<M>);
42    }
43
44    fn cleanup(&self, w: &mut World, state: S, transition_type: TransitionType) {
45        let map = match transition_type {
46            TransitionType::OnExit => &self.on_exit,
47            TransitionType::OnEnter => &self.on_enter,
48        };
49        let Some(fns) = map.get(&state) else {
50            return;
51        };
52        for callback in fns {
53            (*callback)(w);
54        }
55    }
56}
57
58impl<S: States> Default for StateScopedMessages<S> {
59    fn default() -> Self {
60        Self {
61            on_exit: HashMap::default(),
62            on_enter: HashMap::default(),
63        }
64    }
65}
66
67fn clear_messages_on_exit<S: States>(
68    mut c: Commands,
69    mut transitions: MessageReader<StateTransitionEvent<S>>,
70) {
71    let Some(transition) = transitions.read().last() else {
72        return;
73    };
74    if transition.entered == transition.exited {
75        return;
76    }
77    let Some(exited) = transition.exited.clone() else {
78        return;
79    };
80
81    c.queue(move |w: &mut World| {
82        w.resource_scope::<StateScopedMessages<S>, ()>(|w, messages| {
83            messages.cleanup(w, exited, TransitionType::OnExit);
84        });
85    });
86}
87
88fn clear_messages_on_enter<S: States>(
89    mut c: Commands,
90    mut transitions: MessageReader<StateTransitionEvent<S>>,
91) {
92    let Some(transition) = transitions.read().last() else {
93        return;
94    };
95    if transition.entered == transition.exited {
96        return;
97    }
98    let Some(entered) = transition.entered.clone() else {
99        return;
100    };
101
102    c.queue(move |w: &mut World| {
103        w.resource_scope::<StateScopedMessages<S>, ()>(|w, messages| {
104            messages.cleanup(w, entered, TransitionType::OnEnter);
105        });
106    });
107}
108
109fn clear_messages_on_state_transition<M: Message, S: States>(
110    app: &mut SubApp,
111    _p: PhantomData<M>,
112    state: S,
113    transition_type: TransitionType,
114) {
115    if !app.world().contains_resource::<StateScopedMessages<S>>() {
116        app.init_resource::<StateScopedMessages<S>>();
117    }
118    app.world_mut()
119        .resource_mut::<StateScopedMessages<S>>()
120        .add_message::<M>(state.clone(), transition_type);
121    match transition_type {
122        TransitionType::OnExit => app.add_systems(OnExit(state), clear_messages_on_exit::<S>),
123        TransitionType::OnEnter => app.add_systems(OnEnter(state), clear_messages_on_enter::<S>),
124    };
125}
126
127/// Extension trait for [`App`] adding methods for registering state scoped messages.
128pub trait StateScopedMessagesAppExt {
129    /// Clears a [`Message`] when exiting the specified `state`.
130    ///
131    /// Note that message cleanup is ambiguously ordered relative to
132    /// [`DespawnOnExit`](crate::prelude::DespawnOnExit) entity cleanup,
133    /// and the [`OnExit`] schedule for the target state.
134    /// All of these (state scoped entities and messages cleanup, and `OnExit`)
135    /// occur within schedule [`StateTransition`](crate::prelude::StateTransition)
136    /// and system set `StateTransitionSystems::ExitSchedules`.
137    fn clear_messages_on_exit<M: Message>(&mut self, state: impl States) -> &mut Self;
138
139    /// Clears a [`Message`] when entering the specified `state`.
140    ///
141    /// Note that message cleanup is ambiguously ordered relative to
142    /// [`DespawnOnEnter`](crate::prelude::DespawnOnEnter) entity cleanup,
143    /// and the [`OnEnter`] schedule for the target state.
144    /// All of these (state scoped entities and messages cleanup, and `OnEnter`)
145    /// occur within schedule [`StateTransition`](crate::prelude::StateTransition)
146    /// and system set `StateTransitionSystems::EnterSchedules`.
147    fn clear_messages_on_enter<M: Message>(&mut self, state: impl States) -> &mut Self;
148}
149
150impl StateScopedMessagesAppExt for App {
151    fn clear_messages_on_exit<M: Message>(&mut self, state: impl States) -> &mut Self {
152        clear_messages_on_state_transition(
153            self.main_mut(),
154            PhantomData::<M>,
155            state,
156            TransitionType::OnExit,
157        );
158        self
159    }
160
161    fn clear_messages_on_enter<M: Message>(&mut self, state: impl States) -> &mut Self {
162        clear_messages_on_state_transition(
163            self.main_mut(),
164            PhantomData::<M>,
165            state,
166            TransitionType::OnEnter,
167        );
168        self
169    }
170}
171
172impl StateScopedMessagesAppExt for SubApp {
173    fn clear_messages_on_exit<M: Message>(&mut self, state: impl States) -> &mut Self {
174        clear_messages_on_state_transition(self, PhantomData::<M>, state, TransitionType::OnExit);
175        self
176    }
177
178    fn clear_messages_on_enter<M: Message>(&mut self, state: impl States) -> &mut Self {
179        clear_messages_on_state_transition(self, PhantomData::<M>, state, TransitionType::OnEnter);
180        self
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::StatesPlugin;
    use bevy_ecs::message::Message;
    use bevy_state::prelude::*;

    #[derive(States, Default, Clone, Hash, Eq, PartialEq, Debug)]
    enum TestState {
        #[default]
        A,
        B,
    }

    #[derive(Message, Debug)]
    struct StandardMessage;

    #[derive(Message, Debug)]
    struct StateScopedMessage;

    /// A second scoped message type, used to check that several message types
    /// scoped to the same state are all cleared.
    #[derive(Message, Debug)]
    struct OtherStateScopedMessage;

    #[test]
    fn clear_message_on_exit_state() {
        let mut app = App::new();
        app.add_plugins(StatesPlugin);
        app.init_state::<TestState>();

        app.add_message::<StandardMessage>();
        app.add_message::<StateScopedMessage>()
            .clear_messages_on_exit::<StateScopedMessage>(TestState::A);

        app.world_mut().write_message(StandardMessage).unwrap();
        app.world_mut().write_message(StateScopedMessage).unwrap();
        assert!(!app
            .world()
            .resource::<Messages<StandardMessage>>()
            .is_empty());
        assert!(!app
            .world()
            .resource::<Messages<StateScopedMessage>>()
            .is_empty());

        app.world_mut()
            .resource_mut::<NextState<TestState>>()
            .set(TestState::B);
        app.update();

        assert!(!app
            .world()
            .resource::<Messages<StandardMessage>>()
            .is_empty());
        assert!(app
            .world()
            .resource::<Messages<StateScopedMessage>>()
            .is_empty());
    }

    #[test]
    fn clear_message_on_enter_state() {
        let mut app = App::new();
        app.add_plugins(StatesPlugin);
        app.init_state::<TestState>();

        app.add_message::<StandardMessage>();
        app.add_message::<StateScopedMessage>()
            .clear_messages_on_enter::<StateScopedMessage>(TestState::B);

        app.world_mut().write_message(StandardMessage).unwrap();
        app.world_mut().write_message(StateScopedMessage).unwrap();
        assert!(!app
            .world()
            .resource::<Messages<StandardMessage>>()
            .is_empty());
        assert!(!app
            .world()
            .resource::<Messages<StateScopedMessage>>()
            .is_empty());

        app.world_mut()
            .resource_mut::<NextState<TestState>>()
            .set(TestState::B);
        app.update();

        assert!(!app
            .world()
            .resource::<Messages<StandardMessage>>()
            .is_empty());
        assert!(app
            .world()
            .resource::<Messages<StateScopedMessage>>()
            .is_empty());
    }

    /// Several message types scoped to the same state must all be cleared.
    #[test]
    fn clear_multiple_messages_on_exit_same_state() {
        let mut app = App::new();
        app.add_plugins(StatesPlugin);
        app.init_state::<TestState>();

        app.add_message::<StandardMessage>();
        app.add_message::<StateScopedMessage>()
            .clear_messages_on_exit::<StateScopedMessage>(TestState::A);
        app.add_message::<OtherStateScopedMessage>()
            .clear_messages_on_exit::<OtherStateScopedMessage>(TestState::A);

        app.world_mut().write_message(StandardMessage).unwrap();
        app.world_mut().write_message(StateScopedMessage).unwrap();
        app.world_mut().write_message(OtherStateScopedMessage).unwrap();

        app.world_mut()
            .resource_mut::<NextState<TestState>>()
            .set(TestState::B);
        app.update();

        assert!(!app
            .world()
            .resource::<Messages<StandardMessage>>()
            .is_empty());
        assert!(app
            .world()
            .resource::<Messages<StateScopedMessage>>()
            .is_empty());
        assert!(app
            .world()
            .resource::<Messages<OtherStateScopedMessage>>()
            .is_empty());
    }

    /// Re-setting the current state (exited == entered) must not clear scoped messages.
    #[test]
    fn identity_transition_does_not_clear_messages() {
        let mut app = App::new();
        app.add_plugins(StatesPlugin);
        app.init_state::<TestState>();

        app.add_message::<StateScopedMessage>()
            .clear_messages_on_exit::<StateScopedMessage>(TestState::A);

        app.world_mut().write_message(StateScopedMessage).unwrap();

        app.world_mut()
            .resource_mut::<NextState<TestState>>()
            .set(TestState::A);
        app.update();

        assert!(!app
            .world()
            .resource::<Messages<StateScopedMessage>>()
            .is_empty());
    }
}