use core::{
pin::Pin,
task::{Context, Poll, Waker},
};
use derive_more::{Constructor, Deref};
use crate::utils::*;
/// Contents of a wait group's shared slot: the [`Waker`] stored by a pending poll, if any.
pub(crate) type WaitGroupData = Option<Waker>;
/// Backing storage for a wait group: an atomic state byte plus a waker slot.
///
/// # Safety
///
/// NOTE(review): the implementor contract is not spelled out in this file. The code here
/// decides when to mutate the cell returned by [`slot`](Self::slot) purely from the
/// `DONE`/`LOCK` bits in [`state`](Self::state), so presumably both accessors must always
/// refer to the same storage and nothing else may touch the cell — confirm with the
/// implementors.
pub(crate) unsafe trait WaitGroupType: Sized {
/// Returns the atomic byte holding the `DONE` and `LOCK` flag bits.
fn state(&self) -> &AtomicU8;
/// Returns the cell holding the registered waker.
///
/// # Safety
///
/// NOTE(review): caller obligations are not documented here. Within this file the cell
/// appears to be dereferenced only at points where the `DONE`/`LOCK` protocol grants
/// exclusive access — confirm.
unsafe fn slot(&self) -> &UnsafeCell<WaitGroupData>;
}
/// State bit set by `send_done` to mark completion; never cleared in this file.
const DONE: u8 = 0b01;
/// State bit guarding the waker slot: a party that flips it from clear to set may access
/// the slot. `poll` clears it again after registering a waker; `send_done` and `Drop`
/// leave it set.
const LOCK: u8 = 0b10;
/// Runs `f` with a unique reference to the waker slot of `val` and returns its result.
///
/// Under `cfg(loom)` the access goes through `with_mut`, so the model checker records it
/// as a *mutable* access and can report conflicting concurrent accesses to the slot.
/// (Going through the const accessor and casting the pointer would register every
/// mutation as a read, hiding exactly the races this code needs checked.)
///
/// NOTE(review): assumes `UnsafeCell` is `loom::cell::UnsafeCell` when built with
/// `--cfg loom` — confirm against `crate::utils`.
///
/// # Safety
///
/// The caller must have exclusive access to the slot for the whole call (in this file
/// that is arranged through the `DONE`/`LOCK` bits of `val.state()`) and must satisfy
/// whatever [`WaitGroupType::slot`] itself requires.
#[allow(clippy::mut_from_ref)]
#[inline]
unsafe fn with_slot_mut<T: WaitGroupType, R, F: FnOnce(&mut WaitGroupData) -> R>(
    val: &T,
    f: F,
) -> R {
    #[cfg(not(loom))]
    {
        // SAFETY: the caller guarantees exclusive access to the slot, so turning the
        // cell's raw pointer into a unique reference is sound.
        f(unsafe { &mut *val.slot().get() })
    }
    #[cfg(loom)]
    {
        // SAFETY: same argument as above; `with_mut` hands out a `*mut` and lets loom
        // verify the exclusivity claim.
        unsafe { val.slot() }.with_mut(|ptr| f(unsafe { &mut *ptr }))
    }
}
/// Helper operations layered on top of [`WaitGroupType`].
pub(crate) trait WaitGroupUtil: WaitGroupType {
    /// Returns `true` once the `DONE` bit is set in the shared state.
    #[inline]
    fn is_done(&self) -> bool {
        let bits = self.state().load(atomic::Acquire);
        (bits & DONE) != 0
    }

