use crate::prelude::{threading::*, *};
#[cfg(feature = "graceful-shutdown")]
use atomic::{AtomicBool, AtomicIsize};
#[cfg(feature = "graceful-shutdown")]
use std::cell::UnsafeCell;
#[cfg(feature = "graceful-shutdown")]
use tokio::sync::watch::{
channel as watch_channel, Receiver as WatchReceiver, Sender as WatchSender,
};
/// Position of a listener's waker slot inside the manager's waker list.
///
/// Handed out by `Manager::add_listener`, one per registered listener.
#[cfg(feature = "graceful-shutdown")]
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub(crate) struct WakerIndex(usize);
/// Per-listener waker slots used to wake pending `accept` calls on shutdown.
///
/// Slots are written from `AcceptFuture::poll` (on whichever thread polls the
/// accept future) and drained by `Manager::shutdown` (on whichever thread
/// requests shutdown), so every shared access goes through a mutex.
#[derive(Debug)]
#[cfg(feature = "graceful-shutdown")]
#[repr(transparent)]
pub(crate) struct WakerList(std::sync::Mutex<Vec<Option<Waker>>>);

#[cfg(feature = "graceful-shutdown")]
impl WakerList {
    /// Creates an empty list with room for `capacity` listeners.
    pub(crate) fn new(capacity: usize) -> Self {
        Self(std::sync::Mutex::new(Vec::with_capacity(capacity)))
    }

    /// Exclusive access to the slots without locking; used while registering
    /// listeners, before the manager is shared.
    pub(crate) fn get_mut(&mut self) -> &mut Vec<Option<Waker>> {
        self.0
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Locks the slot list. A poisoned lock is recovered because the only
    /// mutations are whole-slot assignments of `Option<Waker>`.
    fn lock(&self) -> std::sync::MutexGuard<'_, Vec<Option<Waker>>> {
        self.0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Stores `waker` in the slot for `index`, or clears it when `None`.
    ///
    /// # Panics
    /// Panics if `index` was not produced by `Manager::add_listener` on this list.
    pub(crate) fn set(&self, index: WakerIndex, waker: Option<Waker>) {
        self.lock()[index.0] = waker;
    }

    /// Takes every registered waker out of its slot and wakes it.
    pub(crate) fn notify(&self) {
        // Drain while holding the lock, but wake only after it is released so
        // no executor code runs with the mutex held.
        let pending: Vec<Waker> = self.lock().iter_mut().filter_map(Option::take).collect();
        for waker in pending {
            waker.wake();
        }
    }
}

/// Tracks listeners and live connections and coordinates graceful shutdown.
///
/// Without the `graceful-shutdown` feature the manager carries no state and
/// its waiting methods never resolve.
#[derive(Debug)]
#[must_use]
pub struct Manager {
    /// Set once `shutdown` has been requested; checked by accept futures.
    #[cfg(feature = "graceful-shutdown")]
    shutdown: AtomicBool,
    /// Ensures the final shutdown task is spawned at most once.
    #[cfg(feature = "graceful-shutdown")]
    shutting_down: AtomicBool,
    /// Number of tracked connections; can go negative if removals outnumber additions.
    #[cfg(feature = "graceful-shutdown")]
    connections: AtomicIsize,
    /// One waker slot per listener registered through `add_listener`.
    #[cfg(feature = "graceful-shutdown")]
    wakers: WakerList,
    /// Final shutdown signal observed by `wait`.
    #[cfg(feature = "graceful-shutdown")]
    channel: (Arc<WatchSender<()>>, WatchReceiver<()>),
    /// Publishes the sender that pre-shutdown participants use to confirm completion.
    #[cfg(feature = "graceful-shutdown")]
    pre_shutdown_channel: (
        Arc<WatchSender<tokio::sync::mpsc::UnboundedSender<()>>>,
        WatchReceiver<tokio::sync::mpsc::UnboundedSender<()>>,
    ),
    /// Number of `wait_for_pre_shutdown` registrations whose confirmation is awaited.
    #[cfg(feature = "graceful-shutdown")]
    pre_shutdown_count: Arc<atomic::AtomicUsize>,
    /// Handover socket file removed (on Unix) when shutdown starts.
    #[cfg(feature = "graceful-shutdown")]
    pub(crate) handover_socket_path: Option<PathBuf>,
}

impl Manager {
    /// Creates a manager; `_capacity` pre-sizes the listener waker list
    /// (only used with the `graceful-shutdown` feature).
    pub fn new(_capacity: usize) -> Self {
        #[cfg(feature = "graceful-shutdown")]
        {
            let channel = watch_channel(());
            let pre_shutdown_channel = watch_channel(tokio::sync::mpsc::unbounded_channel().0);
            Self {
                shutdown: AtomicBool::new(false),
                shutting_down: AtomicBool::new(false),
                connections: AtomicIsize::new(0),
                wakers: WakerList::new(_capacity),
                channel: (Arc::new(channel.0), channel.1),
                pre_shutdown_channel: (Arc::new(pre_shutdown_channel.0), pre_shutdown_channel.1),
                pre_shutdown_count: Arc::new(atomic::AtomicUsize::new(0)),
                handover_socket_path: None,
            }
        }
        #[cfg(not(feature = "graceful-shutdown"))]
        {
            Self {}
        }
    }

