use crate::get_value;
use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};
use core::{any::Provider, sync::atomic::AtomicBool};
use pin_project_lite::pin_project;
use std::sync::Arc;
use tokio::sync::futures::Notified;
use tokio::sync::Notify;
/// State shared through an `Arc` between a `ShutdownSender` and the
/// `ShutdownReceiver`s created from it.
#[derive(Debug, Default)]
struct ShutdownInner {
    /// Woken with `notify_waiters` when shutdown is triggered.
    notifier: Notify,
    /// Set to `true` (with `Release` ordering) when shutdown is triggered.
    /// No code in this file ever resets it.
    shutdown: AtomicBool,
}
// Exposes the receiver's shared state (an `Arc<ShutdownInner>`) by reference
// to `Provider`-based lookups; a receiver without shared state provides nothing.
impl Provider for ShutdownReceiver {
    fn provide<'a>(&'a self, demand: &mut core::any::Demand<'a>) {
        let state = &self.0;
        if let Some(shared) = state {
            demand.provide_ref(shared);
        }
    }
}
/// Handle used to trigger shutdown.
///
/// Clones share the same underlying state, so calling
/// [`ShutdownSender::shutdown`] on any clone is observed by every receiver
/// created from any clone.
#[derive(Clone, Debug, Default)]
pub struct ShutdownSender(Arc<ShutdownInner>);

/// Handle used to wait for shutdown (see [`ShutdownReceiver::wait_for_signal`]).
///
/// Holds `None` when no shared shutdown state was available, e.g. when
/// [`ShutdownReceiver::from_context`] found no value in the context.
#[derive(Clone, Debug)]
pub struct ShutdownReceiver(Option<Arc<ShutdownInner>>);
impl ShutdownSender {
pub fn new() -> Self {
Default::default()
}
pub fn receiver(self) -> ShutdownReceiver {
ShutdownReceiver(Some(self.0))
}
pub fn shutdown(&self) {
self.0
.shutdown
.store(true, std::sync::atomic::Ordering::Release);
self.0.notifier.notify_waiters();
}
}
impl ShutdownReceiver {
    /// Builds a receiver from the value returned by `get_value` for the
    /// current context. The receiver holds no shared state if that lookup
    /// yields `None`.
    pub async fn from_context() -> Self {
        Self(get_value().await)
    }

    /// Waits until shutdown has been triggered on the associated sender.
    ///
    /// Resolves immediately if shutdown was already triggered.
    ///
    /// NOTE(review): a receiver without shared state (`None`) returns at once
    /// rather than waiting forever. `run_until_signal` instead treats a missing
    /// receiver as "never signalled". Confirm this difference is intended.
    pub async fn wait_for_signal(&self) {
        if let Some(shared) = self.0.as_deref() {
            let signal = ShutdownSignal {
                shutdown: &shared.shutdown,
                notified: shared.notifier.notified(),
            };
            signal.await;
        }
    }
}
// Future that resolves once shutdown has been triggered. It completes right
// away if the shared flag is already set, otherwise when `notified` completes.
pin_project!(
    struct ShutdownSignal<'a> {
        // Shared shutdown flag, checked at the start of every poll.
        shutdown: &'a AtomicBool,
        // Notification future obtained via `notified()` when this signal is
        // built. It is structurally pinned so it can be polled through `Pin`.
        #[pin]
        notified: Notified<'a>,
    }
);
impl Future for ShutdownSignal<'_> {
    type Output = ();

    /// Ready as soon as the shared flag is observed set; otherwise defers to
    /// the notification future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let proj = self.project();
        // Acquire pairs with the Release store done when shutdown is triggered.
        let already_shut_down = proj.shutdown.load(core::sync::atomic::Ordering::Acquire);
        if already_shut_down {
            return Poll::Ready(());
        }
        proj.notified.poll(cx)
    }
}
#[cfg(feature = "time")]
#[cfg_attr(docsrs, doc(cfg(feature = "time")))]
pub(crate) mod time {
    //! Racing a future against the context's deadline and shutdown signal.

    use super::{ShutdownReceiver, ShutdownSignal};
    use core::future::Future;
    use core::pin::Pin;
    use core::task::{Context, Poll};
    use pin_project_lite::pin_project;

