use crate::backoff::Backoff;
use crate::shared::{check_timeout, ThinWaker};
#[allow(unused_imports)]
use crate::{tokio_task_id, trace_log};
use std::cell::UnsafeCell;
use std::future::Future;
use std::mem::transmute;
use std::ops::Deref;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::atomic::{
AtomicUsize,
Ordering::{self, Acquire, Relaxed, Release, SeqCst},
};
use std::task::{Context, Poll, Waker};
use std::thread;
use std::time::{Duration, Instant};
/// A wait group whose counter is stored inline in the struct (no heap allocation, no guards).
///
/// `THRESHOLD` is the counter value at or below which waiters are released; it defaults to 0.
/// The counter starts at 0 and is driven manually through [`add`](Self::add) /
/// [`done`](Self::done).
///
/// Most waiting and completing methods are `unsafe`. The contract is not written down in this
/// file; see the `# Safety` notes on each method for what the implementation suggests.
pub struct WaitGroupInline<const THRESHOLD: usize = 0> {
    inner: WaitGroupInner<()>,
}

impl<const THRESHOLD: usize> Default for WaitGroupInline<THRESHOLD> {
    /// Equivalent to [`WaitGroupInline::new`].
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

impl<const THRESHOLD: usize> WaitGroupInline<THRESHOLD> {
    /// Creates a wait group with a count of 0 and no registered waker.
    pub fn new() -> Self {
        Self { inner: WaitGroupInner::new((), 0) }
    }

    /// Returns the current count, loaded with `SeqCst` ordering.
    #[inline(always)]
    pub fn get_left_seqcst(&self) -> usize {
        self.inner.count(SeqCst)
    }

    /// Returns the current count, loaded with `Acquire` ordering.
    #[inline(always)]
    pub fn get_left(&self) -> usize {
        self.inner.count(Acquire)
    }

    /// Increments the count by one.
    ///
    /// # Panics
    ///
    /// Panics when the counter is about to overflow into the flag bits.
    #[inline(always)]
    pub fn add(&self) {
        self.inner.add(1);
    }

    /// Increments the count by `count`.
    ///
    /// # Panics
    ///
    /// Panics when the counter is about to overflow into the flag bits; debug builds also
    /// assert that `count` itself is below the counter limit.
    #[inline(always)]
    pub fn add_many(&self, count: usize) {
        debug_assert!(count < COUNT_MASK - 2);
        self.inner.add(count);
    }

    /// Decrements the count by one, waking the registered waiter when the count reaches
    /// `THRESHOLD` or below.
    ///
    /// The returned flag is the "caller must free the allocation" signal of the shared
    /// implementation; this inline variant never owns an allocation, so it is always `false`.
    ///
    /// # Panics
    ///
    /// Panics when the count would go below zero.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not documented in this file. The implementation keeps
    /// touching `self` (waker slot and state word) after the count has already become visible
    /// to a waiter, so presumably the caller must guarantee that `self` stays alive until this
    /// call returns — confirm with the author.
    pub unsafe fn done(&self) -> bool {
        let p = &self.inner as *const WaitGroupInner<()>;
        WaitGroupInner::<()>::done::<false>(p, 1, THRESHOLD)
    }

    /// Decrements the count by `count`; otherwise identical to [`done`](Self::done).
    ///
    /// # Panics
    ///
    /// Panics when the count would go below zero; debug builds also assert that `count` is
    /// below the counter limit.
    ///
    /// # Safety
    ///
    /// Same requirements as [`done`](Self::done).
    pub unsafe fn done_many(&self, count: usize) -> bool {
        debug_assert!(count < COUNT_MASK - 2);
        let p = &self.inner as *const WaitGroupInner<()>;
        WaitGroupInner::<()>::done::<false>(p, count, THRESHOLD)
    }

    /// Non-blocking check: `Ok(())` when the count is at or below `THRESHOLD`, `Err(())`
    /// otherwise.
    #[inline]
    pub fn try_wait(&self) -> Result<(), ()> {
        if self.inner.count(SeqCst) <= THRESHOLD {
            Ok(())
        } else {
            Err(())
        }
    }

