1use std::marker::PhantomData;
2
3use bevy_app::{App, Plugin, PostUpdate};
4use bevy_derive::{Deref, DerefMut};
5use bevy_ecs::{
6 message::{Message, MessageReader},
7 resource::Resource,
8 system::ResMut,
9};
10
11#[derive(Debug, Deref, DerefMut, Resource)]
12pub struct CollectedMessages<E>(Vec<E>);
13
14impl<E: Message> CollectedMessages<E> {
15 pub fn get(&self) -> &Vec<E> {
16 &self.0
17 }
18}
19
20impl<E: Message> Default for CollectedMessages<E> {
21 fn default() -> Self {
22 Self(Vec::new())
23 }
24}
25
26#[derive(Debug)]
27pub struct MessageCollectorPlugin<E>(PhantomData<E>)
28where
29 E: Message + Clone;
30
31impl<E: Message + Clone> Default for MessageCollectorPlugin<E> {
32 fn default() -> Self {
33 Self(PhantomData)
34 }
35}
36
37impl<E: Message + Clone> Plugin for MessageCollectorPlugin<E> {
38 #[cfg_attr(coverage_nightly, coverage(off))]
39 fn build(&self, app: &mut App) {
40 app.add_message::<E>()
41 .init_resource::<CollectedMessages<E>>()
42 .add_systems(
43 PostUpdate,
44 |mut messages: MessageReader<E>, mut collection: ResMut<CollectedMessages<E>>| {
45 collection.extend(messages.read().cloned());
46 },
47 );
48 }
49}
50
51#[derive(Debug)]
52pub enum MessageFilterPlugin<E>
53where
54 E: Message + Clone + PartialEq,
55{
56 Only(E),
57 AnyOf(Vec<E>),
58}
59
60impl<E: Message + Clone + PartialEq> Plugin for MessageFilterPlugin<E> {
61 #[cfg_attr(coverage_nightly, coverage(off))]
62 fn build(&self, app: &mut App) {
63 app.add_message::<E>()
64 .init_resource::<CollectedMessages<E>>();
65 match &self {
66 MessageFilterPlugin::Only(message) => {
67 app.add_systems(PostUpdate, {
68 let message = message.clone();
69 move |mut messages: MessageReader<E>,
70 mut collection: ResMut<CollectedMessages<E>>| {
71 collection.extend(messages.read().filter(|ev| *ev == &message).cloned());
72 }
73 });
74 }
75 MessageFilterPlugin::AnyOf(any_of_messages) => {
76 app.add_systems(PostUpdate, {
77 let any_of_messages = any_of_messages.clone();
78 move |mut messages: MessageReader<E>,
79 mut collection: ResMut<CollectedMessages<E>>| {
80 collection.extend(
81 messages
82 .read()
83 .filter(|ev| any_of_messages.contains(ev))
84 .cloned(),
85 );
86 }
87 });
88 }
89 }
90 }
91}
92
#[cfg(test)]
mod tests {
    //! Tests for the collector and filter plugins, plus the `FromStr` helpers
    //! that let rstest build `CmpMessage` / `MessageList` values from string
    //! literals in `#[case(...)]` attributes.

    use std::str::FromStr;

    use bevy_app::Update;
    use bevy_ecs::message::MessageWriter;
    use rstest::*;
    use speculoos::prelude::*;

    use super::*;
    use crate::{fixtures::minimal_test_app, test_app::TestApp, traits::CollectMessages};

    #[rstest]
    fn test_collected_messages_default_deref() {
        let collected_messages: CollectedMessages<CmpMessage> = CollectedMessages::default();
        let v1: &Vec<_> = &*collected_messages;
        let v2: &Vec<_> = collected_messages.get();
        assert_that!(v1).is_equal_to(v2);
    }

