rmv_bevy_testing_tools/
events.rs

1use std::marker::PhantomData;
2
3use bevy_app::{App, Plugin, PostUpdate};
4use bevy_derive::{Deref, DerefMut};
5use bevy_ecs::{
6    event::{Event, EventReader},
7    resource::Resource,
8    system::ResMut,
9};
10
11#[derive(Debug, Resource, Deref, DerefMut)]
12pub struct CollectedEvents<E>(Vec<E>);
13
14impl<E: Event> CollectedEvents<E> {
15    pub fn get(&self) -> &Vec<E> {
16        &self.0
17    }
18}
19
20impl<E: Event> Default for CollectedEvents<E> {
21    fn default() -> Self {
22        Self(Vec::new())
23    }
24}
25
26#[derive(Debug)]
27pub struct EventCollectorPlugin<E>(PhantomData<E>)
28where
29    E: Event + Clone;
30
31impl<E: Event + Clone> Default for EventCollectorPlugin<E> {
32    fn default() -> Self {
33        Self(PhantomData)
34    }
35}
36
37impl<E: Event + Clone> Plugin for EventCollectorPlugin<E> {
38    #[cfg_attr(coverage_nightly, coverage(off))]
39    fn build(&self, app: &mut App) {
40        app.add_event::<E>()
41            .init_resource::<CollectedEvents<E>>()
42            .add_systems(
43                PostUpdate,
44                |mut events: EventReader<E>, mut collection: ResMut<CollectedEvents<E>>| {
45                    collection.extend(events.read().cloned());
46                },
47            );
48    }
49}
50
51#[derive(Debug)]
52pub enum EventFilterPlugin<E>
53where
54    E: Event + Clone + PartialEq,
55{
56    Only(E),
57    AnyOf(Vec<E>),
58}
59
60impl<E: Event + Clone + PartialEq> Plugin for EventFilterPlugin<E> {
61    #[cfg_attr(coverage_nightly, coverage(off))]
62    fn build(&self, app: &mut App) {
63        app.add_event::<E>().init_resource::<CollectedEvents<E>>();
64        match &self {
65            EventFilterPlugin::Only(event) => {
66                app.add_systems(PostUpdate, {
67                    let event = event.clone();
68                    move |mut events: EventReader<E>, mut collection: ResMut<CollectedEvents<E>>| {
69                        collection.extend(events.read().filter(|ev| *ev == &event).cloned());
70                    }
71                });
72            }
73            EventFilterPlugin::AnyOf(ref any_of_events) => {
74                app.add_systems(PostUpdate, {
75                    let any_of_events = any_of_events.clone();
76                    move |mut events: EventReader<E>, mut collection: ResMut<CollectedEvents<E>>| {
77                        collection.extend(
78                            events
79                                .read()
80                                .filter(|ev| any_of_events.contains(ev))
81                                .cloned(),
82                        );
83                    }
84                });
85            }
86        }
87    }
88}
89
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use bevy_app::Update;
    use bevy_ecs::event::EventWriter;
    use rstest::*;
    use speculoos::prelude::*;

    use super::*;
    // `CollectEvents` is imported once here; test functions rely on it without re-importing.
    use crate::{fixtures::test_app, test_app::TestApp, traits::CollectEvents};

    #[rstest]
    fn test_collected_events_default_deref() {
        let collected_events: CollectedEvents<CmpEvent> = CollectedEvents::default();
        let v1: &Vec<_> = &*collected_events;
        let v2: &Vec<_> = collected_events.get();
        assert_that!(v1).is_equal_to(v2);
    }

    #[derive(Debug, Event, Copy, Clone)]
    struct NonEqEvent;

    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(10)]
    fn test_event_collector_plugin(
        #[from(test_app)]
        #[with(EventCollectorPlugin::<NonEqEvent>::default())]
        mut app: TestApp,
        #[case] emit_count: usize,
    ) {
        app.add_systems(Update, move |mut writer: EventWriter<NonEqEvent>| {
            for _ in 0..emit_count {
                writer.send(NonEqEvent);
            }
        });

        app.update();

        assert_that!(app.get_collected_events::<NonEqEvent>())
            .is_some()
            .has_length(emit_count);
    }

    #[derive(Debug, Event, Clone, PartialEq)]
    enum CmpEvent {
        A,
        B,
        C,
    }

    #[rstest]
    #[case("ABCA", "A", "AA")]
    #[case("BCAB", "B", "BB")]
    #[case("CABC", "C", "CC")]
    // The filtered event is never emitted: the resource exists but stays empty.
    #[case("BCBC", "A", "")]
    fn test_event_filter_plugin_only(
        #[case] events_to_emit: EventList<CmpEvent>,
        #[case] only_event: CmpEvent,
        #[case] expected_events: EventList<CmpEvent>,
        #[from(test_app)]
        #[with(EventFilterPlugin::Only(only_event.clone()))]
        mut app: TestApp,
    ) {
        app.add_systems(Update, move |mut writer: EventWriter<CmpEvent>| {
            for e in &*events_to_emit {
                writer.send(e.clone());
            }
        });

        app.update();

        let collected_events = app.get_collected_events::<CmpEvent>();
        assert_that!(collected_events)
            .is_some()
            .is_equal_to(&*expected_events);

        for e in &collected_events.unwrap() {
            assert_that!(e).is_equal_to(&only_event);
        }
    }

    #[rstest]
    #[case("AABBCC", "A", "AA")]
    #[case("AABBCC", "B", "BB")]
    #[case("AABBCC", "C", "CC")]
    #[case("ABCCBA", "AB", "ABBA")]
    #[case("ABCCBA", "AC", "ACCA")]
    #[case("ABCCBA", "BC", "BCCB")]
    #[case("AABBCC", "ABC", "AABBCC")]
    // An empty `AnyOf` list accepts nothing.
    #[case("AABBCC", "", "")]
    fn test_event_filter_plugin_any_of(
        #[case] events_to_emit: EventList<CmpEvent>,
        #[case] any_of_events: EventList<CmpEvent>,
        #[case] expected_events: EventList<CmpEvent>,
        #[from(test_app)]
        #[with(EventFilterPlugin::AnyOf((*any_of_events).clone()))]
        mut app: TestApp,
    ) {
        app.add_systems(Update, move |mut writer: EventWriter<CmpEvent>| {
            for e in &*events_to_emit {
                writer.send(e.clone());
            }
        });

        app.update();

        let collected_events = app.get_collected_events::<CmpEvent>();
        assert_that!(collected_events)
            .is_some()
            .is_equal_to(&*expected_events);

        for e in collected_events.unwrap() {
            assert_that!(*any_of_events).contains(e);
        }
    }

    pub struct InvalidEvent;

    impl FromStr for CmpEvent {
        type Err = InvalidEvent;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "A" => Ok(CmpEvent::A),
                "B" => Ok(CmpEvent::B),
                "C" => Ok(CmpEvent::C),
                _ => Err(InvalidEvent),
            }
        }
    }

    #[rstest]
    #[case("A", Some(CmpEvent::A))]
    #[case("B", Some(CmpEvent::B))]
    #[case("C", Some(CmpEvent::C))]
    #[should_panic]
    #[case("", None)]
    #[should_panic]
    #[case("D", None)]
    #[should_panic]
    #[case("more nonsense", None)]
    fn test_filtered_event_fromstr(#[case] magic: CmpEvent, #[case] expected: Option<CmpEvent>) {
        assert_that!(magic).is_equal_to(expected.unwrap());
    }

    /// Test helper parsed from a string where each character names one `CmpEvent`-style event.
    #[derive(Debug, Clone, Deref)]
    struct EventList<E: Event + Clone>(Vec<E>);

    impl<E: Event + Clone + FromStr<Err = InvalidEvent>> FromStr for EventList<E> {
        type Err = InvalidEvent;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            let mut events = Vec::new();
            for c in s.chars() {
                let e = E::from_str(&c.to_string())?;
                events.push(e);
            }
            Ok(EventList(events))
        }
    }

    #[rstest]
    #[case("A", vec![CmpEvent::A])]
    #[case("AB", vec![CmpEvent::A, CmpEvent::B])]
    #[case("ABC", vec![CmpEvent::A, CmpEvent::B, CmpEvent::C])]
    #[case("AABBCC", vec![
        CmpEvent::A, CmpEvent::A,
        CmpEvent::B, CmpEvent::B,
        CmpEvent::C, CmpEvent::C
    ])]
    #[should_panic]
    #[case("abc", vec![])]
    fn test_event_list_fromstr(
        #[case] magic: EventList<CmpEvent>,
        #[case] expected: Vec<CmpEvent>,
    ) {
        assert_that!(*magic).is_equal_to(&expected);
    }
}