    /// Registers `listener` and, with `graceful-shutdown`, reserves a waker
    /// slot for it. Requires `&mut self`, i.e. before the manager is shared.
    pub fn add_listener(&mut self, listener: TcpListener) -> AcceptManager {
        AcceptManager {
            #[cfg(feature = "graceful-shutdown")]
            index: {
                let wakers = self.wakers.get_mut();
                let len = wakers.len();
                wakers.push(None);
                WakerIndex(len)
            },
            listener,
        }
    }

    /// Increments the tracked connection count.
    pub fn add_connection(&self) {
        #[cfg(feature = "graceful-shutdown")]
        {
            let connections = self.connections.fetch_add(1, Ordering::SeqCst) + 1;
            debug!("Current connections: {}", connections);
        }
    }

    /// Decrements the tracked connection count; if shutdown was requested and
    /// no connections remain, starts the final shutdown.
    pub fn remove_connection(&self) {
        #[cfg(feature = "graceful-shutdown")]
        {
            // SeqCst pairs with `shutdown`: it stores the flag then reads the
            // count, while this side updates the count then reads the flag.
            // With weaker orderings both sides could miss each other and the
            // final shutdown would never be triggered.
            let connections = self.connections.fetch_sub(1, Ordering::SeqCst) - 1;
            if connections < 0 {
                info!(
                    "Connection count is less than 0. \
                    This might be an issue if you didn't explicitly \
                    call `ShutdownManager::remove_connection` in your code."
                );
            }
            if connections <= 0 && self.shutdown.load(Ordering::SeqCst) {
                debug!("There are no connections. Shutting down.");
                self._shutdown();
            }
            debug!("Current connections: {}", connections);
        }
    }

    /// Returns the tracked connection count (always `0` without
    /// `graceful-shutdown`). The name keeps its historical spelling.
    #[must_use]
    pub fn get_connecions(&self) -> isize {
        #[cfg(feature = "graceful-shutdown")]
        {
            self.connections.load(Ordering::Acquire)
        }
        #[cfg(not(feature = "graceful-shutdown"))]
        {
            0
        }
    }

    /// Reads the shutdown-requested flag with the caller-supplied ordering.
    #[cfg(feature = "graceful-shutdown")]
    pub fn get_shutdown(&self, order: Ordering) -> bool {
        self.shutdown.load(order)
    }

    /// Stores `waker` for the listener at `index`.
    #[cfg(feature = "graceful-shutdown")]
    pub(crate) fn set_waker(&self, index: WakerIndex, waker: Waker) {
        self.wakers.set(index, Some(waker));
    }

