Skip to main content

go_lib/sync/
waitgroup.rs

1// SPDX-License-Identifier: Apache-2.0
2//! `WaitGroup` — ported from `src/sync/waitgroup.go`.
3//!
4//! ## Semantics (match Go)
5//!
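//! - [`WaitGroup::add`]`(delta)` — increment the counter by `delta`.  `delta`
//!   may be negative (used internally by [`done`][WaitGroup::done]).
//!   Panics if the counter goes negative.  As in Go, an `add` with a positive
//!   `delta` that starts from a zero counter must happen before the `wait` it
//!   is meant to hold back: call `add` before spawning the worker, not inside
//!   it, because `wait` returns at once when it observes a zero counter.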
9//! - [`WaitGroup::done`]`()` — shorthand for `add(-1)`.
10//! - [`WaitGroup::wait`]`()` — block until the counter reaches zero.
11//!   Multiple goroutines may call `wait` concurrently; all are unblocked when
12//!   the last worker calls `done`.
13//!
14//! ## Implementation
15//!
16//! ### Goroutine path (the common case)
17//!
18//! `wait` uses [`gopark`][crate::runtime::park::gopark] to suspend the calling
19//! goroutine back into the scheduler **without blocking the OS thread**.  The
20//! M and its P remain free to run other goroutines.  When `add` decrements the
21//! counter to zero it drains the waiters list and calls
22//! [`goready`][crate::runtime::park::goready] on each, re-enqueuing them for
23//! scheduling.
24//!
25//! ### Non-goroutine / loom path
26//!
27//! If `wait` is called from a bare OS thread (outside the go-lib scheduler, or
//! from a loom model thread) it falls back to blocking on a `Condvar`.  This
//! path is also exercised by the `new_wait_returns_immediately` unit test,
//! which calls `wait` without entering the scheduler.
31//!
32//! Ported from `sync/waitgroup.go`.
33
34use std::cell::UnsafeCell;
35
36use crate::loom_shim::{Condvar, Mutex};
37use crate::runtime::g::{current_g, WaitReason, G};
38use crate::runtime::park::{gopark_commit, goready};
39use crate::runtime::rawmutex::RawMutex;
40
41// ---------------------------------------------------------------------------
42// Internal state
43// ---------------------------------------------------------------------------
44
45struct WgState {
46    /// Number of outstanding workers (`Add` increments, `Done` decrements).
47    count:   i64,
48    /// Goroutines suspended in [`WaitGroup::wait`].  Drained by
49    /// [`WaitGroup::add`] when the counter reaches zero; each entry is woken
50    /// via [`goready`].
51    waiters: Vec<*mut G>,
52}
53
54// SAFETY: `WgState` is always accessed under the `WaitGroup`'s `Mutex`.
55// The `*mut G` pointers are goroutines owned by the scheduler; we never
56// dereference them without holding the lock (except to pass to `goready`,
57// which is safe once the goroutine has reached `GWAITING`).
58unsafe impl Send for WgState {}
59
60// ---------------------------------------------------------------------------
61// WaitGroup
62// ---------------------------------------------------------------------------
63
64/// A synchronisation barrier: wait for a set of goroutines to complete.
65///
66/// The typical usage pattern is:
67///
68/// ```no_run
69/// use std::sync::Arc;
70/// use go_lib::sync::WaitGroup;
71///
72/// #[go_lib::main]
73/// fn main() {
74///     let wg = Arc::new(WaitGroup::new());
75///     for i in 0..5 {
76///         let wg = Arc::clone(&wg);
77///         go_lib::go!(move || {
78///             wg.add(1);
79///             // ... do work, then call wg.done() when finished ...
80///         });
81///     }
82///     // wg.wait();  // blocks until all done() calls have been made
83/// }
84/// ```
85pub struct WaitGroup {
86    /// Spinlock protecting `state`.  A `RawMutex` (not `std::sync::Mutex`)
87    /// so that the goroutine `wait()` path can hold the lock ACROSS the park
88    /// via `gopark_commit` — the lock is released on g0 only after the
89    /// waiter is `GWAITING`.  With a guard-based Mutex released before
90    /// `gopark`, an async preemption in the release-to-park window made the
91    /// registered waiter `GRUNNABLE`; `add()`'s `goready` then dropped the
92    /// wake on the floor and the waiter parked forever.
93    mu:    RawMutex,
94    /// Interior state — always accessed under `mu`.
95    state: UnsafeCell<WgState>,
96    /// Companion lock for `cond` — used only by the non-goroutine fallback
97    /// path (bare OS threads / loom model threads).
98    cond_lock: Mutex<()>,
99    /// Condvar used only on the non-goroutine fallback path.
100    cond:  Condvar,
101}
102
103// SAFETY: `state` is only accessed while `mu` is held; the `*mut G` waiters
104// are scheduler-owned and only ever passed to `goready` (which dereferences
105// the `G` descriptor, never the waiter's stack — see `add`).
106unsafe impl Send for WaitGroup {}
107unsafe impl Sync for WaitGroup {}
108
109impl WaitGroup {
110    /// Create a new `WaitGroup` with a counter of zero.
111    pub fn new() -> Self {
112        Self {
113            mu:        RawMutex::new(),
114            state:     UnsafeCell::new(WgState { count: 0, waiters: Vec::new() }),
115            cond_lock: Mutex::new(()),
116            cond:      Condvar::new(),
117        }
118    }
119
120    /// Add `delta` to the counter.
121    ///
122    /// `delta` is typically positive when called before spawning goroutines and
123    /// negative when they finish (see [`done`][Self::done]).
124    ///
125    /// # Panics
126    ///
127    /// Panics if the counter drops below zero.
128    pub fn add(&self, delta: i64) {
129        // Hold an `m.locks` guard across the std::sync::Mutex critical section.
130        // Without it, SIGURG-based async preemption can fire after
131        // `pthread_mutex_lock` returns but before we drop the MutexGuard: the
132        // preempted goroutine still holds the OS-level pthread mutex, and the
133        // next goroutine scheduled on the same M will self-deadlock trying to
134        // re-acquire it (the default pthread mutex is non-recursive, so a
135        // same-thread re-lock blocks forever in `__psynch_mutexwait`).
136        // Captured live via lldb on a hung `many_goroutines` run.
137        let _lk = crate::runtime::m::m_lock();
138        // WaitGroup waiters are woken only via `goready`, which dereferences
139        // only the `G` descriptor and never the waiter's stack.
140        // Collect goroutine waiters to wake (if counter reaches zero).
141        let goroutine_waiters: Vec<*mut G> = {
142            self.mu.lock();
143            // SAFETY: `mu` is held.
144            let state = unsafe { &mut *self.state.get() };
145            state.count += delta;
146            if state.count < 0 {
147                unsafe { self.mu.unlock() };
148                panic!("sync: negative WaitGroup counter");
149            }
150            let zero = state.count == 0;
151            let w = if zero { std::mem::take(&mut state.waiters) } else { Vec::new() };
152            unsafe { self.mu.unlock() };
153            if zero {
154                // Wake condvar waiters (non-goroutine / loom path).  Taking
155                // `cond_lock` between the count update and the notify pairs
156                // with the re-check the bare-thread `wait` does under
157                // `cond_lock`, so the notify cannot be missed.
158                let _g = self.cond_lock.lock().unwrap();
159                self.cond.notify_all();
160            }
161            w
162        };
163
164        // Wake goroutine waiters outside the lock so we don't hold it during
165        // the goready spin (which waits for GRUNNING → GWAITING).
166        for gp in goroutine_waiters {
167            unsafe { goready(gp) };
168        }
169    }
170
171    /// Decrement the counter by one.
172    ///
173    /// Shorthand for `self.add(-1)`.
174    pub fn done(&self) {
175        self.add(-1);
176    }
177
178    /// Block until the counter is zero.
179    ///
180    /// When called from a goroutine: suspends the goroutine back into the
181    /// scheduler via `gopark` so the M and P remain free to run other
182    /// goroutines.  Resumed by `add` calling `goready` when the counter
183    /// reaches zero.
184    ///
185    /// When called from a bare OS thread (outside the go-lib scheduler):
186    /// blocks the thread on an internal `Condvar`.
187    pub fn wait(&self) {
188        // ── Goroutine path ──────────────────────────────────────────────────
189        // gopark suspends this goroutine without blocking the OS thread.
190        // The M+P are returned to the scheduler to run other goroutines
191        // (including whoever will call done() to reach count == 0).
192        let gp = current_g();
193        if !gp.is_null() {
194            // m_lock suppresses async preemption while we hold `mu` AND
195            // across the `gopark_commit` below — the increment is
196            // transferred to `park_fn`, which balances it on the same M
197            // (see gopark_commit's contract).
198            let _lk = crate::runtime::m::m_lock();
199            self.mu.lock();
200            // SAFETY: `mu` is held.
201            let state = unsafe { &mut *self.state.get() };
202            if state.count == 0 {
203                unsafe { self.mu.unlock() };
204                return; // fast path: already done (`_lk` drops normally)
205            }
206            // Register as a waiter; the lock stays held until park_fn has
207            // committed this goroutine to GWAITING (commit-park protocol),
208            // so add() can only observe the registration once goready is
209            // guaranteed to find us parked.  Releasing before the park left
210            // a window where preemption made us GRUNNABLE and add()'s wake
211            // was silently dropped — the waiter then parked forever.
212            state.waiters.push(gp);
213            // Transfer the m.locks increment to park_fn.
214            std::mem::forget(_lk);
215            unsafe {
216                gopark_commit(
217                    WaitReason::Semacquire,
218                    unlock_wg_mutex,
219                    &self.mu as *const RawMutex as *mut u8,
220                );
221            }
222            return;
223        }
224
225        // ── Non-goroutine / loom path ───────────────────────────────────────
226        // Block the calling OS thread on the condvar.  Used by tests that call
227        // wait() from a bare thread and by loom model threads.  The count is
228        // re-checked under `cond_lock`, which `add()` acquires between
229        // setting count == 0 and notifying — so the notify cannot fall
230        // between our check and the `cond.wait`.
231        let mut guard = self.cond_lock.lock().unwrap();
232        loop {
233            self.mu.lock();
234            // SAFETY: `mu` is held.
235            let count = unsafe { (*self.state.get()).count };
236            unsafe { self.mu.unlock() };
237            if count == 0 {
238                return;
239            }
240            guard = self.cond.wait(guard).unwrap();
241        }
242    }
243}
244
245impl Default for WaitGroup {
246    fn default() -> Self {
247        Self::new()
248    }
249}
250
251/// `gopark_commit` unlock shim: release the WaitGroup's `RawMutex` from g0
252/// after the parking goroutine has reached `GWAITING`.
253///
254/// # Safety
255/// `arg` must be the `&RawMutex` of a `WaitGroup` whose lock is held by the
256/// parking goroutine.
257unsafe fn unlock_wg_mutex(arg: *mut u8) {
258    unsafe { (*(arg as *const RawMutex)).unlock() }
259}
260
261// ---------------------------------------------------------------------------
262// Tests
263// ---------------------------------------------------------------------------
264
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicI32, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    /// A fresh WaitGroup has a zero counter, so `wait()` must not block.
    #[test]
    fn new_wait_returns_immediately() {
        WaitGroup::new().wait();
    }

    /// One worker: `wait()` returns only after that worker's `done()`.
    #[test]
    #[go_lib::main]
    fn single_worker() {
        let wg = Arc::new(WaitGroup::new());
        let ran = Arc::new(AtomicI32::new(0));

        wg.add(1);
        crate::runtime::sched::spawn_goroutine({
            let wg = Arc::clone(&wg);
            let ran = Arc::clone(&ran);
            move || {
                ran.fetch_add(1, Ordering::Relaxed);
                wg.done();
            }
        });

        wg.wait();
        assert_eq!(ran.load(Ordering::Acquire), 1);
    }

    /// Five workers: `wait()` returns only after all five `done()` calls.
    #[test]
    #[go_lib::main]
    fn multiple_workers() {
        const N: i32 = 5;
        let wg = Arc::new(WaitGroup::new());
        let finished = Arc::new(AtomicI32::new(0));

        for _ in 0..N {
            wg.add(1);
            let wg = Arc::clone(&wg);
            let finished = Arc::clone(&finished);
            crate::runtime::sched::spawn_goroutine(move || {
                finished.fetch_add(1, Ordering::Relaxed);
                wg.done();
            });
        }

        wg.wait();
        assert_eq!(finished.load(Ordering::Acquire), N);
    }

    /// Two goroutines wait on the same group; one `done()` releases both.
    #[test]
    #[go_lib::main]
    fn multiple_waiters() {
        const WAITERS: i32 = 2;
        let wg = Arc::new(WaitGroup::new());
        let woke = Arc::new(AtomicI32::new(0));
        wg.add(1);

        for _ in 0..WAITERS {
            let wg = Arc::clone(&wg);
            let woke = Arc::clone(&woke);
            crate::runtime::sched::spawn_goroutine(move || {
                wg.wait();
                woke.fetch_add(1, Ordering::Relaxed);
            });
        }

        // Give the waiters a chance to reach wg.wait() and park.
        for _ in 0..20 {
            crate::gosched();
        }
        wg.done();

        // Poll the atomic rather than relying on a fixed number of yields,
        // which is not deterministic under parallel test load.
        let deadline = Instant::now() + Duration::from_millis(500);
        while woke.load(Ordering::Acquire) < WAITERS {
            assert!(
                Instant::now() < deadline,
                "timed out: only {} of 2 waiters woke",
                woke.load(Ordering::Relaxed),
            );
            crate::gosched();
        }

        assert_eq!(woke.load(Ordering::Acquire), 2, "both waiters must wake");
    }

    /// After a completed `wait()`, the same group can be reused.
    #[test]
    #[go_lib::main]
    fn reuse_after_wait() {
        let wg = Arc::new(WaitGroup::new());

        // First round: a single worker with no observable side effect.
        wg.add(1);
        crate::runtime::sched::spawn_goroutine({
            let wg = Arc::clone(&wg);
            move || {
                wg.done();
            }
        });
        wg.wait();

        // Second round: the worker's store must be visible after wait().
        let flag = Arc::new(AtomicI32::new(0));
        wg.add(1);
        crate::runtime::sched::spawn_goroutine({
            let wg = Arc::clone(&wg);
            let flag = Arc::clone(&flag);
            move || {
                flag.store(1, Ordering::Relaxed);
                wg.done();
            }
        });
        wg.wait();
        assert_eq!(flag.load(Ordering::Acquire), 1);
    }

    /// Taking the counter below zero panics.
    #[test]
    #[should_panic(expected = "sync: negative WaitGroup counter")]
    fn negative_counter_panics() {
        WaitGroup::new().add(-1); // 0 → -1 → panic
    }
}
398
399// ---------------------------------------------------------------------------
400// Loom model tests
401// ---------------------------------------------------------------------------
402
#[cfg(all(test, loom))]
mod loom_tests {
    use super::*;
    use loom::sync::Arc;

    /// One worker calls `done()` while the model thread is in `wait()`.
    /// Loom explores every interleaving of the two; `wait()` must never
    /// deadlock.
    #[test]
    fn done_unblocks_wait() {
        loom::model(|| {
            let wg = Arc::new(WaitGroup::new());
            let wg2 = Arc::clone(&wg);

            wg.add(1);

            let worker = loom::thread::spawn(move || {
                wg2.done();
            });

            wg.wait(); // must not deadlock in any interleaving

            worker.join().unwrap();
        });
    }

    /// Two workers each call `done()` on a counter of 2, so only the second
    /// decrement reaches zero, while a third thread waits.  The waiter must
    /// be released in every interleaving, whether it registers before,
    /// between, or after the two decrements.
    #[test]
    fn two_workers_unblock_wait() {
        loom::model(|| {
            let wg = Arc::new(WaitGroup::new());
            let wg2 = Arc::clone(&wg);
            let wg3 = Arc::clone(&wg);
            let wg4 = Arc::clone(&wg);

            wg.add(2);

            let t1 = loom::thread::spawn(move || wg2.done());
            let t2 = loom::thread::spawn(move || wg3.done());
            let waiter = loom::thread::spawn(move || wg4.wait());

            t1.join().unwrap();
            t2.join().unwrap();
            waiter.join().unwrap();
        });
    }

    /// `add(1)` is performed *before* the worker is spawned, as Go requires
    /// for an add that starts from zero.  The worker's `done()` then races
    /// with the model thread's `wait()`, which must always observe the final
    /// zero and return.
    #[test]
    fn add_and_done_interleave() {
        loom::model(|| {
            let wg = Arc::new(WaitGroup::new());

            wg.add(1);

            let worker = {
                let wg = Arc::clone(&wg);
                loom::thread::spawn(move || wg.done())
            };
            wg.wait();

            worker.join().unwrap();
        });
    }
}