//! Golang-like `WaitGroup` implementation that supports both sync and async Rust.
#![deny(missing_docs)]
#![deny(unsafe_code)]
#![deny(unused_qualifications)]
extern crate alloc;
use alloc::sync::Arc;
use core::pin::Pin;
use core::sync::atomic::{AtomicUsize, Ordering};
use core::task::Poll;
#[cfg(feature = "std")]
use std::fmt;
use event_listener::{Event, EventListener};
use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy};
use futures_core::ready;
/// Enables tasks to synchronize the beginning or end of some computation.
///
/// # Examples
///
/// ```
/// use async_waitgroup::WaitGroup;
///
/// # #[tokio::main(flavor = "current_thread")] async fn main() {
/// // Create a new wait group.
/// let wg = WaitGroup::new();
///
/// for _ in 0..4 {
///     // Create another reference to the wait group.
///     let wg = wg.clone();
///
///     tokio::spawn(async move {
///         // Do some work.
///
///         // Drop the reference to the wait group.
///         drop(wg);
///     });
/// }
///
/// // Block until all tasks have finished their work.
/// wg.wait().await;
/// # }
/// ```
pub struct WaitGroup {
    inner: Arc<WgInner>,
}

/// Inner state of a `WaitGroup`.
struct WgInner {
    /// Number of live `WaitGroup` references.
    count: AtomicUsize,
    /// Notified when the last reference is dropped.
    drop_ops: Event,
}
impl Default for WaitGroup {
    fn default() -> Self {
        Self {
            inner: Arc::new(WgInner {
                count: AtomicUsize::new(1),
                drop_ops: Event::new(),
            }),
        }
    }
}
impl WaitGroup {
    /// Creates a new wait group and returns the single reference to it.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_waitgroup::WaitGroup;
    ///
    /// let wg = WaitGroup::new();
    /// ```
    pub fn new() -> Self {
        Self::default()
    }
    /// Drops this reference and waits until all other references are dropped.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_waitgroup::WaitGroup;
    ///
    /// # #[tokio::main(flavor = "current_thread")] async fn main() {
    /// let wg = WaitGroup::new();
    ///
    /// tokio::spawn({
    ///     let wg = wg.clone();
    ///     async move {
    ///         // Block until both tasks have reached `wait()`.
    ///         wg.wait().await;
    ///     }
    /// });
    ///
    /// // Block until the spawned task has reached `wait()` as well.
    /// wg.wait().await;
    /// # }
    /// ```
    pub fn wait(self) -> Wait {
        // `self` is dropped when this function returns, which releases this
        // reference's share of the count; the future keeps only the shared state.
        Wait::_new(WaitInner {
            wg: self.inner.clone(),
            listener: EventListener::new(),
        })
    }
    /// Drops this reference and blocks the current thread until all other
    /// references are dropped.
    ///
    /// Only available with the `std` feature on non-wasm targets.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::thread;
    ///
    /// use async_waitgroup::WaitGroup;
    ///
    /// let wg = WaitGroup::new();
    ///
    /// thread::spawn({
    ///     let wg = wg.clone();
    ///     move || {
    ///         wg.wait_blocking();
    ///     }
    /// });
    ///
    /// wg.wait_blocking();
    /// ```
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    pub fn wait_blocking(self) {
        self.wait().wait();
    }
}
easy_wrapper! {
    /// A future returned by [`WaitGroup::wait()`].
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub struct Wait(WaitInner => ());
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    pub(crate) wait();
}

pin_project_lite::pin_project! {
    struct WaitInner {
        wg: Arc<WgInner>,
        #[pin]
        listener: EventListener,
    }
}
impl EventListenerFuture for WaitInner {
    type Output = ();

    fn poll_with_strategy<'a, S: Strategy<'a>>(
        self: Pin<&mut Self>,
        strategy: &mut S,
        context: &mut S::Context,
    ) -> Poll<Self::Output> {
        let mut this = self.project();

        // Loop until every reference has been dropped.
        let mut count = this.wg.count.load(Ordering::SeqCst);
        while count > 0 {
            if this.listener.is_listening() {
                // Already registered: wait for the drop notification.
                ready!(strategy.poll(this.listener.as_mut(), context))
            } else {
                // Register first, then re-check the count below so that a drop
                // racing with the registration is not missed.
                this.listener.as_mut().listen(&this.wg.drop_ops);
            }
            count = this.wg.count.load(Ordering::SeqCst);
        }

        Poll::Ready(())
    }
}
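impl Drop for WaitGroup {
    fn drop(&mut self) {
        // `fetch_sub` returns the previous value, so 1 means this was the last
        // reference: wake every waiter.
        if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 {
            self.inner.drop_ops.notify(usize::MAX);
        }
    }
}

impl Clone for WaitGroup {
    fn clone(&self) -> Self {
        // Each clone counts as one more reference that waiters must see dropped.
        self.inner.count.fetch_add(1, Ordering::SeqCst);

        Self {
            inner: self.inner.clone(),
        }
    }
}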
#[cfg(feature = "std")]
impl fmt::Debug for WaitGroup {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let count = self.inner.count.load(Ordering::SeqCst);
        f.debug_struct("WaitGroup").field("count", &count).finish()
    }
}
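
// A minimal test sketch, not part of the original source. It assumes the `std`
// feature is enabled on a non-wasm target so that `wait_blocking()` exists.
// The module name, test name, and the `done` counter are illustrative only.
#[cfg(all(test, feature = "std", not(target_family = "wasm")))]
mod tests {
    use super::WaitGroup;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn wait_blocking_returns_after_all_clones_are_dropped() {
        let wg = WaitGroup::new();
        let done = Arc::new(AtomicUsize::new(0));

        for _ in 0..4 {
            let wg = wg.clone();
            let done = Arc::clone(&done);
            thread::spawn(move || {
                // Record the "work" before releasing the reference.
                done.fetch_add(1, Ordering::SeqCst);
                // Dropping the clone decrements the shared count.
                drop(wg);
            });
        }

        // Returns only once every clone created above has been dropped.
        wg.wait_blocking();
        assert_eq!(done.load(Ordering::SeqCst), 4);
    }
}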