#![deny(warnings, rust_2018_idioms)]
#![forbid(unsafe_code)]
#[cfg(feature = "tower")]
mod retain;
#[cfg(feature = "tower")]
pub use crate::retain::Retain;
use std::future::Future;
use tokio::sync::{mpsc, watch};
/// Creates a connected drain [`Signal`] / [`Watch`] pair.
///
/// The [`Signal`] half owns the sending side of a `watch` channel that
/// announces a drain. It also owns the receiving side of an `mpsc` channel
/// whose senders are held by every [`Watch`] and every [`ReleaseShutdown`]
/// derived from one.
///
/// The `mpsc` item type is the uninhabited `Never`, so no value can ever be
/// sent on it. The only event it ever reports is closure, which happens once
/// every one of those sender handles has been dropped.
pub fn channel() -> (Signal, Watch) {
    // Nothing can be sent on this channel, so its capacity is irrelevant.
    // 1 is simply the smallest capacity tokio accepts.
    let (drained_tx, drained_rx) = mpsc::channel(1);
    // The payload is `()`: only the fact that a new value was sent matters.
    let (signal_tx, signal_rx) = watch::channel(());

    (
        Signal {
            drained_rx,
            signal_tx,
        },
        Watch {
            drained_tx,
            signal_rx,
        },
    )
}
/// An uninhabited type: no value of `Never` can ever be constructed.
///
/// It is used as the item type of the internal `mpsc` channel, so nothing can
/// be sent on that channel. The channel exists only so the `Signal` can
/// observe when every sender has been dropped.
enum Never {}
/// The controlling half of a drain channel, returned by [`channel`].
///
/// Call [`Signal::drain`] to notify every [`Watch`] and wait for all of them
/// (and any [`ReleaseShutdown`] handles they produced) to be dropped.
pub struct Signal {
/// Yields `None` once every `Watch` and `ReleaseShutdown` holding a clone of
/// the paired sender has been dropped. No item can ever arrive because
/// `Never` is uninhabited.
drained_rx: mpsc::Receiver<Never>,
/// Publishes the drain notification to every `Watch`.
signal_tx: watch::Sender<()>,
}
/// A listener handle for the drain notification sent by [`Signal::drain`].
///
/// `Watch` derives `Clone`, which clones both inner channel handles. The
/// paired `Signal::drain` therefore waits for every clone, or the
/// `ReleaseShutdown` each clone is turned into, to be dropped.
#[derive(Clone)]
pub struct Watch {
/// Held only so that its drop can be observed by `Signal`. Nothing is ever
/// sent on it because `Never` is uninhabited.
drained_tx: mpsc::Sender<Never>,
/// Receives the notification sent by `Signal::drain`.
signal_rx: watch::Receiver<()>,
}
/// A handle that keeps a drain in progress until it is dropped.
///
/// It is obtained from [`Watch::signaled`] or [`Watch::ignore_signaled`] and
/// wraps a clone of the drain-completion sender. The paired
/// [`Signal::drain`] future does not complete while any `ReleaseShutdown`
/// (or `Watch`) from the same channel is still alive.
#[must_use = "ReleaseShutdown should be dropped explicitly to release the runtime"]
#[derive(Clone)]
pub struct ReleaseShutdown(mpsc::Sender<Never>);
impl Signal {
    /// Waits until no [`Watch`] is listening for the drain notification any
    /// more.
    ///
    /// The future resolves once every receiver of the internal `watch`
    /// channel is gone. That happens when each `Watch` has been dropped, has
    /// had [`Watch::signaled`] complete, or has called
    /// [`Watch::ignore_signaled`].
    ///
    /// Outstanding [`ReleaseShutdown`] handles do not keep this future
    /// pending, because they do not hold a notification receiver.
    pub async fn closed(&mut self) {
        let sender = &mut self.signal_tx;
        sender.closed().await;
    }