    /// Returns a future that resolves once the count is at or below `THRESHOLD`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not documented in this file. There is a single waker
    /// slot that is written without holding the lock flag when no waker is registered yet, so
    /// presumably at most one waiter (async or blocking) may be active at a time — confirm
    /// with the author.
    #[inline]
    pub unsafe fn wait_async<'a>(&'a self) -> WaitGroupFuture<'a, ()> {
        WaitGroupFuture { inner: &self.inner, threshold: THRESHOLD, waker: None }
    }

    /// Like [`wait_async`](Self::wait_async), but resolves to `Err(())` when `timeout`
    /// elapses first (timer provided by tokio).
    ///
    /// # Safety
    ///
    /// Same requirements as [`wait_async`](Self::wait_async).
    #[cfg(feature = "tokio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
    #[inline]
    pub unsafe fn wait_async_timeout<'a>(
        &'a self, timeout: Duration,
    ) -> WaitGroupTimeoutFuture<'a, (), tokio::time::Sleep, ()> {
        let sleep = tokio::time::sleep(timeout);
        self.wait_async_with_timer(sleep)
    }

    /// Like [`wait_async`](Self::wait_async), but resolves to `Err(())` when `timeout`
    /// elapses first (timer provided by async-std).
    ///
    /// # Safety
    ///
    /// Same requirements as [`wait_async`](Self::wait_async).
    #[cfg(feature = "async_std")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async_std")))]
    #[inline]
    pub unsafe fn wait_async_timeout<'a>(
        &'a self, timeout: Duration,
    ) -> WaitGroupTimeoutFuture<'a, (), impl Future<Output = ()>, ()> {
        let sleep = async_std::task::sleep(timeout);
        self.wait_async_with_timer(sleep)
    }

    /// Like [`wait_async`](Self::wait_async), but gives up with `Err(())` as soon as the
    /// caller-supplied future `fut` completes. The output of `fut` is discarded.
    ///
    /// # Safety
    ///
    /// Same requirements as [`wait_async`](Self::wait_async).
    #[inline]
    pub unsafe fn wait_async_with_timer<'a, FR, R>(
        &'a self, fut: FR,
    ) -> WaitGroupTimeoutFuture<'a, (), FR, R>
    where
        FR: Future<Output = R>,
    {
        WaitGroupTimeoutFuture { inner: &self.inner, threshold: THRESHOLD, sleep: fut, waker: None }
    }

    /// Blocks the current thread until the count is at or below `THRESHOLD`.
    ///
    /// # Safety
    ///
    /// Same requirements as [`wait_async`](Self::wait_async).
    #[inline]
    pub unsafe fn wait(&self) {
        // Without a deadline the blocking wait only ever returns `Ok`.
        let _ = self.inner.wait_blocking(None, THRESHOLD);
    }

    /// Blocks the current thread until the count is at or below `THRESHOLD`, or returns
    /// `Err(())` once `timeout` has elapsed.
    ///
    /// # Safety
    ///
    /// Same requirements as [`wait_async`](Self::wait_async).
    #[inline]
    pub unsafe fn wait_timeout(&self, timeout: Duration) -> Result<(), ()> {
        self.inner.wait_blocking(Some(Instant::now() + timeout), THRESHOLD)
    }
}
/// Heap-allocated wait group that carries a value `T` and hands out counting guards.
///
/// The allocation is shared with every `WaitGroupGuard` created through `add_guard`; it is
/// released by whichever side lets go of it last (see `WaitGroupInner::destroy` and
/// `WaitGroupInner::done_ptr`).
pub struct WaitGroup<T> {
// Release threshold plus one: the `WaitGroup` itself accounts for one unit of the counter
// (`new` starts the count at 1), so the user-facing threshold is shifted by one.
threshold: usize,
// Pointer to the boxed shared state, created in `new` via `Box::into_raw`.
inner: NonNull<WaitGroupInner<T>>,
}
// NOTE(review): guards created from this group may live on other threads and also hand out
// `&T` through `Deref`, so `T` ends up shared between threads. A `T: Send` bound alone looks
// too weak for that (compare `Arc<T>: Send`, which requires `T: Send + Sync`) — confirm.
unsafe impl<T: Send> Send for WaitGroup<T> {}
impl<T> WaitGroup<T> {
    /// Allocates the shared state around `inner` and returns the owning handle.
    ///
    /// Waiters are released once the number of outstanding guards is at or below
    /// `threshold`. The handle itself occupies one unit of the internal counter, which is
    /// why the stored threshold is `threshold + 1`.
    #[inline(always)]
    pub fn new(inner: T, threshold: usize) -> Self {
        let shared = Box::new(WaitGroupInner::new(inner, 1));
        let internal_threshold = threshold + 1;
        let raw = Box::into_raw(shared);
        // SAFETY: `Box::into_raw` never yields a null pointer.
        let inner = unsafe { NonNull::new_unchecked(raw) };
        Self { threshold: internal_threshold, inner }
    }

    /// Changes the release threshold used by subsequently created guards and waits.
    #[inline]
    pub fn set_threshold(&mut self, threshold: usize) {
        // Shifted by one for the unit held by the handle itself.
        self.threshold = threshold + 1;
    }

    /// Borrows the shared state behind the pointer.
    #[inline(always)]
    fn get_inner(&self) -> &WaitGroupInner<T> {
        // SAFETY: the allocation is only freed by `destroy`/`done_ptr`, and `destroy` runs
        // from this handle's own `Drop`, so the pointer is valid while `self` is alive.
        let shared: &WaitGroupInner<T> = unsafe { self.inner.as_ref() };
        shared
    }

    /// Number of outstanding guards, read with `SeqCst` ordering.
    #[inline(always)]
    pub fn get_left_seqcst(&self) -> usize {
        let total = self.get_inner().count(SeqCst);
        // Exclude the unit held by the handle itself.
        total - 1
    }

    /// Number of outstanding guards, read with `Acquire` ordering.
    #[inline(always)]
    pub fn get_left(&self) -> usize {
        let total = self.get_inner().count(Acquire);
        // Exclude the unit held by the handle itself.
        total - 1
    }

    /// Increments the count and returns a guard that decrements it again when dropped.
    #[inline(always)]
    pub fn add_guard(&self) -> WaitGroupGuard<T> {
        let shared = self.get_inner();
        shared.add(1);
        WaitGroupGuard { inner: self.inner, threshold: self.threshold }
    }

