rex/
manager.rs

1/*!
2```text
3  ╔═════════════════════════════════╗
4  ║S => Signal       [ Kind, Input ]║
5  ║N => Notification [ Kind, Input ]║                                                    ▲
6  ╚═════════════════════════════════╝                                                    │
7                              ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓                            │
8                              ┃                             ┃                            │
9                              ┃                             ┃                          other
10                              ┃                             ┃                            │
11                 ┌─────S──────┃       TimeoutManager        ◀─────────┐                  │
12                 │            ┃                             ┃         │                  │
13                 │            ┃                             ┃         │   ┏━━━━━━━━━━━━━━▼━━━━━━━━━━━━━━┓
14                 │            ┃                             ┃         N   ┃                             ┃
15                 │            ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛         │   ┃                             ┃
16                 │                                                    │   ┃    Outlet (External I/O)    ┃
17                 ├───────────────────────────S────────────────────────┼───┫[GatewayClient + Service TX +┃
18                 │                                                    │   ┃           etc...]           ┃
19                 │                    ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─           │   ┃                             ┃
20                 │                          Signals &      │          │   ┃                             ┃
21  ┏━━━━━━━━━━━━━━▼━━━━━━━━━━━━━━┓     │   Notifications               │   ┗━━━━━━━━━━━━━━▲━━━━━━━━━━━━━━┛
22  ┃                             ┃      ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘          │                  │
23  ┃                             ┃                           ╔══════════════════╗         │
24  ┃                             ┃                           ║                  ║         N
25  ┃     StateMachineManager     ◀───┬──┐                    ║   Notification   ║         │
26  ┃                             ┃   │  │                    ║      Queue       ║─────────┘
27  ┃                             ┃   │  │                    ║                  ║
28  ┃                             ┃   │  │                    ╚═════════▲════════╝
29  ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛   │  │                              │
30                 │                  │  │                              N
31                 S                  │  │                              │
32                 │                  S  S                ┏━━━━━━━━━━━━━┻━━━━━━━━━━━━┓
33       ╔═════════▼════════╗         │  │                ┃                          ┃
34       ║                  ║         │  │                ┃                          ┃
35       ║   Signal Queue   ║─┐       │  │         ┌──┬───▶   NotificationManager    ┃
36       ║                  ║ │       │  │         │  │   ┃                          ┃
37       ╚══════════════════╝ │       │  │         │  │   ┃                          ┃
38                 │          S ┌─────┴──┴─────┐   N  N   ┗━━━━━━━━━━━━━━━━━━━━━━━━━━┛
39                 │          │ │              │   │  │
40                 │          └─▶ StateMachine │───┼──┘
41                 │         ┌──┴───────────┐  │   │
42                 │         │              ├──┘   │
43                 └────S────▶ StateMachine │──────┘
44                           │              │
45                           └──────────────┘
46```
47*/
48
49use std::{collections::HashMap, fmt, hash::Hash, sync::Arc, time::Duration};
50
51use bigerror::{LogError, OptionReport};
52use parking_lot::FairMutex;
53use tokio::task::JoinSet;
54use tokio_stream::StreamExt;
55use tracing::{debug, Instrument};
56
57use crate::{
58    node::{Insert, Node, Update},
59    notification::{Notification, NotificationQueue},
60    queue::StreamableDeque,
61    storage::{StateStore, Tree},
62    timeout::{RetainItem, TimeoutInput, TimeoutMessage},
63    Kind, Rex, StateId,
64};
65
66pub trait HashKind: Kind + fmt::Debug + Hash + Eq + PartialEq + 'static + Copy
67where
68    Self: Send + Sync,
69{
70}
71
72impl<K> HashKind for K where
73    K: Kind + fmt::Debug + Hash + Eq + PartialEq + Send + Sync + 'static + Copy
74{
75}
76
77/// The [`Signal`] struct represents a routable input meant to be consumed
78/// by a state machine processor.
79/// The `id` field holds:
80/// * The routing logic accessed by the [`Kind`] portion of the id
81/// * a distinguishing identifier that separates state of the _same_ kind
82///
83///
84/// The `input` field :
85/// * holds an event or message meant to be processed by a given state machine
86/// * an event is defined as any output emitted by a state machine
87#[derive(Debug, PartialEq)]
88pub struct Signal<K>
89where
90    K: Rex,
91{
92    pub id: StateId<K>,
93    pub input: K::Input,
94}
95
96impl<K> Signal<K>
97where
98    K: Rex,
99{
100    fn state_change(id: StateId<K>, state: K::State) -> Option<Self> {
101        id.kind.state_input(state).map(|input| Self { id, input })
102    }
103}
104
105pub type SignalQueue<K> = Arc<StreamableDeque<Signal<K>>>;
106
107/// [`SignalExt`] calls [`Signal::state_change`] to consume a [`Kind::State`] and emit
108/// a state change [`Signal`] with a valid [`StateMachine::Input`]
109pub trait SignalExt<K>
110where
111    K: Rex,
112{
113    fn signal_state_change(&self, id: StateId<K>, state: K::State);
114}
115
116impl<K> SignalExt<K> for StreamableDeque<Signal<K>>
117where
118    K: Rex,
119{
120    fn signal_state_change(&self, id: StateId<K>, state: K::State) {
121        if let Some(sig) = Signal::state_change(id, state) {
122            self.push_back(sig);
123        }
124    }
125}
126
127/// Store the injectable dependencies provided by the [`StateMachineManager`]
128/// to a given state machine processor.
129pub struct SmContext<K: Rex> {
130    pub signal_queue: SignalQueue<K>,
131    pub notification_queue: NotificationQueue<K::Message>,
132    pub state_store: Arc<StateStore<StateId<K>, K::State>>,
133    pub id: StateId<K>,
134}
135
136impl<K: Rex> SmContext<K> {
137    pub fn notify(&self, notification: Notification<K::Message>) {
138        self.notification_queue.send(notification);
139    }
140
141    pub fn signal_self(&self, input: K::Input) {
142        self.signal_queue.push_front(Signal { id: self.id, input });
143    }
144
145    pub fn get_state(&self) -> Option<K::State> {
146        let tree = self.state_store.get_tree(self.id)?;
147        let guard = tree.lock();
148        guard.get_state(self.id).copied()
149    }
150
151    pub fn get_tree(&self) -> Option<Tree<K>> {
152        self.state_store.get_tree(self.id)
153    }
154
155    pub fn has_state(&self) -> bool {
156        self.state_store.get_tree(self.id).is_some()
157    }
158
159    pub fn get_parent_id(&self) -> Option<StateId<K>> {
160        self.get_tree().and_then(|tree| {
161            let guard = tree.lock();
162            guard.get_parent_id(self.id)
163        })
164    }
165
166    pub fn has_parent(&self) -> bool {
167        self.get_parent_id().is_some()
168    }
169}
170impl<K: Rex> Clone for SmContext<K> {
171    fn clone(&self) -> Self {
172        Self {
173            signal_queue: self.signal_queue.clone(),
174            notification_queue: self.notification_queue.clone(),
175            state_store: self.state_store.clone(),
176            id: self.id,
177        }
178    }
179}
180
181impl<K: Rex> std::ops::Deref for SmContext<K> {
182    type Target = StateId<K>;
183
184    fn deref(&self) -> &Self::Target {
185        &self.id
186    }
187}
188
189/// Manages the [`Signal`] scope of various [`State`]s and [`StateMachine`]s bounded by
190/// a [`Kind`] enumerable
191pub struct StateMachineManager<K: Rex> {
192    signal_queue: SignalQueue<K>,
193    notification_queue: NotificationQueue<K::Message>,
194    state_machines: Arc<HashMap<K, BoxedStateMachine<K>>>,
195    state_store: Arc<StateStore<StateId<K>, K::State>>,
196}
197
198pub struct EmptyContext<K: Rex> {
199    pub signal_queue: SignalQueue<K>,
200    pub notification_queue: NotificationQueue<K::Message>,
201    pub state_store: Arc<StateStore<StateId<K>, K::State>>,
202}
203impl<K: Rex> EmptyContext<K> {
204    fn init(&self, id: StateId<K>) -> SmContext<K> {
205        SmContext {
206            signal_queue: self.signal_queue.clone(),
207            notification_queue: self.notification_queue.clone(),
208            state_store: self.state_store.clone(),
209            id,
210        }
211    }
212}
213
214impl<K: Rex> StateMachineManager<K> {
215    #[must_use]
216    pub fn ctx_builder(&self) -> EmptyContext<K> {
217        EmptyContext {
218            signal_queue: self.signal_queue.clone(),
219            notification_queue: self.notification_queue.clone(),
220            state_store: self.state_store.clone(),
221        }
222    }
223
224    // const fn does not currently support Iterable
225    pub fn new(
226        state_machines: Vec<BoxedStateMachine<K>>,
227        signal_queue: SignalQueue<K>,
228        notification_queue: NotificationQueue<K::Message>,
229    ) -> Self {
230        let sm_count = state_machines.len();
231        let state_machines: HashMap<K, BoxedStateMachine<K>> = state_machines
232            .into_iter()
233            .map(|sm| (sm.get_kind(), sm))
234            .collect();
235        assert_eq!(
236            sm_count,
237            state_machines.len(),
238            "multiple state machines using the same kind, SMs: {sm_count}, Kinds: {}",
239            state_machines.len(),
240        );
241        Self {
242            signal_queue,
243            notification_queue,
244            state_machines: Arc::new(state_machines),
245            state_store: Arc::new(StateStore::new()),
246        }
247    }
248
249    pub fn init(&self, join_set: &mut JoinSet<()>) {
250        let stream_queue = self.signal_queue.clone();
251        let sm_dispatcher = self.state_machines.clone();
252        let ctx = self.ctx_builder();
253        join_set.spawn(
254            async move {
255                debug!(target:  "state_machine", spawning = "StateMachineManager.signal_queue");
256                let mut stream = stream_queue.stream();
257                while let Some(Signal { id, input }) = stream.next().await {
258                    if let Ok(sm) = sm_dispatcher
259                        .get(&id)
260                        .expect_kv("state_machine", id)
261                        .and_log_err()
262                    {
263                        sm.process(ctx.init(id), input);
264                        continue;
265                    }
266                }
267            }
268            .in_current_span(),
269        );
270    }
271}
272
273pub(crate) type BoxedStateMachine<K> = Box<dyn StateMachine<K>>;
274
275/// Represents the trait that a state machine must fulfill to process signals
276/// A [`StateMachine`] consumes the `input` portion of a [`Signal`] and...
277/// * optionally emits [`Signal`]s - consumed by the [`StateMachineManager`] `signal_queue`
278/// * optionally emits [`Notification`]s  - consumed by the [`NotificationQueue`]
279pub trait StateMachine<K>: Send + Sync
280where
281    K: Rex,
282{
283    fn process(&self, ctx: SmContext<K>, input: K::Input);
284
285    fn get_kind(&self) -> K;
286
287    fn new_child(&self, ctx: &SmContext<K>, child_id: StateId<K>) {
288        let id = ctx.id;
289        let tree = ctx.state_store.get_tree(id).unwrap();
290        ctx.state_store.insert_ref(child_id, tree.clone());
291        let mut tree = tree.lock();
292        tree.insert(Insert {
293            parent_id: Some(ctx.id),
294            id: child_id,
295        });
296    }
297
298    /// Panic: will panic if passed in an id without a previously stored state
299    fn update(&self, ctx: &SmContext<K>, state: K::State) {
300        let id = ctx.id;
301        let tree = ctx.state_store.get_tree(id).expect("missing id for update");
302        let mut guard = tree.lock();
303
304        guard.update(Update { id, state });
305    }
306}
307
308pub trait StateMachineExt<K>: StateMachine<K>
309where
310    K: Rex,
311    K::Message: TimeoutMessage<K>,
312{
313    /// NOTE [`StateMachineExt::new`] is created without a hierarchy
314    fn create_tree(&self, ctx: &SmContext<K>) {
315        let id = ctx.id;
316        ctx.state_store
317            .insert_ref(id, Arc::new(FairMutex::new(Node::new(id))));
318    }
319
320    fn fail(&self, ctx: &SmContext<K>) -> Option<StateId<K>> {
321        let id = ctx.id;
322        self.update_state_and_signal(ctx, id.failed_state())
323    }
324
325    fn complete(&self, ctx: &SmContext<K>) -> Option<StateId<K>> {
326        let id = ctx.id;
327        self.update_state_and_signal(ctx, id.completed_state())
328    }
329
330    /// update state is meant to be used to signal a parent state of a child state
331    /// _if_ a parent exists, this function makes no assumptions of the potential
332    /// structure of a state hierarchy and _should_ be just as performant on a single
333    /// state tree as it is for multiple states.
334    /// Returns the parent's [`StateId`] if there was one.
335    fn update_state_and_signal(&self, ctx: &SmContext<K>, state: K::State) -> Option<StateId<K>> {
336        let id = ctx.id;
337        let Some(tree) = ctx.get_tree() else {
338            // TODO propagate error
339            tracing::error!(%id, "Tree not found!");
340            panic!("missing SmTree");
341        };
342
343        let parent_id = tree.lock().update_and_get_parent_id(Update { id, state });
344        if let Some(id) = parent_id {
345            ctx.signal_queue.signal_state_change(id, state);
346
347            return Some(id);
348        }
349
350        None
351    }
352
353    fn notify(&self, ctx: &SmContext<K>, msg: impl Into<K::Message>) {
354        ctx.notify(Notification(msg.into()));
355    }
356
357    fn set_timeout(&self, ctx: &SmContext<K>, duration: Duration) {
358        ctx.notification_queue.priority_send(Notification(
359            TimeoutInput::set_timeout(ctx.id, duration).into(),
360        ));
361    }
362
363    fn return_in(&self, ctx: &SmContext<K>, item: RetainItem<K>, duration: Duration) {
364        ctx.notification_queue.priority_send(Notification(
365            TimeoutInput::retain(ctx.id, item, duration).into(),
366        ));
367    }
368
369    fn cancel_timeout(&self, ctx: &SmContext<K>) {
370        ctx.notification_queue
371            .priority_send(Notification(TimeoutInput::cancel_timeout(ctx.id).into()));
372    }
373}
374
375impl<K, T> StateMachineExt<K> for T
376where
377    T: StateMachine<K>,
378    K: Rex,
379    K::Message: TimeoutMessage<K>,
380{
381}
382
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use bigerror::{ConversionError, Report};
    use dashmap::DashMap;
    use tokio::time::Instant;
    use tracing::*;

    use super::*;
    use crate::{
        notification::GetTopic,
        test_support::Hold,
        timeout::{Timeout, TimeoutMessage, TimeoutTopic, TEST_TICK_RATE, TEST_TIMEOUT},
        Rex, RexBuilder, RexMessage, State,
    };

    impl From<TimeoutInput<Game>> for GameMsg {
        fn from(value: TimeoutInput<Game>) -> Self {
            Self(value)
        }
    }

    #[derive(Debug, Clone)]
    pub struct GameMsg(TimeoutInput<Game>);
    impl GetTopic<TimeoutTopic> for GameMsg {
        fn get_topic(&self) -> TimeoutTopic {
            TimeoutTopic
        }
    }
    impl RexMessage for GameMsg {
        type Topic = TimeoutTopic;
    }

    impl TryInto<TimeoutInput<Game>> for GameMsg {
        type Error = Report<ConversionError>;

        fn try_into(self) -> Result<TimeoutInput<Game>, Self::Error> {
            Ok(self.0)
        }
    }

    impl TimeoutMessage<Game> for GameMsg {
        type Item = Hold<Packet>;
    }

    #[derive(Copy, Clone, Debug, derive_more::Display)]
    #[display("{msg}")]
    pub struct Packet {
        msg: u64,
        sender: StateId<Game>,
        who_holds: WhoHolds,
    }

    #[derive(Clone, Debug, derive_more::From)]
    pub enum GameInput {
        Ping(PingInput),
        Pong(PongInput),
        Menu(MenuInput),
    }

    /// Which state machine (if any) holds on to the packet for a while before
    /// sending it back.
    #[derive(Copy, Clone, PartialEq, Eq, Debug)]
    pub struct WhoHolds(Option<Game>);

    #[derive(Clone, PartialEq, Eq, Debug)]
    pub enum MenuInput {
        Play(WhoHolds),
        PingPongComplete,
        FailedPing,
        FailedPong,
    }

    #[derive(Copy, Clone, PartialEq, Eq, Default, Debug)]
    pub enum MenuState {
        #[default]
        Ready,
        Done,
        Failed,
    }

    #[derive(Copy, Clone, PartialEq, Eq, Default, Debug)]
    pub enum PingState {
        #[default]
        Ready,
        Sending,
        Done,
        Failed,
    }

    #[derive(Clone, Debug, derive_more::From)]
    pub enum PingInput {
        StartSending(StateId<Game>, WhoHolds),
        Returned(Hold<Packet>),
        Packet(Packet),
        // the `Instant` payload is never read: handlers match it with `_`
        #[allow(dead_code)]
        RecvTimeout(Instant),
    }

    #[derive(Copy, Clone, PartialEq, Eq, Default, Debug)]
    pub enum PongState {
        #[default]
        Ready,
        Responding,
        Done,
        Failed,
    }

    #[derive(Clone, Debug, derive_more::From)]
    pub enum PongInput {
        Packet(Packet),
        // the `Instant` payload is never read: handlers match it with `_`
        #[allow(dead_code)]
        RecvTimeout(Instant),
        Returned(Hold<Packet>),
    }

    #[derive(Copy, Clone, PartialEq, Eq, Debug)]
    pub enum GameState {
        Ping(PingState),
        Pong(PongState),
        Menu(MenuState),
    }

    impl State for GameState {
        type Input = GameInput;
    }
    impl AsRef<Game> for GameState {
        fn as_ref(&self) -> &Game {
            match self {
                Self::Ping(_) => &Game::Ping,
                Self::Pong(_) => &Game::Pong,
                Self::Menu(_) => &Game::Menu,
            }
        }
    }

    #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
    pub enum Game {
        Ping,
        Pong,
        Menu,
    }

    impl Rex for Game {
        type Message = GameMsg;

        fn state_input(&self, state: <Self as Kind>::State) -> Option<Self::Input> {
            if *self != Self::Menu {
                return None;
            }

            match state {
                GameState::Ping(PingState::Done) => Some(MenuInput::PingPongComplete),
                GameState::Ping(PingState::Failed) => Some(MenuInput::FailedPing),
                GameState::Pong(PongState::Failed) => Some(MenuInput::FailedPong),
                _ => None,
            }
            .map(std::convert::Into::into)
        }

        fn timeout_input(&self, instant: Instant) -> Option<Self::Input> {
            match self {
                Self::Ping => Some(PingInput::RecvTimeout(instant).into()),
                Self::Pong => Some(PongInput::RecvTimeout(instant).into()),
                Self::Menu => None,
            }
        }
    }

    impl Timeout for Game {
        fn return_item(&self, item: RetainItem<Self>) -> Option<Self::Input> {
            match self {
                Self::Ping => Some(GameInput::Ping(item.into())),
                Self::Pong => Some(GameInput::Pong(item.into())),
                Self::Menu => None,
            }
        }
    }

    impl Kind for Game {
        type State = GameState;
        type Input = GameInput;

        fn new_state(&self) -> Self::State {
            match self {
                Self::Ping => GameState::Ping(PingState::default()),
                Self::Pong => GameState::Pong(PongState::default()),
                Self::Menu => GameState::Menu(MenuState::default()),
            }
        }

        fn failed_state(&self) -> Self::State {
            match self {
                Self::Ping => GameState::Ping(PingState::Failed),
                Self::Pong => GameState::Pong(PongState::Failed),
                Self::Menu => GameState::Menu(MenuState::Failed),
            }
        }

        fn completed_state(&self) -> Self::State {
            match self {
                Self::Ping => GameState::Ping(PingState::Done),
                Self::Pong => GameState::Pong(PongState::Done),
                Self::Menu => GameState::Menu(MenuState::Done),
            }
        }
    }

    struct MenuStateMachine {
        failures: Arc<DashMap<StateId<Game>, MenuInput>>,
    }

    impl StateMachine<Game> for MenuStateMachine {
        #[instrument(name = "menu", skip_all)]
        fn process(&self, ctx: SmContext<Game>, input: GameInput) {
            let id = ctx.id;
            let GameInput::Menu(input) = input else {
                error!(input = ?input, "invalid input!");
                return;
            };

            let state = ctx.get_state();
            if state.map(Game::is_terminal) == Some(true) {
                warn!(%id, ?state, "Ignoring input due to invalid state");
                return;
            }

            match input {
                MenuInput::Play(who_holds) => {
                    let ping_id = StateId::new_rand(Game::Ping);
                    let pong_id = StateId::new_rand(Game::Pong);
                    // Menu + Ping + Pong
                    self.create_tree(&ctx);
                    self.new_child(&ctx, ping_id);
                    self.new_child(&ctx, pong_id);
                    // signal to Ping state machine
                    ctx.signal_queue.push_back(Signal {
                        id: ping_id,
                        input: GameInput::Ping(PingInput::StartSending(pong_id, who_holds)),
                    });
                }
                MenuInput::PingPongComplete => {
                    info!("I'M DONE!");
                    self.complete(&ctx);
                }
                failure @ (MenuInput::FailedPing | MenuInput::FailedPong) => {
                    let tree = ctx.get_tree().unwrap();
                    // set all states to failed state
                    tree.lock().update_all_fn(|mut z| {
                        z.node.state = z.node.state.as_ref().failed_state();
                        let id = z.node.id;
                        ctx.notification_queue
                            .priority_send(Notification(TimeoutInput::cancel_timeout(id).into()));
                        z.finish_update()
                    });

                    error!(input = ?failure, "Encountered Failure!");
                    self.failures.insert(id, failure);
                }
            }
        }

        fn get_kind(&self) -> Game {
            Game::Menu
        }
    }

    impl MenuStateMachine {
        fn new() -> Self {
            Self {
                failures: Arc::new(DashMap::new()),
            }
        }
    }

    struct PingStateMachine;

    impl StateMachine<Game> for PingStateMachine {
        #[instrument(name = "ping", skip_all)]
        fn process(&self, ctx: SmContext<Game>, input: GameInput) {
            let id = ctx.id;
            let GameInput::Ping(input) = input else {
                error!(?input, "invalid input!");
                return;
            };
            assert!(ctx.get_parent_id().is_some());
            let state = ctx.get_state().unwrap();
            if Game::is_terminal(state) {
                warn!(%id, ?state, "Ignoring input due to invalid state");
                return;
            }

            match input {
                PingInput::StartSending(pong_id, who_holds) => {
                    self.update(&ctx, GameState::Ping(PingState::Sending));
                    info!(msg = 0, "PINGING");
                    ctx.signal_queue.push_back(Signal {
                        id: pong_id,
                        input: GameInput::Pong(PongInput::Packet(Packet {
                            msg: 0,
                            sender: id,
                            who_holds,
                        })),
                    });
                    // TODO let timeout = now + Duration::from_millis(250);
                }
                PingInput::Packet(Packet { msg: 25.., .. }) => {
                    info!("PING Complete!");
                    self.complete(&ctx);
                    self.cancel_timeout(&ctx);
                }
                PingInput::Packet(mut packet) => {
                    self.set_timeout(&ctx, TEST_TIMEOUT);
                    packet.msg += 5;

                    if packet.who_holds == WhoHolds(Some(Game::Ping)) {
                        info!(msg = packet.msg, "HOLDING");
                        // hold the packet for `packet.msg` milliseconds before returning it
                        let hold_for = Duration::from_millis(packet.msg);
                        self.return_in(&ctx, Hold(packet), hold_for);
                        return;
                    }

                    info!(msg = packet.msg, "PINGING");
                    return_packet(&ctx, packet);
                }
                PingInput::Returned(Hold(packet)) => {
                    self.set_timeout(&ctx, TEST_TIMEOUT);
                    info!(msg = packet.msg, "PINGING");
                    return_packet(&ctx, packet);
                }

                PingInput::RecvTimeout(_) => {
                    self.fail(&ctx);
                }
            }
        }

        fn get_kind(&self) -> Game {
            Game::Ping
        }
    }

    struct PongStateMachine;

    impl StateMachine<Game> for PongStateMachine {
        #[instrument(name = "pong", skip_all, fields(id = %ctx.id))]
        fn process(&self, ctx: SmContext<Game>, input: GameInput) {
            let GameInput::Pong(input) = input else {
                error!(?input, "invalid input!");
                return;
            };
            let state = ctx.get_state().unwrap();
            if Game::is_terminal(state) {
                warn!(?state, "Ignoring input due to invalid state");
                return;
            }
            assert!(ctx.get_parent_id().is_some());

            match input {
                PongInput::Packet(Packet {
                    // https://doc.rust-lang.org/book/ch18-03-pattern-syntax.html#-bindings
                    msg: mut msg @ 20..,
                    sender,
                    who_holds,
                }) => {
                    msg += 5;
                    info!(?msg, "PONGING");
                    self.complete(&ctx);
                    self.cancel_timeout(&ctx);
                    ctx.signal_queue.push_back(Signal {
                        id: sender,
                        input: GameInput::Ping(PingInput::Packet(Packet {
                            msg,
                            sender: ctx.id,
                            who_holds,
                        })),
                    });
                }
                PongInput::Packet(mut packet) => {
                    self.set_timeout(&ctx, TEST_TIMEOUT);
                    if packet.msg == 0 {
                        self.update(&ctx, GameState::Pong(PongState::Responding));
                    }
                    packet.msg += 5;

                    if packet.who_holds == WhoHolds(Some(Game::Pong)) {
                        info!(msg = packet.msg, "HOLDING");
                        // hold the packet for `packet.msg` milliseconds before returning it
                        let hold_for = Duration::from_millis(packet.msg);
                        self.return_in(&ctx, Hold(packet), hold_for);
                        return;
                    }

                    info!(msg = packet.msg, "PONGING");
                    return_packet(&ctx, packet);
                }
                PongInput::Returned(Hold(packet)) => {
                    self.set_timeout(&ctx, TEST_TIMEOUT);
                    info!(msg = packet.msg, "PONGING");
                    return_packet(&ctx, packet);
                }
                PongInput::RecvTimeout(_) => {
                    self.fail(&ctx);
                }
            }
        }
        fn get_kind(&self) -> Game {
            Game::Pong
        }
    }

    /// Sends `packet` back to its sender, re-addressed as coming from `ctx.id`.
    fn return_packet(ctx: &SmContext<Game>, mut packet: Packet) {
        let recipient = packet.sender;
        packet.sender = ctx.id;
        ctx.signal_queue.push_back(Signal {
            id: recipient,
            input: match *ctx.id {
                Game::Ping => PongInput::Packet(packet).into(),
                Game::Pong => PingInput::Packet(packet).into(),
                Game::Menu => unreachable!(),
            },
        });
    }

    #[tracing_test::traced_test]
    #[tokio::test]
    async fn state_machine() {
        // This test does not initialize the NotificationManager
        let ctx = RexBuilder::new()
            .with_sm(MenuStateMachine::new())
            .with_sm(PingStateMachine)
            .with_sm(PongStateMachine)
            .build();

        let menu_id = StateId::new_rand(Game::Menu);
        ctx.signal_queue.push_back(Signal {
            id: menu_id,
            input: GameInput::Menu(MenuInput::Play(WhoHolds(None))),
        });
        tokio::time::sleep(Duration::from_millis(1)).await;

        let tree = ctx.state_store.get_tree(menu_id).unwrap();
        let node = tree.lock();
        let ping_id = node.children[0].id;
        let pong_id = node.children[1].id;
        assert_eq!(menu_id, node.id);
        assert_eq!(GameState::Menu(MenuState::Done), node.state);

        // !!NOTE!! ============================================================
        // we are trying to acquire another lock...
        // * from the SAME Mutex
        // * in the SAME thread
        // this WILL deadlock unless the previous lock is dropped
        drop(node);
        // !!NOTE!! ============================================================

        // ensure the Menu's tree is also indexed by the ping id
        let tree = ctx.state_store.get_tree(ping_id).unwrap();
        let node = tree.lock();
        let state = node.get_state(ping_id).unwrap();
        assert_eq!(GameState::Ping(PingState::Done), *state);

        drop(node);

        let tree = ctx.state_store.get_tree(pong_id).unwrap();
        let state = tree.lock().get_state(pong_id).copied().unwrap();
        assert_eq!(GameState::Pong(PongState::Done), state);
    }

    #[tracing_test::traced_test]
    #[tokio::test]
    async fn pong_timeout() {
        let menu_sm = MenuStateMachine::new();
        let menu_failures = menu_sm.failures.clone();
        let ctx = RexBuilder::new()
            .with_sm(menu_sm)
            .with_sm(PingStateMachine)
            .with_sm(PongStateMachine)
            .with_timeout_manager(TimeoutTopic)
            .with_tick_rate(TEST_TICK_RATE / 2)
            .build();

        let menu_id = StateId::new_rand(Game::Menu);
        ctx.signal_queue.push_back(Signal {
            id: menu_id,
            input: GameInput::Menu(MenuInput::Play(WhoHolds(Some(Game::Ping)))),
        });

        tokio::time::sleep(TEST_TIMEOUT * 4).await;

        {
            let tree = ctx.state_store.get_tree(menu_id).unwrap();
            let node = tree.lock();
            let ping = &node.children[0];
            let pong = &node.children[1];
            assert_eq!(menu_id, node.id);
            assert_eq!(GameState::Menu(MenuState::Failed), node.state);
            assert_eq!(GameState::Ping(PingState::Failed), ping.state);
            assert_eq!(GameState::Pong(PongState::Failed), pong.state);

            // !!NOTE!! ============================================================
            // we are trying to acquire another lock...
            // * from the SAME Mutex
            // * in the SAME thread
            // this WILL deadlock unless the previous lock is dropped
            drop(node);
            // !!NOTE!! ============================================================
        }

        // Ensure that our Menu failed due to Pong timing out since
        // Ping "slept on" the packet
        assert_eq!(MenuInput::FailedPong, *menu_failures.get(&menu_id).unwrap());
    }

    #[tracing_test::traced_test]
    #[tokio::test]
    async fn ping_timeout() {
        // Now fail due to Ping
        let menu_sm = MenuStateMachine::new();
        let menu_failures = menu_sm.failures.clone();
        let ctx = RexBuilder::new()
            .with_sm(menu_sm)
            .with_sm(PingStateMachine)
            .with_sm(PongStateMachine)
            .with_timeout_manager(TimeoutTopic)
            .with_tick_rate(TEST_TICK_RATE / 2)
            .build();
        let menu_id = StateId::new_rand(Game::Menu);
        ctx.signal_queue.push_back(Signal {
            id: menu_id,
            input: GameInput::Menu(MenuInput::Play(WhoHolds(Some(Game::Pong)))),
        });

        tokio::time::sleep(TEST_TIMEOUT * 4).await;

        let tree = ctx.state_store.get_tree(menu_id).unwrap();
        let node = tree.lock();
        let ping_node = &node.children[0];
        let pong_node = &node.children[1];
        assert_eq!(menu_id, node.id);
        assert_eq!(GameState::Menu(MenuState::Failed), node.state);
        assert_eq!(GameState::Ping(PingState::Failed), ping_node.state);
        assert_eq!(GameState::Pong(PongState::Failed), pong_node.state);
        // Ensure that our Menu failed due to Ping
        drop(node);
        assert_eq!(MenuInput::FailedPing, *menu_failures.get(&menu_id).unwrap());
    }
}