go_lib/sync/waitgroup.rs
1// SPDX-License-Identifier: Apache-2.0
2//! `WaitGroup` — ported from `src/sync/waitgroup.go`.
3//!
4//! ## Semantics (match Go)
5//!
6//! - [`WaitGroup::add`]`(delta)` — increment the counter by `delta`. `delta`
7//! may be negative (used internally by [`done`][WaitGroup::done]).
8//! Panics if the counter goes negative.
9//! - [`WaitGroup::done`]`()` — shorthand for `add(-1)`.
10//! - [`WaitGroup::wait`]`()` — block until the counter reaches zero.
11//! Multiple goroutines may call `wait` concurrently; all are unblocked when
12//! the last worker calls `done`.
13//!
14//! ## Implementation
15//!
16//! ### Goroutine path (the common case)
17//!
//! `wait` uses [`gopark_commit`][crate::runtime::park::gopark_commit] to
//! suspend the calling goroutine back into the scheduler **without blocking
//! the OS thread**; the internal lock is released only after the goroutine is
//! parked. The M and its P remain free to run other goroutines. When `add`
//! decrements the counter to zero it drains the waiters list and calls
//! [`goready`][crate::runtime::park::goready] on each, re-enqueuing them for
//! scheduling.
24//!
25//! ### Non-goroutine / loom path
26//!
//! If `wait` is called from a bare OS thread (outside the go-lib scheduler, or
//! from a loom model thread) it falls back to blocking on a `Condvar`. This
//! path is also the one taken by the `new_wait_returns_immediately` unit test,
//! which calls `wait` without ever entering the scheduler.
31//!
32//! Ported from `sync/waitgroup.go`.
33
34use std::cell::UnsafeCell;
35
36use crate::loom_shim::{Condvar, Mutex};
37use crate::runtime::g::{current_g, WaitReason, G};
38use crate::runtime::park::{gopark_commit, goready};
39use crate::runtime::rawmutex::RawMutex;
40
41// ---------------------------------------------------------------------------
42// Internal state
43// ---------------------------------------------------------------------------
44
45struct WgState {
46 /// Number of outstanding workers (`Add` increments, `Done` decrements).
47 count: i64,
48 /// Goroutines suspended in [`WaitGroup::wait`]. Drained by
49 /// [`WaitGroup::add`] when the counter reaches zero; each entry is woken
50 /// via [`goready`].
51 waiters: Vec<*mut G>,
52}
53
54// SAFETY: `WgState` is always accessed under the `WaitGroup`'s `Mutex`.
55// The `*mut G` pointers are goroutines owned by the scheduler; we never
56// dereference them without holding the lock (except to pass to `goready`,
57// which is safe once the goroutine has reached `GWAITING`).
58unsafe impl Send for WgState {}
59
60// ---------------------------------------------------------------------------
61// WaitGroup
62// ---------------------------------------------------------------------------
63
64/// A synchronisation barrier: wait for a set of goroutines to complete.
65///
66/// The typical usage pattern is:
67///
68/// ```no_run
69/// use std::sync::Arc;
70/// use go_lib::sync::WaitGroup;
71///
72/// #[go_lib::main]
73/// fn main() {
74/// let wg = Arc::new(WaitGroup::new());
75/// for i in 0..5 {
76/// let wg = Arc::clone(&wg);
77/// go_lib::go!(move || {
78/// wg.add(1);
79/// // ... do work, then call wg.done() when finished ...
80/// });
81/// }
82/// // wg.wait(); // blocks until all done() calls have been made
83/// }
84/// ```
85pub struct WaitGroup {
86 /// Spinlock protecting `state`. A `RawMutex` (not `std::sync::Mutex`)
87 /// so that the goroutine `wait()` path can hold the lock ACROSS the park
88 /// via `gopark_commit` — the lock is released on g0 only after the
89 /// waiter is `GWAITING`. With a guard-based Mutex released before
90 /// `gopark`, an async preemption in the release-to-park window made the
91 /// registered waiter `GRUNNABLE`; `add()`'s `goready` then dropped the
92 /// wake on the floor and the waiter parked forever.
93 mu: RawMutex,
94 /// Interior state — always accessed under `mu`.
95 state: UnsafeCell<WgState>,
96 /// Companion lock for `cond` — used only by the non-goroutine fallback
97 /// path (bare OS threads / loom model threads).
98 cond_lock: Mutex<()>,
99 /// Condvar used only on the non-goroutine fallback path.
100 cond: Condvar,
101}
102
103// SAFETY: `state` is only accessed while `mu` is held; the `*mut G` waiters
104// are scheduler-owned and only ever passed to `goready` (which dereferences
105// the `G` descriptor, never the waiter's stack — see `add`).
106unsafe impl Send for WaitGroup {}
107unsafe impl Sync for WaitGroup {}
108
109impl WaitGroup {
110 /// Create a new `WaitGroup` with a counter of zero.
111 pub fn new() -> Self {
112 Self {
113 mu: RawMutex::new(),
114 state: UnsafeCell::new(WgState { count: 0, waiters: Vec::new() }),
115 cond_lock: Mutex::new(()),
116 cond: Condvar::new(),
117 }
118 }
119
120 /// Add `delta` to the counter.
121 ///
122 /// `delta` is typically positive when called before spawning goroutines and
123 /// negative when they finish (see [`done`][Self::done]).
124 ///
125 /// # Panics
126 ///
127 /// Panics if the counter drops below zero.
128 pub fn add(&self, delta: i64) {
129 // Hold an `m.locks` guard across the std::sync::Mutex critical section.
130 // Without it, SIGURG-based async preemption can fire after
131 // `pthread_mutex_lock` returns but before we drop the MutexGuard: the
132 // preempted goroutine still holds the OS-level pthread mutex, and the
133 // next goroutine scheduled on the same M will self-deadlock trying to
134 // re-acquire it (the default pthread mutex is non-recursive, so a
135 // same-thread re-lock blocks forever in `__psynch_mutexwait`).
136 // Captured live via lldb on a hung `many_goroutines` run.
137 let _lk = crate::runtime::m::m_lock();
138 // WaitGroup waiters are woken only via `goready`, which dereferences
139 // only the `G` descriptor and never the waiter's stack.
140 // Collect goroutine waiters to wake (if counter reaches zero).
141 let goroutine_waiters: Vec<*mut G> = {
142 self.mu.lock();
143 // SAFETY: `mu` is held.
144 let state = unsafe { &mut *self.state.get() };
145 state.count += delta;
146 if state.count < 0 {
147 unsafe { self.mu.unlock() };
148 panic!("sync: negative WaitGroup counter");
149 }
150 let zero = state.count == 0;
151 let w = if zero { std::mem::take(&mut state.waiters) } else { Vec::new() };
152 unsafe { self.mu.unlock() };
153 if zero {
154 // Wake condvar waiters (non-goroutine / loom path). Taking
155 // `cond_lock` between the count update and the notify pairs
156 // with the re-check the bare-thread `wait` does under
157 // `cond_lock`, so the notify cannot be missed.
158 let _g = self.cond_lock.lock().unwrap();
159 self.cond.notify_all();
160 }
161 w
162 };
163
164 // Wake goroutine waiters outside the lock so we don't hold it during
165 // the goready spin (which waits for GRUNNING → GWAITING).
166 for gp in goroutine_waiters {
167 unsafe { goready(gp) };
168 }
169 }
170
171 /// Decrement the counter by one.
172 ///
173 /// Shorthand for `self.add(-1)`.
174 pub fn done(&self) {
175 self.add(-1);
176 }
177
178 /// Block until the counter is zero.
179 ///
180 /// When called from a goroutine: suspends the goroutine back into the
181 /// scheduler via `gopark` so the M and P remain free to run other
182 /// goroutines. Resumed by `add` calling `goready` when the counter
183 /// reaches zero.
184 ///
185 /// When called from a bare OS thread (outside the go-lib scheduler):
186 /// blocks the thread on an internal `Condvar`.
187 pub fn wait(&self) {
188 // ── Goroutine path ──────────────────────────────────────────────────
189 // gopark suspends this goroutine without blocking the OS thread.
190 // The M+P are returned to the scheduler to run other goroutines
191 // (including whoever will call done() to reach count == 0).
192 let gp = current_g();
193 if !gp.is_null() {
194 // m_lock suppresses async preemption while we hold `mu` AND
195 // across the `gopark_commit` below — the increment is
196 // transferred to `park_fn`, which balances it on the same M
197 // (see gopark_commit's contract).
198 let _lk = crate::runtime::m::m_lock();
199 self.mu.lock();
200 // SAFETY: `mu` is held.
201 let state = unsafe { &mut *self.state.get() };
202 if state.count == 0 {
203 unsafe { self.mu.unlock() };
204 return; // fast path: already done (`_lk` drops normally)
205 }
206 // Register as a waiter; the lock stays held until park_fn has
207 // committed this goroutine to GWAITING (commit-park protocol),
208 // so add() can only observe the registration once goready is
209 // guaranteed to find us parked. Releasing before the park left
210 // a window where preemption made us GRUNNABLE and add()'s wake
211 // was silently dropped — the waiter then parked forever.
212 state.waiters.push(gp);
213 // Transfer the m.locks increment to park_fn.
214 std::mem::forget(_lk);
215 unsafe {
216 gopark_commit(
217 WaitReason::Semacquire,
218 unlock_wg_mutex,
219 &self.mu as *const RawMutex as *mut u8,
220 );
221 }
222 return;
223 }
224
225 // ── Non-goroutine / loom path ───────────────────────────────────────
226 // Block the calling OS thread on the condvar. Used by tests that call
227 // wait() from a bare thread and by loom model threads. The count is
228 // re-checked under `cond_lock`, which `add()` acquires between
229 // setting count == 0 and notifying — so the notify cannot fall
230 // between our check and the `cond.wait`.
231 let mut guard = self.cond_lock.lock().unwrap();
232 loop {
233 self.mu.lock();
234 // SAFETY: `mu` is held.
235 let count = unsafe { (*self.state.get()).count };
236 unsafe { self.mu.unlock() };
237 if count == 0 {
238 return;
239 }
240 guard = self.cond.wait(guard).unwrap();
241 }
242 }
243}
244
245impl Default for WaitGroup {
246 fn default() -> Self {
247 Self::new()
248 }
249}
250
251/// `gopark_commit` unlock shim: release the WaitGroup's `RawMutex` from g0
252/// after the parking goroutine has reached `GWAITING`.
253///
254/// # Safety
255/// `arg` must be the `&RawMutex` of a `WaitGroup` whose lock is held by the
256/// parking goroutine.
257unsafe fn unlock_wg_mutex(arg: *mut u8) {
258 unsafe { (*(arg as *const RawMutex)).unlock() }
259}
260
261// ---------------------------------------------------------------------------
262// Tests
263// ---------------------------------------------------------------------------
264
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicI32, Ordering};
    use std::sync::Arc;

    /// `wait()` on a brand-new group (counter zero) returns without blocking.
    #[test]
    fn new_wait_returns_immediately() {
        let group = WaitGroup::new();
        group.wait(); // would hang the test if it blocked
    }

    /// One registered worker: `wait()` returns only after the worker's `done()`.
    #[test]
    #[go_lib::main]
    fn single_worker() {
        let group = Arc::new(WaitGroup::new());
        let ran = Arc::new(AtomicI32::new(0));

        // Register the worker before it is spawned.
        group.add(1);
        let worker_group = Arc::clone(&group);
        let worker_ran = Arc::clone(&ran);
        crate::runtime::sched::spawn_goroutine(move || {
            worker_ran.fetch_add(1, Ordering::Relaxed);
            worker_group.done();
        });

        group.wait();
        assert_eq!(ran.load(Ordering::Acquire), 1);
    }

    /// Five registered workers: `wait()` returns once every one has called `done()`.
    #[test]
    #[go_lib::main]
    fn multiple_workers() {
        const N: i32 = 5;
        let finished = Arc::new(AtomicI32::new(0));
        let finished_alias = Arc::clone(&finished);

        let group = Arc::new(WaitGroup::new());

        for _ in 0..N {
            group.add(1);
            let worker_group = Arc::clone(&group);
            let worker_finished = Arc::clone(&finished_alias);
            crate::runtime::sched::spawn_goroutine(move || {
                worker_finished.fetch_add(1, Ordering::Relaxed);
                worker_group.done();
            });
        }

        group.wait();
        // Both handles observe the same atomic; each must report all N workers.
        assert_eq!(finished_alias.load(Ordering::Acquire), N);

        assert_eq!(finished.load(Ordering::Acquire), N);
    }

    /// Two goroutines park in `wait()`; a single `done()` must release both.
    #[test]
    #[go_lib::main]
    fn multiple_waiters() {
        use std::time::{Duration, Instant};

        let woken = Arc::new(AtomicI32::new(0));
        let woken_probe = Arc::clone(&woken);

        let group = Arc::new(WaitGroup::new());
        group.add(1);

        // Start the two waiters.
        for _ in 0..2 {
            let waiter_group = Arc::clone(&group);
            let waiter_woken = Arc::clone(&woken_probe);
            crate::runtime::sched::spawn_goroutine(move || {
                waiter_group.wait();
                waiter_woken.fetch_add(1, Ordering::Relaxed);
            });
        }

        // Give the waiters a chance to reach `wait()` and park.
        for _ in 0..20 {
            crate::gosched();
        }
        group.done();

        // Poll the atomic rather than yielding a fixed number of times: a
        // fixed gosched-loop is not deterministic under parallel test load.
        let deadline = Instant::now() + Duration::from_millis(500);
        while woken_probe.load(Ordering::Acquire) < 2 {
            assert!(
                Instant::now() < deadline,
                "timed out: only {} of 2 waiters woke",
                woken_probe.load(Ordering::Relaxed),
            );
            crate::gosched();
        }

        assert_eq!(woken.load(Ordering::Acquire), 2, "both waiters must wake");
    }

    /// After `wait()` returns, the same group can be armed and awaited again.
    #[test]
    #[go_lib::main]
    fn reuse_after_wait() {
        let group = Arc::new(WaitGroup::new());

        // First round: a worker that only signals completion.
        group.add(1);
        let first_worker = Arc::clone(&group);
        crate::runtime::sched::spawn_goroutine(move || {
            first_worker.done();
        });
        group.wait();

        // Second round: the worker also records that it ran.
        let flag = Arc::new(AtomicI32::new(0));
        group.add(1);
        let second_worker = Arc::clone(&group);
        let worker_flag = Arc::clone(&flag);
        crate::runtime::sched::spawn_goroutine(move || {
            worker_flag.store(1, Ordering::Relaxed);
            second_worker.done();
        });
        group.wait();
        assert_eq!(flag.load(Ordering::Acquire), 1);
    }

    /// Driving the counter below zero panics with the Go-compatible message.
    #[test]
    #[should_panic(expected = "sync: negative WaitGroup counter")]
    fn negative_counter_panics() {
        let group = WaitGroup::new();
        group.add(-1); // 0 → -1 is rejected
    }
}
398
399// ---------------------------------------------------------------------------
400// Loom model tests
401// ---------------------------------------------------------------------------
402
#[cfg(all(test, loom))]
mod loom_tests {
    use super::*;
    use loom::sync::Arc;