    /// Non-blocking check: `Ok(())` when the threshold is already met, `Err(())` otherwise.
    #[inline]
    pub fn try_wait(&self) -> Result<(), ()> {
        match self.get_inner().count(SeqCst) {
            current if current <= self.threshold => Ok(()),
            _ => Err(()),
        }
    }

    /// Returns a future that resolves once the threshold is met.
    #[inline]
    pub fn wait_async<'a>(&'a self) -> WaitGroupFuture<'a, T>
    where
        T: Send + Unpin,
    {
        WaitGroupFuture { inner: self.get_inner(), threshold: self.threshold, waker: None }
    }

    /// Like [`wait_async`](Self::wait_async), but resolves to `Err(())` when `timeout`
    /// elapses first (timer provided by tokio).
    #[cfg(feature = "tokio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
    #[inline]
    pub fn wait_async_timeout<'a>(
        &'a self, timeout: Duration,
    ) -> WaitGroupTimeoutFuture<'a, T, tokio::time::Sleep, ()>
    where
        T: Send + Unpin,
    {
        self.wait_async_with_timer(tokio::time::sleep(timeout))
    }

    /// Like [`wait_async`](Self::wait_async), but resolves to `Err(())` when `timeout`
    /// elapses first (timer provided by async-std).
    #[cfg(feature = "async_std")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async_std")))]
    #[inline]
    pub fn wait_async_timeout<'a>(
        &'a self, timeout: Duration,
    ) -> WaitGroupTimeoutFuture<'a, T, impl Future<Output = ()>, ()>
    where
        T: Send + Unpin,
    {
        self.wait_async_with_timer(async_std::task::sleep(timeout))
    }

    /// Like [`wait_async`](Self::wait_async), but gives up with `Err(())` as soon as the
    /// caller-supplied future `fut` completes. The output of `fut` is discarded.
    #[inline]
    pub fn wait_async_with_timer<'a, FR, R>(
        &'a self, fut: FR,
    ) -> WaitGroupTimeoutFuture<'a, T, FR, R>
    where
        FR: Future<Output = R>,
        T: Send + Unpin,
    {
        WaitGroupTimeoutFuture {
            inner: self.get_inner(),
            threshold: self.threshold,
            sleep: fut,
            waker: None,
        }
    }

    /// Blocks the current thread until the threshold is met.
    #[inline]
    pub fn wait(&self) {
        let shared = self.get_inner();
        // Without a deadline the result carries no information.
        let _ = shared.wait_blocking(None, self.threshold);
    }

    /// Blocks the current thread until the threshold is met, or returns `Err(())` once
    /// `timeout` has elapsed.
    #[inline]
    pub fn wait_timeout(&self, timeout: Duration) -> Result<(), ()> {
        let shared = self.get_inner();
        let deadline = Instant::now() + timeout;
        shared.wait_blocking(Some(deadline), self.threshold)
    }
}
impl<T> Drop for WaitGroup<T> {
    /// Gives up the unit of the counter held by the handle; the allocation is freed here
    /// only when `destroy` finds no guard left and no `done` call in progress.
    #[inline]
    fn drop(&mut self) {
        let shared = self.inner;
        // SAFETY: `shared` came from `Box::into_raw` in `new` and this handle releases its
        // unit exactly once, here.
        let _freed = unsafe { WaitGroupInner::destroy(shared) };
    }
}
impl<T> Deref for WaitGroup<T> {
    type Target = T;

    /// Gives read access to the value stored next to the counter.
    #[inline]
    fn deref(&self) -> &T {
        // SAFETY: the allocation stays valid for as long as this handle exists.
        let shared = unsafe { self.inner.as_ref() };
        &shared.inner
    }
}
/// Counting guard created by `WaitGroup::add_guard` (or by cloning another guard).
///
/// Each live guard accounts for one unit of the shared counter; dropping it subtracts that
/// unit again and may wake the waiter or free the shared allocation.
pub struct WaitGroupGuard<T> {
// Pointer to the shared state, copied from the `WaitGroup` that created the guard.
inner: NonNull<WaitGroupInner<T>>,
// Internal threshold (user threshold + 1) captured when the guard was created; a later
// `WaitGroup::set_threshold` does not update guards that already exist.
threshold: usize,
}
// NOTE(review): guards are `Clone` and all clones hand out `&T` through `Deref`, so sending a
// clone to another thread shares `T` between threads. A `T: Send` bound alone looks too weak
// for that (compare `Arc<T>: Send`, which requires `T: Send + Sync`) — confirm.
unsafe impl<T: Send> Send for WaitGroupGuard<T> {}
unsafe impl<T: Sync> Sync for WaitGroupGuard<T> {}
impl<T> Drop for WaitGroupGuard<T> {
    /// Subtracts this guard's unit from the counter; `done_ptr` wakes the waiter and frees
    /// the shared allocation when that becomes this guard's responsibility.
    #[inline(always)]
    fn drop(&mut self) {
        let (shared, limit) = (self.inner, self.threshold);
        // SAFETY: the guard holds one unit of the counter, which keeps the allocation alive
        // until this call has released it.
        let _freed = unsafe { WaitGroupInner::done_ptr(shared, 1, limit) };
    }
}
impl<T> Clone for WaitGroupGuard<T> {
#[inline]
fn clone(&self) -> Self {
let inner = unsafe { self.inner.as_ref() };
inner.add(1);
Self { inner: self.inner, threshold: self.threshold }
}
}
impl<T> Deref for WaitGroupGuard<T> {
    type Target = T;

