1use std::marker::PhantomData;
2
3use bevy_app::{App, Plugin, PostUpdate};
4use bevy_derive::{Deref, DerefMut};
5use bevy_ecs::{
6 event::{Event, EventReader},
7 system::{ResMut, Resource},
8};
9
10#[derive(Debug, Resource, Deref, DerefMut)]
11pub struct CollectedEvents<E>(Vec<E>);
12
13impl<E: Event> CollectedEvents<E> {
14 pub fn get(&self) -> &Vec<E> {
15 &self.0
16 }
17}
18
19impl<E: Event> Default for CollectedEvents<E> {
20 fn default() -> Self {
21 Self(Vec::new())
22 }
23}
24
25#[derive(Debug)]
26pub struct EventCollectorPlugin<E>(PhantomData<E>)
27where
28 E: Event + Clone;
29
30impl<E: Event + Clone> Default for EventCollectorPlugin<E> {
31 fn default() -> Self {
32 Self(PhantomData)
33 }
34}
35
36impl<E: Event + Clone> Plugin for EventCollectorPlugin<E> {
37 #[cfg_attr(coverage_nightly, coverage(off))]
38 fn build(&self, app: &mut App) {
39 app.add_event::<E>()
40 .init_resource::<CollectedEvents<E>>()
41 .add_systems(
42 PostUpdate,
43 |mut events: EventReader<E>, mut collection: ResMut<CollectedEvents<E>>| {
44 collection.extend(events.read().cloned());
45 },
46 );
47 }
48}
49
50#[derive(Debug)]
51pub enum EventFilterPlugin<E>
52where
53 E: Event + Clone + PartialEq,
54{
55 Only(E),
56 AnyOf(Vec<E>),
57}
58
59impl<E: Event + Clone + PartialEq> Plugin for EventFilterPlugin<E> {
60 #[cfg_attr(coverage_nightly, coverage(off))]
61 fn build(&self, app: &mut App) {
62 app.add_event::<E>().init_resource::<CollectedEvents<E>>();
63 match &self {
64 EventFilterPlugin::Only(event) => {
65 app.add_systems(PostUpdate, {
66 let event = event.clone();
67 move |mut events: EventReader<E>, mut collection: ResMut<CollectedEvents<E>>| {
68 collection.extend(events.read().filter(|ev| *ev == &event).cloned());
69 }
70 });
71 }
72 EventFilterPlugin::AnyOf(ref any_of_events) => {
73 app.add_systems(PostUpdate, {
74 let any_of_events = any_of_events.clone();
75 move |mut events: EventReader<E>, mut collection: ResMut<CollectedEvents<E>>| {
76 collection.extend(
77 events
78 .read()
79 .filter(|ev| any_of_events.contains(ev))
80 .cloned(),
81 );
82 }
83 });
84 }
85 }
86 }
87}
88
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use bevy_app::Update;
    use bevy_ecs::event::EventWriter;
    use rstest::*;
    use speculoos::prelude::*;

    use super::*;
    use crate::{app::TestApp, fixtures::test_app, traits::CollectEvents};

    #[rstest]
    fn test_collected_events_default_deref() {
        let collected_events: CollectedEvents<CmpEvent> = CollectedEvents::default();
        let v1: &Vec<_> = &*collected_events;
        let v2: &Vec<_> = collected_events.get();
        assert_that!(v1).is_equal_to(v2);
    }

    /// Event without `PartialEq`, proving `EventCollectorPlugin` needs only `Clone`.
    #[derive(Debug, Event, Copy, Clone)]
    struct NonEqEvent;

    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(10)]
    fn test_event_collector_plugin(
        #[from(test_app)]
        #[with(EventCollectorPlugin::<NonEqEvent>::default())]
        mut app: TestApp,
        #[case] emit_count: usize,
    ) {
        app.add_systems(Update, move |mut writer: EventWriter<NonEqEvent>| {
            for _ in 0..emit_count {
                writer.send(NonEqEvent);
            }
        });

        app.update();

        assert_that!(app.get_collected_events::<NonEqEvent>())
            .is_some()
            .has_length(emit_count);
    }

    /// Comparable event used by the filter tests; parsed from single letters.
    #[derive(Debug, Event, Clone, PartialEq)]
    enum CmpEvent {
        A,
        B,
        C,
    }

    #[rstest]
    #[case("ABCA", "A", "AA")]
    #[case("BCAB", "B", "BB")]
    #[case("CABC", "C", "CC")]
    fn test_event_filter_plugin_only(
        #[case] events_to_emit: EventList<CmpEvent>,
        #[case] only_event: CmpEvent,
        #[case] expected_events: EventList<CmpEvent>,
        #[from(test_app)]
        #[with(EventFilterPlugin::Only(only_event.clone()))]
        mut app: TestApp,
    ) {
        app.add_systems(Update, move |mut writer: EventWriter<CmpEvent>| {
            for e in &*events_to_emit {
                writer.send(e.clone());
            }
        });

        app.update();

        let collected_events = app.get_collected_events::<CmpEvent>();
        assert_that!(collected_events)
            .is_some()
            .is_equal_to(&*expected_events);

        for e in &collected_events.unwrap() {
            assert_that!(e).is_equal_to(&only_event);
        }
    }

    #[rstest]
    #[case("AABBCC", "A", "AA")]
    #[case("AABBCC", "B", "BB")]
    #[case("AABBCC", "C", "CC")]
    #[case("ABCCBA", "AB", "ABBA")]
    #[case("ABCCBA", "AC", "ACCA")]
    #[case("ABCCBA", "BC", "BCCB")]
    #[case("AABBCC", "ABC", "AABBCC")]
    fn test_event_filter_plugin_any_of(
        #[case] events_to_emit: EventList<CmpEvent>,
        #[case] any_of_events: EventList<CmpEvent>,
        #[case] expected_events: EventList<CmpEvent>,
        #[from(test_app)]
        #[with(EventFilterPlugin::AnyOf((*any_of_events).clone()))]
        mut app: TestApp,
    ) {
        app.add_systems(Update, move |mut writer: EventWriter<CmpEvent>| {
            for e in &*events_to_emit {
                writer.send(e.clone());
            }
        });

        app.update();

        let collected_events = app.get_collected_events::<CmpEvent>();
        assert_that!(collected_events)
            .is_some()
            .is_equal_to(&*expected_events);

        for e in collected_events.unwrap() {
            assert_that!(*any_of_events).contains(e);
        }
    }

    /// Parse error for `CmpEvent` / `EventList`: the input was not a known event letter.
    pub struct InvalidEvent;

    impl FromStr for CmpEvent {
        type Err = InvalidEvent;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "A" => Ok(CmpEvent::A),
                "B" => Ok(CmpEvent::B),
                "C" => Ok(CmpEvent::C),
                _ => Err(InvalidEvent),
            }
        }
    }

    #[rstest]
    #[case("A", Some(CmpEvent::A))]
    #[case("B", Some(CmpEvent::B))]
    #[case("C", Some(CmpEvent::C))]
    #[should_panic]
    #[case("", None)]
    #[should_panic]
    #[case("D", None)]
    #[should_panic]
    #[case("more nonsense", None)]
    fn test_filtered_event_fromstr(#[case] magic: CmpEvent, #[case] expected: Option<CmpEvent>) {
        assert_that!(magic).is_equal_to(expected.unwrap());
    }

    /// Sequence of events parsed one character at a time (e.g. `"ABCA"`), so rstest
    /// cases can be written as compact string literals.
    #[derive(Debug, Clone, Deref)]
    struct EventList<E: Event + Clone>(Vec<E>);

    impl<E: Event + Clone + FromStr<Err = InvalidEvent>> FromStr for EventList<E> {
        type Err = InvalidEvent;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            // Parse each character as its own event; the first failure aborts the parse.
            s.chars()
                .map(|c| E::from_str(&c.to_string()))
                .collect::<Result<Vec<_>, _>>()
                .map(EventList)
        }
    }

    #[rstest]
    #[case("A", vec![CmpEvent::A])]
    #[case("AB", vec![CmpEvent::A, CmpEvent::B])]
    #[case("ABC", vec![CmpEvent::A, CmpEvent::B, CmpEvent::C])]
    #[case("AABBCC", vec![
        CmpEvent::A, CmpEvent::A,
        CmpEvent::B, CmpEvent::B,
        CmpEvent::C, CmpEvent::C
    ])]
    #[should_panic]
    #[case("abc", vec![])]
    fn test_event_list_fromstr(
        #[case] magic: EventList<CmpEvent>,
        #[case] expected: Vec<CmpEvent>,
    ) {
        assert_that!(*magic).is_equal_to(&expected);
    }
}