use pin_project_lite::pin_project;
use std::{
future::Future,
pin::{pin, Pin},
sync::{atomic::AtomicBool, Arc},
task::Poll,
};
use tokio::{
signal::unix::{signal, SignalKind},
sync::{futures::Notified, Notify},
};
/// Coordinates a one-shot shutdown request between tasks.
///
/// Once [`ShutdownHandler::shutdown`] has been called the handler never leaves the
/// shut-down state: futures from [`ShutdownHandler::wait_for_signal`] that already
/// exist are woken, and futures created later complete as soon as they are polled.
#[derive(Debug, Default)]
pub struct ShutdownHandler {
    // Wakes the tasks that are already waiting when shutdown is requested.
    notifier: Notify,
    // Sticky flag (only ever set to `true`) so late waiters still complete.
    shutdown: AtomicBool,
}

impl ShutdownHandler {
    /// Creates a handler that has not been shut down.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a shared handler that shuts down when the process receives `SIGTERM`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error produced while registering the `SIGTERM` listener.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime; see
    /// [`ShutdownHandler::spawn_signal_handler`].
    pub fn sigterm() -> std::io::Result<Arc<Self>> {
        let this = Arc::new(Self::new());
        this.spawn_sigterm_handler()?;
        Ok(this)
    }

    /// Spawns a task that calls [`ShutdownHandler::shutdown`] on `SIGTERM`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error produced while registering the `SIGTERM` listener.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime; see
    /// [`ShutdownHandler::spawn_signal_handler`].
    pub fn spawn_sigterm_handler(self: &Arc<Self>) -> std::io::Result<()> {
        self.spawn_signal_handler(SignalKind::terminate())
    }

    /// Spawns a task that calls [`ShutdownHandler::shutdown`] once `signal_kind` arrives.
    ///
    /// The signal listener is created synchronously, before the task is spawned.
    /// The spawned task keeps a strong reference to the handler until it fires.
    ///
    /// # Errors
    ///
    /// Returns the error from `tokio::signal::unix::signal` if the listener
    /// cannot be registered.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because both
    /// `tokio::signal::unix::signal` and `tokio::spawn` require one.
    pub fn spawn_signal_handler(self: &Arc<Self>, signal_kind: SignalKind) -> std::io::Result<()> {
        let mut signal = signal(signal_kind)?;
        let shutdown = Arc::clone(self);
        tokio::spawn(async move {
            // `recv` yields `None` once the stream can no longer deliver signals.
            // The result is not inspected, so that case triggers shutdown as well.
            signal.recv().await;
            shutdown.shutdown();
        });
        Ok(())
    }

    /// Requests shutdown and wakes every task currently waiting on this handler.
    ///
    /// Calling it more than once has no additional effect.
    pub fn shutdown(&self) {
        // Set the flag before notifying. Waiters created after the notification
        // never see it, so they rely on observing this flag (the `Release` store
        // pairs with the `Acquire` load in `ShutdownSignal::poll`).
        self.shutdown
            .store(true, std::sync::atomic::Ordering::Release);
        self.notifier.notify_waiters();
    }

    /// Returns a future that completes once [`ShutdownHandler::shutdown`] has been called.
    ///
    /// The future is subscribed to notifications as soon as it is created, so a
    /// shutdown that happens between creating it and first polling it is not lost.
    pub fn wait_for_signal(&self) -> ShutdownSignal<'_> {
        ShutdownSignal {
            shutdown: &self.shutdown,
            notified: self.notifier.notified(),
        }
    }

    /// Races `f` against the shutdown signal.
    ///
    /// If `f` finishes first, its output is returned in
    /// [`SignalOrComplete::Completed`]. If shutdown is requested first, the
    /// unfinished `f` is handed back in [`SignalOrComplete::ShutdownSignal`] so the
    /// caller can decide whether to drop it or drive it further.
    ///
    /// The shutdown signal is checked first on every poll, so it wins when both
    /// become ready at the same time.
    pub async fn wait_for_signal_or_future<F: Future + Unpin>(&self, f: F) -> SignalOrComplete<F> {
        let mut shutdown_signal = pin!(self.wait_for_signal());
        // `f` is kept in an `Option` so it can be moved out when shutdown wins.
        let mut f = Some(f);
        std::future::poll_fn(|cx| {
            if shutdown_signal.as_mut().poll(cx).is_ready() {
                let unfinished = f.take().expect("future polled after completion");
                return Poll::Ready(SignalOrComplete::ShutdownSignal(unfinished));
            }
            let fut = f.as_mut().expect("future polled after completion");
            Pin::new(fut).poll(cx).map(SignalOrComplete::Completed)
        })
        .await
    }
}
/// Outcome of racing a future against the shutdown signal.
#[derive(Debug)]
pub enum SignalOrComplete<F>
where
    F: Future,
{
    /// Shutdown was requested before the future completed; the future is returned unfinished.
    ShutdownSignal(F),
    /// The future completed first; this carries its output.
    Completed(F::Output),
}
pin_project! {
    /// Future returned by [`ShutdownHandler::wait_for_signal`].
    ///
    /// Resolves immediately if shutdown has already been requested. Otherwise it
    /// resolves once [`ShutdownHandler::shutdown`] notifies it.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    #[derive(Debug)]
    pub struct ShutdownSignal<'a> {
        // Sticky shutdown flag borrowed from the `ShutdownHandler`.
        shutdown: &'a AtomicBool,
        // Subscription to the handler's `Notify`, created with this future.
        #[pin]
        notified: Notified<'a>,
    }
}

