// mm1_core/context/stop.rs

1use std::future::Future;
2use std::pin::pin;
3use std::time::Duration;
4
5use futures::FutureExt;
6use mm1_address::address::Address;
7use mm1_common::errors::error_of::ErrorOf;
8use mm1_common::impl_error_kind;
9use mm1_proc_macros::dispatch;
10use mm1_proto_system::{self as system, Down, System};
11use tracing::warn;
12
13use super::{ForkErrorKind, Recv, RecvErrorKind};
14use crate::context::{Call, Fork, Watching};
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
17pub enum ShutdownErrorKind {
18    InternalError,
19    Fork(ForkErrorKind),
20    Recv(RecvErrorKind),
21}
22
23pub trait Stop<Sys>:
24    Call<Sys, system::Exit, Outcome = bool> + Call<Sys, system::Kill, Outcome = bool>
25where
26    Sys: System,
27{
28    fn shutdown(
29        &mut self,
30        peer: Address,
31        stop_timeout: Duration,
32    ) -> impl Future<Output = Result<(), ErrorOf<ShutdownErrorKind>>> + Send
33    where
34        Sys: Default,
35        Self: Watching<Sys> + Fork + Recv,
36    {
37        async move {
38            let mut fork = self
39                .fork()
40                .await
41                .map_err(|e| e.map_kind(ShutdownErrorKind::Fork))?;
42
43            let watch_ref = fork.watch(peer).await;
44
45            let mut shutdown_sequence = pin!(async {
46                self.exit(peer).await;
47                tokio::time::sleep(stop_timeout).await;
48                self.kill(peer).await;
49            }
50            .fuse());
51
52            let mut recv_result = pin!(async {
53                loop {
54                    dispatch!(match fork
55                        .recv()
56                        .await
57                        .map_err(|e| e.map_kind(ShutdownErrorKind::Recv))?
58                    {
59                        down @ Down { .. } if down.watch_ref == watch_ref && down.peer == peer => {
60                            break Ok(())
61                        },
62
63                        unexpected @ _ => {
64                            warn!("unexpected message: {:?}", unexpected);
65                        },
66                    })
67                }
68            }
69            .fuse());
70
71            loop {
72                tokio::select! {
73                    recv_result = recv_result.as_mut() => { break recv_result },
74                    () = shutdown_sequence.as_mut() => {}
75                }
76            }
77        }
78    }
79
80    fn exit(&mut self, peer: Address) -> impl Future<Output = bool> + Send
81    where
82        Sys: Default,
83    {
84        async move { self.call(Sys::default(), system::Exit { peer }).await }
85    }
86
87    fn kill(&mut self, peer: Address) -> impl Future<Output = bool> + Send
88    where
89        Sys: Default,
90    {
91        async move { self.call(Sys::default(), system::Kill { peer }).await }
92    }
93}
94
95impl<Sys, T> Stop<Sys> for T
96where
97    T: Call<Sys, system::Exit, Outcome = bool> + Call<Sys, system::Kill, Outcome = bool>,
98    Sys: System,
99{
100}
101
102impl_error_kind!(ShutdownErrorKind);