    /// Marks the wait group as done and wakes the registered waker, if the slot is free.
    ///
    /// `DONE` and `LOCK` are set in one atomic step. When `LOCK` was already held by
    /// someone else, the slot is left alone and no wake-up is issued from here.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller contract is not stated in this file. Because `LOCK` is
    /// left set afterwards, this presumably must be called at most once per wait group —
    /// confirm at the call sites.
    #[inline]
    unsafe fn send_done(&self) {
        let before = self.state().fetch_or(DONE | LOCK, atomic::AcqRel);
        if before & LOCK != 0 {
            // The slot is owned by another party right now; do not touch it.
            return;
        }
        // SAFETY: `LOCK` was clear and has just been set by us, so under this module's
        // protocol we are the only party accessing the slot.
        let parked = unsafe { with_slot_mut(self, |slot| slot.take()) };
        if let Some(waker) = parked {
            waker.wake();
        }
    }
}
/// Blanket impl: every [`WaitGroupType`] gets the [`WaitGroupUtil`] helpers.
impl<T: WaitGroupType> WaitGroupUtil for T {}
/// Future adapter around a [`WaitGroupType`] that resolves once its `DONE` bit is set.
///
/// `Constructor` and `Deref` come from `derive_more`; the impls in this file call `T`'s
/// methods directly on the wrapper, relying on the derived `Deref`.
#[must_use]
#[derive(Debug, Constructor, Deref)]
pub(crate) struct WaitGroupWrapper<T: WaitGroupType>(T);
/// Resolves to `()` once the `DONE` bit is observed in the shared state.
impl<T: WaitGroupType> Future for WaitGroupWrapper<T> {
type Output = ();
/// Returns `Ready` if `DONE` is already set; otherwise stores the task's waker in the
/// slot under `LOCK`, releases the lock and returns `Pending` — unless `DONE` was set
/// while the lock was held, in which case it cleans up and returns `Ready`.
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Take the slot lock; the previous value also tells us whether completion was
// already signalled.
let prev_state = self.state().fetch_or(LOCK, atomic::Acquire);
if prev_state & DONE != 0 {
// Already done. `LOCK` deliberately stays set, so `Drop` will skip the slot.
return Poll::Ready(());
}
// Not done, so nobody else should be holding the lock at this point.
debug_assert!(prev_state & LOCK == 0);
// If the slot update below unwinds, the guard clears `LOCK` again.
let guard = UnlockGuard(self.state());
let waker = cx.waker();
// SAFETY: `LOCK` was clear and is now set by us, so under this module's protocol
// we have exclusive access to the slot.
unsafe {
with_slot_mut(&self.0, |slot| {
match slot {
// The stored waker already wakes this task: skip the clone.
Some(old) if old.will_wake(waker) => {}
_ => {
*slot = Some(waker.clone());
}
};
});
}
// Unlock by hand instead of via the guard so the previous state can be inspected.
guard.defuse();
let prev_state = self.state().fetch_and(!LOCK, atomic::AcqRel);
if prev_state & DONE != 0 {
// `send_done` ran while we held `LOCK`; it saw the lock and skipped the wake-up,
// so complete here: discard the waker we just stored ...
drop(unsafe { with_slot_mut(&self.0, |slot| slot.take()) });
// ... and set `LOCK` again so that `Drop` leaves the slot alone.
self.state().fetch_or(LOCK, atomic::Release);
return Poll::Ready(());
}
Poll::Pending
}
}
/// Discards any waker still parked in the slot when the future goes away.
impl<T: WaitGroupType> Drop for WaitGroupWrapper<T> {
    #[inline]
    fn drop(&mut self) {
        // Set `LOCK` and leave it set; the previous value tells us who owns the slot.
        let before = self.state().fetch_or(LOCK, atomic::Acquire);
        if before & LOCK != 0 {
            // The slot is owned by another party (or was already emptied); leave it be.
            return;
        }
        // SAFETY: `LOCK` was clear and has just been set by us, so under this module's
        // protocol we are the only party accessing the slot.
        let leftover = unsafe { with_slot_mut(&self.0, |slot| slot.take()) };
        drop(leftover);
    }
}
/// `FusedFuture` support, available with the `futures-core` feature.
#[cfg(feature = "futures-core")]
impl<T: WaitGroupType> futures_core::FusedFuture for WaitGroupWrapper<T> {
/// Reports the shared `DONE` bit.
///
/// NOTE(review): `DONE` is set by `send_done`, so this can return `true` before `poll`
/// has ever returned `Ready`. Combinators that skip terminated futures may then never
/// observe the completion — confirm this is the intended behavior.
#[inline]
fn is_terminated(&self) -> bool {
self.is_done()
}
}
/// Scope guard that clears the `LOCK` bit of the referenced state byte when dropped.
struct UnlockGuard<'s>(&'s AtomicU8);

impl<'s> UnlockGuard<'s> {
    /// Consumes the guard without running its destructor, leaving `LOCK` untouched.
    #[inline]
    fn defuse(self) {
        // Wrapping in `ManuallyDrop` suppresses the `Drop` impl below.
        let _disarmed = core::mem::ManuallyDrop::new(self);
    }
}

impl<'s> Drop for UnlockGuard<'s> {
    /// Clears `LOCK`, leaving every other bit as it was.
    #[inline]
    fn drop(&mut self) {
        let state: &AtomicU8 = self.0;
        state.fetch_and(!LOCK, atomic::AcqRel);
    }
}