    /// A message type without `PartialEq`, usable only with the unfiltered collector.
    #[derive(Clone, Copy, Debug, Message)]
    struct NonEqMessage;

    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(10)]
    fn test_message_collector_plugin(
        #[from(minimal_test_app)]
        #[with(MessageCollectorPlugin::<NonEqMessage>::default())]
        mut app: TestApp,
        #[case] emit_count: usize,
    ) {
        // `CollectMessages` is already in scope via the module-level import.
        app.add_systems(Update, move |mut writer: MessageWriter<NonEqMessage>| {
            for _ in 0..emit_count {
                writer.write(NonEqMessage);
            }
        });

        app.update();

        assert_that!(app.get_collected_messages::<NonEqMessage>())
            .is_some()
            .has_length(emit_count);
    }

    /// A comparable message type used to exercise the filter plugin.
    #[derive(Clone, Debug, Message, PartialEq)]
    enum CmpMessage {
        A,
        B,
        C,
    }

    #[rstest]
    #[case("ABCA", "A", "AA")]
    #[case("BCAB", "B", "BB")]
    #[case("CABC", "C", "CC")]
    fn test_message_filter_plugin_only(
        #[case] messages_to_emit: MessageList<CmpMessage>,
        #[case] only_message: CmpMessage,
        #[case] expected_messages: MessageList<CmpMessage>,
        #[from(minimal_test_app)]
        #[with(MessageFilterPlugin::Only(only_message.clone()))]
        mut app: TestApp,
    ) {
        app.add_systems(Update, move |mut writer: MessageWriter<CmpMessage>| {
            for e in &*messages_to_emit {
                writer.write(e.clone());
            }
        });

        app.update();

        let collected_messages = app.get_collected_messages::<CmpMessage>();
        assert_that!(collected_messages)
            .is_some()
            .is_equal_to(&*expected_messages);

        for e in &collected_messages.unwrap() {
            assert_that!(e).is_equal_to(&only_message);
        }
    }

    #[rstest]
    #[case("AABBCC", "A", "AA")]
    #[case("AABBCC", "B", "BB")]
    #[case("AABBCC", "C", "CC")]
    #[case("ABCCBA", "AB", "ABBA")]
    #[case("ABCCBA", "AC", "ACCA")]
    #[case("ABCCBA", "BC", "BCCB")]
    #[case("AABBCC", "ABC", "AABBCC")]
    fn test_message_filter_plugin_any_of(
        #[case] messages_to_emit: MessageList<CmpMessage>,
        #[case] any_of_messages: MessageList<CmpMessage>,
        #[case] expected_messages: MessageList<CmpMessage>,
        #[from(minimal_test_app)]
        #[with(MessageFilterPlugin::AnyOf((*any_of_messages).clone()))]
        mut app: TestApp,
    ) {
        // `CollectMessages` is already in scope via the module-level import.
        app.add_systems(Update, move |mut writer: MessageWriter<CmpMessage>| {
            for e in &*messages_to_emit {
                writer.write(e.clone());
            }
        });

        app.update();

        let collected_messages = app.get_collected_messages::<CmpMessage>();
        assert_that!(collected_messages)
            .is_some()
            .is_equal_to(&*expected_messages);

        for e in collected_messages.unwrap().into_iter() {
            assert_that!(*any_of_messages).contains(e);
        }
    }

    /// Parse error for the single-letter `CmpMessage` / `MessageList` encodings.
    pub struct InvalidMessage;

    /// Parses exactly `"A"`, `"B"` or `"C"`; anything else is an error.
    impl FromStr for CmpMessage {
        type Err = InvalidMessage;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "A" => Ok(CmpMessage::A),
                "B" => Ok(CmpMessage::B),
                "C" => Ok(CmpMessage::C),
                _ => Err(InvalidMessage),
            }
        }
    }

    #[rstest]
    #[case("A", Some(CmpMessage::A))]
    #[case("B", Some(CmpMessage::B))]
    #[case("C", Some(CmpMessage::C))]
    #[should_panic]
    #[case("", None)]
    #[should_panic]
    #[case("D", None)]
    #[should_panic]
    #[case("more nonsense", None)]
    fn test_filtered_message_fromstr(
        #[case] magic: CmpMessage,
        #[case] expected: Option<CmpMessage>,
    ) {
        assert_that!(magic).is_equal_to(expected.unwrap());
    }

    /// A sequence of messages parsed one character at a time (e.g. `"ABBA"`).
    #[derive(Clone, Debug, Deref)]
    struct MessageList<E: Message + Clone>(Vec<E>);

    impl<E: Message + Clone + FromStr<Err = InvalidMessage>> FromStr for MessageList<E> {
        type Err = InvalidMessage;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            let mut messages = Vec::new();
            for c in s.chars() {
                let e = E::from_str(&c.to_string())?;
                messages.push(e);
            }
            Ok(MessageList(messages))
        }
    }

    #[rstest]
    #[case("A", vec![CmpMessage::A])]
    #[case("AB", vec![CmpMessage::A, CmpMessage::B])]
    #[case("ABC", vec![CmpMessage::A, CmpMessage::B, CmpMessage::C])]
    #[case("AABBCC", vec![
        CmpMessage::A, CmpMessage::A,
        CmpMessage::B, CmpMessage::B,
        CmpMessage::C, CmpMessage::C
    ])]
    #[should_panic]
    #[case("abc", vec![])]
    fn test_message_list_fromstr(
        #[case] magic: MessageList<CmpMessage>,
        #[case] expected: Vec<CmpMessage>,
    ) {
        assert_that!(*magic).is_equal_to(&expected);
    }
}