    /// Gives read access to the value stored next to the counter.
    #[inline]
    fn deref(&self) -> &T {
        // SAFETY: the guard's unit of the counter keeps the allocation alive.
        let shared = unsafe { self.inner.as_ref() };
        &shared.inner
    }
}
/// Shared state behind both wait-group flavours: a packed state word, one waker slot and the
/// user value.
struct WaitGroupInner<T> {
// Packed word: the two highest bits are the waker flags (`WAKER_FLAG_SET`,
// `WAKER_FLAG_LOCK`), the remaining bits (`COUNT_MASK`) hold the counter.
state: AtomicUsize,
// The single waiter's waker. Accessed through `get_waker`, which hands out `&mut` from
// `&self`; access is coordinated through the flag bits of `state`.
o_waker: UnsafeCell<Option<ThinWaker>>,
// User value exposed through `Deref` on `WaitGroup` and `WaitGroupGuard`.
inner: T,
}
// The `UnsafeCell` makes the type `!Sync` by default; this opts back in whenever `T` is `Sync`.
unsafe impl<T: Sync> Sync for WaitGroupInner<T> {}
impl<T> WaitGroupInner<T> {
    /// Builds the state with a counter of `init_count`, no flags and an empty waker slot.
    #[inline(always)]
    fn new(inner: T, init_count: usize) -> Self {
        Self { state: AtomicUsize::new(init_count), o_waker: UnsafeCell::new(None), inner }
    }

    /// Loads the state word with `order` and returns only the counter bits.
    #[inline]
    fn count(&self, order: Ordering) -> usize {
        self.state.load(order) & COUNT_MASK
    }

    /// Hands out mutable access to the waker slot from a shared reference.
    ///
    /// Callers coordinate access through the flag bits of `state`; nothing in this function
    /// enforces exclusivity.
    #[inline(always)]
    fn get_waker(&self) -> &mut Option<ThinWaker> {
        // NOTE(review): this transmutes the `*mut` from `UnsafeCell::get` into `&mut`;
        // `&mut *self.o_waker.get()` would express the same thing without `transmute`.
        unsafe { transmute(self.o_waker.get()) }
    }

    /// Adds `count` to the counter.
    ///
    /// # Panics
    ///
    /// Panics with "WaitGroup count overflowed" when the counter was already at the limit
    /// (`COUNT_MASK - 2`) or when adding `count` takes it past that limit. The addition has
    /// already been applied to the state word when the panic is raised.
    #[inline]
    fn add(&self, count: usize) {
        const LIMIT: usize = COUNT_MASK - 2;
        let old_state = self.state.fetch_add(count, Relaxed);
        let old_count = State::new(old_state).count();
        // The delta must be part of the check: looking at the old count alone lets a large
        // `count` carry into the flag bits unnoticed. The first operand short-circuits, so
        // the subtraction cannot underflow.
        if old_count >= LIMIT || count > LIMIT - old_count {
            panic!("WaitGroup count overflowed");
        }
    }

    /// Releases the unit held by the owning `WaitGroup`.
    ///
    /// Returns `true` when the allocation was freed here, `false` when the unit was merely
    /// subtracted because guards are still outstanding or a `done` call currently holds the
    /// lock flag (that side then becomes responsible for freeing).
    ///
    /// # Safety
    ///
    /// `p` must point to a live allocation created by `Box::into_raw`, and the caller must
    /// own the unit it is releasing. `p` must not be used after `true` is returned.
    #[inline]
    unsafe fn destroy(p: NonNull<Self>) -> bool {
        let this = unsafe { p.as_ref() };
        let mut state = this.state.load(SeqCst);
        loop {
            let s = State::new(state);
            if s.is_locked() || s.count() > 1 {
                // Someone else still references the allocation: just give up our unit.
                if let Err(_state) =
                    this.state.compare_exchange_weak(state, state - 1, SeqCst, Acquire)
                {
                    state = _state;
                    continue;
                }
                trace_log!("wg:({:?}) drop delay state={}", tokio_task_id!(), state - 1);
                return false;
            }
            {
                // Only our own unit is left and nobody holds the lock flag: free it.
                trace_log!("wg:({:?}) drop", tokio_task_id!());
                let _ = unsafe { Box::from_raw(p.as_ptr()) };
                return true;
            }
        }
    }

    /// Owning variant of [`done`](Self::done): subtracts `count` and frees the allocation
    /// when `done::<true>` reports that the caller was the last holder.
    ///
    /// Returns `true` when the allocation was freed.
    ///
    /// # Safety
    ///
    /// `p` must point to a live allocation created by `Box::into_raw`, and the caller must
    /// own the `count` units it is releasing. `p` must not be used after `true` is returned.
    #[inline(always)]
    unsafe fn done_ptr(p: NonNull<Self>, count: usize, threshold: usize) -> bool {
        let _p = p.as_ptr();
        if Self::done::<true>(_p, count, threshold) {
            let _ = unsafe { Box::from_raw(_p) };
            return true;
        } else {
            false
        }
    }

