use crate::loom_shim::{Condvar, Mutex};
use crate::runtime::g::{current_g, WaitReason, G};
use crate::runtime::park::{gopark, goready};
/// Mutable state of a `WaitGroup`; every access goes through the group's `state` mutex.
struct WgState {
    /// Outstanding units of work. `wait` returns once this is zero.
    count: i64,
    /// Goroutines that registered in `wait`; handed to `goready` when `count` reaches zero.
    waiters: Vec<*mut G>,
}

// SAFETY: `WgState` is only `!Send` because of the raw `*mut G` pointers. This
// module never dereferences them: they are pushed in `wait` and drained in `add`
// while the `WaitGroup` mutex is held, and are only ever passed to `goready`.
// NOTE(review): soundness additionally requires that `goready` may be called for
// a `G` from any OS thread — confirm against the scheduler's contract.
unsafe impl Send for WgState {}
/// Lets callers block until a counter of outstanding work drops to zero.
///
/// Work is registered with `add`, completed with `done`, and awaited with
/// `wait`. Waiters running on a goroutine are parked and later readied by the
/// scheduler; waiters on plain OS threads block on the condition variable.
pub struct WaitGroup {
    /// Notified (`notify_all`) whenever the counter reaches zero, waking OS-thread waiters.
    cond: Condvar,
    /// Counter plus the list of registered goroutine waiters, always accessed under this lock.
    state: Mutex<WgState>,
}
impl WaitGroup {
    /// Creates a wait group with a zero counter and no waiters.
    ///
    /// `wait` on a fresh group returns immediately.
    pub fn new() -> Self {
        Self {
            state: Mutex::new(WgState { count: 0, waiters: Vec::new() }),
            cond: Condvar::new(),
        }
    }

    /// Adds `delta` (which may be negative) to the counter.
    ///
    /// If the counter becomes zero, every waiter is released. OS threads blocked
    /// in [`WaitGroup::wait`] are woken through the condition variable. The
    /// goroutines it registered are handed to `goready`.
    ///
    /// # Panics
    ///
    /// * `"sync: negative WaitGroup counter"` if the counter would go below zero.
    /// * `"sync: WaitGroup counter overflow"` if it would exceed `i64::MAX`.
    ///
    /// In both cases the counter is left unchanged and the lock is released
    /// before panicking. The group therefore stays consistent and its mutex is
    /// not poisoned.
    pub fn add(&self, delta: i64) {
        let released: Vec<*mut G> = {
            let mut state = self.state.lock().unwrap();
            // Validate before storing. Committing a negative or wrapped value
            // would leave later `wait` callers seeing an impossible counter.
            // A debug-build `+=` overflow would panic with the guard held and
            // poison the mutex.
            let next = state.count.checked_add(delta);
            match next {
                None => {
                    drop(state);
                    panic!("sync: WaitGroup counter overflow");
                }
                Some(n) if n < 0 => {
                    drop(state);
                    panic!("sync: negative WaitGroup counter");
                }
                Some(n) => state.count = n,
            }
            if state.count != 0 {
                return;
            }
            let waiters = std::mem::take(&mut state.waiters);
            // The lock is dropped before notifying or readying anyone, so no
            // waiter is woken while it is still held.
            drop(state);
            self.cond.notify_all();
            waiters
        };
        for gp in released {
            // SAFETY: `gp` came from `current_g()` in `wait` and was removed from
            // `waiters` under the lock, so this is the only `goready` for that
            // registration. NOTE(review): confirm `goready` tolerates being
            // called before the goroutine has finished parking (see `wait`).
            unsafe { goready(gp) };
        }
    }

    /// Decrements the counter by one; shorthand for `add(-1)`.
    ///
    /// # Panics
    ///
    /// Panics with `"sync: negative WaitGroup counter"` if the counter is already zero.
    pub fn done(&self) {
        self.add(-1);
    }

