1use crate::notify_once::NotifyOnce;
2use std::{
3 net::SocketAddr,
4 sync::{
5 atomic::{AtomicUsize, Ordering},
6 Arc, Mutex,
7 },
8 time::Duration,
9};
10use tokio::{sync::Notify, time::sleep};
11
12#[derive(Clone, Debug, Default)]
14pub struct Handle {
15 inner: Arc<HandleInner>,
16}
17
18#[derive(Debug, Default)]
19struct HandleInner {
20 addr: Mutex<Option<SocketAddr>>,
21 addr_notify: Notify,
22 conn_count: AtomicUsize,
23 shutdown: NotifyOnce,
24 graceful: NotifyOnce,
25 graceful_dur: Mutex<Option<Duration>>,
26 conn_end: NotifyOnce,
27}
28
29impl Handle {
30 pub fn new() -> Self {
32 Self::default()
33 }
34
35 pub fn connection_count(&self) -> usize {
37 self.inner.conn_count.load(Ordering::SeqCst)
38 }
39
40 pub fn shutdown(&self) {
42 self.inner.shutdown.notify_waiters();
43 }
44
45 pub fn graceful_shutdown(&self, duration: Option<Duration>) {
49 *self.inner.graceful_dur.lock().unwrap() = duration;
50
51 self.inner.graceful.notify_waiters();
52 }
53
54 pub async fn listening(&self) -> Option<SocketAddr> {
58 let notified = self.inner.addr_notify.notified();
59
60 if let Some(addr) = *self.inner.addr.lock().unwrap() {
61 return Some(addr);
62 }
63
64 notified.await;
65
66 *self.inner.addr.lock().unwrap()
67 }
68
69 pub(crate) fn notify_listening(&self, addr: Option<SocketAddr>) {
70 *self.inner.addr.lock().unwrap() = addr;
71
72 self.inner.addr_notify.notify_waiters();
73 }
74
75 pub(crate) fn watcher(&self) -> Watcher {
76 Watcher::new(self.clone())
77 }
78
79 pub(crate) async fn wait_shutdown(&self) {
80 self.inner.shutdown.notified().await;
81 }
82
83 pub(crate) async fn wait_graceful_shutdown(&self) {
84 self.inner.graceful.notified().await;
85 }
86
87 pub(crate) async fn wait_connections_end(&self) {
88 if self.inner.conn_count.load(Ordering::SeqCst) == 0 {
89 return;
90 }
91
92 let deadline = *self.inner.graceful_dur.lock().unwrap();
93
94 match deadline {
95 Some(duration) => tokio::select! {
96 biased;
97 _ = sleep(duration) => self.shutdown(),
98 _ = self.inner.conn_end.notified() => (),
99 },
100 None => self.inner.conn_end.notified().await,
101 }
102 }
103}
104
105pub(crate) struct Watcher {
106 handle: Handle,
107}
108
109impl Watcher {
110 fn new(handle: Handle) -> Self {
111 handle.inner.conn_count.fetch_add(1, Ordering::SeqCst);
112
113 Self { handle }
114 }
115
116 pub(crate) async fn wait_graceful_shutdown(&self) {
117 self.handle.wait_graceful_shutdown().await
118 }
119
120 pub(crate) async fn wait_shutdown(&self) {
121 self.handle.wait_shutdown().await
122 }
123}
124
125impl Drop for Watcher {
126 fn drop(&mut self) {
127 let count = self.handle.inner.conn_count.fetch_sub(1, Ordering::SeqCst) - 1;
128
129 if count == 0 && self.handle.inner.graceful.is_notified() {
130 self.handle.inner.conn_end.notify_waiters();
131 }
132 }
133}