use std::sync::{Arc, atomic::{AtomicI64, Ordering}};
use futures_util::{{future::Future}, task::{Context, AtomicWaker, Poll}};
use std::pin::Pin;
/// A cloneable handle for tracking a dynamic number of outstanding units of
/// work and asynchronously waiting until all of them have completed.
///
/// Every clone shares the same underlying state (the handle wraps an `Arc`),
/// so work can be registered in one task and completed in another.
#[derive(Clone)]
pub struct TaskGroup {
    /// Shared state: the outstanding-work counter and the waiter's waker.
    inner: Arc<Inner>,
}
impl Default for TaskGroup {
fn default() -> Self {
Self::new()
}
}
impl TaskGroup {
    /// Creates an empty group whose outstanding-work counter starts at zero.
    pub fn new() -> Self {
        let inner = Arc::new(Inner::new());
        TaskGroup { inner }
    }

    /// Registers one unit of outstanding work.
    ///
    /// Pair every call with a later [`TaskGroup::done`].
    pub fn add(&self) {
        self.add_n(1);
    }

    /// Registers `n` units of outstanding work at once; `n == 0` is a no-op.
    pub fn add_n(&self, n: u32) {
        self.inner.add(n);
    }

    /// Registers `n` units of work and returns a [`Work`] guard that marks
    /// all `n` units as done when it is dropped.
    pub fn add_work(&self, n: u32) -> Work {
        self.inner.add(n);
        Work {
            n,
            inner: Arc::clone(&self.inner),
        }
    }

    /// Marks one unit of work as done.
    pub fn done(&self) {
        self.done_n(1);
    }

    /// Marks `n` units of work as done; `n == 0` is a no-op.
    ///
    /// Completing more work than was added drives the counter negative; that
    /// value is what the future returned by [`TaskGroup::wait`] yields.
    pub fn done_n(&self, n: u32) {
        self.inner.done_n(n);
    }

    /// Returns a future that resolves once the outstanding-work counter is
    /// zero or negative, yielding the counter value it observed.
    ///
    /// On completion a non-positive counter is reset to zero, so the group can
    /// be reused for another round of work.
    pub fn wait(&self) -> WaitFuture {
        WaitFuture {
            inner: Arc::clone(&self.inner),
        }
    }
}
/// State shared by every [`TaskGroup`] clone, [`Work`] guard and
/// [`WaitFuture`].
struct Inner {
    /// Outstanding units of work. Signed, so completing more work than was
    /// added shows up as a negative value rather than being hidden.
    counter: AtomicI64,
    /// Waker of the task polling a [`WaitFuture`]; woken from `done_n`.
    waker: AtomicWaker,
}
impl Inner {
    /// Creates state with no outstanding work and no registered waker.
    fn new() -> Self {
        Inner {
            counter: AtomicI64::new(0),
            waker: AtomicWaker::new(),
        }
    }

    /// Brings a non-positive counter back to zero so the group can be reused
    /// after a wait completes.
    ///
    /// `fetch_max(0)` is used instead of an unconditional `store(0)`. Work
    /// registered concurrently, between the waiter observing `<= 0` and this
    /// call, makes the counter positive again. A plain store would silently
    /// erase that work and leave the group permanently unbalanced.
    fn reset(&self) {
        self.counter.fetch_max(0, Ordering::Relaxed);
    }

    /// Registers `n` more units of outstanding work; `n == 0` is a no-op.
    fn add(&self, n: u32) {
        if n == 0 {
            return;
        }
        self.counter.fetch_add(i64::from(n), Ordering::Relaxed);
    }

    /// Marks `n` units of work as done and wakes the waiter once no work
    /// remains (counter `<= 0`); `n == 0` is a no-op.
    fn done_n(&self, n: u32) {
        if n == 0 {
            return;
        }
        let n = i64::from(n);
        // Release pairs with the waiter's Acquire load of the counter.
        let prev = self.counter.fetch_sub(n, Ordering::Release);
        // `fetch_sub` returns the value *before* the subtraction, so the new
        // value is `prev - n`. Wake only once nothing is outstanding; waking
        // while work remains just causes a spurious poll.
        if prev - n <= 0 {
            self.waker.wake();
        }
    }

    /// Marks a single unit of work as done.
    pub fn done(&self) {
        self.done_n(1);
    }
}
/// RAII guard returned by [`TaskGroup::add_work`].
///
/// Holds `n` units of registered work and marks all of them done when it is
/// dropped. Discarding it immediately (`tg.add_work(n);`) completes the work
/// at once, which is almost always a mistake, hence `#[must_use]`.
#[must_use = "dropping a `Work` guard immediately marks its work as done"]
pub struct Work {
    /// Number of work units released when the guard is dropped.
    n: u32,
    /// Shared group state the units are released back to.
    inner: Arc<Inner>,
}
impl Drop for Work {
    /// Releases every unit of work this guard was created with.
    fn drop(&mut self) {
        let units = self.n;
        self.inner.done_n(units);
    }
}
/// Future returned by [`TaskGroup::wait`].
///
/// Resolves with the counter value it observed once that value is `<= 0`.
/// A negative value means more work was marked done than was added.
///
/// The shared state holds a single `AtomicWaker`, which keeps only the most
/// recently registered waker. NOTE(review): if several `WaitFuture`s on the
/// same group are polled concurrently, earlier waiters may never be woken.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WaitFuture {
    /// Shared group state being waited on.
    inner: Arc<Inner>,
}
impl Future for WaitFuture {
    type Output = i64;