    /// Subtracts `count` from the counter and wakes the registered waiter when the new count
    /// is at or below `threshold`.
    ///
    /// With `OWNER_SHIP = true` the return value tells the caller to free the allocation:
    /// either the caller held the last units (the counter is then left untouched), or the
    /// counter was found at 0 after the waker had been taken under the lock flag. With
    /// `OWNER_SHIP = false` the function always returns `false`.
    ///
    /// `this` is dereferenced unconditionally, so it must point to a live `WaitGroupInner`.
    ///
    /// # Panics
    ///
    /// Panics (inside `State::try_done`) when `count` exceeds the current counter.
    #[inline]
    fn done<const OWNER_SHIP: bool>(this: *const Self, count: usize, threshold: usize) -> bool {
        trace_log!("wg:({:?}) enter done {count} {threshold}", tokio_task_id!());
        unsafe {
            let mut state = (*this).state.load(Relaxed);
            loop {
                let mut s = State::new(state);
                if OWNER_SHIP && s.is_last(count) {
                    // The first load was `Relaxed`; re-read with `SeqCst` and only report
                    // "last holder" when the state has not moved in the meantime.
                    let _state = (*this).state.load(SeqCst);
                    if _state == state {
                        trace_log!("wg:({:?}) done drop {count} {threshold}", tokio_task_id!());
                        return true;
                    }
                    state = _state;
                    continue;
                }
                // `try_done` computes the next state word; `true` means it also swapped the
                // SET flag for the LOCK flag because the waiter has to be woken.
                let try_lock = s.try_done(count, threshold);
                if try_lock {
                    debug_assert!(s.is_locked());
                }
                match (*this).state.compare_exchange_weak(state, s.to_usize(), SeqCst, Acquire) {
                    Ok(_) => {
                        if try_lock {
                            // We hold the lock flag: take the waker, then clear both flags.
                            let o_waker = (*this).get_waker().take();
                            if OWNER_SHIP {
                                let old = (*this).state.fetch_and(!WAKER_FLAG_MASK, SeqCst);
                                if old & COUNT_MASK == 0 {
                                    // Every other holder released its units while we held
                                    // the lock flag, so freeing is up to the caller.
                                    trace_log!(
                                        "wg:({:?}) done locked drop cur {count} = 0",
                                        tokio_task_id!(),
                                    );
                                    return true;
                                }
                            } else {
                                (*this).state.fetch_and(!WAKER_FLAG_MASK, Release);
                            }
                            if let Some(waker) = o_waker {
                                trace_log!(
                                    "wg:({:?}) done waked {count} -> {} <= {threshold}",
                                    tokio_task_id!(),
                                    s.count()
                                );
                                waker.wake();
                            }
                        } else {
                            trace_log!("wg:({:?}) done {count} -> {}", tokio_task_id!(), s.count());
                        }
                        return false;
                    }
                    Err(cur) => {
                        // Lost the race (or spurious failure): retry with the fresh value.
                        state = cur;
                    }
                }
            }
        }
    }

    /// Stores `waker` in the waker slot and marks it with `WAKER_FLAG_SET`.
    ///
    /// Returns `Err(())` when the counter is (or has become) at or below `threshold`, i.e.
    /// the caller should not go to sleep; `Ok(())` when a waker is registered. With
    /// `may_skip = true` an already registered waker is left in place.
    #[inline]
    fn try_set_waker(&self, waker: ThinWaker, threshold: usize, may_skip: bool) -> Result<(), ()> {
        let mut state = self.state.load(SeqCst);
        loop {
            let s = State::new(state);
            if s.count() <= threshold {
                return Err(());
            } else if s.is_locked() {
                // Another party is touching the waker slot; spin until the flag clears.
                std::hint::spin_loop();
                state = self.state.load(Acquire);
                trace_log!("wg:({:?}) set_waker try again", tokio_task_id!());
                continue;
            }
            let old_state = if s.has_waker() {
                if may_skip {
                    trace_log!("wg:({:?}) set_waker skip", tokio_task_id!());
                    return Ok(());
                }
                // Replacing an existing waker: swap SET for LOCK first so that `done` cannot
                // take the slot while it is being overwritten.
                if let Err(s) =
                    self.state.compare_exchange_weak(state, s.try_lock(), SeqCst, Acquire)
                {
                    state = s;
                    continue;
                }
                self.get_waker().replace(waker);
                trace_log!("wg:({:?}) set_waker replaced", tokio_task_id!());
                // LOCK is set and SET is clear here, so XOR with the mask flips both.
                self.state.fetch_xor(WAKER_FLAG_MASK, SeqCst)
            } else {
                // No waker registered: the slot is written before SET is published.
                self.get_waker().replace(waker);
                trace_log!("wg:({:?}) set_waker ok", tokio_task_id!());
                self.state.fetch_or(WAKER_FLAG_SET, SeqCst)
            };
            // The counter may have dropped while the waker was being stored.
            if State::new(old_state).count() <= threshold {
                return Err(());
            }
            return Ok(());
        }
    }

