Skip to main content

go_lib/sync/
waitgroup.rs

1// SPDX-License-Identifier: Apache-2.0
2//! `WaitGroup` — ported from `src/sync/waitgroup.go`.
3//!
4//! ## Semantics (match Go)
5//!
6//! - [`WaitGroup::add`]`(delta)` — increment the counter by `delta`.  `delta`
7//!   may be negative (used internally by [`done`][WaitGroup::done]).
8//!   Panics if the counter goes negative.
9//! - [`WaitGroup::done`]`()` — shorthand for `add(-1)`.
10//! - [`WaitGroup::wait`]`()` — block until the counter reaches zero.
11//!   Multiple goroutines may call `wait` concurrently; all are unblocked when
12//!   the last worker calls `done`.
13//!
14//! ## Implementation
15//!
16//! ### Goroutine path (the common case)
17//!
//! `wait` uses [`gopark_commit`][crate::runtime::park::gopark_commit] to
//! suspend the calling goroutine back into the scheduler **without blocking
//! the OS thread**.  The M and its P remain free to run other goroutines.  The
//! WaitGroup's internal lock stays held across the park and is released on g0
//! only once the goroutine is `GWAITING`, so `add` can never observe a
//! registered waiter that is not yet parked.  When `add` brings the counter to
//! zero it drains the waiters list and calls
//! [`goready`][crate::runtime::park::goready] on each, re-enqueuing them for
//! scheduling.
24//!
25//! ### Non-goroutine / loom path
26//!
//! If `wait` is called from a bare OS thread (outside the go-lib scheduler, or
//! from a loom model thread) it falls back to blocking on a `Condvar`.  The
//! `new_wait_returns_immediately` unit test exercises this path, since it
//! calls `wait` without ever entering the scheduler.
31//!
//! Ported from Go's `src/sync/waitgroup.go`.
33
34use std::cell::UnsafeCell;
35
36use crate::loom_shim::{Condvar, Mutex};
37use crate::runtime::g::{current_g, WaitReason, G};
38use crate::runtime::park::{gopark_commit, goready};
39use crate::runtime::rawmutex::RawMutex;
40
41// ---------------------------------------------------------------------------
42// Internal state
43// ---------------------------------------------------------------------------
44
45struct WgState {
46    /// Number of outstanding workers (`Add` increments, `Done` decrements).
47    count:   i64,
48    /// Goroutines suspended in [`WaitGroup::wait`].  Drained by
49    /// [`WaitGroup::add`] when the counter reaches zero; each entry is woken
50    /// via [`goready`].
51    waiters: Vec<*mut G>,
52}
53
54// SAFETY: `WgState` is always accessed under the `WaitGroup`'s `Mutex`.
55// The `*mut G` pointers are goroutines owned by the scheduler; we never
56// dereference them without holding the lock (except to pass to `goready`,
57// which is safe once the goroutine has reached `GWAITING`).
58unsafe impl Send for WgState {}
59
60// ---------------------------------------------------------------------------
61// WaitGroup
62// ---------------------------------------------------------------------------
63
64/// A synchronisation barrier: wait for a set of goroutines to complete.
65///
66/// The typical usage pattern is:
67///
68/// ```no_run
69/// use std::sync::Arc;
70/// use go_lib::sync::WaitGroup;
71///
72/// let wg = Arc::new(WaitGroup::new());
73/// for i in 0..5 {
74///     let wg = Arc::clone(&wg);
75///     go_lib::run(move || {
76///         wg.add(1);
77///         // ... spawn goroutine that calls wg.done() when finished ...
78///     });
79/// }
80/// // wg.wait();  // blocks until all Done() calls have been made
81/// ```
82pub struct WaitGroup {
83    /// Spinlock protecting `state`.  A `RawMutex` (not `std::sync::Mutex`)
84    /// so that the goroutine `wait()` path can hold the lock ACROSS the park
85    /// via `gopark_commit` — the lock is released on g0 only after the
86    /// waiter is `GWAITING`.  With a guard-based Mutex released before
87    /// `gopark`, an async preemption in the release-to-park window made the
88    /// registered waiter `GRUNNABLE`; `add()`'s `goready` then dropped the
89    /// wake on the floor and the waiter parked forever.
90    mu:    RawMutex,
91    /// Interior state — always accessed under `mu`.
92    state: UnsafeCell<WgState>,
93    /// Companion lock for `cond` — used only by the non-goroutine fallback
94    /// path (bare OS threads / loom model threads).
95    cond_lock: Mutex<()>,
96    /// Condvar used only on the non-goroutine fallback path.
97    cond:  Condvar,
98}
99
100// SAFETY: `state` is only accessed while `mu` is held; the `*mut G` waiters
101// are scheduler-owned and only ever passed to `goready` (which dereferences
102// the `G` descriptor, never the waiter's stack — see `add`).
103unsafe impl Send for WaitGroup {}
104unsafe impl Sync for WaitGroup {}
105
106impl WaitGroup {
107    /// Create a new `WaitGroup` with a counter of zero.
108    pub fn new() -> Self {
109        Self {
110            mu:        RawMutex::new(),
111            state:     UnsafeCell::new(WgState { count: 0, waiters: Vec::new() }),
112            cond_lock: Mutex::new(()),
113            cond:      Condvar::new(),
114        }
115    }
116
117    /// Add `delta` to the counter.
118    ///
119    /// `delta` is typically positive when called before spawning goroutines and
120    /// negative when they finish (see [`done`][Self::done]).
121    ///
122    /// # Panics
123    ///
124    /// Panics if the counter drops below zero.
125    pub fn add(&self, delta: i64) {
126        // Hold an `m.locks` guard across the std::sync::Mutex critical section.
127        // Without it, SIGURG-based async preemption can fire after
128        // `pthread_mutex_lock` returns but before we drop the MutexGuard: the
129        // preempted goroutine still holds the OS-level pthread mutex, and the
130        // next goroutine scheduled on the same M will self-deadlock trying to
131        // re-acquire it (the default pthread mutex is non-recursive, so a
132        // same-thread re-lock blocks forever in `__psynch_mutexwait`).
133        // Captured live via lldb on a hung `many_goroutines` run.
134        let _lk = crate::runtime::m::m_lock();
135        // WaitGroup waiters are woken only via `goready`, which dereferences
136        // only the `G` descriptor and never the waiter's stack.
137        // Collect goroutine waiters to wake (if counter reaches zero).
138        let goroutine_waiters: Vec<*mut G> = {
139            self.mu.lock();
140            // SAFETY: `mu` is held.
141            let state = unsafe { &mut *self.state.get() };
142            state.count += delta;
143            if state.count < 0 {
144                unsafe { self.mu.unlock() };
145                panic!("sync: negative WaitGroup counter");
146            }
147            let zero = state.count == 0;
148            let w = if zero { std::mem::take(&mut state.waiters) } else { Vec::new() };
149            unsafe { self.mu.unlock() };
150            if zero {
151                // Wake condvar waiters (non-goroutine / loom path).  Taking
152                // `cond_lock` between the count update and the notify pairs
153                // with the re-check the bare-thread `wait` does under
154                // `cond_lock`, so the notify cannot be missed.
155                let _g = self.cond_lock.lock().unwrap();
156                self.cond.notify_all();
157            }
158            w
159        };
160
161        // Wake goroutine waiters outside the lock so we don't hold it during
162        // the goready spin (which waits for GRUNNING → GWAITING).
163        for gp in goroutine_waiters {
164            unsafe { goready(gp) };
165        }
166    }
167
168    /// Decrement the counter by one.
169    ///
170    /// Shorthand for `self.add(-1)`.
171    pub fn done(&self) {
172        self.add(-1);
173    }
174
175    /// Block until the counter is zero.
176    ///
177    /// When called from a goroutine: suspends the goroutine back into the
178    /// scheduler via `gopark` so the M and P remain free to run other
179    /// goroutines.  Resumed by `add` calling `goready` when the counter
180    /// reaches zero.
181    ///
182    /// When called from a bare OS thread (outside the go-lib scheduler):
183    /// blocks the thread on an internal `Condvar`.
184    pub fn wait(&self) {
185        // ── Goroutine path ──────────────────────────────────────────────────
186        // gopark suspends this goroutine without blocking the OS thread.
187        // The M+P are returned to the scheduler to run other goroutines
188        // (including whoever will call done() to reach count == 0).
189        let gp = current_g();
190        if !gp.is_null() {
191            // m_lock suppresses async preemption while we hold `mu` AND
192            // across the `gopark_commit` below — the increment is
193            // transferred to `park_fn`, which balances it on the same M
194            // (see gopark_commit's contract).
195            let _lk = crate::runtime::m::m_lock();
196            self.mu.lock();
197            // SAFETY: `mu` is held.
198            let state = unsafe { &mut *self.state.get() };
199            if state.count == 0 {
200                unsafe { self.mu.unlock() };
201                return; // fast path: already done (`_lk` drops normally)
202            }
203            // Register as a waiter; the lock stays held until park_fn has
204            // committed this goroutine to GWAITING (commit-park protocol),
205            // so add() can only observe the registration once goready is
206            // guaranteed to find us parked.  Releasing before the park left
207            // a window where preemption made us GRUNNABLE and add()'s wake
208            // was silently dropped — the waiter then parked forever.
209            state.waiters.push(gp);
210            // Transfer the m.locks increment to park_fn.
211            std::mem::forget(_lk);
212            unsafe {
213                gopark_commit(
214                    WaitReason::Semacquire,
215                    unlock_wg_mutex,
216                    &self.mu as *const RawMutex as *mut u8,
217                );
218            }
219            return;
220        }
221
222        // ── Non-goroutine / loom path ───────────────────────────────────────
223        // Block the calling OS thread on the condvar.  Used by tests that call
224        // wait() from a bare thread and by loom model threads.  The count is
225        // re-checked under `cond_lock`, which `add()` acquires between
226        // setting count == 0 and notifying — so the notify cannot fall
227        // between our check and the `cond.wait`.
228        let mut guard = self.cond_lock.lock().unwrap();
229        loop {
230            self.mu.lock();
231            // SAFETY: `mu` is held.
232            let count = unsafe { (*self.state.get()).count };
233            unsafe { self.mu.unlock() };
234            if count == 0 {
235                return;
236            }
237            guard = self.cond.wait(guard).unwrap();
238        }
239    }
240}
241
242impl Default for WaitGroup {
243    fn default() -> Self {
244        Self::new()
245    }
246}
247
248/// `gopark_commit` unlock shim: release the WaitGroup's `RawMutex` from g0
249/// after the parking goroutine has reached `GWAITING`.
250///
251/// # Safety
252/// `arg` must be the `&RawMutex` of a `WaitGroup` whose lock is held by the
253/// parking goroutine.
254unsafe fn unlock_wg_mutex(arg: *mut u8) {
255    unsafe { (*(arg as *const RawMutex)).unlock() }
256}
257
258// ---------------------------------------------------------------------------
259// Tests
260// ---------------------------------------------------------------------------
261
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use crate::runtime::sched::{run_impl, spawn_goroutine};
    use std::sync::atomic::{AtomicI32, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    /// A fresh group has counter zero, so `wait` must not block.
    #[test]
    fn new_wait_returns_immediately() {
        WaitGroup::new().wait();
    }

    /// One worker raises and lowers the counter; `wait` returns after `done`.
    #[test]
    fn single_worker() {
        run_impl(|| {
            let wg = Arc::new(WaitGroup::new());
            let finished = Arc::new(AtomicI32::new(0));

            wg.add(1);
            spawn_goroutine({
                let wg = Arc::clone(&wg);
                let finished = Arc::clone(&finished);
                move || {
                    finished.fetch_add(1, Ordering::Relaxed);
                    wg.done();
                }
            });

            wg.wait();
            assert_eq!(finished.load(Ordering::Acquire), 1);
        });
    }

    /// `wait` only returns once every one of the N workers has called `done`.
    #[test]
    fn multiple_workers() {
        const N: i32 = 5;
        let counter = Arc::new(AtomicI32::new(0));
        let counter_in = Arc::clone(&counter);

        run_impl(move || {
            let wg = Arc::new(WaitGroup::new());

            for _ in 0..N {
                wg.add(1);
                let wg = Arc::clone(&wg);
                let counter = Arc::clone(&counter_in);
                spawn_goroutine(move || {
                    counter.fetch_add(1, Ordering::Relaxed);
                    wg.done();
                });
            }

            wg.wait();
            assert_eq!(counter_in.load(Ordering::Acquire), N);
        });

        assert_eq!(counter.load(Ordering::Acquire), N);
    }

    /// Two goroutines wait on the same group; both resume at counter zero.
    #[test]
    fn multiple_waiters() {
        let woke = Arc::new(AtomicI32::new(0));
        let woke_in = Arc::clone(&woke);

        run_impl(move || {
            let wg = Arc::new(WaitGroup::new());
            wg.add(1);

            for _ in 0..2 {
                let wg = Arc::clone(&wg);
                let woke = Arc::clone(&woke_in);
                spawn_goroutine(move || {
                    wg.wait();
                    woke.fetch_add(1, Ordering::Relaxed);
                });
            }

            // Give the waiters a chance to reach wg.wait() and park.
            for _ in 0..20 {
                crate::gosched();
            }
            wg.done();

            // Poll the atomic until both waiters are past wg.wait(); a fixed
            // number of yields is not deterministic under parallel test load.
            let deadline = Instant::now() + Duration::from_millis(500);
            while woke_in.load(Ordering::Acquire) < 2 {
                assert!(
                    Instant::now() < deadline,
                    "timed out: only {} of 2 waiters woke",
                    woke_in.load(Ordering::Relaxed),
                );
                crate::gosched();
            }
        });

        assert_eq!(woke.load(Ordering::Acquire), 2, "both waiters must wake");
    }

    /// After a completed `wait`, the same group can be used for another round.
    #[test]
    fn reuse_after_wait() {
        run_impl(|| {
            let wg = Arc::new(WaitGroup::new());

            // First round.
            wg.add(1);
            {
                let wg = Arc::clone(&wg);
                spawn_goroutine(move || wg.done());
            }
            wg.wait();

            // Second round on the same group.
            let flag = Arc::new(AtomicI32::new(0));
            wg.add(1);
            {
                let wg = Arc::clone(&wg);
                let flag = Arc::clone(&flag);
                spawn_goroutine(move || {
                    flag.store(1, Ordering::Relaxed);
                    wg.done();
                });
            }
            wg.wait();
            assert_eq!(flag.load(Ordering::Acquire), 1);
        });
    }

    /// Lowering the counter below zero panics (0 -> -1).
    #[test]
    #[should_panic(expected = "sync: negative WaitGroup counter")]
    fn negative_counter_panics() {
        WaitGroup::new().add(-1);
    }
}
400
401// ---------------------------------------------------------------------------
402// Loom model tests
403// ---------------------------------------------------------------------------
404
#[cfg(all(test, loom))]
mod loom_tests {
    use super::*;
    use loom::sync::Arc;

