1use std::marker::PhantomData;
2
3use bevy_app::{App, Plugin, PostUpdate};
4use bevy_derive::{Deref, DerefMut};
5use bevy_ecs::{
6 event::{Event, EventReader},
7 resource::Resource,
8 system::ResMut,
9};
10
11#[derive(Debug, Resource, Deref, DerefMut)]
12pub struct CollectedEvents<E>(Vec<E>);
13
14impl<E: Event> CollectedEvents<E> {
15 pub fn get(&self) -> &Vec<E> {
16 &self.0
17 }
18}
19
20impl<E: Event> Default for CollectedEvents<E> {
21 fn default() -> Self {
22 Self(Vec::new())
23 }
24}
25
26#[derive(Debug)]
27pub struct EventCollectorPlugin<E>(PhantomData<E>)
28where
29 E: Event + Clone;
30
31impl<E: Event + Clone> Default for EventCollectorPlugin<E> {
32 fn default() -> Self {
33 Self(PhantomData)
34 }
35}
36
37impl<E: Event + Clone> Plugin for EventCollectorPlugin<E> {
38 #[cfg_attr(coverage_nightly, coverage(off))]
39 fn build(&self, app: &mut App) {
40 app.add_event::<E>()
41 .init_resource::<CollectedEvents<E>>()
42 .add_systems(
43 PostUpdate,
44 |mut events: EventReader<E>, mut collection: ResMut<CollectedEvents<E>>| {
45 collection.extend(events.read().cloned());
46 },
47 );
48 }
49}
50
/// Plugin that records only selected events of type `E` into [`CollectedEvents<E>`].
///
/// An event read in `PostUpdate` is kept when it compares equal (`PartialEq`) to the
/// configured value(s). The `E: Event + Clone + PartialEq` requirement is placed on the
/// `Plugin` impl rather than on the enum definition.
#[derive(Debug)]
pub enum EventFilterPlugin<E> {
    /// Keep only events equal to this one.
    Only(E),
    /// Keep events equal to any entry of this list. An empty list keeps nothing.
    AnyOf(Vec<E>),
}
59
60impl<E: Event + Clone + PartialEq> Plugin for EventFilterPlugin<E> {
61 #[cfg_attr(coverage_nightly, coverage(off))]
62 fn build(&self, app: &mut App) {
63 app.add_event::<E>().init_resource::<CollectedEvents<E>>();
64 match &self {
65 EventFilterPlugin::Only(event) => {
66 app.add_systems(PostUpdate, {
67 let event = event.clone();
68 move |mut events: EventReader<E>, mut collection: ResMut<CollectedEvents<E>>| {
69 collection.extend(events.read().filter(|ev| *ev == &event).cloned());
70 }
71 });
72 }
73 EventFilterPlugin::AnyOf(ref any_of_events) => {
74 app.add_systems(PostUpdate, {
75 let any_of_events = any_of_events.clone();
76 move |mut events: EventReader<E>, mut collection: ResMut<CollectedEvents<E>>| {
77 collection.extend(
78 events
79 .read()
80 .filter(|ev| any_of_events.contains(ev))
81 .cloned(),
82 );
83 }
84 });
85 }
86 }
87 }
88}
89
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use bevy_app::Update;
    use bevy_ecs::event::EventWriter;
    use rstest::*;
    use speculoos::prelude::*;

    use super::*;
    use crate::{fixtures::test_app, test_app::TestApp, traits::CollectEvents};

    #[rstest]
    fn test_collected_events_default_deref() {
        let collected_events: CollectedEvents<CmpEvent> = CollectedEvents::default();
        let v1: &Vec<_> = &*collected_events;
        let v2: &Vec<_> = collected_events.get();
        assert_that!(v1).is_equal_to(v2);
    }

    // Deliberately lacks `PartialEq`: the unfiltered collector must not require it.
    #[derive(Debug, Event, Copy, Clone)]
    struct NonEqEvent;

    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(10)]
    fn test_event_collector_plugin(
        #[from(test_app)]
        #[with(EventCollectorPlugin::<NonEqEvent>::default())]
        mut app: TestApp,
        #[case] emit_count: usize,
    ) {
        app.add_systems(Update, move |mut writer: EventWriter<NonEqEvent>| {
            for _ in 0..emit_count {
                writer.send(NonEqEvent);
            }
        });

        app.update();

        assert_that!(app.get_collected_events::<NonEqEvent>())
            .is_some()
            .has_length(emit_count);
    }

    // Comparable event for the filter tests; parsed from single letters via `FromStr`.
    #[derive(Debug, Event, Clone, PartialEq)]
    enum CmpEvent {
        A,
        B,
        C,
    }

    #[rstest]
    #[case("ABCA", "A", "AA")]
    #[case("BCAB", "B", "BB")]
    #[case("CABC", "C", "CC")]
    fn test_event_filter_plugin_only(
        #[case] events_to_emit: EventList<CmpEvent>,
        #[case] only_event: CmpEvent,
        #[case] expected_events: EventList<CmpEvent>,
        #[from(test_app)]
        #[with(EventFilterPlugin::Only(only_event.clone()))]
        mut app: TestApp,
    ) {
        app.add_systems(Update, move |mut writer: EventWriter<CmpEvent>| {
            for e in &*events_to_emit {
                writer.send(e.clone());
            }
        });

        app.update();

        let collected_events = app.get_collected_events::<CmpEvent>();
        assert_that!(collected_events)
            .is_some()
            .is_equal_to(&*expected_events);

        for e in &collected_events.unwrap() {
            assert_that!(e).is_equal_to(&only_event);
        }
    }

    #[rstest]
    #[case("AABBCC", "A", "AA")]
    #[case("AABBCC", "B", "BB")]
    #[case("AABBCC", "C", "CC")]
    #[case("ABCCBA", "AB", "ABBA")]
    #[case("ABCCBA", "AC", "ACCA")]
    #[case("ABCCBA", "BC", "BCCB")]
    #[case("AABBCC", "ABC", "AABBCC")]
    fn test_event_filter_plugin_any_of(
        #[case] events_to_emit: EventList<CmpEvent>,
        #[case] any_of_events: EventList<CmpEvent>,
        #[case] expected_events: EventList<CmpEvent>,
        #[from(test_app)]
        #[with(EventFilterPlugin::AnyOf((*any_of_events).clone()))]
        mut app: TestApp,
    ) {
        app.add_systems(Update, move |mut writer: EventWriter<CmpEvent>| {
            for e in &*events_to_emit {
                writer.send(e.clone());
            }
        });

        app.update();

        let collected_events = app.get_collected_events::<CmpEvent>();
        assert_that!(collected_events)
            .is_some()
            .is_equal_to(&*expected_events);

        for e in collected_events.unwrap() {
            assert_that!(*any_of_events).contains(e);
        }
    }

    // Parse failure for `CmpEvent` / `EventList`. `Debug` lets rstest's magic
    // conversion include the error in its panic message.
    #[derive(Debug)]
    pub struct InvalidEvent;

    impl FromStr for CmpEvent {
        type Err = InvalidEvent;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "A" => Ok(CmpEvent::A),
                "B" => Ok(CmpEvent::B),
                "C" => Ok(CmpEvent::C),
                _ => Err(InvalidEvent),
            }
        }
    }

    #[rstest]
    #[case("A", Some(CmpEvent::A))]
    #[case("B", Some(CmpEvent::B))]
    #[case("C", Some(CmpEvent::C))]
    #[should_panic]
    #[case("", None)]
    #[should_panic]
    #[case("D", None)]
    #[should_panic]
    #[case("more nonsense", None)]
    fn test_filtered_event_fromstr(#[case] magic: CmpEvent, #[case] expected: Option<CmpEvent>) {
        assert_that!(magic).is_equal_to(expected.unwrap());
    }

    // Sequence of events parsed one character at a time, e.g. "ABCA".
    #[derive(Debug, Clone, Deref)]
    struct EventList<E: Event + Clone>(Vec<E>);

    impl<E: Event + Clone + FromStr<Err = InvalidEvent>> FromStr for EventList<E> {
        type Err = InvalidEvent;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            let mut events = Vec::new();
            for c in s.chars() {
                let e = E::from_str(&c.to_string())?;
                events.push(e);
            }
            Ok(EventList(events))
        }
    }

    #[rstest]
    #[case("A", vec![CmpEvent::A])]
    #[case("AB", vec![CmpEvent::A, CmpEvent::B])]
    #[case("ABC", vec![CmpEvent::A, CmpEvent::B, CmpEvent::C])]
    #[case("AABBCC", vec![
        CmpEvent::A, CmpEvent::A,
        CmpEvent::B, CmpEvent::B,
        CmpEvent::C, CmpEvent::C
    ])]
    #[should_panic]
    #[case("abc", vec![])]
    fn test_event_list_fromstr(
        #[case] magic: EventList<CmpEvent>,
        #[case] expected: Vec<CmpEvent>,
    ) {
        assert_that!(*magic).is_equal_to(&expected);
    }
}