    /// Blocks the current thread until the counter is at or below `threshold`.
    ///
    /// Returns `Ok(())` once the threshold is met and `Err(())` when `check_timeout(deadline)`
    /// reports an error (presumably: the deadline has passed — `check_timeout` is defined
    /// elsewhere).
    #[inline]
    fn wait_blocking(&self, deadline: Option<Instant>, threshold: usize) -> Result<(), ()> {
        macro_rules! check {
            ($order: expr) => {
                let cur = self.count($order);
                if cur <= threshold {
                    trace_log!("wg:({:?}) check {cur} <= {threshold}", tokio_task_id!());
                    return Ok(());
                }
                trace_log!("wg:({:?}) check {cur} > {threshold}", tokio_task_id!());
            };
        }
        check!(Acquire);
        let mut backoff = Backoff::new();
        // Whether this call has already registered its waker; later rounds may then skip
        // re-registering when a waker is still in place.
        let mut set_waker = false;
        loop {
            let r = backoff.snooze();
            check!(Acquire);
            // NOTE(review): `snooze()` returning `true` is treated as "stop spinning and
            // park"; `Backoff` is defined elsewhere.
            if r {
                let waker = ThinWaker::Blocking(thread::current());
                if self.try_set_waker(waker, threshold, set_waker).is_err() {
                    return Ok(());
                } else {
                    set_waker = true;
                }
                match check_timeout(deadline) {
                    Ok(None) => thread::park(),
                    Ok(Some(dur)) => thread::park_timeout(dur),
                    Err(_) => {
                        return Err(());
                    }
                }
                backoff.reset();
            }
        }
    }

    /// Shared poll logic of the wait futures.
    ///
    /// `o_waker` is the future's own copy of the waker it registered last; it is used to
    /// skip re-registration when the task's waker has not changed. Returns `Poll::Ready(())`
    /// once the counter is at or below `threshold`.
    #[inline]
    fn poll_async(
        &self, ctx: &mut Context, o_waker: &mut Option<Waker>, threshold: usize,
    ) -> Poll<()> {
        // Returns `Ready` from the enclosing function when the threshold is met; otherwise
        // evaluates to whether a waker is currently registered.
        macro_rules! check {
            ($order: expr) => {{
                let s = State::new(self.state.load($order));
                let cur = s.count();
                if cur <= threshold {
                    trace_log!("wg:({:?}) READY check {cur} <= {threshold}", tokio_task_id!());
                    return Poll::Ready(());
                }
                trace_log!("wg:({:?}) check {cur} > {threshold}", tokio_task_id!());
                s.has_waker()
            }};
        }
        let has_waker = check!(Acquire);
        let new_waker = ctx.waker();
        if has_waker {
            #[allow(clippy::needless_else)]
            if let Some(old_waker) = o_waker {
                if old_waker.will_wake(new_waker) {
                    // The registered waker still targets this task: re-check once with the
                    // strongest ordering and stay pending without touching the slot.
                    trace_log!("wg:({:?}) will_wake=true", tokio_task_id!());
                    check!(SeqCst);
                    trace_log!("wg:({:?}) PENDING", tokio_task_id!());
                    return Poll::Pending;
                } else {
                    trace_log!("wg:({:?}) waker will_wake=false", tokio_task_id!())
                }
            }
        }
        if self.try_set_waker(ThinWaker::Async(new_waker.clone()), threshold, false).is_err() {
            trace_log!("wg:({:?}) READY during set_waker", tokio_task_id!());
            Poll::Ready(())
        } else {
            o_waker.replace(new_waker.clone());
            trace_log!("wg:({:?}) PENDING", tokio_task_id!());
            Poll::Pending
        }
    }
}
/// Future returned by the `wait_async` methods; resolves to `()` once the counter is at or
/// below `threshold`.
#[must_use]
pub struct WaitGroupFuture<'a, T> {
// Shared state being waited on.
inner: &'a WaitGroupInner<T>,
// Counter value at or below which the future completes.
threshold: usize,
// Copy of the waker registered on the previous poll, if any; lets `poll_async` skip
// re-registration when the task's waker is unchanged.
waker: Option<Waker>,
}
impl<'a, T> Future for WaitGroupFuture<'a, T>
where
    T: Send + Unpin,
{
    type Output = ();

    /// Delegates to `WaitGroupInner::poll_async`, passing the future's cached waker.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
        // Every field (a shared reference, a `usize` and an `Option<Waker>`) is `Unpin`, so
        // the future itself is `Unpin` and the safe projection is sufficient.
        let this = self.get_mut();
        this.inner.poll_async(ctx, &mut this.waker, this.threshold)
    }
}
/// Future returned by the `wait_async_timeout` / `wait_async_with_timer` methods.
///
/// Resolves to `Ok(())` once the counter is at or below `threshold`, or to `Err(())` when the
/// timer future `sleep` completes first.
#[must_use]
pub struct WaitGroupTimeoutFuture<'a, T, FR, R>
where
FR: Future<Output = R>,
T: Send + Unpin,
{
// Shared state being waited on.
inner: &'a WaitGroupInner<T>,
// Timer future; polled in place through a pinned projection, so it may be `!Unpin`.
sleep: FR,
// Counter value at or below which the future completes successfully.
threshold: usize,
// Copy of the waker registered on the previous poll, if any.
waker: Option<Waker>,
}
impl<'a, T, FR, R> Future for WaitGroupTimeoutFuture<'a, T, FR, R>
where
    FR: Future<Output = R>,
    T: Send + Unpin,
{
    type Output = Result<(), ()>;

    /// Polls the wait group first and the timer second, so a wait group that is already
    /// complete wins over an expired timer.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
        // SAFETY: nothing is moved out of `this`; the only field that may be `!Unpin`
        // (`sleep`) is re-pinned below before it is polled.
        let this = unsafe { self.get_unchecked_mut() };

        if let Poll::Ready(()) = this.inner.poll_async(ctx, &mut this.waker, this.threshold) {
            return Poll::Ready(Ok(()));
        }

        // SAFETY: `this.sleep` lives inside the pinned future and is never moved.
        let timer = unsafe { Pin::new_unchecked(&mut this.sleep) };
        // The timer's output is discarded; its completion only signals the timeout.
        timer.poll(ctx).map(|_| Err(()))
    }
}
/// Highest bit of the state word: a waker is stored in the waker slot.
const WAKER_FLAG_SET: usize = 1 << (usize::BITS - 1);
/// Second-highest bit of the state word: the waker slot is currently being taken or replaced.
const WAKER_FLAG_LOCK: usize = 1 << (usize::BITS - 2);
/// Both flag bits.
const WAKER_FLAG_MASK: usize = WAKER_FLAG_SET | WAKER_FLAG_LOCK;
/// Every bit below the flags; these hold the counter.
const COUNT_MASK: usize = !WAKER_FLAG_MASK;

