1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
use std::marker::PhantomData;

use bevy::{
    ecs::schedule::{run_enter_schedule, ScheduleLabel},
    prelude::*,
};

mod frame_count;

// re-exports
pub use frame_count::{increase_frame_count, RollFrameCount};

pub mod prelude {
    pub use super::RollApp;
}

pub trait RollApp {
    /// Init state transitions in the given schedule
    fn init_roll_state<S: States + FromWorld>(&mut self, schedule: impl ScheduleLabel)
        -> &mut Self;

    #[cfg(feature = "bevy_ggrs")]
    /// Register this state to be rolled back by bevy_ggrs
    fn init_ggrs_state<S: States + FromWorld + Clone>(&mut self) -> &mut Self;

    #[cfg(feature = "bevy_ggrs")]
    /// Register this state to be rolled back by bevy_ggrs in the specified schedule
    fn init_ggrs_state_in_schedule<S: States + FromWorld + Clone>(
        &mut self,
        schedule: impl ScheduleLabel,
    ) -> &mut Self;
}

impl RollApp for App {
    fn init_roll_state<S>(&mut self, schedule: impl ScheduleLabel) -> &mut Self
    where
        S: States + FromWorld,
    {
        // Builds a fresh run condition that is true only while the initial state
        // has not been entered yet. Each gated system needs its own condition value.
        let initial_state_pending =
            || resource_equals(InitialStateEntered::<S>(false, default()));

        // Enter the initial state once, record that it happened, then process any
        // queued transitions. `chain` keeps this order deterministic.
        let transition_systems = (
            run_enter_schedule::<S>.run_if(initial_state_pending()),
            mark_state_initialized::<S>.run_if(initial_state_pending()),
            apply_state_transition::<S>,
        )
            .chain();

        self.init_resource::<State<S>>()
            .init_resource::<NextState<S>>()
            .init_resource::<InitialStateEntered<S>>()
            // The transition event queue is not rollback safe, but
            // `apply_state_transition` errors if it is missing.
            .add_event::<StateTransitionEvent<S>>()
            .add_systems(schedule, transition_systems)
    }

    #[cfg(feature = "bevy_ggrs")]
    fn init_ggrs_state<S>(&mut self) -> &mut Self
    where
        S: States + FromWorld + Clone,
    {
        self.init_ggrs_state_in_schedule::<S>(bevy_ggrs::GgrsSchedule)
    }

    #[cfg(feature = "bevy_ggrs")]
    fn init_ggrs_state_in_schedule<S>(&mut self, schedule: impl ScheduleLabel) -> &mut Self
    where
        S: States + FromWorld + Clone,
    {
        use crate::ggrs_support::{NextStateStrategy, StateStrategy};
        use bevy_ggrs::{CloneStrategy, ResourceSnapshotPlugin};

        // Snapshot the `State<S>`, `NextState<S>` and `InitialStateEntered<S>`
        // resources so a rollback restores them. The transition event queue is
        // not snapshotted.
        let snapshot_plugins = (
            ResourceSnapshotPlugin::<StateStrategy<S>>::default(),
            ResourceSnapshotPlugin::<NextStateStrategy<S>>::default(),
            ResourceSnapshotPlugin::<CloneStrategy<InitialStateEntered<S>>>::default(),
        );

        self.init_roll_state::<S>(schedule)
            .add_plugins(snapshot_plugins)
    }
}

#[cfg(feature = "bevy_ggrs")]
mod ggrs_support {
    //! bevy_ggrs snapshot strategies for Bevy's `State<S>` and `NextState<S>`
    //! resources. The strategies copy out the wrapped value instead of the resource
    //! itself, presumably because those resources cannot be cloned directly (see
    //! the TODOs below).

    use std::marker::PhantomData;

    use bevy::prelude::*;
    use bevy_ggrs::Strategy;

    /// Snapshots a [`State<S>`] by storing a copy of its current state value.
    pub(crate) struct StateStrategy<S: States>(PhantomData<S>);

    // todo: make State<S> implement clone instead
    impl<S: States> Strategy for StateStrategy<S> {
        type Target = State<S>;
        type Stored = S;

        fn store(current: &State<S>) -> S {
            current.get().clone()
        }

        fn load(saved: &S) -> State<S> {
            State::new(saved.clone())
        }
    }

    /// Snapshots a [`NextState<S>`] by storing a copy of its pending value, if any.
    pub(crate) struct NextStateStrategy<S: States>(PhantomData<S>);

    // todo: make NextState<S> implement clone instead
    impl<S: States> Strategy for NextStateStrategy<S> {
        type Target = NextState<S>;
        type Stored = Option<S>;

        fn store(queued: &NextState<S>) -> Option<S> {
            queued.0.clone()
        }

        fn load(saved: &Option<S>) -> NextState<S> {
            NextState(saved.clone())
        }
    }
}

#[derive(Resource, Debug, Reflect, Eq, PartialEq, Clone)]
#[reflect(Resource)]
pub struct InitialStateEntered<S: States>(bool, PhantomData<S>);

impl<S: States> Default for InitialStateEntered<S> {
    fn default() -> Self {
        Self(false, default())
    }
}

/// Sets [`InitialStateEntered<S>`] to `true`, so the systems gated on it being
/// `false` stop running while the flag stays set.
///
/// Only `S: States` is required: the body never constructs `S`, so the former
/// `FromWorld` bound was unnecessary.
fn mark_state_initialized<S: States>(mut state_initialized: ResMut<InitialStateEntered<S>>) {
    state_initialized.0 = true;
}