use core::sync::atomic::{AtomicUsize, Ordering};
#[cfg(all(any(feature = "std", feature = "alloc"), not(feature = "triomphe")))]
use std::sync::Arc;
#[cfg(feature = "triomphe")]
use triomphe::Arc;
/// Performs one round of exponential backoff and advances `iter`.
///
/// While `*iter` is at most 6, spins `2^*iter` times on the CPU spin-loop
/// hint. Beyond that limit, the current thread yields to the OS scheduler
/// when the `std` feature is enabled; otherwise it spins a fixed 64 times.
/// The counter is advanced with saturation, so it never wraps around.
#[inline]
fn backoff_step(iter: &mut u32) {
    const SPIN_LIMIT: u32 = 6;

    // Busy-waits for `rounds` iterations of the spin-loop hint.
    fn spin(rounds: u32) {
        for _ in 0..rounds {
            core::hint::spin_loop();
        }
    }

    let step = *iter;
    // `iter` is exclusively borrowed, so bumping it before the wait is
    // indistinguishable from bumping it afterwards.
    *iter = step.saturating_add(1);

    if step <= SPIN_LIMIT {
        spin(1u32 << step);
        return;
    }

    // Spin budget exhausted: hand the CPU back when `std` is available,
    // otherwise keep spinning at the capped rate.
    #[cfg(feature = "std")]
    std::thread::yield_now();
    #[cfg(not(feature = "std"))]
    spin(1u32 << SPIN_LIMIT);
}
/// State shared by every handle of a `WaitGroup`.
#[derive(Debug)]
struct Inner {
    /// Pending-work count: raised by `WaitGroup::add`, lowered by
    /// `WaitGroup::done`, polled by `WaitGroup::wait`.
    counter: AtomicUsize,
}
/// A cloneable handle to a shared counter of pending work.
///
/// All clones observe the same counter. `add` raises it, `done` lowers it,
/// and `wait` polls it (spinning, and yielding when `std` is enabled) until
/// it reaches zero.
pub struct WaitGroup {
    /// Reference-counted shared state; cloning a handle only bumps this `Arc`.
    inner: Arc<Inner>,
}
impl Default for WaitGroup {
fn default() -> Self {
Self {
inner: Arc::new(Inner {
counter: AtomicUsize::new(0),
}),
}
}
}
impl From<usize> for WaitGroup {
fn from(count: usize) -> Self {
Self {
inner: Arc::new(Inner {
counter: AtomicUsize::new(count),
}),
}
}
}
impl Clone for WaitGroup {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl core::fmt::Debug for WaitGroup {
    /// Renders as `WaitGroup { counter: N }`, where `N` is the counter value
    /// read at the moment of formatting.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let counter = &self.inner.counter;
        let mut builder = f.debug_struct("WaitGroup");
        builder.field("counter", counter);
        builder.finish()
    }
}
impl core::ops::AddAssign<usize> for WaitGroup {
    /// `wg += n` is shorthand for `wg.add(n)`; the extra handle that `add`
    /// returns is dropped straight away.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as `add` (counter overflow).
    fn add_assign(&mut self, rhs: usize) {
        let _ = self.add(rhs);
    }
}
impl WaitGroup {
pub fn new() -> Self {
Self::default()
}
pub fn add(&self, num: usize) -> Self {
self
.inner
.counter
.fetch_update(Ordering::Release, Ordering::Relaxed, |prev| {
prev.checked_add(num)
})
.unwrap_or_else(|prev| panic!("WaitGroup counter overflow: prev={prev}, num={num}"));
Self {
inner: self.inner.clone(),
}
}
pub fn done(&self) -> usize {
match self
.inner
.counter
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
{
Ok(old) => old - 1,
Err(_) => 0,
}
}
pub fn remaining(&self) -> usize {
self.inner.counter.load(Ordering::Acquire)
}
pub fn wait(&self) {
if self.inner.counter.load(Ordering::Acquire) == 0 {
return;
}
let mut iter = 0u32;
while self.inner.counter.load(Ordering::Acquire) != 0 {
backoff_step(&mut iter);
}
}
}