/// Decoded view of one snapshot of the packed state word (flag bits + counter).
struct State(usize);

impl State {
    /// Wraps a raw state word.
    #[inline(always)]
    fn new(state: usize) -> Self {
        State(state)
    }

    /// Counter part of the word.
    #[inline(always)]
    fn count(&self) -> usize {
        COUNT_MASK & self.0
    }

    /// Flag part of the word (both flag bits, still in their original position).
    #[inline(always)]
    fn waker_flag(&self) -> usize {
        WAKER_FLAG_MASK & self.0
    }

    /// Whether the lock flag is set.
    #[inline(always)]
    fn is_locked(&self) -> bool {
        (self.0 & WAKER_FLAG_LOCK) != 0
    }

    /// Whether the "waker stored" flag is set.
    #[inline(always)]
    fn has_waker(&self) -> bool {
        (self.0 & WAKER_FLAG_SET) != 0
    }

    /// Raw word with the same counter and only the lock flag set.
    #[inline(always)]
    fn try_lock(&self) -> usize {
        WAKER_FLAG_LOCK | self.count()
    }

    /// Whether subtracting `delta` would release the very last units: the counter equals
    /// `delta` and the flags are not exactly "locked".
    #[inline]
    fn is_last(&self, delta: usize) -> bool {
        self.count() == delta && self.waker_flag() != WAKER_FLAG_LOCK
    }

    /// Subtracts `delta` from the counter in place.
    ///
    /// Returns `true` when the new count is at or below `threshold` while exactly the
    /// "waker stored" flag was set; in that case the flags are replaced by the lock flag.
    /// Otherwise the flags are kept unchanged and `false` is returned.
    ///
    /// # Panics
    ///
    /// Panics when `delta` is larger than the current counter.
    #[inline(always)]
    fn try_done(&mut self, delta: usize, threshold: usize) -> bool {
        let flags = self.waker_flag();
        let before = self.count();
        let after = match before.checked_sub(delta) {
            Some(left) => left,
            None => panic!("underflow detected {} < {}", before, delta),
        };
        let must_wake = flags == WAKER_FLAG_SET && after <= threshold;
        let next_flags = if must_wake { WAKER_FLAG_LOCK } else { flags };
        self.0 = next_flags | after;
        must_wake
    }

