use std::cell::UnsafeCell;
use crate::loom_shim::{Condvar, Mutex};
use crate::runtime::g::{current_g, WaitReason, G};
use crate::runtime::park::{gopark_commit, goready};
use crate::runtime::rawmutex::RawMutex;
/// Mutable state of a [`WaitGroup`]; every access happens while `WaitGroup::mu` is held.
struct WgState {
    // Outstanding-work counter. `add` panics once it becomes negative.
    count: i64,
    // Goroutines that called `wait` while `count != 0`; `add` drains this list and
    // readies each entry when `count` reaches zero.
    waiters: Vec<*mut G>,
}
// SAFETY(review): the raw `*mut G` pointers make this type `!Send` by default. They are only
// read or written under `WaitGroup::mu`; confirm that G pointers may be handed across OS
// threads in this runtime.
unsafe impl Send for WgState {}
/// A Go-style wait group: a counter that `wait` blocks on until it returns to zero.
///
/// Callers running on a goroutine are parked through the scheduler. Callers with no
/// current goroutine block on an OS-level condition variable instead.
pub struct WaitGroup {
    // Protects `state`.
    mu: RawMutex,
    // Counter and parked-goroutine list; only accessed while `mu` is held.
    state: UnsafeCell<WgState>,
    // Mutex paired with `cond` for the OS-thread wait path.
    cond_lock: Mutex<()>,
    // Notified by `add` every time the counter reaches zero.
    cond: Condvar,
}
// SAFETY: all access to `state` goes through `mu`, and `cond_lock`/`cond` are the shim's
// (presumably thread-safe) sync primitives. NOTE(review): soundness also relies on the raw
// `*mut G` waiters being safe to ready from another thread; confirm against the scheduler.
unsafe impl Send for WaitGroup {}
unsafe impl Sync for WaitGroup {}
impl WaitGroup {
    /// Creates a wait group whose counter starts at zero.
    pub fn new() -> Self {
        Self {
            mu: RawMutex::new(),
            state: UnsafeCell::new(WgState { count: 0, waiters: Vec::new() }),
            cond_lock: Mutex::new(()),
            cond: Condvar::new(),
        }
    }

    /// Adds `delta` to the counter; `delta` may be negative.
    ///
    /// If the counter becomes zero, every goroutine parked in [`wait`](Self::wait) is made
    /// runnable, and every OS thread blocked in `wait` is notified.
    ///
    /// # Panics
    ///
    /// * `"sync: negative WaitGroup counter"` if the counter becomes negative. The counter is
    ///   left at that negative value.
    /// * `"sync: WaitGroup counter overflow"` if `count + delta` does not fit in an `i64`.
    ///   The counter is left unchanged.
    ///
    /// In both cases the internal mutex is released before panicking.
    pub fn add(&self, delta: i64) {
        // Kept alive until the end of the call, i.e. across the `goready` loop below.
        // NOTE(review): presumably required by `goready`; confirm.
        let _m_guard = crate::runtime::m::m_lock();
        let goroutine_waiters: Vec<*mut G> = {
            self.mu.lock();
            // SAFETY: `self.mu` is held, so no other reference to the state is live.
            let state = unsafe { &mut *self.state.get() };
            // A plain `+=` would panic in debug builds while `mu` is still held, leaving the
            // raw mutex locked forever. In release builds it would silently wrap.
            let count = match state.count.checked_add(delta) {
                Some(count) => count,
                None => {
                    // SAFETY: `self.mu` was locked above by this thread.
                    unsafe { self.mu.unlock() };
                    panic!("sync: WaitGroup counter overflow");
                }
            };
            state.count = count;
            if count < 0 {
                // SAFETY: `self.mu` was locked above by this thread.
                unsafe { self.mu.unlock() };
                panic!("sync: negative WaitGroup counter");
            }
            let zero = count == 0;
            let woken = if zero { std::mem::take(&mut state.waiters) } else { Vec::new() };
            // SAFETY: `self.mu` was locked above by this thread; `state` is not used after this.
            unsafe { self.mu.unlock() };
            if zero {
                // Notify while holding `cond_lock`. OS-thread waiters check the counter while
                // holding `cond_lock`, so they cannot miss this wakeup before blocking.
                let _cond_guard = self.cond_lock.lock().unwrap();
                self.cond.notify_all();
            }
            woken
        };
        for gp in goroutine_waiters {
            // SAFETY: each `gp` was pushed by `wait` and then moved out of the waiter list
            // under `mu`, so it is readied exactly once.
            unsafe { goready(gp) };
        }
    }

