axum_server/
handle.rs

1use crate::{notify_once::NotifyOnce, server::Address};
2use std::{
3    sync::{
4        atomic::{AtomicUsize, Ordering},
5        Arc, Mutex,
6    },
7    time::Duration,
8};
9use tokio::{sync::Notify, time::sleep};
10
11/// A handle for [`Server`](crate::server::Server).
12#[derive(Clone, Debug)]
13pub struct Handle<A: Address> {
14    inner: Arc<HandleInner<A>>,
15}
16
17impl<A: Address> Default for Handle<A> {
18    fn default() -> Self {
19        Self {
20            inner: Default::default(),
21        }
22    }
23}
24
25#[derive(Debug)]
26struct HandleInner<A: Address> {
27    addr: Mutex<Option<A>>,
28    addr_notify: Notify,
29    conn_count: AtomicUsize,
30    shutdown: NotifyOnce,
31    graceful: NotifyOnce,
32    graceful_dur: Mutex<Option<Duration>>,
33    conn_end: NotifyOnce,
34}
35
36// Manually implemented as the derive macro will want A to be Default.
37impl<A: Address> Default for HandleInner<A> {
38    fn default() -> Self {
39        Self {
40            addr: Default::default(),
41            addr_notify: Default::default(),
42            conn_count: Default::default(),
43            shutdown: Default::default(),
44            graceful: Default::default(),
45            graceful_dur: Default::default(),
46            conn_end: Default::default(),
47        }
48    }
49}
50
51impl<A: Address> Handle<A> {
52    /// Create a new handle.
53    pub fn new() -> Self {
54        Self::default()
55    }
56
57    /// Get the number of connections.
58    pub fn connection_count(&self) -> usize {
59        self.inner.conn_count.load(Ordering::SeqCst)
60    }
61
62    /// Shutdown the server.
63    pub fn shutdown(&self) {
64        self.inner.shutdown.notify_waiters();
65    }
66
67    /// Gracefully shutdown the server.
68    ///
69    /// `None` means indefinite grace period.
70    pub fn graceful_shutdown(&self, duration: Option<Duration>) {
71        *self.inner.graceful_dur.lock().unwrap() = duration;
72
73        self.inner.graceful.notify_waiters();
74    }
75
76    /// Returns local address and port when server starts listening.
77    ///
78    /// Returns `None` if server fails to bind.
79    pub async fn listening(&self) -> Option<A> {
80        let notified = self.inner.addr_notify.notified();
81
82        if let Some(addr) = self.inner.addr.lock().unwrap().clone() {
83            return Some(addr);
84        }
85
86        notified.await;
87
88        self.inner.addr.lock().unwrap().clone()
89    }
90
91    pub(crate) fn notify_listening(&self, addr: Option<A>) {
92        *self.inner.addr.lock().unwrap() = addr;
93
94        self.inner.addr_notify.notify_waiters();
95    }
96
97    pub(crate) fn watcher(&self) -> Watcher<A> {
98        Watcher::new(self.clone())
99    }
100
101    pub(crate) async fn wait_shutdown(&self) {
102        self.inner.shutdown.notified().await;
103    }
104
105    pub(crate) async fn wait_graceful_shutdown(&self) {
106        self.inner.graceful.notified().await;
107    }
108
109    pub(crate) async fn wait_connections_end(&self) {
110        if self.inner.conn_count.load(Ordering::SeqCst) == 0 {
111            return;
112        }
113
114        let deadline = *self.inner.graceful_dur.lock().unwrap();
115
116        match deadline {
117            Some(duration) => tokio::select! {
118                biased;
119                _ = sleep(duration) => self.shutdown(),
120                _ = self.inner.conn_end.notified() => (),
121            },
122            None => self.inner.conn_end.notified().await,
123        }
124    }
125}
126
127pub(crate) struct Watcher<A: Address> {
128    handle: Handle<A>,
129}
130
131impl<A: Address> Watcher<A> {
132    fn new(handle: Handle<A>) -> Self {
133        handle.inner.conn_count.fetch_add(1, Ordering::SeqCst);
134
135        Self { handle }
136    }
137
138    pub(crate) async fn wait_graceful_shutdown(&self) {
139        self.handle.wait_graceful_shutdown().await
140    }
141
142    pub(crate) async fn wait_shutdown(&self) {
143        self.handle.wait_shutdown().await
144    }
145}
146
147impl<A: Address> Drop for Watcher<A> {
148    fn drop(&mut self) {
149        let count = self.handle.inner.conn_count.fetch_sub(1, Ordering::SeqCst) - 1;
150
151        if count == 0 && self.handle.inner.graceful.is_notified() {
152            self.handle.inner.conn_end.notify_waiters();
153        }
154    }
155}