    /// A single `done()` on another thread must release `wait()` in every
    /// interleaving loom explores.
    #[test]
    fn done_unblocks_wait() {
        loom::model(|| {
            let wg = Arc::new(WaitGroup::new());
            let worker_wg = Arc::clone(&wg);

            wg.add(1);

            let worker = loom::thread::spawn(move || worker_wg.done());

            // Must not deadlock in any interleaving.
            wg.wait();

            worker.join().unwrap();
        });
    }

    /// Two racing `done()` calls bring the counter to zero; a concurrent
    /// `wait()` must observe that zero in every interleaving.
    #[test]
    fn two_workers_unblock_wait() {
        loom::model(|| {
            let wg = Arc::new(WaitGroup::new());
            let (first, second, waiting) =
                (Arc::clone(&wg), Arc::clone(&wg), Arc::clone(&wg));

            wg.add(2);

            let t1 = loom::thread::spawn(move || first.done());
            let t2 = loom::thread::spawn(move || second.done());
            let waiter = loom::thread::spawn(move || waiting.wait());

            for handle in [t1, t2, waiter] {
                handle.join().unwrap();
            }
        });
    }

    /// `done()` on a worker thread races `wait()` on the model thread;
    /// `wait()` must always see the true zero.
    #[test]
    fn add_and_done_interleave() {
        loom::model(|| {
            let wg = Arc::new(WaitGroup::new());
            let worker_wg = Arc::clone(&wg);
            let waiter_wg = Arc::clone(&wg);

            wg.add(1);

            let adder = loom::thread::spawn(move || worker_wg.done());
            waiter_wg.wait();

            adder.join().unwrap();
        });
    }
}