    /// A single worker's `done()` races with the main thread's `wait()`;
    /// loom explores every interleaving and none may deadlock.
    #[test]
    fn done_unblocks_wait() {
        loom::model(|| {
            let group = Arc::new(WaitGroup::new());
            let worker_group = Arc::clone(&group);

            group.add(1);

            let worker = loom::thread::spawn(move || {
                worker_group.done();
            });

            // Must return in every interleaving explored by loom.
            group.wait();

            worker.join().unwrap();
        });
    }

    /// Two `done()` calls and one `wait()` run on separate threads; the
    /// waiter must observe the final zero count in every interleaving.
    #[test]
    fn two_workers_unblock_wait() {
        loom::model(|| {
            let group = Arc::new(WaitGroup::new());
            let first = Arc::clone(&group);
            let second = Arc::clone(&group);
            let waiting = Arc::clone(&group);

            group.add(2);

            let first_worker = loom::thread::spawn(move || first.done());
            let second_worker = loom::thread::spawn(move || second.done());
            let waiter = loom::thread::spawn(move || waiting.wait());

            first_worker.join().unwrap();
            second_worker.join().unwrap();
            waiter.join().unwrap();
        });
    }

    /// After one `add(1)`, a spawned `done()` runs concurrently with `wait()`
    /// on the model's main thread; `wait()` must always see the true zero.
    #[test]
    fn add_and_done_interleave() {
        loom::model(|| {
            let group = Arc::new(WaitGroup::new());
            let finisher_group = Arc::clone(&group);
            let waiter_group = Arc::clone(&group);

            // Arm the group once, then race done() against wait().
            group.add(1);

            let finisher = loom::thread::spawn(move || finisher_group.done());
            waiter_group.wait();

            finisher.join().unwrap();
        });
    }
}
467}