1use std::future::Future;
2use std::pin::pin;
3use std::time::Duration;
4
5use futures::FutureExt;
6use mm1_address::address::Address;
7use mm1_common::errors::error_of::ErrorOf;
8use mm1_common::impl_error_kind;
9use mm1_proc_macros::dispatch;
10use mm1_proto::message;
11use mm1_proto_system::Down;
12use tracing::warn;
13
14use super::{ForkErrorKind, Messaging, RecvErrorKind};
15use crate::context::{Fork, Watching};
16
/// Error kinds reported by [`Stop::shutdown`].
///
/// NOTE(review): `Copy` is derived without `Clone`, which only compiles if
/// `#[message]` provides the `Clone` impl — confirm before touching the
/// derive list. Variant order is significant for the derived `PartialOrd`/`Ord`.
#[derive(Debug, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[message]
pub enum ShutdownErrorKind {
    /// Internal failure. Not constructed by the default `Stop::shutdown`
    /// implementation.
    InternalError,
    /// Forking the caller's context failed; wraps the underlying [`ForkErrorKind`].
    Fork(ForkErrorKind),
    /// Receiving on the forked context failed; wraps the underlying [`RecvErrorKind`].
    Recv(RecvErrorKind),
}
24
/// Terminates another actor: a first-step request, then a second step after a timeout.
///
/// Implementors supply [`Stop::exit`] and [`Stop::kill`]; [`Stop::shutdown`]
/// is a provided method that sequences them.
pub trait Stop {
    /// Shuts `peer` down and waits until its termination is observed.
    ///
    /// Steps:
    /// 1. Forks the current context and, on the fork, watches `peer`.
    /// 2. Concurrently:
    ///    - awaits [`Stop::exit`] for `peer`, sleeps for `stop_timeout`, then
    ///      awaits [`Stop::kill`] for `peer`;
    ///    - receives on the fork until a [`Down`] arrives whose `watch_ref`
    ///      and `peer` match the watch installed in step 1.
    /// 3. Resolves with `Ok(())` as soon as that `Down` is received. If it
    ///    arrives before the sleep finishes, the sequence is dropped and `kill`
    ///    is never called.
    ///
    /// Messages handled by the catch-all arm of the receive loop are logged
    /// at `warn` level and dropped.
    ///
    /// # Errors
    ///
    /// - [`ShutdownErrorKind::Fork`] if forking the context fails.
    /// - [`ShutdownErrorKind::Recv`] if receiving on the fork fails; the
    ///   exit/sleep/kill sequence is abandoned at whatever point it reached.
    ///
    /// NOTE(review): there is no deadline after `kill`; if no matching `Down`
    /// ever arrives, the returned future never resolves. The `bool` results of
    /// `exit` and `kill` are ignored.
    fn shutdown(
        &mut self,
        peer: Address,
        stop_timeout: Duration,
    ) -> impl Future<Output = Result<(), ErrorOf<ShutdownErrorKind>>> + Send
    where
        Self: Watching + Fork + Messaging,
    {
        async move {
            // Watch and receive on a forked context rather than on `self`. The
            // receive loop and the exit/kill sequence below run concurrently,
            // and this keeps their `&mut` borrows (`fork` vs `self`) disjoint.
            let mut fork = self
                .fork()
                .await
                .map_err(|e| e.map_kind(ShutdownErrorKind::Fork))?;

            // Identifies the `Down` notification that belongs to this watch.
            let watch_ref = fork.watch(peer).await;

            // First step, grace period, then the final step. Fused because the
            // `select!` loop below may poll it again after it has completed:
            // a completed `Fuse` yields `Pending` instead of panicking.
            let mut shutdown_sequence = pin!(
                async {
                    self.exit(peer).await;
                    tokio::time::sleep(stop_timeout).await;
                    self.kill(peer).await;
                }
                .fuse()
            );

            // Resolves with `Ok(())` on the matching `Down`, or with the first
            // receive error (propagated by `?` out of this async block).
            let mut recv_result = pin!(
                async {
                    loop {
                        dispatch!(match fork
                            .recv()
                            .await
                            .map_err(|e| e.map_kind(ShutdownErrorKind::Recv))?
                        {
                            down @ Down { .. }
                                if down.watch_ref == watch_ref && down.peer == peer =>
                            {
                                break Ok(())
                            },

                            unexpected @ _ => {
                                warn!("unexpected message: {:?}", unexpected);
                            },
                        })
                    }
                }
                .fuse()
            );

            // Drive both futures together. Completion of `shutdown_sequence`
            // does not end the wait: the loop re-enters `select!` and keeps
            // waiting on `recv_result`, which alone decides the outcome.
            loop {
                tokio::select! {
                    recv_result = recv_result.as_mut() => { break recv_result },
                    () = shutdown_sequence.as_mut() => {}
                }
            }
        }
    }

    /// First step of [`Stop::shutdown`], awaited before the `stop_timeout`
    /// grace period. Presumably requests a graceful exit of `peer` — confirm
    /// against implementors. `shutdown` ignores the returned `bool`.
    fn exit(&mut self, peer: Address) -> impl Future<Output = bool> + Send;

    /// Final step of [`Stop::shutdown`], awaited only after `exit` completed
    /// and `stop_timeout` elapsed without a matching `Down`. Presumably
    /// terminates `peer` forcibly — confirm against implementors. `shutdown`
    /// ignores the returned `bool`.
    fn kill(&mut self, peer: Address) -> impl Future<Output = bool> + Send;
}
86
// Wires `ShutdownErrorKind` into `mm1_common`'s error-kind machinery.
// NOTE(review): presumably required for `ErrorOf<ShutdownErrorKind>` and the
// `map_kind(ShutdownErrorKind::…)` calls in `Stop::shutdown` — confirm.
impl_error_kind!(ShutdownErrorKind);