1use std::pin::pin;
2use std::time::Duration;
3
4use futures::FutureExt;
5use mm1_address::address::Address;
6use mm1_common::errors::error_of::ErrorOf;
7use mm1_common::impl_error_kind;
8use mm1_proc_macros::dispatch;
9use mm1_proto::message;
10use mm1_proto_system::Down;
11use tracing::warn;
12
13use super::{ForkErrorKind, Messaging, RecvErrorKind};
14use crate::context::{Fork, Watching};
15
/// Failure kinds reported by [`Stop::shutdown`].
///
/// NOTE(review): variant order is significant. The derived `PartialOrd`/`Ord`
/// follow declaration order, and the `#[message]` encoding may depend on it as
/// well (unverified). Append new variants instead of reordering existing ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[message(base_path = ::mm1_proto)]
pub enum ShutdownErrorKind {
    /// Not produced by the default `Stop::shutdown` body. Presumably reserved
    /// for implementors or other callers — confirm.
    InternalError,
    /// Forking the caller's context failed; wraps the underlying [`ForkErrorKind`].
    Fork(ForkErrorKind),
    /// Receiving a message on the forked context failed; wraps the underlying
    /// [`RecvErrorKind`].
    Recv(RecvErrorKind),
}
23
24pub trait Stop {
25 fn shutdown(
26 &mut self,
27 peer: Address,
28 stop_timeout: Duration,
29 ) -> impl Future<Output = Result<(), ErrorOf<ShutdownErrorKind>>> + Send
30 where
31 Self: Watching + Fork + Messaging,
32 {
33 async move {
34 let mut fork = self
35 .fork()
36 .await
37 .map_err(|e| e.map_kind(ShutdownErrorKind::Fork))?;
38
39 let watch_ref = fork.watch(peer).await;
40
41 let mut shutdown_sequence = pin!(
42 async {
43 self.exit(peer).await;
44 tokio::time::sleep(stop_timeout).await;
45 self.kill(peer).await;
46 }
47 .fuse()
48 );
49
50 let mut recv_result = pin!(
51 async {
52 loop {
53 dispatch!(match fork
54 .recv()
55 .await
56 .map_err(|e| e.map_kind(ShutdownErrorKind::Recv))?
57 {
58 down @ Down { .. }
59 if down.watch_ref == watch_ref && down.peer == peer =>
60 {
61 break Ok(())
62 },
63
64 unexpected @ _ => {
65 warn!("unexpected message: {:?}", unexpected);
66 },
67 })
68 }
69 }
70 .fuse()
71 );
72
73 loop {
74 tokio::select! {
75 recv_result = recv_result.as_mut() => { break recv_result },
76 () = shutdown_sequence.as_mut() => {}
77 }
78 }
79 }
80 }
81
82 fn exit(&mut self, peer: Address) -> impl Future<Output = bool> + Send;
83 fn kill(&mut self, peer: Address) -> impl Future<Output = bool> + Send;
84}
85
// NOTE(review): presumably implements mm1_common's error-kind support for
// `ShutdownErrorKind`, so that it can parameterize `ErrorOf` as used by
// `Stop::shutdown` — confirm against `mm1_common::impl_error_kind`.
impl_error_kind!(ShutdownErrorKind);