use std::fmt;
use std::future::Future;
use std::future::IntoFuture;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use crate::internal::CountdownState;
#[cfg(test)]
mod tests;
/// A handle onto a shared countdown that lets a task wait for a group of handles to go away.
///
/// Cloning a handle raises the shared count by one and dropping a handle lowers it by one.
/// Converting a handle with `IntoFuture` releases that handle and yields a `Wait` future
/// bound to the same shared state.
pub struct WaitGroup {
// Countdown state shared by every clone of this handle and by every `Wait` derived from it.
state: Arc<CountdownState>,
}
impl fmt::Debug for WaitGroup {
    /// Formats the handle as an opaque `WaitGroup { .. }`; the shared state is not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("WaitGroup");
        builder.finish_non_exhaustive()
    }
}
impl Default for WaitGroup {
fn default() -> Self {
Self::new()
}
}
impl WaitGroup {
    /// Creates a fresh wait group backed by its own shared countdown state.
    ///
    /// The state is built with `CountdownState::new(1)`; the `1` presumably accounts for the
    /// handle returned here (each clone adds one, each drop removes one).
    pub fn new() -> Self {
        let state = Arc::new(CountdownState::new(1));
        Self { state }
    }
}
impl Clone for WaitGroup {
fn clone(&self) -> Self {
let sync = self.state.clone();
let mut cnt = sync.state();
loop {
let new_cnt = cnt.saturating_add(1);
match sync.cas_state(cnt, new_cnt) {
Ok(_) => return Self { state: sync },
Err(x) => cnt = x,
}
}
}
}
impl Drop for WaitGroup {
    /// Lowers the shared count by one and wakes every registered waiter when `decrement`
    /// reports `true` (presumably: this was the last outstanding handle).
    fn drop(&mut self) {
        let should_wake = self.state.decrement(1);
        if !should_wake {
            return;
        }
        self.state.wake_all();
    }
}
impl IntoFuture for WaitGroup {
type Output = ();
type IntoFuture = Wait;
fn into_future(self) -> Self::IntoFuture {
let state = self.state.clone();
drop(self);
Wait { idx: None, state }
}
}
/// Future produced by converting a `WaitGroup` with `IntoFuture`.
///
/// It resolves to `()` once polling observes `spin_wait` succeeding on the shared state
/// (presumably when the shared count has reached zero — confirm against `CountdownState`).
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Wait {
// Slot handed to `register_waker` on each pending poll; starts as `None`.
// NOTE(review): presumably the index of this future's registered waker — confirm.
idx: Option<usize>,
// Countdown state shared with the originating `WaitGroup` handles.
state: Arc<CountdownState>,
}
impl Clone for Wait {
fn clone(&self) -> Self {
Wait {
idx: None,
state: self.state.clone(),
}
}
}
impl fmt::Debug for Wait {
    /// Formats the future as an opaque `Wait { .. }`; no fields are printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Wait");
        builder.finish_non_exhaustive()
    }
}
impl Future for Wait {
    type Output = ();

    /// Polls the shared state, returning `Poll::Ready(())` as soon as `spin_wait` succeeds.
    ///
    /// When the first bounded check fails, the task's waker is registered and the state is
    /// checked once more before reporting `Poll::Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let Self {
            idx: waker_slot,
            state: shared,
        } = self.get_mut();

        // Fast path: a short spin (argument 16) already succeeded, no waker needed.
        if shared.spin_wait(16).is_ok() {
            return Poll::Ready(());
        }

        shared.register_waker(waker_slot, cx);

        // Check again after registering so a change that landed between the first check and
        // the registration is not missed.
        if shared.spin_wait(0).is_err() {
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }
}