    /// Raw state word.
    #[inline(always)]
    #[allow(clippy::wrong_self_convention)]
    fn to_usize(&self) -> usize {
        self.0
    }
}
// Unit tests for the packed state word and the single-threaded paths of the wait group.
#[cfg(test)]
mod tests {
use super::*;
use captains_log::{recipe, ConsoleTarget, Level};
use std::thread;
// The public count must exclude the unit held by the `WaitGroup` handle itself.
#[test]
fn test_waitgroup_inner_count() {
let wg = WaitGroup::new((), 0);
assert_eq!(wg.get_left_seqcst(), 0);
let guard1 = wg.add_guard();
assert_eq!(wg.get_left_seqcst(), 1);
let guard2 = wg.add_guard();
assert_eq!(wg.get_left_seqcst(), 2);
drop(guard1);
assert_eq!(wg.get_left_seqcst(), 1);
drop(guard2);
assert_eq!(wg.get_left_seqcst(), 0);
}
// Pure bit-level checks of `State`: flag decoding, `is_last` and the `try_done` transitions.
#[test]
fn test_waitgroup_state() {
assert_eq!(State::new(2).count(), 2);
assert!(State::new(2 | WAKER_FLAG_SET).has_waker());
assert!(!State::new(2 | WAKER_FLAG_SET).is_locked());
assert!(!State::new(2 | WAKER_FLAG_LOCK).has_waker());
assert!(State::new(2 | WAKER_FLAG_LOCK).is_locked());
// No waker registered: reaching the threshold must not lock.
let mut s = State::new(2);
assert_eq!(s.try_done(1, 1), false);
assert!(!s.is_locked());
assert_eq!(s.count(), 1);
assert!(s.is_last(1));
assert_eq!(s.count(), 1);
// Waker registered: reaching the threshold swaps SET for LOCK.
let mut s = State::new(3 | WAKER_FLAG_SET);
assert!(!s.is_last(1));
assert_eq!(s.try_done(1, 2), true);
assert!(s.is_locked());
assert!(!s.has_waker());
assert_eq!(s.count(), 2);
// Already locked: further decrements keep the lock flag and report `false`.
assert_eq!(s.try_done(1, 0), false);
assert!(s.is_locked());
assert_eq!(s.count(), 1);
let _s = s.0 & (!WAKER_FLAG_MASK);
assert_eq!(_s, 1);
assert_eq!(s.try_done(1, 0), false);
assert_eq!(s.count(), 0);
}
// Walks a boxed `WaitGroupInner` through waker registration and the owning `done_ptr` path,
// ending with the call that frees the allocation.
#[test]
fn test_waitgroup_ptr() {
recipe::console_logger(ConsoleTarget::Stdout, Level::Trace).test().build().expect("log");
let inner = Box::new(WaitGroupInner::new((), 1));
assert_eq!(inner.count(SeqCst), 1);
assert_eq!(State::new(inner.state.load(Ordering::SeqCst)).waker_flag(), 0);
println!("test try_set_waker met threshold reach");
assert_eq!(inner.try_set_waker(ThinWaker::Blocking(thread::current()), 1, false), Err(()));
inner.add(1);
assert_eq!(inner.count(SeqCst), 2);
println!("test try_set_waker ok");
assert!(inner.try_set_waker(ThinWaker::Blocking(thread::current()), 1, false).is_ok());
let s = State::new(inner.state.load(Ordering::SeqCst));
assert_eq!(s.waker_flag(), WAKER_FLAG_SET, "s {}, {}", s.is_locked(), s.has_waker());
println!("test try_set_waker again skip");
assert!(inner.try_set_waker(ThinWaker::Blocking(thread::current()), 1, true).is_ok());
let s = State::new(inner.state.load(Ordering::SeqCst));
assert_eq!(s.waker_flag(), WAKER_FLAG_SET);
println!("test try_set_waker again force");
assert!(inner.try_set_waker(ThinWaker::Blocking(thread::current()), 1, false).is_ok());
let s = State::new(inner.state.load(Ordering::SeqCst));
assert_eq!(s.waker_flag(), WAKER_FLAG_SET);
assert_eq!(inner.count(SeqCst), 2);
// From here on the allocation is managed through the raw pointer, as `WaitGroupGuard` does.
let p = unsafe { NonNull::new_unchecked(Box::into_raw(inner)) };
println!("test done triggering wakeup");
unsafe {
// 2 -> 1 meets the threshold of 1: the waker is taken and both flags are cleared.
assert!(!WaitGroupInner::done_ptr(p, 1, 1));
{
let inner = p.as_ref();
assert_eq!(inner.count(SeqCst), 1);
let s = State::new(inner.state.load(Ordering::SeqCst));
assert_eq!(s.waker_flag(), 0);
}
println!("test done triggering drop");
// Releasing the last unit reports `true`, i.e. the allocation was freed by `done_ptr`.
assert!(WaitGroupInner::done_ptr(p, 1, 0));
}
}
// Same sequence as `test_waitgroup_ptr`, but on a stack value with the non-owning
// `done::<false>` path, which never frees anything.
#[test]
fn test_waitgroup_inner() {
recipe::console_logger(ConsoleTarget::Stdout, Level::Trace).test().build().expect("log");
let inner = WaitGroupInner::new((), 1);
assert_eq!(inner.count(SeqCst), 1);
assert_eq!(State::new(inner.state.load(Ordering::SeqCst)).waker_flag(), 0);
println!("test try_set_waker met threshold reach");
assert_eq!(inner.try_set_waker(ThinWaker::Blocking(thread::current()), 1, false), Err(()));
inner.add(1);
assert_eq!(inner.count(SeqCst), 2);
println!("test try_set_waker ok");
assert!(inner.try_set_waker(ThinWaker::Blocking(thread::current()), 1, false).is_ok());
let s = State::new(inner.state.load(Ordering::SeqCst));
assert_eq!(s.waker_flag(), WAKER_FLAG_SET, "s {}, {}", s.is_locked(), s.has_waker());
println!("test try_set_waker again skip");
assert!(inner.try_set_waker(ThinWaker::Blocking(thread::current()), 1, true).is_ok());
let s = State::new(inner.state.load(Ordering::SeqCst));
assert_eq!(s.waker_flag(), WAKER_FLAG_SET);
println!("test try_set_waker again force");
assert!(inner.try_set_waker(ThinWaker::Blocking(thread::current()), 1, false).is_ok());
let s = State::new(inner.state.load(Ordering::SeqCst));
assert_eq!(s.waker_flag(), WAKER_FLAG_SET);
assert_eq!(inner.count(SeqCst), 2);
let p = &inner as *const WaitGroupInner<()>;
println!("test done triggering wakeup");
assert!(!WaitGroupInner::<()>::done::<false>(p, 1, 1));
{
assert_eq!(inner.count(SeqCst), 1);
let s = State::new(inner.state.load(Ordering::SeqCst));
assert_eq!(s.waker_flag(), 0);
}
println!("test done last");
// Without ownership the counter simply reaches 0; the return value is not inspected.
WaitGroupInner::<()>::done::<false>(p, 1, 0);
assert_eq!(inner.count(Ordering::SeqCst), 0)
}
}