rmv-bevy-testing-tools 0.10.2

Write simple tests for Bevy systems using rstest, insta, and speculoos.
Documentation
use std::marker::PhantomData;

use bevy_app::{App, Plugin, PostUpdate};
use bevy_derive::{Deref, DerefMut};
use bevy_ecs::{
    message::{Message, MessageReader},
    resource::Resource,
    system::ResMut,
};

/// Resource holding the messages of type `E` gathered by
/// [`MessageCollectorPlugin`] or [`MessageFilterPlugin`].
///
/// It dereferences (mutably as well) to the inner `Vec<E>`, so the regular
/// `Vec` API is usable directly on the resource.
#[derive(Debug, Deref, DerefMut, Resource)]
pub struct CollectedMessages<E>(Vec<E>);

impl<E: Message> CollectedMessages<E> {
    /// Borrows the collected messages, in the order they were appended.
    pub fn get(&self) -> &Vec<E> {
        let Self(inner) = self;
        inner
    }
}

impl<E: Message> Default for CollectedMessages<E> {
    /// Creates a collection with no messages in it.
    fn default() -> Self {
        Self(Vec::default())
    }
}

/// Plugin that records every message of type `E` into [`CollectedMessages<E>`].
///
/// Building it registers the message type, initialises the
/// [`CollectedMessages<E>`] resource, and adds a `PostUpdate` system that
/// appends a clone of each message read that frame.
///
/// Note: [`MessageFilterPlugin<E>`] appends to the same resource, so using
/// both for the same `E` stores matching messages from each.
#[derive(Debug)]
pub struct MessageCollectorPlugin<E>(PhantomData<E>)
where
    E: Message + Clone;

impl<E: Message + Clone> Default for MessageCollectorPlugin<E> {
    fn default() -> Self {
        MessageCollectorPlugin(PhantomData)
    }
}

impl<E: Message + Clone> Plugin for MessageCollectorPlugin<E> {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn build(&self, app: &mut App) {
        app.add_message::<E>()
            .init_resource::<CollectedMessages<E>>()
            .add_systems(
                PostUpdate,
                |mut reader: MessageReader<E>, mut collected: ResMut<CollectedMessages<E>>| {
                    // Borrowing mutably through `ResMut` every run (even when no
                    // message arrived) keeps change-detection behaviour stable.
                    let store: &mut Vec<E> = &mut collected;
                    store.extend(reader.read().cloned());
                },
            );
    }
}

/// Plugin that records only the `E` messages matching a filter into
/// [`CollectedMessages<E>`].
///
/// Building it registers the message type, initialises the
/// [`CollectedMessages<E>`] resource, and adds a `PostUpdate` system that
/// appends a clone of every message comparing equal (`PartialEq`) to an
/// accepted value.
#[derive(Debug)]
pub enum MessageFilterPlugin<E>
where
    E: Message + Clone + PartialEq,
{
    /// Keep only messages equal to this value.
    Only(E),
    /// Keep messages equal to any value in this list (an empty list keeps none).
    AnyOf(Vec<E>),
}

impl<E: Message + Clone + PartialEq> Plugin for MessageFilterPlugin<E> {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn build(&self, app: &mut App) {
        app.add_message::<E>()
            .init_resource::<CollectedMessages<E>>();

        // Both variants reduce to "keep what is in this list", so a single
        // system serves them; the list is moved into the system closure.
        let accepted: Vec<E> = match self {
            Self::Only(wanted) => vec![wanted.clone()],
            Self::AnyOf(wanted) => wanted.clone(),
        };

        app.add_systems(
            PostUpdate,
            move |mut reader: MessageReader<E>, mut collected: ResMut<CollectedMessages<E>>| {
                // Borrow mutably through `ResMut` on every run, matching the
                // unconditional `extend` of a filtered iterator.
                let store: &mut Vec<E> = &mut collected;
                store.extend(reader.read().filter(|msg| accepted.contains(msg)).cloned());
            },
        );
    }
}

#[cfg(test)]
mod tests {
    // Tests for the collector/filter plugins. `CmpMessage` and `MessageList`
    // implement `FromStr` so rstest cases can be written as compact strings
    // such as `"ABCA"` (one character per message).

    use std::str::FromStr;

    use bevy_app::Update;
    use bevy_ecs::message::MessageWriter;
    use rstest::*;
    use speculoos::prelude::*;

    use super::*;
    use crate::{fixtures::minimal_test_app, test_app::TestApp, traits::CollectMessages};

    #[rstest]
    fn test_collected_messages_default_deref() {
        let collected_messages: CollectedMessages<CmpMessage> = CollectedMessages::default();
        let v1: &Vec<_> = &*collected_messages;
        let v2: &Vec<_> = collected_messages.get();
        assert_that!(v1).is_equal_to(v2);
    }

