use log::error;
use std::{
future::Future,
pin::Pin,
sync::{
atomic::{AtomicI64, Ordering},
Arc,
},
task::{Context, Poll, Waker},
};
use parking_lot::Mutex;
/// Counter of outstanding tasks that one async task can wait on.
///
/// Cloning is cheap: every clone is a handle to the same shared state (only the
/// `Arc` is cloned), so `add`/`done` on any clone are observed by all of them.
#[derive(Clone)]
pub struct WaitGroup(Arc<WaitGroupInner>);
// Logs a formatted message at `error` level, then panics with the same text.
//
// The message is formatted exactly once, so argument expressions are evaluated a
// single time (the previous expansion evaluated them once for `error!` and again
// for `panic!`). The expansion is a block of type `!`, so the macro can be used
// in statement or expression position.
macro_rules! log_and_panic {
    ($($arg:tt)+) => {{
        let msg = format!($($arg)+);
        error!("{}", msg);
        panic!("{}", msg)
    }};
}
// Emits a `log::trace!` record only when built with the `trace_log` feature.
// Without the feature the statement is stripped by `cfg` before type checking,
// so its arguments are neither evaluated nor required to be in scope.
macro_rules! trace_log {
    ($($arg:tt)+) => {
        #[cfg(feature = "trace_log")]
        log::trace!($($arg)+);
    };
}
impl WaitGroup {
pub fn new() -> Self {
Self(WaitGroupInner::new())
}
#[inline(always)]
pub fn left(&self) -> usize {
let count = self.0.left.load(Ordering::SeqCst);
if count < 0 {
log_and_panic!("WaitGroup.left {} < 0", count);
}
count as usize
}
#[inline(always)]
pub fn add(&self, i: usize) {
let _r = self.0.left.fetch_add(i as i64, Ordering::Acquire);
trace_log!("add {}->{}", i, _r + i as i64);
}
#[inline(always)]
pub fn add_guard(&self) -> WaitGroupGuard {
self.add(1);
WaitGroupGuard {
inner: self.0.clone(),
}
}
pub async fn wait_to(&self, target: usize) -> bool {
let _self = self.0.as_ref();
let left = _self.left.load(Ordering::Acquire);
if left <= target as i64 {
trace_log!("wait_to skip {} <= target {}", left, target);
return false;
}
WaitGroupFuture {
wg: &_self,
target,
waker: None,
}
.await;
return true;
}
#[inline(always)]
pub async fn wait(&self) {
self.wait_to(0).await;
}
#[inline]
pub fn done(&self) {
let inner = self.0.as_ref();
inner.done(1);
}
#[inline]
pub fn done_many(&self, count: usize) {
let inner = self.0.as_ref();
inner.done(count as i64);
}
}
/// RAII guard returned by [`WaitGroup::add_guard`]; marks one task as done when
/// it is dropped.
#[must_use = "dropping the guard immediately marks the task as done"]
pub struct WaitGroupGuard {
    /// Shared state of the wait group this guard belongs to.
    inner: Arc<WaitGroupInner>,
}

impl Drop for WaitGroupGuard {
    fn drop(&mut self) {
        self.inner.done(1);
    }
}
/// State shared by every `WaitGroup` clone and `WaitGroupGuard`.
///
/// At most one waiter is tracked at a time. `waiting` is only modified while the
/// `waker` mutex is held (`set_waker`, `cancel_wait`, `done`). So, under the lock,
/// `waiting >= 0` exactly when the slot holds the registered waiter's waker.
struct WaitGroupInner {
    /// Outstanding task count; `add` increments it, `done` decrements it.
    left: AtomicI64,
    /// Target of the registered waiter, or `-1` when no waiter is registered.
    waiting: AtomicI64,
    /// Waker of the registered waiter, if any.
    waker: Mutex<Option<Arc<Waker>>>,
}

impl WaitGroupInner {
    /// Allocates shared state with no outstanding tasks and no waiter.
    #[inline(always)]
    fn new() -> Arc<Self> {
        Arc::new(Self {
            left: AtomicI64::new(0),
            waiting: AtomicI64::new(-1),
            waker: Mutex::new(None),
        })
    }

    /// Marks `count` tasks as finished and wakes the waiter if its target is reached.
    ///
    /// # Panics
    ///
    /// Logs and panics if the counter drops below zero.
    #[inline]
    fn done(&self, count: i64) {
        let left = self.left.fetch_sub(count, Ordering::SeqCst) - count;
        if left < 0 {
            log_and_panic!("WaitGroup.left {} < 0", left);
        }
        // Lock-free fast path: nobody waits, or the waiter's target is not reached.
        let waiting = self.waiting.load(Ordering::SeqCst);
        if waiting < 0 {
            trace_log!("done {}->{} not waiting", count, left);
            return;
        }
        if left > waiting {
            trace_log!("done {}->{} waiting {}", count, left, waiting);
            return;
        }
        // Claim the wake-up and take the waker under the same lock that
        // `set_waker`/`cancel_wait` hold while writing `waiting`. Claiming outside
        // the lock let a delayed `done` take the waker of a *later* wait, leaving
        // `waiting >= 0` with an empty slot; that future then returned Pending
        // forever (lost wake-up).
        let waker = {
            let mut guard = self.waker.lock();
            if self
                .waiting
                .compare_exchange(waiting, -1, Ordering::SeqCst, Ordering::Relaxed)
                .is_err()
            {
                // The waiter re-registered or cancelled, or another `done` won the claim.
                return;
            }
            guard.take()
        };
        // Wake after releasing the lock so the woken task does not contend on it.
        if let Some(waker) = waker {
            waker.wake_by_ref();
            trace_log!("done {}->{} wake {}", count, left, waiting);
        } else {
            trace_log!("done {}->{} wake {} but no waker", count, left, waiting);
        }
    }