    /// Outcome of [`run_until_signal`].
    #[derive(Debug)]
    pub enum SignalOrComplete<F: Future> {
        /// A signal (deadline elapsed or shutdown triggered) fired before the
        /// future finished. The unfinished future is handed back so the caller
        /// can keep driving it or drop it.
        ShutdownSignal(F),
        /// The future finished and produced this output. Completion wins ties:
        /// if the future is ready on the same poll as a signal, this variant is
        /// returned.
        Completed(F::Output),
    }

    impl<F: Future> SignalOrComplete<F> {
        /// Returns the output if the future completed, or `None` if a signal
        /// interrupted it. In the `None` case the unfinished future is dropped.
        pub fn completed(self) -> Option<F::Output> {
            match self {
                SignalOrComplete::ShutdownSignal(_) => None,
                SignalOrComplete::Completed(output) => Some(output),
            }
        }
    }

    pin_project!(
        // Future behind `run_until_signal`. It polls `inner` first, then the
        // optional signal futures `a` and `b`. It resolves as soon as `inner`
        // completes or either signal is ready.
        struct SignalOrCompleteFut<F, A, B> {
            // Wrapped in `Option` so `poll` can move the future out when it
            // resolves. It is `None` once this future has returned `Ready`.
            inner: Option<F>,
            // Optional signal futures; `None` means "this signal never fires".
            #[pin]
            a: Option<A>,
            #[pin]
            b: Option<B>,
        }
    );

    impl<F, A, B> Future for SignalOrCompleteFut<F, A, B>
    where
        F: Future + Unpin,
        A: Future<Output = ()>,
        B: Future<Output = ()>,
    {
        type Output = SignalOrComplete<F>;

        /// # Panics
        ///
        /// Panics if polled again after it has returned `Ready`.
        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.project();
            // `inner` is only `None` after a previous poll returned `Ready`.
            let mut f = this
                .inner
                .take()
                .expect("SignalOrCompleteFut polled after completion");
            // Poll the wrapped future first so it wins a tie with a signal.
            if let Poll::Ready(output) = Pin::new(&mut f).poll(cx) {
                return Poll::Ready(SignalOrComplete::Completed(output));
            }
            if let Some(a) = this.a.as_pin_mut() {
                if a.poll(cx).is_ready() {
                    return Poll::Ready(SignalOrComplete::ShutdownSignal(f));
                }
            }
            if let Some(b) = this.b.as_pin_mut() {
                if b.poll(cx).is_ready() {
                    return Poll::Ready(SignalOrComplete::ShutdownSignal(f));
                }
            }
            // Nothing is ready yet: put the future back for the next poll.
            *this.inner = Some(f);
            Poll::Pending
        }
    }

    /// Drives `f` until one of these happens:
    /// - `f` completes;
    /// - the context's deadline (if any) elapses;
    /// - shutdown is triggered through the context's receiver (if it has state).
    ///
    /// If the context provides neither a deadline nor receiver state, this
    /// simply awaits `f`.
    ///
    /// Returns [`SignalOrComplete::Completed`] with `f`'s output, or
    /// [`SignalOrComplete::ShutdownSignal`] carrying the unfinished `f`.
    pub async fn run_until_signal<F: Future + Unpin>(f: F) -> SignalOrComplete<F> {
        use crate::well_known::Deadline;
        let deadline = Deadline::get().await;
        let shutdown = ShutdownReceiver::from_context().await.0;
        let res = SignalOrCompleteFut {
            inner: Some(f),
            a: deadline.map(|deadline| tokio::time::sleep_until(deadline.into())),
            b: shutdown.as_deref().map(|shutdown| ShutdownSignal {
                shutdown: &shutdown.shutdown,
                notified: shutdown.notifier.notified(),
            }),
        }
        .await;
        // The awaited future borrows the local `shutdown`. Binding the result
        // to `res` first presumably stops that future from becoming a
        // tail-expression temporary, which would be dropped after `shutdown`
        // (pre-2024-edition drop order) and rejected by the borrow checker.
        // That is why clippy's `let_and_return` is silenced here.
        #[allow(clippy::let_and_return)]
        res
    }
}