    /// Clears the waker stored for the listener at `index`.
    #[cfg(feature = "graceful-shutdown")]
    pub(crate) fn remove_waker(&self, index: WakerIndex) {
        self.wakers.set(index, None);
    }

    /// Wraps the manager in an `Arc` for sharing between tasks.
    #[must_use]
    pub fn build(self) -> Arc<Self> {
        Arc::new(self)
    }

    /// Requests shutdown: marks the flag, removes the handover socket (Unix),
    /// starts the final shutdown if no connections remain, and wakes all
    /// listeners so their pending `accept` calls resolve to `Shutdown`.
    #[cfg(feature = "graceful-shutdown")]
    pub fn shutdown(&self) {
        info!(
            "Initiating shutdown. Handover path: {:?}",
            self.handover_socket_path
        );
        // SeqCst store followed by a SeqCst load of `connections`; mirrors
        // `remove_connection` so at least one side observes the other.
        self.shutdown.store(true, Ordering::SeqCst);
        #[cfg(unix)]
        if let Some(path) = &self.handover_socket_path {
            // Best effort: the socket file may already be gone.
            std::fs::remove_file(path).ok();
        }
        let connections = self.connections.load(Ordering::SeqCst);
        // `<= 0` matches `remove_connection`, so an over-decremented count
        // does not block shutdown.
        if connections <= 0 {
            self._shutdown();
        }
        debug!("Current connections: {}", connections);
        info!("Notifying wakers.");
        self.wakers.notify();
    }

    /// Spawns (at most once) the task that runs the pre-shutdown handshake and
    /// then sends the final shutdown signal. Requires a tokio runtime context.
    #[cfg(feature = "graceful-shutdown")]
    fn _shutdown(&self) {
        if self.shutting_down.swap(true, Ordering::AcqRel) {
            return;
        }
        let channel = Arc::clone(&self.channel.0);
        let pre_channel = Arc::clone(&self.pre_shutdown_channel.0);
        let count = Arc::clone(&self.pre_shutdown_count);
        tokio::spawn(async move {
            let (confirm_tx, mut confirm_rx) = tokio::sync::mpsc::unbounded_channel();
            // Fails only when every pre-shutdown receiver is gone; the sender is
            // then dropped with the error, so `recv` below returns `None`
            // instead of the task panicking.
            drop(pre_channel.send(confirm_tx));
            let wanted = count.load(Ordering::Acquire);
            let mut received = 0;
            while received < wanted {
                if confirm_rx.recv().await.is_none() {
                    // Every confirmation sender is gone; nobody can confirm anymore.
                    break;
                }
                received += 1;
            }
            info!("Sending shutdown signal");
            drop(channel.send(()));
        });
    }

    /// Waits for the final shutdown signal. Never resolves without
    /// `graceful-shutdown`.
    pub async fn wait(&self) {
        #[cfg(feature = "graceful-shutdown")]
        {
            let mut receiver = WatchReceiver::clone(&self.channel.1);
            drop(receiver.changed().await);
            info!("Received shutdown signal");
        }
        #[cfg(not(feature = "graceful-shutdown"))]
        {
            std::future::pending::<()>().await;
        }
    }