    /// Registers `waker` for a waiter that waits for `left <= target`.
    ///
    /// `force` means the calling future already registered a waker and is
    /// replacing it (this is how `WaitGroupFuture::poll` uses it). Without
    /// `force`, an existing registration indicates a second concurrent waiter,
    /// which logs and panics.
    #[inline]
    fn set_waker(&self, waker: Arc<Waker>, target: usize, force: bool) {
        trace_log!("set_waker {} force={}", target, force);
        let mut guard = self.waker.lock();
        if !force && guard.is_some() {
            drop(guard);
            log_and_panic!("concurrent wait detected");
        }
        guard.replace(waker);
        let old_target = self.waiting.swap(target as i64, Ordering::SeqCst);
        drop(guard);
        if !force && old_target >= 0 {
            log_and_panic!("Concurrent wait() by multiple coroutines, enter unlikely code");
        }
    }

    /// Unregisters the waiter: resets `waiting` to `-1` and drops the stored waker.
    #[inline]
    fn cancel_wait(&self) {
        trace_log!("cancel_wait");
        let mut guard = self.waker.lock();
        self.waiting.store(-1, Ordering::SeqCst);
        *guard = None;
    }
}
/// Future built by `WaitGroup::wait_to`; resolves once at most `target` tasks remain.
struct WaitGroupFuture<'a> {
    /// Shared wait-group state being observed.
    wg: &'a WaitGroupInner,
    /// Completion threshold: the future is ready when `left <= target`.
    target: usize,
    /// Waker this future last registered with `wg`, if it registered one.
    waker: Option<Arc<Waker>>,
}

impl WaitGroupFuture<'_> {
    /// Returns `true` when `left <= target`, unregistering this future's waker first.
    #[inline(always)]
    fn _poll(&mut self) -> bool {
        let cur = self.wg.left.load(Ordering::SeqCst);
        if cur > self.target as i64 {
            trace_log!("poll not ready {}>{}", cur, self.target);
            return false;
        }
        trace_log!("poll ready {}<={}", cur, self.target);
        self._clear();
        true
    }

    /// Forgets the local waker and, if one was registered, cancels the wait on `wg`.
    #[inline(always)]
    fn _clear(&mut self) {
        if self.waker.take().is_none() {
            return;
        }
        self.wg.cancel_wait();
    }
}
impl<'a> Drop for WaitGroupFuture<'a> {
fn drop(&mut self) {
self._clear();
}
}
impl Future for WaitGroupFuture<'_> {
    type Output = ();

    /// Completes once `left <= target`; otherwise (re)registers this task's waker.
    ///
    /// The order is: check, register, check again. A `done` that runs between the
    /// first check and the registration sees no waiter and wakes nobody, so the
    /// second check is what catches it.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        // `get_mut` is available because every field of this future is `Unpin`.
        let this = self.get_mut();
        if this._poll() {
            return Poll::Ready(());
        }

        let force = match &this.waker {
            None => false,
            Some(registered) => {
                // Still registered and the same task would be woken: nothing to update.
                let still_waiting = this.wg.waiting.load(Ordering::SeqCst) >= 0;
                if still_waiting && registered.will_wake(ctx.waker()) {
                    return Poll::Pending;
                }
                // Replace our own earlier registration instead of treating it as a
                // second, concurrent waiter.
                true
            }
        };

        let fresh = Arc::new(ctx.waker().clone());
        this.wg.set_waker(Arc::clone(&fresh), this.target, force);
        this.waker = Some(fresh);

        if this._poll() {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use tokio::time::{sleep, timeout};
    use super::*;

    /// Builds a multi-threaded Tokio runtime with `threads` workers and all drivers enabled.
    fn make_runtime(threads: usize) -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(threads)
            .build()
            .unwrap()
    }

    #[test]
    fn test_inner() {
        make_runtime(1).block_on(async move {
            let wg = WaitGroup::new();
            wg.add(2);
            let _wg = wg.clone();
            let th = tokio::spawn(async move {
                assert!(_wg.wait_to(1).await);
            });
            // Give the spawned waiter time to register its waker.
            sleep(Duration::from_secs(1)).await;
            {
                let guard = wg.0.waker.lock();
                assert!(guard.is_some());
                assert_eq!(wg.0.waiting.load(Ordering::Acquire), 1);
            }
            wg.done();
            // Surface a failed assertion inside the spawned waiter instead of
            // silently discarding the JoinError.
            th.await.expect("waiter task panicked");
            assert_eq!(wg.0.waiting.load(Ordering::Acquire), -1);
            assert_eq!(wg.left(), 1);
            wg.done();
            assert_eq!(wg.left(), 0);
            assert_eq!(wg.wait_to(0).await, false);
        });
    }

    #[test]
    fn test_cancel() {
        let wg = WaitGroup::new();
        make_runtime(1).block_on(async move {
            wg.add(1);
            println!("test timeout");
            assert!(timeout(Duration::from_secs(1), wg.wait()).await.is_err());
            println!("timeout happened");
            // Dropping the timed-out future must unregister the waiter.
            assert_eq!(wg.0.waiting.load(Ordering::Acquire), -1);
            wg.done();
            wg.add(2);
            wg.done_many(2);
            wg.add(2);
            let _wg = wg.clone();
            let th = tokio::spawn(async move {
                _wg.wait().await;
            });
            sleep(Duration::from_millis(200)).await;
            wg.done();
            wg.done();
            th.await.expect("waiter task panicked");
        });
    }
}