Skip to main content

rmv_bevy_testing_tools/
messages.rs

1use std::marker::PhantomData;
2
3use bevy_app::{App, Plugin, PostUpdate};
4use bevy_derive::{Deref, DerefMut};
5use bevy_ecs::{
6    message::{Message, MessageReader},
7    resource::Resource,
8    system::ResMut,
9};
10
11#[derive(Debug, Deref, DerefMut, Resource)]
12pub struct CollectedMessages<E>(Vec<E>);
13
14impl<E: Message> CollectedMessages<E> {
15    pub fn get(&self) -> &Vec<E> {
16        &self.0
17    }
18}
19
20impl<E: Message> Default for CollectedMessages<E> {
21    fn default() -> Self {
22        Self(Vec::new())
23    }
24}
25
26#[derive(Debug)]
27pub struct MessageCollectorPlugin<E>(PhantomData<E>)
28where
29    E: Message + Clone;
30
31impl<E: Message + Clone> Default for MessageCollectorPlugin<E> {
32    fn default() -> Self {
33        Self(PhantomData)
34    }
35}
36
37impl<E: Message + Clone> Plugin for MessageCollectorPlugin<E> {
38    #[cfg_attr(coverage_nightly, coverage(off))]
39    fn build(&self, app: &mut App) {
40        app.add_message::<E>()
41            .init_resource::<CollectedMessages<E>>()
42            .add_systems(
43                PostUpdate,
44                |mut messages: MessageReader<E>, mut collection: ResMut<CollectedMessages<E>>| {
45                    collection.extend(messages.read().cloned());
46                },
47            );
48    }
49}
50
51#[derive(Debug)]
52pub enum MessageFilterPlugin<E>
53where
54    E: Message + Clone + PartialEq,
55{
56    Only(E),
57    AnyOf(Vec<E>),
58}
59
60impl<E: Message + Clone + PartialEq> Plugin for MessageFilterPlugin<E> {
61    #[cfg_attr(coverage_nightly, coverage(off))]
62    fn build(&self, app: &mut App) {
63        app.add_message::<E>()
64            .init_resource::<CollectedMessages<E>>();
65        match &self {
66            MessageFilterPlugin::Only(message) => {
67                app.add_systems(PostUpdate, {
68                    let message = message.clone();
69                    move |mut messages: MessageReader<E>,
70                          mut collection: ResMut<CollectedMessages<E>>| {
71                        collection.extend(messages.read().filter(|ev| *ev == &message).cloned());
72                    }
73                });
74            }
75            MessageFilterPlugin::AnyOf(any_of_messages) => {
76                app.add_systems(PostUpdate, {
77                    let any_of_messages = any_of_messages.clone();
78                    move |mut messages: MessageReader<E>,
79                          mut collection: ResMut<CollectedMessages<E>>| {
80                        collection.extend(
81                            messages
82                                .read()
83                                .filter(|ev| any_of_messages.contains(ev))
84                                .cloned(),
85                        );
86                    }
87                });
88            }
89        }
90    }
91}
92
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use bevy_app::Update;
    use bevy_ecs::message::MessageWriter;
    use rstest::*;
    use speculoos::prelude::*;

    use super::*;
    use crate::{fixtures::minimal_test_app, test_app::TestApp, traits::CollectMessages};

    // `Deref` and `get` must expose the same underlying vector.
    #[rstest]
    fn test_collected_messages_default_deref() {
        let collected_messages: CollectedMessages<CmpMessage> = CollectedMessages::default();
        let v1: &Vec<_> = &*collected_messages;
        let v2: &Vec<_> = collected_messages.get();
        assert_that!(v1).is_equal_to(v2);
    }

    /// Message type without `PartialEq`; exercises the unfiltered collector only.
    #[derive(Clone, Copy, Debug, Message)]
    struct NonEqMessage;

    // Every emitted message must end up in the collection.
    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(10)]
    fn test_message_collector_plugin(
        #[from(minimal_test_app)]
        #[with(MessageCollectorPlugin::<NonEqMessage>::default())]
        mut app: TestApp,
        #[case] emit_count: usize,
    ) {
        app.add_systems(Update, move |mut writer: MessageWriter<NonEqMessage>| {
            for _ in 0..emit_count {
                writer.write(NonEqMessage);
            }
        });

        app.update();

        assert_that!(app.get_collected_messages::<NonEqMessage>())
            .is_some()
            .has_length(emit_count);
    }

    /// Comparable message type; parsed from single letters via `FromStr` below.
    #[derive(Clone, Debug, Message, PartialEq)]
    enum CmpMessage {
        A,
        B,
        C,
    }

    // `Only` keeps exactly the messages equal to the configured one.
    #[rstest]
    #[case("ABCA", "A", "AA")]
    #[case("BCAB", "B", "BB")]
    #[case("CABC", "C", "CC")]
    fn test_message_filter_plugin_only(
        #[case] messages_to_emit: MessageList<CmpMessage>,
        #[case] only_message: CmpMessage,
        #[case] expected_messages: MessageList<CmpMessage>,
        #[from(minimal_test_app)]
        #[with(MessageFilterPlugin::Only(only_message.clone()))]
        mut app: TestApp,
    ) {
        app.add_systems(Update, move |mut writer: MessageWriter<CmpMessage>| {
            for e in &*messages_to_emit {
                writer.write(e.clone());
            }
        });

        app.update();

        let collected_messages = app.get_collected_messages::<CmpMessage>();
        assert_that!(collected_messages)
            .is_some()
            .is_equal_to(&*expected_messages);

        for e in &collected_messages.unwrap() {
            assert_that!(e).is_equal_to(&only_message);
        }
    }

    // `AnyOf` keeps messages equal to any listed value, preserving emission order.
    #[rstest]
    #[case("AABBCC", "A", "AA")]
    #[case("AABBCC", "B", "BB")]
    #[case("AABBCC", "C", "CC")]
    #[case("ABCCBA", "AB", "ABBA")]
    #[case("ABCCBA", "AC", "ACCA")]
    #[case("ABCCBA", "BC", "BCCB")]
    #[case("AABBCC", "ABC", "AABBCC")]
    fn test_message_filter_plugin_any_of(
        #[case] messages_to_emit: MessageList<CmpMessage>,
        #[case] any_of_messages: MessageList<CmpMessage>,
        #[case] expected_messages: MessageList<CmpMessage>,
        #[from(minimal_test_app)]
        #[with(MessageFilterPlugin::AnyOf((*any_of_messages).clone()))]
        mut app: TestApp,
    ) {
        app.add_systems(Update, move |mut writer: MessageWriter<CmpMessage>| {
            for e in &*messages_to_emit {
                writer.write(e.clone());
            }
        });

        app.update();

        let collected_messages = app.get_collected_messages::<CmpMessage>();
        assert_that!(collected_messages)
            .is_some()
            .is_equal_to(&*expected_messages);

        for e in collected_messages.unwrap().into_iter() {
            assert_that!(*any_of_messages).contains(e);
        }
    }

    /// Parse error for strings that are not a known `CmpMessage` letter.
    pub struct InvalidMessage;

    impl FromStr for CmpMessage {
        type Err = InvalidMessage;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "A" => Ok(CmpMessage::A),
                "B" => Ok(CmpMessage::B),
                "C" => Ok(CmpMessage::C),
                _ => Err(InvalidMessage),
            }
        }
    }

    // Invalid inputs panic while rstest converts the string case into a `CmpMessage`.
    #[rstest]
    #[case("A", Some(CmpMessage::A))]
    #[case("B", Some(CmpMessage::B))]
    #[case("C", Some(CmpMessage::C))]
    #[should_panic]
    #[case("", None)]
    #[should_panic]
    #[case("D", None)]
    #[should_panic]
    #[case("more nonsense", None)]
    fn test_filtered_message_fromstr(
        #[case] magic: CmpMessage,
        #[case] expected: Option<CmpMessage>,
    ) {
        assert_that!(magic).is_equal_to(expected.unwrap());
    }

    /// Sequence of messages parsed one character at a time (e.g. `"ABA"`).
    #[derive(Clone, Debug, Deref)]
    struct MessageList<E: Message + Clone>(Vec<E>);

    impl<E: Message + Clone + FromStr<Err = InvalidMessage>> FromStr for MessageList<E> {
        type Err = InvalidMessage;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            let mut messages = Vec::new();
            for c in s.chars() {
                let e = E::from_str(&c.to_string())?;
                messages.push(e);
            }
            Ok(MessageList(messages))
        }
    }

    #[rstest]
    #[case("A", vec![CmpMessage::A])]
    #[case("AB", vec![CmpMessage::A, CmpMessage::B])]
    #[case("ABC", vec![CmpMessage::A, CmpMessage::B, CmpMessage::C])]
    #[case("AABBCC", vec![
        CmpMessage::A, CmpMessage::A,
        CmpMessage::B, CmpMessage::B,
        CmpMessage::C, CmpMessage::C
    ])]
    #[should_panic]
    #[case("abc", vec![])]
    fn test_message_list_fromstr(
        #[case] magic: MessageList<CmpMessage>,
        #[case] expected: Vec<CmpMessage>,
    ) {
        assert_that!(*magic).is_equal_to(&expected);
    }
}