    /// Message without `PartialEq`: the collector plugin only requires `Clone`.
    #[derive(Clone, Copy, Debug, Message)]
    struct NonEqMessage;

    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(10)]
    fn test_message_collector_plugin(
        #[from(minimal_test_app)]
        #[with(MessageCollectorPlugin::<NonEqMessage>::default())]
        mut app: TestApp,
        #[case] emit_count: usize,
    ) {
        app.add_systems(Update, move |mut writer: MessageWriter<NonEqMessage>| {
            for _ in 0..emit_count {
                writer.write(NonEqMessage);
            }
        });

        app.update();

        assert_that!(app.get_collected_messages::<NonEqMessage>())
            .is_some()
            .has_length(emit_count);
    }

    #[derive(Clone, Debug, Message, PartialEq)]
    enum CmpMessage {
        A,
        B,
        C,
    }

    #[rstest]
    #[case("ABCA", "A", "AA")]
    #[case("BCAB", "B", "BB")]
    #[case("CABC", "C", "CC")]
    // No emitted message matches: the collection exists but stays empty.
    #[case("BCCB", "A", "")]
    fn test_message_filter_plugin_only(
        #[case] messages_to_emit: MessageList<CmpMessage>,
        #[case] only_message: CmpMessage,
        #[case] expected_messages: MessageList<CmpMessage>,
        #[from(minimal_test_app)]
        #[with(MessageFilterPlugin::Only(only_message.clone()))]
        mut app: TestApp,
    ) {
        app.add_systems(Update, move |mut writer: MessageWriter<CmpMessage>| {
            for e in &*messages_to_emit {
                writer.write(e.clone());
            }
        });

        app.update();

        let collected_messages = app.get_collected_messages::<CmpMessage>();
        assert_that!(collected_messages)
            .is_some()
            .is_equal_to(&*expected_messages);

        for e in &collected_messages.unwrap() {
            assert_that!(e).is_equal_to(&only_message);
        }
    }

    #[rstest]
    #[case("AABBCC", "A", "AA")]
    #[case("AABBCC", "B", "BB")]
    #[case("AABBCC", "C", "CC")]
    #[case("ABCCBA", "AB", "ABBA")]
    #[case("ABCCBA", "AC", "ACCA")]
    #[case("ABCCBA", "BC", "BCCB")]
    #[case("AABBCC", "ABC", "AABBCC")]
    // An empty accept-list keeps nothing.
    #[case("ABCCBA", "", "")]
    fn test_message_filter_plugin_any_of(
        #[case] messages_to_emit: MessageList<CmpMessage>,
        #[case] any_of_messages: MessageList<CmpMessage>,
        #[case] expected_messages: MessageList<CmpMessage>,
        #[from(minimal_test_app)]
        #[with(MessageFilterPlugin::AnyOf((*any_of_messages).clone()))]
        mut app: TestApp,
    ) {
        app.add_systems(Update, move |mut writer: MessageWriter<CmpMessage>| {
            for e in &*messages_to_emit {
                writer.write(e.clone());
            }
        });

        app.update();

        let collected_messages = app.get_collected_messages::<CmpMessage>();
        assert_that!(collected_messages)
            .is_some()
            .is_equal_to(&*expected_messages);

        for e in collected_messages.unwrap() {
            assert_that!(*any_of_messages).contains(e);
        }
    }

    /// Parse error for single-character message codes other than `A`/`B`/`C`.
    #[derive(Debug)]
    pub struct InvalidMessage;

    impl FromStr for CmpMessage {
        type Err = InvalidMessage;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "A" => Ok(CmpMessage::A),
                "B" => Ok(CmpMessage::B),
                "C" => Ok(CmpMessage::C),
                _ => Err(InvalidMessage),
            }
        }
    }

    #[rstest]
    #[case("A", Some(CmpMessage::A))]
    #[case("B", Some(CmpMessage::B))]
    #[case("C", Some(CmpMessage::C))]
    #[should_panic]
    #[case("", None)]
    #[should_panic]
    #[case("D", None)]
    #[should_panic]
    #[case("more nonsense", None)]
    fn test_filtered_message_fromstr(
        #[case] magic: CmpMessage,
        #[case] expected: Option<CmpMessage>,
    ) {
        assert_that!(magic).is_equal_to(expected.unwrap());
    }

    /// Sequence of messages parsed one character at a time (`"AB"` -> `[A, B]`).
    #[derive(Clone, Debug, Deref)]
    struct MessageList<E: Message + Clone>(Vec<E>);

    impl<E: Message + Clone + FromStr<Err = InvalidMessage>> FromStr for MessageList<E> {
        type Err = InvalidMessage;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            let mut messages = Vec::new();
            for c in s.chars() {
                let e = E::from_str(&c.to_string())?;
                messages.push(e);
            }
            Ok(MessageList(messages))
        }
    }

    #[rstest]
    #[case("A", vec![CmpMessage::A])]
    #[case("AB", vec![CmpMessage::A, CmpMessage::B])]
    #[case("ABC", vec![CmpMessage::A, CmpMessage::B, CmpMessage::C])]
    #[case("AABBCC", vec![
        CmpMessage::A, CmpMessage::A,
        CmpMessage::B, CmpMessage::B,
        CmpMessage::C, CmpMessage::C
    ])]
    #[should_panic]
    #[case("abc", vec![])]
    fn test_message_list_fromstr(
        #[case] magic: MessageList<CmpMessage>,
        #[case] expected: Vec<CmpMessage>,
    ) {
        assert_that!(*magic).is_equal_to(&expected);
    }
}