    /// Decrements the counter by one; equivalent to `add(-1)`.
    ///
    /// # Panics
    ///
    /// Panics if the counter becomes negative (see [`add`](Self::add)).
    pub fn done(&self) {
        self.add(-1);
    }

    /// Blocks the caller until the counter is zero.
    ///
    /// Returns immediately if the counter is already zero. Otherwise the behavior depends on
    /// the caller:
    ///
    /// * On a goroutine, the caller is added to the waiter list and parked until
    ///   [`add`](Self::add) readies it.
    /// * With no current goroutine, the caller blocks on the condition variable and
    ///   re-checks the counter after every notification.
    pub fn wait(&self) {
        let gp = current_g();
        if !gp.is_null() {
            let m_guard = crate::runtime::m::m_lock();
            self.mu.lock();
            // SAFETY: `self.mu` is held.
            let state = unsafe { &mut *self.state.get() };
            if state.count == 0 {
                // SAFETY: `self.mu` was locked above by this thread.
                unsafe { self.mu.unlock() };
                return;
            }
            state.waiters.push(gp);
            // The M-lock guard is deliberately leaked instead of dropped before parking.
            // NOTE(review): presumably the park/scheduler path releases the M lock; confirm.
            std::mem::forget(m_guard);
            // SAFETY: `self.mu` is still locked and is released by `unlock_wg_mutex`.
            // NOTE(review): this assumes `gopark_commit` invokes the callback only after the
            // park is committed (as Go's `gopark` does), so `add` cannot ready `gp` early.
            unsafe {
                gopark_commit(
                    WaitReason::Semacquire,
                    unlock_wg_mutex,
                    &self.mu as *const RawMutex as *mut u8,
                );
            }
            return;
        }

        // No current goroutine: block this OS thread on the condition variable.
        let mut guard = self.cond_lock.lock().unwrap();
        loop {
            self.mu.lock();
            // SAFETY: `self.mu` is held for the duration of this read.
            let count = unsafe { (*self.state.get()).count };
            // SAFETY: `self.mu` was locked just above by this thread.
            unsafe { self.mu.unlock() };
            if count == 0 {
                return;
            }
            guard = self.cond.wait(guard).unwrap();
        }
    }
}
impl Default for WaitGroup {
fn default() -> Self {
Self::new()
}
}
/// Park-commit callback that releases a `WaitGroup`'s internal `RawMutex`.
///
/// # Safety
///
/// `arg` must be a pointer to a live `RawMutex` that is currently locked (in this file it is
/// `&WaitGroup::mu`, locked in `wait` before parking).
unsafe fn unlock_wg_mutex(arg: *mut u8) {
    let mutex: *mut RawMutex = arg.cast();
    // SAFETY: the caller guarantees `mutex` points to a live, locked `RawMutex`.
    unsafe { (*mutex).unlock() }
}
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use crate::runtime::sched::run_impl;
    use std::sync::atomic::{AtomicI32, Ordering};
    use std::sync::Arc;

    /// A fresh group has a zero counter, so `wait` must not block.
    #[test]
    fn new_wait_returns_immediately() {
        let group = WaitGroup::new();
        group.wait();
    }

    /// One goroutine calls `done`; `wait` must return only after its side effect is visible.
    #[test]
    fn single_worker() {
        run_impl(|| {
            let group = Arc::new(WaitGroup::new());
            let hits = Arc::new(AtomicI32::new(0));
            group.add(1);
            let worker_group = Arc::clone(&group);
            let worker_hits = Arc::clone(&hits);
            crate::runtime::sched::spawn_goroutine(move || {
                worker_hits.fetch_add(1, Ordering::Relaxed);
                worker_group.done();
            });
            group.wait();
            assert_eq!(hits.load(Ordering::Acquire), 1);
        });
    }

    /// Several goroutines each `add(1)`/`done()`; `wait` releases only after all of them.
    #[test]
    fn multiple_workers() {
        const WORKERS: i32 = 5;
        let total = Arc::new(AtomicI32::new(0));
        let total_in_runtime = Arc::clone(&total);
        run_impl(move || {
            let group = Arc::new(WaitGroup::new());
            for _ in 0..WORKERS {
                group.add(1);
                let worker_group = Arc::clone(&group);
                let worker_total = Arc::clone(&total_in_runtime);
                crate::runtime::sched::spawn_goroutine(move || {
                    worker_total.fetch_add(1, Ordering::Relaxed);
                    worker_group.done();
                });
            }
            group.wait();
            assert_eq!(total_in_runtime.load(Ordering::Acquire), WORKERS);
        });
        assert_eq!(total.load(Ordering::Acquire), WORKERS);
    }

    /// Two goroutines wait on the same group; a single `done` must wake both.
    #[test]
    fn multiple_waiters() {
        let woken = Arc::new(AtomicI32::new(0));
        let woken_in_runtime = Arc::clone(&woken);
        run_impl(move || {
            let group = Arc::new(WaitGroup::new());
            group.add(1);
            for _ in 0..2 {
                let waiter_group = Arc::clone(&group);
                let waiter_woken = Arc::clone(&woken_in_runtime);
                crate::runtime::sched::spawn_goroutine(move || {
                    waiter_group.wait();
                    waiter_woken.fetch_add(1, Ordering::Relaxed);
                });
            }
            // Yield repeatedly so the waiters get a chance to park before the release.
            for _ in 0..20 {
                crate::gosched();
            }
            group.done();
            let deadline = std::time::Instant::now() + std::time::Duration::from_millis(500);
            while woken_in_runtime.load(Ordering::Acquire) < 2 {
                assert!(
                    std::time::Instant::now() < deadline,
                    "timed out: only {} of 2 waiters woke",
                    woken_in_runtime.load(Ordering::Relaxed),
                );
                crate::gosched();
            }
        });
        assert_eq!(woken.load(Ordering::Acquire), 2, "both waiters must wake");
    }

    /// After the counter returns to zero, the same group can be used for a second round.
    #[test]
    fn reuse_after_wait() {
        run_impl(|| {
            let group = Arc::new(WaitGroup::new());

            // First round.
            group.add(1);
            let first = Arc::clone(&group);
            crate::runtime::sched::spawn_goroutine(move || {
                first.done();
            });
            group.wait();

            // Second round on the same group.
            let flag = Arc::new(AtomicI32::new(0));
            group.add(1);
            let second = Arc::clone(&group);
            let second_flag = Arc::clone(&flag);
            crate::runtime::sched::spawn_goroutine(move || {
                second_flag.store(1, Ordering::Relaxed);
                second.done();
            });
            group.wait();
            assert_eq!(flag.load(Ordering::Acquire), 1);
        });
    }

    /// Decrementing below zero is a usage error and must panic.
    #[test]
    #[should_panic(expected = "sync: negative WaitGroup counter")]
    fn negative_counter_panics() {
        let group = WaitGroup::new();
        group.add(-1);
    }
}
#[cfg(all(test, loom))]
mod loom_tests {
    use super::*;
    use loom::sync::Arc;