    /// Resolves with the observed counter value once it is `<= 0`.
    ///
    /// Follows the `AtomicWaker` protocol: the waker is registered *before*
    /// the final counter check. Checking first and registering afterwards
    /// loses the wakeup whenever the last `done_n` runs in between, because it
    /// wakes a stale or absent waker. The future would then stay pending
    /// forever.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let inner = &self.inner;

        // Fast path: already drained, so there is no need to register.
        let n = inner.counter.load(Ordering::Acquire);
        if n <= 0 {
            inner.reset();
            return Poll::Ready(n);
        }

        inner.waker.register(cx.waker());

        // Re-check after registering, so a completion that raced with
        // `register` is observed here instead of being lost.
        let n = inner.counter.load(Ordering::Acquire);
        if n <= 0 {
            inner.reset();
            Poll::Ready(n)
        } else {
            Poll::Pending
        }
    }
}
// Async tests for TaskGroup, run on the runtime provided by #[tokio::test].
#[cfg(test)]
mod tests {
use std::sync::Arc;
use tokio::sync::Mutex;
use super::*;
// add()/done() pairing: each spawned task increments the shared number and
// then calls done(); wait() must not resolve before all `count` increments.
#[tokio::test]
async fn basic_test_add() {
let tg = TaskGroup::new();
let num = Arc::new(Mutex::new(0));
let count = 10000;
for _ in 0..count {
tg.add();
let tg_c = tg.clone();
let n = num.clone();
tokio::spawn(async move {
{
// Inner scope releases the lock before signalling completion.
let mut n = n.lock().await;
*n += 1;
}
tg_c.done();
});
}
tg.wait().await;
let n = num.lock().await;
assert_eq!(count, *n);
}
// Shared body for the Work-guard tests: each spawned task owns a one-unit
// Work guard. `_work` is declared before the lock guard `n`, and locals drop
// in reverse declaration order, so the increment is released before the Work
// guard marks its unit done.
async fn _basic_test_add_work(tg: TaskGroup) {
let num = Arc::new(Mutex::new(0));
let count = 10000;
for _ in 0..count {
let work = tg.add_work(1);
let n = num.clone();
tokio::spawn(async move {
let _work = work;
let mut n = n.lock().await;
*n += 1;
});
}
tg.wait().await;
let n = num.lock().await;
assert_eq!(count, *n);
}
#[tokio::test]
async fn basic_test_add_work() {
let tg = TaskGroup::new();
_basic_test_add_work(tg).await;
}
// Runs the shared body twice on the same group to check that it can be
// reused after wait() resolves.
// NOTE(review): "resuse" in the name looks like a typo for "reuse"; left
// unchanged so existing test filters keep working.
#[tokio::test]
async fn basic_test_add_work_resuse() {
let tg = TaskGroup::new();
_basic_test_add_work(tg.clone()).await;
_basic_test_add_work(tg).await;
}
// One Work guard covers all `count` units. It is moved into the outer task
// and dropped only after every inner task handle has been awaited.
#[tokio::test]
async fn basic_test_addn_workn() {
let tg = TaskGroup::new();
let num = Arc::new(Mutex::new(0));
let count = 10000;
let work = tg.add_work(count);
let num_c = num.clone();
tokio::spawn(async move {
let _work = work;
let mut hvec = Vec::new();
for _ in 0..count {
let n = num_c.clone();
let h = tokio::spawn(async move {
let mut n = n.lock().await;
*n += 1;
});
hvec.push(h);
}
for h in hvec {
let _ = h.await;
}
});
tg.wait().await;
let n = num.lock().await;
assert_eq!(count, *n);
}
// Same as above, but with add_n()/done_n() instead of a Work guard: the
// outer task calls done_n(count) after awaiting every inner task.
#[tokio::test]
async fn basic_test_addn_donen() {
let tg = TaskGroup::new();
let num = Arc::new(Mutex::new(0));
let count = 10000;
tg.add_n(count);
let num_c = num.clone();
let tg_c = tg.clone();
tokio::spawn(async move {
let mut hvec = Vec::new();
for _ in 0..count {
let n = num_c.clone();
let h = tokio::spawn(async move {
let mut n = n.lock().await;
*n += 1;
});
hvec.push(h);
}
for h in hvec {
let _ = h.await;
}
tg_c.done_n(count);
});
tg.wait().await;
let n = num.lock().await;
assert_eq!(count, *n);
}
// A zero-unit Work guard registers nothing. Dropping the real guard drains
// the counter, so wait() resolves even though `_work_0` is still alive.
#[tokio::test]
async fn basic_test_add0_work() {
let tg = TaskGroup::new();
let count = 10000;
let _work_0 = tg.add_work(0);
let work = tg.add_work(count);
drop(work);
tg.wait().await;
}
// Completing one more unit than was added drives the counter to -1; wait()
// resolves and yields that observed value.
#[tokio::test]
async fn neg_wait() {
let tg = TaskGroup::new();
let count = 1000;
let _work = tg.add_work(count);
tg.done_n(count + 1);
let n = tg.wait().await;
assert_eq!(-1, n);
}
}