#[cfg(feature = "runtime")]
use std::{
pin::Pin,
task::{Context, Poll},
};
use tokio::signal::unix::{signal, Signal, SignalKind};
use tracing::debug;
#[cfg_attr(docsrs, doc(cfg(feature = "shutdown")))]
pub use drain::Watch;
/// Handles SIGINT and SIGTERM for graceful shutdown.
///
/// Created by [`sigint_or_sigterm`]. Call [`Shutdown::signaled`] to wait for a
/// signal and then drain the paired [`Watch`] handles.
#[derive(Debug)]
#[must_use = "call `Shutdown::signaled` to await a signal"]
#[cfg_attr(docsrs, doc(cfg(feature = "shutdown")))]
pub struct Shutdown {
    /// Stream of SIGINT notifications.
    interrupt: Signal,
    /// Stream of SIGTERM notifications.
    terminate: Signal,
    /// Sending half of the drain channel whose receiving half is the [`Watch`].
    tx: drain::Signal,
}
/// Error returned by [`Shutdown::signaled`] when a second SIGINT or SIGTERM
/// arrives before draining has completed.
///
/// The unit field is private, so this value can only be constructed within
/// this module.
#[derive(Debug, thiserror::Error)]
#[cfg_attr(docsrs, doc(cfg(feature = "shutdown")))]
#[error("process aborted by signal")]
pub struct Aborted(());
/// Error returned by [`sigint_or_sigterm`] when a SIGINT or SIGTERM handler
/// cannot be registered.
///
/// Wraps the underlying [`std::io::Error`], which converts into this type via
/// `From` so that `?` can be used on registration results.
#[derive(Debug, thiserror::Error)]
#[cfg_attr(docsrs, doc(cfg(feature = "shutdown")))]
#[error("failed to register signal handler: {0}")]
pub struct RegisterError(#[from] std::io::Error);
#[cfg(feature = "runtime")]
pin_project_lite::pin_project! {
    /// Wraps a future or stream so that it ends early once shutdown is signaled.
    ///
    /// As a future, it resolves when either the inner future completes or the
    /// shutdown signal fires. As a stream, it yields the inner stream's results
    /// whenever they are ready and returns `None` once the shutdown signal fires
    /// while the inner stream is not ready.
    #[cfg_attr(docsrs, doc(cfg(feature = "runtime")))]
    pub struct CancelOnShutdown<T> {
        // The wrapped future or stream; always polled before `shutdown`.
        #[pin]
        inner: T,
        // Resolves once the shutdown `Watch` passed to `new` has been signaled.
        #[pin]
        shutdown: Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync + 'static>>,
    }
}
/// Registers handlers for SIGINT and SIGTERM.
///
/// Returns a [`Shutdown`] handle, whose [`Shutdown::signaled`] waits for either
/// signal, together with the [`Watch`] half of the drain channel it controls.
///
/// # Errors
///
/// Returns [`RegisterError`] if either signal handler cannot be registered.
#[cfg_attr(docsrs, doc(cfg(feature = "shutdown")))]
pub fn sigint_or_sigterm() -> Result<(Shutdown, Watch), RegisterError> {
    // SIGINT is registered first; if that fails, SIGTERM is never registered.
    let sigint = signal(SignalKind::interrupt())?;
    let sigterm = signal(SignalKind::terminate())?;
    let (tx, watch) = drain::channel();
    Ok((
        Shutdown {
            interrupt: sigint,
            terminate: sigterm,
            tx,
        },
        watch,
    ))
}
impl Shutdown {
    /// Waits for SIGINT or SIGTERM, then drains the paired [`Watch`] handles.
    ///
    /// This proceeds in two phases:
    ///
    /// 1. Wait for the first SIGINT or SIGTERM. If all shutdown receivers are
    ///    dropped before either signal arrives, return `Ok(())` right away.
    /// 2. Await `drain::Signal::drain`, racing it against a second SIGINT or
    ///    SIGTERM.
    ///
    /// # Errors
    ///
    /// Returns [`Aborted`] if a second signal arrives before the drain completes.
    pub async fn signaled(self) -> Result<(), Aborted> {
        let Self {
            interrupt: mut sigint,
            terminate: mut sigterm,
            tx: mut drain_tx,
        } = self;

        // Phase 1: wait for a shutdown request.
        tokio::select! {
            _ = sigint.recv() => debug!("Received SIGINT; draining"),
            _ = sigterm.recv() => debug!("Received SIGTERM; draining"),
            _ = drain_tx.closed() => {
                // Nobody is left to drain, so there is nothing to wait for.
                debug!("All shutdown receivers dropped");
                return Ok(());
            }
        }

        // Phase 2: drain gracefully unless another signal forces an abort.
        tokio::select! {
            _ = drain_tx.drain() => {
                debug!("Drained");
                Ok(())
            }
            _ = sigint.recv() => {
                debug!("Received SIGINT; aborting");
                Err(Aborted(()))
            }
            _ = sigterm.recv() => {
                debug!("Received SIGTERM; aborting");
                Err(Aborted(()))
            }
        }
    }
}
#[cfg(feature = "runtime")]
impl<T> CancelOnShutdown<T> {
    /// Wraps `inner` so that it is cut short once `watch` is signaled.
    ///
    /// Whatever `Watch::signaled` resolves to is dropped immediately.
    /// NOTE(review): if that value is a release guard, the drain may finish
    /// before `inner` itself is dropped; confirm this is intended.
    pub(crate) fn new(watch: Watch, inner: T) -> Self {
        let on_signal = async move {
            let _ = watch.signaled().await;
        };
        Self {
            inner,
            shutdown: Box::pin(on_signal),
        }
    }
}
#[cfg(feature = "runtime")]
impl<F: std::future::Future<Output = ()>> std::future::Future for CancelOnShutdown<F> {
    type Output = ();

    /// Resolves as soon as either the inner future completes or the shutdown
    /// signal fires. The inner future is checked first on every poll.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let mut this = self.project();
        match this.inner.poll(cx) {
            Poll::Ready(()) => Poll::Ready(()),
            // Still running: let the shutdown signal decide whether to stop early.
            Poll::Pending => this.shutdown.as_mut().poll(cx),
        }
    }
}
#[cfg(feature = "runtime")]
impl<S: futures_core::Stream> futures_core::Stream for CancelOnShutdown<S> {
    type Item = S::Item;