    /// Registers a pre-shutdown participant and returns a future that resolves,
    /// once shutdown begins, to a sender on which the participant must send
    /// one `()` when it is done.
    ///
    /// Registration happens when this method is called, not when the future is
    /// first polled. The final shutdown signal waits for one confirmation per
    /// registration, so a participant that never confirms delays it
    /// indefinitely. Never resolves without `graceful-shutdown`.
    // Written as `-> impl Future` (not `async fn`) precisely so the counter is
    // incremented eagerly at call time.
    #[allow(clippy::manual_async_fn)]
    pub fn wait_for_pre_shutdown(
        &self,
    ) -> impl Future<Output = tokio::sync::mpsc::UnboundedSender<()>> + '_ {
        #[cfg(feature = "graceful-shutdown")]
        {
            let mut receiver = WatchReceiver::clone(&self.pre_shutdown_channel.1);
            self.pre_shutdown_count.fetch_add(1, Ordering::SeqCst);
            async move {
                drop(receiver.changed().await);
                info!("Received pre shutdown signal");
                let borrow = receiver.borrow();
                (*borrow).clone()
            }
        }
        #[cfg(not(feature = "graceful-shutdown"))]
        async {
            std::future::pending::<()>().await;
            unreachable!()
        }
    }
}
/// Outcome of waiting on an `AcceptManager`.
#[must_use]
#[derive(Debug)]
pub enum AcceptAction {
    /// Shutdown was requested; the caller should stop accepting.
    Shutdown,
    /// Result of accepting one inbound connection from the listener.
    Accept(io::Result<(TcpStream, SocketAddr)>),
}
/// A listener registered with a `Manager` through `Manager::add_listener`.
///
/// With `graceful-shutdown`, `index` names this listener's waker slot inside
/// the manager.
#[must_use]
#[derive(Debug)]
pub struct AcceptManager {
    #[cfg(feature = "graceful-shutdown")]
    index: WakerIndex,
    listener: TcpListener,
}
impl AcceptManager {
    /// Waits for the next inbound connection on this listener, or for
    /// `_manager` to begin shutting down.
    ///
    /// With `graceful-shutdown`, this listener's waker slot is cleared once the
    /// wait completes.
    // `let_and_return` is allowed because, with the feature enabled, the
    // `remove_waker` call sits between binding the result and returning it.
    #[allow(clippy::let_and_return)]
    pub async fn accept(&mut self, _manager: &Manager) -> AcceptAction {
        let future = AcceptFuture {
            #[cfg(feature = "graceful-shutdown")]
            manager: _manager,
            #[cfg(feature = "graceful-shutdown")]
            index: self.index,
            listener: &mut self.listener,
        };
        let outcome = future.await;
        #[cfg(feature = "graceful-shutdown")]
        _manager.remove_waker(self.index);
        outcome
    }

    /// Borrows the wrapped listener.
    pub fn get_inner(&self) -> &TcpListener {
        &self.listener
    }
}
/// Future behind `AcceptManager::accept`: resolves to the next accepted
/// connection, or to `AcceptAction::Shutdown` once shutdown was requested.
struct AcceptFuture<'m> {
    /// Manager whose shutdown flag and waker slots are consulted.
    #[cfg(feature = "graceful-shutdown")]
    manager: &'m Manager,
    /// Waker slot reserved for this listener.
    #[cfg(feature = "graceful-shutdown")]
    index: WakerIndex,
    /// Listener polled for inbound connections.
    listener: &'m mut TcpListener,
}
impl<'a> Future for AcceptFuture<'a> {
    type Output = AcceptAction;

    /// Resolves to `AcceptAction::Shutdown` once the manager's shutdown flag is
    /// set; otherwise polls the listener for a connection.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let me = self.get_mut();
        #[cfg(feature = "graceful-shutdown")]
        {
            // Register the waker *before* reading the flag (register-then-check).
            // Checking first leaves a window: a shutdown landing between the
            // check and the registration notifies no one, and `accept` then
            // stays pending until an unrelated connection arrives.
            debug!("Set listener waker.");
            me.manager.set_waker(me.index, Waker::clone(cx.waker()));
            let shutting_down = me.manager.shutdown.load(Ordering::Acquire);
            debug!("Shutting down? {}", shutting_down);
            if shutting_down {
                Poll::Ready(AcceptAction::Shutdown)
            } else {
                me.listener.poll_accept(cx).map(AcceptAction::Accept)
            }
        }
        #[cfg(not(feature = "graceful-shutdown"))]
        {
            me.listener.poll_accept(cx).map(AcceptAction::Accept)
        }
    }
}