use std::sync::atomic::{AtomicU32, Ordering};
use omango_util::hint::likely;
/// A synchronization primitive that lets one or more threads block until a
/// counter of outstanding tasks drops to zero.
///
/// Register work with [`WaitGroup::add`] (or the initial value passed to
/// [`WaitGroup::new`]), call [`WaitGroup::done`] once per finished task, and
/// call [`WaitGroup::wait`] to block until the counter reaches zero.
#[derive(Debug)]
pub struct WaitGroup {
    /// Number of outstanding tasks.
    count: AtomicU32,
    /// Futex word used as a generation counter: it is bumped (wrapping) every
    /// time `count` drops to zero, so sleeping waiters observe a change. It is
    /// never reset, which keeps concurrent waiters from losing wake-ups.
    flag: AtomicU32,
}

impl Default for WaitGroup {
    /// Creates a wait group whose counter starts at zero.
    #[inline(always)]
    fn default() -> Self {
        Self::new(0)
    }
}

impl WaitGroup {
    /// Creates a wait group whose counter starts at `n`.
    #[inline(always)]
    pub fn new(n: u32) -> Self {
        Self {
            count: AtomicU32::new(n),
            flag: AtomicU32::new(0),
        }
    }

    /// Adds `n` outstanding tasks to the counter.
    ///
    /// # Panics
    ///
    /// Panics if the counter would exceed `u32::MAX`; the counter is left
    /// unchanged in that case.
    #[inline(always)]
    pub fn add(&self, n: u32) {
        // A CAS loop (rather than `fetch_add`) so overflow is detected before
        // the counter is modified instead of silently wrapping.
        if self
            .count
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| c.checked_add(n))
            .is_err()
        {
            panic!("WaitGroup counter overflow");
        }
    }

    /// Marks one outstanding task as finished and wakes every waiter when the
    /// counter reaches zero.
    ///
    /// # Panics
    ///
    /// Panics if the counter is already zero (more `done` calls than tasks
    /// registered). The counter is left unchanged in that case.
    #[inline(always)]
    pub fn done(&self) {
        // Decrement with a CAS loop so an unbalanced `done` cannot wrap the
        // counter to `u32::MAX` before the panic fires.
        let prev = match self
            .count
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| c.checked_sub(1))
        {
            Ok(prev) => prev,
            Err(_) => panic!("WaitGroup::done called with a zero counter"),
        };
        if likely(prev > 1) {
            return;
        }
        // The counter just hit zero: publish a new generation, then wake all
        // sleepers. Bumping (instead of setting to a fixed value) guarantees the
        // futex word differs from any snapshot a waiter took earlier.
        self.flag.fetch_add(1, Ordering::SeqCst);
        omango_futex::wake_all(&self.flag);
    }

    /// Blocks the calling thread until the counter reaches zero.
    ///
    /// Returns immediately if the counter is already zero. Any number of
    /// threads may wait at the same time.
    pub fn wait(&self) {
        loop {
            // Snapshot the generation *before* checking the counter. If the
            // final `done` lands after this load, it bumps `flag`, so the futex
            // wait below either sees a changed value and returns at once or is
            // woken by `wake_all`, so the wake-up cannot be missed.
            let generation = self.flag.load(Ordering::SeqCst);
            if !self.should_wait() {
                return;
            }
            // May return spuriously; the loop re-checks the counter.
            omango_futex::wait(&self.flag, generation);
        }
    }

    /// Returns `true` while there are outstanding tasks.
    #[inline(always)]
    fn should_wait(&self) -> bool {
        self.count.load(Ordering::SeqCst) > 0
    }
}
/// Unit tests for `WaitGroup`; compiled only for `cargo test`.
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    use crate::wg::WaitGroup;

    /// The main thread waits for itself plus one worker. The worker is
    /// registered with `add` *before* it is spawned, so `wait` must genuinely
    /// block until the worker calls `done`.
    #[test]
    fn test_wait_on_one() {
        let wg = Arc::new(WaitGroup::new(1));
        let count = Arc::new(AtomicU32::new(0));
        // Calling `add` inside the spawned thread would race with the main
        // thread's `done`/`wait` and could let `wait` return early.
        wg.add(1);
        let thread = {
            let wg = Arc::clone(&wg);
            let count = Arc::clone(&count);
            std::thread::spawn(move || {
                count.fetch_add(1, Ordering::Relaxed);
                wg.done();
            })
        };
        count.fetch_add(1, Ordering::Relaxed);
        wg.done();
        wg.wait();
        // Both increments happen before their thread's `done`, so they are
        // visible once `wait` returns, without joining first.
        assert_eq!(count.load(Ordering::Relaxed), 2);
        thread.join().unwrap();
    }

    /// Two threads (main and worker) both wait on the same group; each must
    /// observe the other's work once `wait` returns.
    #[test]
    fn test_wait_on_gt_one() {
        let wg = Arc::new(WaitGroup::new(1));
        let count = Arc::new(AtomicU32::new(0));
        // Register the worker before spawning it (see `test_wait_on_one`).
        wg.add(1);
        let thread = {
            let wg = Arc::clone(&wg);
            let count = Arc::clone(&count);
            std::thread::spawn(move || {
                count.fetch_add(1, Ordering::Relaxed);
                wg.done();
                wg.wait();
                assert_eq!(count.load(Ordering::Relaxed), 2);
            })
        };
        count.fetch_add(1, Ordering::Relaxed);
        wg.done();
        wg.wait();
        assert_eq!(count.load(Ordering::Relaxed), 2);
        thread.join().unwrap();
    }

    /// `done` on a group whose counter is zero must panic.
    #[test]
    fn test_done_without_size() {
        let result = std::panic::catch_unwind(|| {
            let wg = WaitGroup::default();
            wg.done();
        });
        assert!(result.is_err());
    }
}