use parking_lot::Mutex;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::Notify;
use crate::error::RelayError;
/// Counts completions toward a target number and lets callers wait until the
/// target is reached, keeping the first error reported along the way.
pub struct CompletionTracker {
    /// Number of completions required before the tracker reports itself complete.
    expected: AtomicUsize,
    /// Completions recorded so far; successful and failed completions both count.
    completed: AtomicUsize,
    /// First error reported through `fail`, if any.
    error: Mutex<Option<RelayError>>,
    /// Used to wake waiters once `completed` reaches `expected`.
    notify: Arc<Notify>,
}
impl CompletionTracker {
    /// Creates a tracker that becomes complete after `expected` calls to
    /// [`complete_one`](Self::complete_one) or [`fail`](Self::fail).
    ///
    /// A tracker created with `expected == 0` is complete immediately.
    pub fn new(expected: usize) -> Self {
        Self {
            expected: AtomicUsize::new(expected),
            completed: AtomicUsize::new(0),
            error: Mutex::new(None),
            notify: Arc::new(Notify::new()),
        }
    }

    /// Returns the number of completions required.
    pub fn expected(&self) -> usize {
        self.expected.load(Ordering::SeqCst)
    }

    /// Returns the number of completions (successful or failed) recorded so far.
    pub fn completed(&self) -> usize {
        self.completed.load(Ordering::SeqCst)
    }

    /// Returns `true` once at least `expected` completions have been recorded.
    pub fn is_complete(&self) -> bool {
        self.completed() >= self.expected()
    }

    /// Records one completion and wakes every waiter once the target is reached.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if this brings the number of completions above
    /// `expected`. That usually means the same unit of work was completed twice.
    pub fn complete_one(&self) {
        let completed = self.completed.fetch_add(1, Ordering::SeqCst) + 1;
        let expected = self.expected.load(Ordering::SeqCst);
        // The invariant is `completed <= expected`. The previous `expected + 1` bound
        // let exactly one extra (double) completion through unnoticed.
        debug_assert!(
            completed <= expected,
            "complete_one called {} times but only {} expected (possible double-completion bug)",
            completed,
            expected
        );
        if completed >= expected {
            self.notify.notify_waiters();
        }
    }

    /// Records a failed completion.
    ///
    /// Only the first error is retained; later errors are dropped. The failure
    /// still counts toward `expected`, so waiters are released normally.
    pub fn fail(&self, error: RelayError) {
        // The guard is a temporary, so the lock is released before waiters are woken.
        self.error.lock().get_or_insert(error);
        self.complete_one();
    }

    /// Removes and returns the recorded error, leaving `None` behind.
    pub fn take_error(&self) -> Option<RelayError> {
        self.error.lock().take()
    }

    /// Returns a future that resolves once the tracker is complete.
    ///
    /// The `Notified` registration is created here, before the first completion
    /// check in `poll`. Because `notify_waiters` reaches any `Notified` that
    /// already exists, a completion racing with the first poll cannot be missed.
    pub fn wait(&self) -> CompletionFuture<'_> {
        CompletionFuture {
            tracker: self,
            notified: Box::pin(self.notify.notified()),
        }
    }

    /// Waits until the tracker is complete.
    ///
    /// Delegates to [`wait`](Self::wait). Checking `is_complete()` before creating
    /// the `Notified` would lose a `notify_waiters` call made in between.
    pub async fn wait_async(&self) {
        self.wait().await;
    }

    /// Waits until the tracker is complete.
    ///
    /// Unlike [`wait_async`](Self::wait_async), the returned future holds its own
    /// `Arc` handle to the tracker rather than borrowing it.
    pub async fn wait_owned(self: Arc<Self>) {
        self.wait().await;
    }
}

/// Future returned by [`CompletionTracker::wait`]; resolves once the tracker is complete.
pub struct CompletionFuture<'a> {
    tracker: &'a CompletionTracker,
    /// Registration with the tracker's `Notify`. It lives in the future, not on the
    /// stack of `poll`, so the waker it registers survives between polls. Dropping
    /// a `Notified` unregisters its waker, and a later `notify_waiters` would then
    /// have nothing to wake. Boxed so the future stays `Unpin`. `Send + Sync` keep
    /// the auto traits the future had when it held only `&CompletionTracker`.
    notified: Pin<Box<dyn Future<Output = ()> + Send + Sync + 'a>>,
}

impl<'a> Future for CompletionFuture<'a> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Every field is `Unpin`, so it is fine to work through a plain `&mut`.
        let this = self.get_mut();
        loop {
            // `this.notified` was created before this check. A completion landing
            // after the check is therefore still delivered to it.
            if this.tracker.is_complete() {
                return Poll::Ready(());
            }
            match this.notified.as_mut().poll(cx) {
                Poll::Pending => return Poll::Pending,
                // Defensive: a notification fired but the tracker is not complete.
                // Register again before re-checking, so no wakeup is lost.
                Poll::Ready(()) => {
                    this.notified = Box::pin(this.tracker.notify.notified());
                }
            }
        }
    }
}
impl std::fmt::Debug for CompletionTracker {
    /// Formats a point-in-time snapshot: target count, completions so far, and
    /// whether an error has been recorded (the error itself is not printed).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Read in the same order as before: expected, completed, then the error slot.
        let expected = self.expected();
        let completed = self.completed();
        let has_error = self.error.lock().is_some();

        let mut out = f.debug_struct("CompletionTracker");
        out.field("expected", &expected);
        out.field("completed", &completed);
        out.field("has_error", &has_error);
        out.finish()
    }
}