    /// Returns the inner stream's result whenever it is ready. Otherwise, it
    /// ends the stream with `None` once the shutdown signal fires.
    ///
    /// The inner stream is checked first, so items that are already available
    /// keep flowing even after shutdown has been signaled. Once shutdown has
    /// ended the stream, it should not be polled again. The shutdown future (an
    /// `async` block, as built by `new`) has already completed at that point.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<S::Item>> {
        use std::future::Future;

        let mut this = self.project();
        if let ready @ Poll::Ready(_) = this.inner.poll_next(cx) {
            return ready;
        }
        this.shutdown.as_mut().poll(cx).map(|()| None)
    }
}
// Tests for `CancelOnShutdown`. Futures and streams are driven by hand through
// `tokio_test::task::spawn`, so every poll happens at an exact point relative
// to the drain.
#[cfg(all(test, feature = "runtime"))]
mod test {
    use super::CancelOnShutdown;
    use tokio_stream::wrappers::ReceiverStream;
    use tokio_test::{assert_pending, assert_ready, assert_ready_eq, task};

    // A wrapped stream keeps yielding items that are already buffered after a
    // drain starts. It ends with `None` only once it has no ready item, and only
    // then can the drain finish.
    #[tokio::test]
    async fn cancel_stream_drains() {
        let (shutdown_tx, shutdown_rx) = drain::channel();
        let (stream_tx, stream_rx) = tokio::sync::mpsc::channel(3);
        let mut stream_rx = task::spawn(CancelOnShutdown::new(
            shutdown_rx,
            ReceiverStream::new(stream_rx),
        ));
        // Fill the channel to its capacity of 3 before shutdown begins.
        stream_tx.try_send(1).unwrap();
        stream_tx.try_send(2).unwrap();
        stream_tx.try_send(3).unwrap();
        assert_ready_eq!(stream_rx.poll_next(), Some(1));
        // Begin draining while two items are still buffered.
        let mut drain = task::spawn(shutdown_tx.drain());
        // The inner stream is polled first, so buffered items still come through.
        assert_ready_eq!(stream_rx.poll_next(), Some(2));
        assert_ready_eq!(stream_rx.poll_next(), Some(3));
        // The wrapper's shutdown future still owns the watch, so the drain is pending.
        assert_pending!(drain.poll());
        // `stream_tx` is still alive, so the inner stream is pending; the
        // shutdown signal ends the wrapped stream instead.
        assert_ready_eq!(stream_rx.poll_next(), None);
        assert_ready!(drain.poll());
    }

    // A wrapped future completes once shutdown is signaled, even though its
    // inner future cannot complete while `_tx` is alive.
    #[tokio::test]
    async fn cancel_future_ends() {
        let (shutdown_tx, shutdown_rx) = drain::channel();
        // Bound to `_tx` (not `_`) so the sender stays alive and `rx` never resolves.
        let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
        let mut rx = task::spawn(CancelOnShutdown::new(
            shutdown_rx,
            Box::pin(async move {
                rx.await.unwrap();
            }),
        ));
        // Neither the inner future nor the shutdown signal is ready yet.
        assert_pending!(rx.poll());
        let mut drain = task::spawn(shutdown_tx.drain());
        // The drain waits on the watch still held by the wrapper.
        assert_pending!(drain.poll());
        // The shutdown signal resolves the wrapper, after which the drain completes.
        assert_ready!(rx.poll());
        assert_ready!(drain.poll());
    }
}