    /// A single `done` on another thread must release the main thread's `wait`.
    #[test]
    fn done_unblocks_wait() {
        loom::model(|| {
            let group = Arc::new(WaitGroup::new());
            let worker_group = Arc::clone(&group);
            group.add(1);
            let worker = loom::thread::spawn(move || {
                worker_group.done();
            });
            group.wait();
            worker.join().unwrap();
        });
    }

    /// Two concurrent `done` calls must release a waiter blocked on a separate thread.
    #[test]
    fn two_workers_unblock_wait() {
        loom::model(|| {
            let group = Arc::new(WaitGroup::new());
            let first = Arc::clone(&group);
            let second = Arc::clone(&group);
            let waiter_group = Arc::clone(&group);
            group.add(2);
            let first_handle = loom::thread::spawn(move || first.done());
            let second_handle = loom::thread::spawn(move || second.done());
            let waiter = loom::thread::spawn(move || waiter_group.wait());
            first_handle.join().unwrap();
            second_handle.join().unwrap();
            waiter.join().unwrap();
        });
    }

    /// `done` on one thread racing with `wait` on the main thread must never hang.
    #[test]
    fn add_and_done_interleave() {
        loom::model(|| {
            let group = Arc::new(WaitGroup::new());
            let adder_group = Arc::clone(&group);
            let waiter_group = Arc::clone(&group);
            group.add(1);
            let adder = loom::thread::spawn(move || adder_group.done());
            waiter_group.wait();
            adder.join().unwrap();
        });
    }
}