rmv_bevy_testing_tools/
events.rs

1use std::marker::PhantomData;
2
3use bevy_app::{App, Plugin, PostUpdate};
4use bevy_derive::{Deref, DerefMut};
5use bevy_ecs::{
6    event::{Event, EventReader},
7    system::{ResMut, Resource},
8};
9
10#[derive(Debug, Resource, Deref, DerefMut)]
11pub struct CollectedEvents<E>(Vec<E>);
12
13impl<E: Event> CollectedEvents<E> {
14    pub fn get(&self) -> &Vec<E> {
15        &self.0
16    }
17}
18
19impl<E: Event> Default for CollectedEvents<E> {
20    fn default() -> Self {
21        Self(Vec::new())
22    }
23}
24
/// Plugin that collects the `E` events it reads into a [`CollectedEvents`] resource.
///
/// Construct it with `Default`. The bounds required to use it (`E: Event + Clone`) are
/// stated on its `impl` blocks instead of on the type, so merely naming the type does
/// not force those bounds onto every signature that mentions it.
#[derive(Debug)]
pub struct EventCollectorPlugin<E>(PhantomData<E>);
29
30impl<E: Event + Clone> Default for EventCollectorPlugin<E> {
31    fn default() -> Self {
32        Self(PhantomData)
33    }
34}
35
36impl<E: Event + Clone> Plugin for EventCollectorPlugin<E> {
37    #[cfg_attr(coverage_nightly, coverage(off))]
38    fn build(&self, app: &mut App) {
39        app.add_event::<E>()
40            .init_resource::<CollectedEvents<E>>()
41            .add_systems(
42                PostUpdate,
43                |mut events: EventReader<E>, mut collection: ResMut<CollectedEvents<E>>| {
44                    collection.extend(events.read().cloned());
45                },
46            );
47    }
48}
49
/// Plugin that collects only selected `E` events into a [`CollectedEvents`] resource.
///
/// An event is kept when it compares equal (`PartialEq`) to a configured value.
/// The bounds required to use it (`E: Event + Clone + PartialEq`) are stated on the
/// `Plugin` impl rather than on the type definition.
#[derive(Debug)]
pub enum EventFilterPlugin<E> {
    /// Keep only events equal to this value.
    Only(E),
    /// Keep events equal to any value in this list; an empty list keeps nothing.
    AnyOf(Vec<E>),
}
58
59impl<E: Event + Clone + PartialEq> Plugin for EventFilterPlugin<E> {
60    #[cfg_attr(coverage_nightly, coverage(off))]
61    fn build(&self, app: &mut App) {
62        app.add_event::<E>().init_resource::<CollectedEvents<E>>();
63        match &self {
64            EventFilterPlugin::Only(event) => {
65                app.add_systems(PostUpdate, {
66                    let event = event.clone();
67                    move |mut events: EventReader<E>, mut collection: ResMut<CollectedEvents<E>>| {
68                        collection.extend(events.read().filter(|ev| *ev == &event).cloned());
69                    }
70                });
71            }
72            EventFilterPlugin::AnyOf(ref any_of_events) => {
73                app.add_systems(PostUpdate, {
74                    let any_of_events = any_of_events.clone();
75                    move |mut events: EventReader<E>, mut collection: ResMut<CollectedEvents<E>>| {
76                        collection.extend(
77                            events
78                                .read()
79                                .filter(|ev| any_of_events.contains(ev))
80                                .cloned(),
81                        );
82                    }
83                });
84            }
85        }
86    }
87}
88
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use bevy_app::Update;
    use bevy_ecs::event::EventWriter;
    use rstest::*;
    use speculoos::prelude::*;

    use super::*;
    // `CollectEvents` provides `get_collected_events`; imported once here for every test.
    use crate::{app::TestApp, fixtures::test_app, traits::CollectEvents};

    // `Deref` and the explicit `get` accessor must expose the same underlying vector.
    #[rstest]
    fn test_collected_events_default_deref() {
        let collected_events: CollectedEvents<CmpEvent> = CollectedEvents::default();
        let v1: &Vec<_> = &*collected_events;
        let v2: &Vec<_> = collected_events.get();
        assert_that!(v1).is_equal_to(v2);
    }

    /// Event type without `PartialEq`; only usable with the unfiltered collector.
    #[derive(Debug, Event, Copy, Clone)]
    struct NonEqEvent;

    // Every event sent during `Update` ends up in the collection, whatever the count.
    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(10)]
    fn test_event_collector_plugin(
        #[from(test_app)]
        #[with(EventCollectorPlugin::<NonEqEvent>::default())]
        mut app: TestApp,
        #[case] emit_count: usize,
    ) {
        app.add_systems(Update, move |mut writer: EventWriter<NonEqEvent>| {
            for _ in 0..emit_count {
                writer.send(NonEqEvent);
            }
        });

        app.update();

        assert_that!(app.get_collected_events::<NonEqEvent>())
            .is_some()
            .has_length(emit_count);
    }

    /// Comparable event used by the filter tests; parsed from single letters via `FromStr`.
    #[derive(Debug, Event, Clone, PartialEq)]
    enum CmpEvent {
        A,
        B,
        C,
    }

    // `Only(x)` keeps exactly the emitted events equal to `x`, in emission order.
    #[rstest]
    #[case("ABCA", "A", "AA")]
    #[case("BCAB", "B", "BB")]
    #[case("CABC", "C", "CC")]
    fn test_event_filter_plugin_only(
        #[case] events_to_emit: EventList<CmpEvent>,
        #[case] only_event: CmpEvent,
        #[case] expected_events: EventList<CmpEvent>,
        #[from(test_app)]
        #[with(EventFilterPlugin::Only(only_event.clone()))]
        mut app: TestApp,
    ) {
        app.add_systems(Update, move |mut writer: EventWriter<CmpEvent>| {
            for e in &*events_to_emit {
                writer.send(e.clone());
            }
        });

        app.update();

        let collected_events = app.get_collected_events::<CmpEvent>();
        assert_that!(collected_events)
            .is_some()
            .is_equal_to(&*expected_events);

        for e in &collected_events.unwrap() {
            assert_that!(e).is_equal_to(&only_event);
        }
    }

    // `AnyOf(xs)` keeps the emitted events that appear in `xs`, in emission order.
    #[rstest]
    #[case("AABBCC", "A", "AA")]
    #[case("AABBCC", "B", "BB")]
    #[case("AABBCC", "C", "CC")]
    #[case("ABCCBA", "AB", "ABBA")]
    #[case("ABCCBA", "AC", "ACCA")]
    #[case("ABCCBA", "BC", "BCCB")]
    #[case("AABBCC", "ABC", "AABBCC")]
    fn test_event_filter_plugin_any_of(
        #[case] events_to_emit: EventList<CmpEvent>,
        #[case] any_of_events: EventList<CmpEvent>,
        #[case] expected_events: EventList<CmpEvent>,
        #[from(test_app)]
        #[with(EventFilterPlugin::AnyOf((*any_of_events).clone()))]
        mut app: TestApp,
    ) {
        app.add_systems(Update, move |mut writer: EventWriter<CmpEvent>| {
            for e in &*events_to_emit {
                writer.send(e.clone());
            }
        });

        app.update();

        let collected_events = app.get_collected_events::<CmpEvent>();
        assert_that!(collected_events)
            .is_some()
            .is_equal_to(&*expected_events);

        for e in collected_events.unwrap().into_iter() {
            assert_that!(*any_of_events).contains(e);
        }
    }

    /// Parse error for letters that do not name a `CmpEvent`.
    pub struct InvalidEvent;

    // Lets rstest turn string-literal case arguments into `CmpEvent` values.
    impl FromStr for CmpEvent {
        type Err = InvalidEvent;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "A" => Ok(CmpEvent::A),
                "B" => Ok(CmpEvent::B),
                "C" => Ok(CmpEvent::C),
                _ => Err(InvalidEvent),
            }
        }
    }

    // Invalid inputs are expected to panic while rstest converts the case argument.
    #[rstest]
    #[case("A", Some(CmpEvent::A))]
    #[case("B", Some(CmpEvent::B))]
    #[case("C", Some(CmpEvent::C))]
    #[should_panic]
    #[case("", None)]
    #[should_panic]
    #[case("D", None)]
    #[should_panic]
    #[case("more nonsense", None)]
    fn test_filtered_event_fromstr(#[case] magic: CmpEvent, #[case] expected: Option<CmpEvent>) {
        assert_that!(magic).is_equal_to(expected.unwrap());
    }

    /// Ordered list of events parsed one character at a time (e.g. `"ABA"`).
    #[derive(Debug, Clone, Deref)]
    struct EventList<E: Event + Clone>(Vec<E>);

    impl<E: Event + Clone + FromStr<Err = InvalidEvent>> FromStr for EventList<E> {
        type Err = InvalidEvent;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            let mut events = Vec::new();
            for c in s.chars() {
                let e = E::from_str(&c.to_string())?;
                events.push(e);
            }
            Ok(EventList(events))
        }
    }

    // Each character maps to one event; any unknown character fails the whole parse.
    #[rstest]
    #[case("A", vec![CmpEvent::A])]
    #[case("AB", vec![CmpEvent::A, CmpEvent::B])]
    #[case("ABC", vec![CmpEvent::A, CmpEvent::B, CmpEvent::C])]
    #[case("AABBCC", vec![
        CmpEvent::A, CmpEvent::A,
        CmpEvent::B, CmpEvent::B,
        CmpEvent::C, CmpEvent::C
    ])]
    #[should_panic]
    #[case("abc", vec![])]
    fn test_event_list_fromstr(
        #[case] magic: EventList<CmpEvent>,
        #[case] expected: Vec<CmpEvent>,
    ) {
        assert_that!(*magic).is_equal_to(&expected);
    }
}
270}