    /// Notifies every [`Watch`] that draining has begun, then waits for all
    /// watchers to release their handles.
    ///
    /// The returned future completes once every `Watch` and every
    /// [`ReleaseShutdown`] created from this channel has been dropped.
    pub async fn drain(mut self) {
        // A send error only means no receivers remain. There is nobody left
        // to notify, so the error is ignored.
        let _ = self.signal_tx.send(());

        // `Never` is uninhabited, so `recv` can never produce an item. It
        // only resolves (with `None`) once every sender clone is dropped.
        if let Some(impossible) = self.drained_rx.recv().await {
            match impossible {}
        }
    }
}
impl std::fmt::Debug for Signal {
    /// Formats as `Signal { .. }`; the inner channel handles are omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("Signal");
        builder.finish_non_exhaustive()
    }
}
impl Watch {
pub async fn signaled(mut self) -> ReleaseShutdown {
let _ = self.signal_rx.changed().await;
ReleaseShutdown(self.drained_tx)
}
pub fn ignore_signaled(self) -> ReleaseShutdown {
drop(self.signal_rx);
ReleaseShutdown(self.drained_tx)
}
pub async fn watch<A, F>(self, mut future: A, on_drain: F) -> A::Output
where
A: Future + Unpin,
F: FnOnce(&mut A),
{
tokio::select! {
res = &mut future => res,
shutdown = self.signaled() => {
on_drain(&mut future);
shutdown.release_after(future).await
}
}
}
}
impl std::fmt::Debug for Watch {
    /// Formats as `Watch { .. }`; the inner channel handles are omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("Watch");
        builder.finish_non_exhaustive()
    }
}
impl ReleaseShutdown {
pub async fn release_after<F: Future>(self, future: F) -> F::Output {
let res = future.await;
drop(self.0);
res
}
}
impl std::fmt::Debug for ReleaseShutdown {
    /// Formats as `ReleaseShutdown { .. }`; the inner sender is omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("ReleaseShutdown");
        builder.finish_non_exhaustive()
    }
}
// Unit tests that step futures manually with `tokio_test::task` so that the
// exact ordering of drain notifications and handle releases can be asserted.
#[cfg(test)]
mod tests {
use std::{
future::Future,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
task::{Context, Poll},
};
use tokio::sync::oneshot;
use tokio_test::{assert_pending, assert_ready, task};
// A test future that completes when its paired `oneshot::Sender` sends (or is
// dropped). `poll` never touches `notified`; the tests' `on_drain` callbacks
// set it so that each test can observe whether the callback ran.
pin_project_lite::pin_project! {
struct Fut {
notified: Arc<AtomicBool>,
#[pin]
inner: oneshot::Receiver<()>,
}
}
impl Fut {
// Returns the future, the sender that completes it, and a shared handle to
// its `notified` flag (initially `false`).
pub fn new() -> (Self, oneshot::Sender<()>, Arc<AtomicBool>) {
let notified = Arc::new(AtomicBool::new(false));
let (tx, rx) = oneshot::channel::<()>();
let fut = Fut {
notified: notified.clone(),
inner: rx,
};
(fut, tx, notified)
}
}
impl Future for Fut {
type Output = ();
// Resolves to `()` once the inner receiver resolves. Whether a value was
// sent or the sender was dropped is ignored.
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
let this = self.project();
let _ = futures::ready!(this.inner.poll(cx));
Poll::Ready(())
}
}
// `Watch::watch` calls `on_drain` for each watched future once draining
// starts, and `Signal::drain` only completes after every watched future has
// finished.
#[tokio::test]
async fn watch() {
let (signal, watch) = super::channel();
let (fut0, tx0, notified0) = Fut::new();
let mut watch0 = task::spawn(
watch
.clone()
.watch(fut0, |f| f.notified.store(true, SeqCst)),
);
let (fut1, tx1, notified1) = Fut::new();
let mut watch1 = task::spawn(watch.watch(fut1, |f| f.notified.store(true, SeqCst)));
// Before any drain, both watched futures are pending and no callback has run.
assert_pending!(watch0.poll());
assert_pending!(watch1.poll());
assert!(!notified0.load(SeqCst));
assert!(!notified1.load(SeqCst));
// The first poll of `drain` sends the notification; it stays pending because
// both watches still hold drain handles.
let mut drain = task::spawn(signal.drain());
assert_pending!(drain.poll());
// Each watch now observes the notification and runs its `on_drain` callback,
// but remains pending until its inner future completes.
assert_pending!(watch0.poll());
assert!(notified0.load(SeqCst));
assert_pending!(watch1.poll());
assert!(notified1.load(SeqCst));
// Completing only the first future is not enough for `drain` to finish.
tx0.send(()).expect("must send");
assert_ready!(watch0.poll());
assert_pending!(watch1.poll());
assert_pending!(drain.poll());
// Once the last watched future completes, `drain` resolves.
tx1.send(()).expect("must send");
assert_ready!(watch1.poll());
assert_ready!(drain.poll());
}
// `Watch::signaled` resolves after `drain` is started, and `drain` completes
// once the returned `ReleaseShutdown` is dropped.
#[tokio::test]
async fn drain() {
let (signal, watch) = super::channel();
let mut signaled = task::spawn(async move {
let release = watch.signaled().await;
drop(release);
});
assert_pending!(signaled.poll());
let mut drain = task::spawn(signal.drain());
assert_pending!(drain.poll());
assert_ready!(signaled.poll());
assert_ready!(drain.poll());
}
// `Signal::closed` resolves once the only `Watch` is dropped.
#[tokio::test]
async fn closed() {
let (mut signal, watch) = super::channel();
let mut closed = task::spawn(async move {
signal.closed().await;
});
assert_pending!(closed.poll());
drop(watch);
assert_ready!(closed.poll());
}
}