impl Future for ShutdownSignal<'_> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // A shutdown that happened before this future subscribed never reaches
        // `notified`, so the flag is checked first to complete right away.
        if this.shutdown.load(std::sync::atomic::Ordering::Acquire) {
            return Poll::Ready(());
        }
        this.notified.poll(cx)
    }
}
#[cfg(test)]
mod test {
    use std::{sync::Arc, time::Duration};

    use nix::sys::signal::{raise, Signal};
    use tokio::{signal::unix::SignalKind, sync::oneshot, time::timeout};

    use crate::{ShutdownHandler, SignalOrComplete};

    /// How long a waiter may take to observe shutdown before a test fails.
    const WAIT: Duration = Duration::from_secs(1);

    /// Spawns a task that waits on `shutdown` and reports `true` through the returned receiver.
    fn spawn_waiter(shutdown: Arc<ShutdownHandler>) -> oneshot::Receiver<bool> {
        let (tx, rx) = oneshot::channel();
        tokio::spawn(async move {
            shutdown.wait_for_signal().await;
            // The receiver is gone if the test already gave up waiting.
            let _ = tx.send(true);
        });
        rx
    }

    /// Asserts that the waiter behind `rx` really observed shutdown within [`WAIT`].
    async fn assert_shut_down(rx: oneshot::Receiver<bool>) {
        match timeout(WAIT, rx).await {
            Ok(Ok(true)) => {}
            // The sender was dropped (e.g. the waiter panicked) without observing shutdown.
            Ok(other) => panic!("waiter stopped without observing shutdown: {other:?}"),
            Err(_) => panic!("Shutdown handler took longer than 1 second!"),
        }
    }

    #[tokio::test]
    async fn shutdown_sigterm() {
        let shutdown = Arc::new(ShutdownHandler::new());
        shutdown.spawn_sigterm_handler().unwrap();
        let rx = spawn_waiter(shutdown);
        raise(Signal::SIGTERM).unwrap();
        assert_shut_down(rx).await;
    }

    #[tokio::test]
    async fn shutdown_custom_signal() {
        let shutdown = Arc::new(ShutdownHandler::new());
        shutdown.spawn_signal_handler(SignalKind::hangup()).unwrap();
        let rx = spawn_waiter(shutdown);
        raise(Signal::SIGHUP).unwrap();
        assert_shut_down(rx).await;
    }

    #[tokio::test]
    async fn shutdown() {
        let shutdown = Arc::new(ShutdownHandler::new());
        let rx = spawn_waiter(Arc::clone(&shutdown));
        tokio::spawn(async move {
            shutdown.shutdown();
        });
        assert_shut_down(rx).await;
    }

    #[tokio::test]
    async fn shutdown_before_wait() {
        // Exercises the sticky flag: the waiter is created after the notification.
        let shutdown = Arc::new(ShutdownHandler::new());
        shutdown.shutdown();
        assert_shut_down(spawn_waiter(shutdown)).await;
    }

    #[tokio::test]
    async fn no_notification() {
        let shutdown = Arc::new(ShutdownHandler::new());
        let rx = spawn_waiter(shutdown);
        assert!(
            timeout(WAIT, rx).await.is_err(),
            "Shutdown handler ran without a signal!"
        );
    }

    #[tokio::test]
    async fn future_completes_without_shutdown() {
        let shutdown = ShutdownHandler::new();
        let outcome = shutdown
            .wait_for_signal_or_future(std::future::ready(5))
            .await;
        assert!(matches!(outcome, SignalOrComplete::Completed(5)));
    }

    #[tokio::test]
    async fn shutdown_interrupts_pending_future() {
        let shutdown = ShutdownHandler::new();
        shutdown.shutdown();
        let outcome = shutdown
            .wait_for_signal_or_future(std::future::pending::<()>())
            .await;
        assert!(matches!(outcome, SignalOrComplete::ShutdownSignal(_)));
    }
}