    /// Blocks the caller until the counter is zero; returns at once if it already is.
    ///
    /// On a goroutine (`current_g()` is non-null), the caller registers itself
    /// in `waiters` and parks with `WaitReason::Semacquire`. `add` readies it
    /// when the counter reaches zero. On a plain OS thread, the caller blocks on
    /// the condition variable and re-checks the counter after every wakeup.
    pub fn wait(&self) {
        let gp = current_g();
        let mut state = self.state.lock().unwrap();
        if gp.is_null() {
            // Condvar wakeups may be spurious, so loop until the counter is zero.
            while state.count > 0 {
                state = self.cond.wait(state).unwrap();
            }
            return;
        }
        if state.count == 0 {
            return;
        }
        state.waiters.push(gp);
        drop(state);
        // NOTE(review): the lock is released before `gopark`, so another thread
        // may call `goready(gp)` before this goroutine has actually parked. That
        // is only sound if the scheduler tolerates this ordering — confirm.
        // This code also assumes `gopark` returns only after `goready`, because
        // the counter is not re-checked afterwards.
        // SAFETY: `gp` is the current goroutine and is registered in `waiters`,
        // so the `add` call that drops the counter to zero will ready it.
        unsafe { gopark(WaitReason::Semacquire) };
    }
}
impl Default for WaitGroup {
fn default() -> Self {
Self::new()
}
}
// Scheduler-level tests; compiled only for test builds without `--cfg loom`.
#[cfg(all(test, not(loom)))]
mod tests {
use super::*;
use crate::runtime::sched::run_impl;
use std::sync::atomic::{AtomicI32, Ordering};
use std::sync::Arc;
// `wait` on a fresh group (counter zero) must return without blocking. This runs
// outside `run_impl`, presumably so `wait` takes the OS-thread path — TODO confirm
// that `current_g()` is null here.
#[test]
fn new_wait_returns_immediately() {
let wg = WaitGroup::new();
wg.wait(); }
// One spawned goroutine records its work and then calls `done`. `wait` inside the
// `run_impl` closure must not return before that increment is visible.
#[test]
fn single_worker() {
run_impl(|| {
let wg = Arc::new(WaitGroup::new());
let done = Arc::new(AtomicI32::new(0));
wg.add(1);
let wg2 = Arc::clone(&wg);
let done2 = Arc::clone(&done);
// NOTE(review): `spawn_goroutine` is `unsafe`, but its safety contract is not
// visible here; consider adding a SAFETY comment at each call site.
unsafe {
crate::runtime::sched::spawn_goroutine(move || {
done2.fetch_add(1, Ordering::Relaxed);
wg2.done();
});
}
wg.wait();
assert_eq!(done.load(Ordering::Acquire), 1);
});
}
// All N goroutines must have called `done` before `wait` returns. The count is
// checked both inside the `run_impl` closure and again after `run_impl` returns.
#[test]
fn multiple_workers() {
const N: i32 = 5;
let count = Arc::new(AtomicI32::new(0));
let count2 = Arc::clone(&count);
run_impl(move || {
let wg = Arc::new(WaitGroup::new());
for _ in 0..N {
wg.add(1);
let wg2 = Arc::clone(&wg);
let count3 = Arc::clone(&count2);
unsafe {
crate::runtime::sched::spawn_goroutine(move || {
count3.fetch_add(1, Ordering::Relaxed);
wg2.done();
});
}
}
wg.wait();
assert_eq!(count2.load(Ordering::Acquire), N);
});
assert_eq!(count.load(Ordering::Acquire), N);
}
// Two goroutines block in `wait` on the same group; a single `done` must wake both.
#[test]
fn multiple_waiters() {
let woke = Arc::new(AtomicI32::new(0));
let woke2 = Arc::clone(&woke);
run_impl(move || {
let wg = Arc::new(WaitGroup::new());
wg.add(1);
for _ in 0..2 {
let wg3 = Arc::clone(&wg);
let woke3 = Arc::clone(&woke2);
unsafe {
crate::runtime::sched::spawn_goroutine(move || {
wg3.wait();
woke3.fetch_add(1, Ordering::Relaxed);
});
}
}
// The 20 `gosched` calls are meant to let both waiters reach `wait` and park
// before the counter drops (assumes `gosched` yields to other runnable
// goroutines — confirm).
for _ in 0..20 { crate::gosched(); }
wg.done();
// Poll, yielding on each iteration, until both waiters report in. Fail after
// 500 ms instead of hanging if a wakeup is lost.
let deadline = std::time::Instant::now()
+ std::time::Duration::from_millis(500);
loop {
if woke2.load(Ordering::Acquire) >= 2 { break; }
assert!(
std::time::Instant::now() < deadline,
"timed out: only {} of 2 waiters woke",
woke2.load(Ordering::Relaxed),
);
crate::gosched();
}
});
assert_eq!(woke.load(Ordering::Acquire), 2, "both waiters must wake");
}
// After the counter returns to zero, the same group can be armed again with `add`
// and waited on a second time.
#[test]
fn reuse_after_wait() {
run_impl(|| {
let wg = Arc::new(WaitGroup::new());
wg.add(1);
let wg2 = Arc::clone(&wg);
unsafe {
crate::runtime::sched::spawn_goroutine(move || { wg2.done(); });
}
wg.wait();
let done = Arc::new(AtomicI32::new(0));
wg.add(1);
let wg3 = Arc::clone(&wg);
let done2 = Arc::clone(&done);
unsafe {
crate::runtime::sched::spawn_goroutine(move || {
done2.store(1, Ordering::Relaxed);
wg3.done();
});
}
wg.wait();
assert_eq!(done.load(Ordering::Acquire), 1);
});
}
// Driving the counter below zero (here: `add(-1)` on a fresh group) must panic
// with the "sync: negative WaitGroup counter" message.
#[test]
#[should_panic(expected = "sync: negative WaitGroup counter")]
fn negative_counter_panics() {
let wg = WaitGroup::new();
wg.add(-1); }
}
// Loom model-checking tests; compiled only with `--cfg loom`. `loom::model` runs
// each closure repeatedly, exploring different thread interleavings.
// NOTE(review): these tests exercise the OS-thread (condvar) path of `wait`,
// assuming `current_g()` returns null on loom threads — confirm.
#[cfg(all(test, loom))]
mod loom_tests {
use super::*;
use loom::sync::Arc;
// The main thread blocks in `wait` while a worker thread calls `done`; every
// explored interleaving must terminate.
#[test]
fn done_unblocks_wait() {
loom::model(|| {
let wg = Arc::new(WaitGroup::new());
let wg2 = Arc::clone(&wg);
wg.add(1);
let worker = loom::thread::spawn(move || {
wg2.done();
});
wg.wait();
worker.join().unwrap();
});
}
// Two worker threads each call `done` while a third thread waits. `wait` must
// return in every interleaving once both decrements have happened.
#[test]
fn two_workers_unblock_wait() {
loom::model(|| {
let wg = Arc::new(WaitGroup::new());
let wg2 = Arc::clone(&wg);
let wg3 = Arc::clone(&wg);
let wg4 = Arc::clone(&wg);
wg.add(2);
let t1 = loom::thread::spawn(move || wg2.done());
let t2 = loom::thread::spawn(move || wg3.done());
let waiter = loom::thread::spawn(move || wg4.wait());
t1.join().unwrap();
t2.join().unwrap();
waiter.join().unwrap();
});
}
// `add(1)` runs on the main thread before the worker is spawned, so only `done`
// (worker) and `wait` (main thread) race here.
// NOTE(review): despite the name, `add` itself never interleaves with `done`.
#[test]
fn add_and_done_interleave() {
loom::model(|| {
let wg = Arc::new(WaitGroup::new());
let wg2 = Arc::clone(&wg);
let wg3 = Arc::clone(&wg);
wg.add(1);
let adder = loom::thread::spawn(move || wg2.done());
wg3.wait();
adder.join().unwrap();
});
}
}