// mm1_core/context/stop.rs
use std::future::Future;
use std::pin::pin;
use std::time::Duration;

use futures::FutureExt;
use mm1_address::address::Address;
use mm1_common::errors::error_of::ErrorOf;
use mm1_common::impl_error_kind;
use mm1_proc_macros::dispatch;
use mm1_proto_system::{self as system, Down, System};
use tracing::warn;

use super::{ForkErrorKind, Recv, RecvErrorKind};
use crate::context::{Call, Fork, Watching};

/// Categories of failure reported by [`Stop::shutdown`].
///
/// Declaration order matters: the derived `PartialOrd`/`Ord` compare variants in the
/// order they are listed below.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum ShutdownErrorKind {
    /// Generic internal failure (not constructed by the default `Stop::shutdown` body).
    InternalError,
    /// Forking the calling context failed; carries the underlying fork error kind.
    Fork(ForkErrorKind),
    /// Receiving on the forked context failed; carries the underlying receive error kind.
    Recv(RecvErrorKind),
}

/// Graceful-then-forceful termination of a peer.
///
/// Requires the context to be able to `Call` both [`system::Exit`] and
/// [`system::Kill`], each yielding a `bool` outcome.
pub trait Stop<Sys>:
    Call<Sys, system::Exit, Outcome = bool> + Call<Sys, system::Kill, Outcome = bool>
where
    Sys: System,
{
    /// Asks `peer` to exit, and sends a kill after `stop_timeout` has elapsed.
    ///
    /// The context is forked and the fork starts watching `peer` before any request
    /// is sent. Two futures then run side by side:
    ///
    /// - the escalation sends [`system::Exit`] via [`Stop::exit`], sleeps for
    ///   `stop_timeout`, then sends [`system::Kill`] via [`Stop::kill`]; the `bool`
    ///   outcomes of both calls are ignored;
    /// - the fork receives messages until a [`Down`] whose `watch_ref` matches this
    ///   watch and whose `peer` is `peer` arrives. Any other message is logged with
    ///   `warn!` and dropped.
    ///
    /// There is no overall deadline. Once the kill has been sent, this keeps waiting
    /// for the matching `Down`.
    ///
    /// # Errors
    ///
    /// - [`ShutdownErrorKind::Fork`] if forking the context fails.
    /// - [`ShutdownErrorKind::Recv`] if receiving on the fork fails.
    fn shutdown(
        &mut self,
        peer: Address,
        stop_timeout: Duration,
    ) -> impl Future<Output = Result<(), ErrorOf<ShutdownErrorKind>>> + Send
    where
        Sys: Default,
        Self: Watching<Sys> + Fork + Recv,
    {
        async move {
            // Watching and receiving happen on a forked context while `self` is used
            // for the exit/kill calls, so the two futures below borrow different
            // contexts and can be polled concurrently.
            let mut forked = self
                .fork()
                .await
                .map_err(|e| e.map_kind(ShutdownErrorKind::Fork))?;
            // Established before the escalation future is first polled, i.e. before
            // the exit request goes out.
            let watch = forked.watch(peer).await;

            // Polite exit first, then a kill once the grace period is over.
            let mut escalation = pin!(async {
                self.exit(peer).await;
                tokio::time::sleep(stop_timeout).await;
                self.kill(peer).await;
            }
            .fuse());

            // Resolves once the `Down` for our watch on `peer` has been received.
            let mut peer_down = pin!(async {
                loop {
                    let message = forked
                        .recv()
                        .await
                        .map_err(|e| e.map_kind(ShutdownErrorKind::Recv))?;
                    dispatch!(match message {
                        down @ Down { .. } if down.watch_ref == watch && down.peer == peer => {
                            break Ok(())
                        },

                        unexpected @ _ => {
                            warn!("unexpected message: {:?}", unexpected);
                        },
                    })
                }
            }
            .fuse());

            // `escalation` is fused: after it completes it stays pending forever, so
            // later iterations effectively wait on `peer_down` alone.
            loop {
                tokio::select! {
                    outcome = peer_down.as_mut() => { break outcome },
                    () = escalation.as_mut() => {}
                }
            }
        }
    }

    /// Issues a [`system::Exit`] request for `peer` through `Call`, passing
    /// `Sys::default()` as the system argument.
    ///
    /// Resolves to the call's `bool` outcome. NOTE(review): presumably "request
    /// accepted"; confirm against the `Call` implementation.
    fn exit(&mut self, peer: Address) -> impl Future<Output = bool> + Send
    where
        Sys: Default,
    {
        let request = system::Exit { peer };
        async move { self.call(Sys::default(), request).await }
    }

    /// Issues a [`system::Kill`] request for `peer` through `Call`, passing
    /// `Sys::default()` as the system argument.
    ///
    /// Resolves to the call's `bool` outcome. NOTE(review): presumably "request
    /// accepted"; confirm against the `Call` implementation.
    fn kill(&mut self, peer: Address) -> impl Future<Output = bool> + Send
    where
        Sys: Default,
    {
        let request = system::Kill { peer };
        async move { self.call(Sys::default(), request).await }
    }
}

/// Every context that can `Call` both [`system::Kill`] and [`system::Exit`]
/// (each with a `bool` outcome) gets [`Stop`] with its default methods.
impl<Sys: System, T> Stop<Sys> for T
where
    T: Call<Sys, system::Kill, Outcome = bool> + Call<Sys, system::Exit, Outcome = bool>,
{
}

// Wires `ShutdownErrorKind` into mm1_common's error machinery via its
// `impl_error_kind!` macro. NOTE(review): presumably this is what allows it to be
// used as the kind in `ErrorOf<ShutdownErrorKind>`; confirm in mm1_common.
impl_error_kind!(ShutdownErrorKind);