Skip to main content

go_lib/sync/
waitgroup.rs

1// SPDX-License-Identifier: Apache-2.0
2//! `WaitGroup` — ported from `src/sync/waitgroup.go`.
3//!
4//! ## Semantics (match Go)
5//!
6//! - [`WaitGroup::add`]`(delta)` — increment the counter by `delta`.  `delta`
7//!   may be negative (used internally by [`done`][WaitGroup::done]).
8//!   Panics if the counter goes negative.
9//! - [`WaitGroup::done`]`()` — shorthand for `add(-1)`.
10//! - [`WaitGroup::wait`]`()` — block until the counter reaches zero.
11//!   Multiple goroutines may call `wait` concurrently; all are unblocked when
12//!   the last worker calls `done`.
13//!
14//! ## Implementation
15//!
16//! ### Goroutine path (the common case)
17//!
18//! `wait` uses [`gopark`][crate::runtime::park::gopark] to suspend the calling
19//! goroutine back into the scheduler **without blocking the OS thread**.  The
20//! M and its P remain free to run other goroutines.  When `add` decrements the
21//! counter to zero it drains the waiters list and calls
22//! [`goready`][crate::runtime::park::goready] on each, re-enqueuing them for
23//! scheduling.
24//!
25//! ### Non-goroutine / loom path
26//!
27//! If `wait` is called from a bare OS thread (outside the go-lib scheduler, or
//! from a loom model thread) it falls back to blocking on a `Condvar`.  This
//! path is also exercised by the `new_wait_returns_immediately` unit test,
//! which calls `wait` from the test thread without entering the scheduler.
31//!
32//! Ported from `sync/waitgroup.go`.
33
34use crate::loom_shim::{Condvar, Mutex};
35use crate::runtime::g::{current_g, WaitReason, G};
36use crate::runtime::park::{gopark, goready};
37
38// ---------------------------------------------------------------------------
39// Internal state
40// ---------------------------------------------------------------------------
41
42struct WgState {
43    /// Number of outstanding workers (`Add` increments, `Done` decrements).
44    count:   i64,
45    /// Goroutines suspended in [`WaitGroup::wait`].  Drained by
46    /// [`WaitGroup::add`] when the counter reaches zero; each entry is woken
47    /// via [`goready`].
48    waiters: Vec<*mut G>,
49}
50
51// SAFETY: `WgState` is always accessed under the `WaitGroup`'s `Mutex`.
52// The `*mut G` pointers are goroutines owned by the scheduler; we never
53// dereference them without holding the lock (except to pass to `goready`,
54// which is safe once the goroutine has reached `GWAITING`).
55unsafe impl Send for WgState {}
56
57// ---------------------------------------------------------------------------
58// WaitGroup
59// ---------------------------------------------------------------------------
60
61/// A synchronisation barrier: wait for a set of goroutines to complete.
62///
63/// The typical usage pattern is:
64///
65/// ```no_run
66/// use std::sync::Arc;
67/// use go_lib::sync::WaitGroup;
68///
69/// let wg = Arc::new(WaitGroup::new());
70/// for i in 0..5 {
71///     let wg = Arc::clone(&wg);
72///     go_lib::run(move || {
73///         wg.add(1);
74///         // ... spawn goroutine that calls wg.done() when finished ...
75///     });
76/// }
77/// // wg.wait();  // blocks until all Done() calls have been made
78/// ```
79pub struct WaitGroup {
80    state: Mutex<WgState>,
81    /// Condvar used only on the non-goroutine fallback path.
82    cond:  Condvar,
83}
84
85impl WaitGroup {
86    /// Create a new `WaitGroup` with a counter of zero.
87    pub fn new() -> Self {
88        Self {
89            state: Mutex::new(WgState { count: 0, waiters: Vec::new() }),
90            cond:  Condvar::new(),
91        }
92    }
93
94    /// Add `delta` to the counter.
95    ///
96    /// `delta` is typically positive when called before spawning goroutines and
97    /// negative when they finish (see [`done`][Self::done]).
98    ///
99    /// # Panics
100    ///
101    /// Panics if the counter drops below zero.
102    pub fn add(&self, delta: i64) {
103        // Hold an `m.locks` guard across the std::sync::Mutex critical section.
104        // Without it, SIGURG-based async preemption can fire after
105        // `pthread_mutex_lock` returns but before we drop the MutexGuard: the
106        // preempted goroutine still holds the OS-level pthread mutex, and the
107        // next goroutine scheduled on the same M will self-deadlock trying to
108        // re-acquire it (the default pthread mutex is non-recursive, so a
109        // same-thread re-lock blocks forever in `__psynch_mutexwait`).
110        // Captured live via lldb on a hung `many_goroutines` run.
111        let _lk = crate::runtime::m::m_lock();
112        // Collect goroutine waiters to wake (if counter reaches zero).
113        let goroutine_waiters: Vec<*mut G> = {
114            let mut state = self.state.lock().unwrap();
115            state.count += delta;
116            if state.count < 0 {
117                drop(state);
118                panic!("sync: negative WaitGroup counter");
119            }
120            if state.count == 0 {
121                // Wake condvar waiters (non-goroutine / loom path).
122                // Drain goroutine waiters to wake via goready below.
123                let w = std::mem::take(&mut state.waiters);
124                drop(state);
125                self.cond.notify_all();
126                w
127            } else {
128                Vec::new()
129            }
130        };
131
132        // Wake goroutine waiters outside the lock so we don't hold it during
133        // the goready spin (which waits for GRUNNING → GWAITING).
134        for gp in goroutine_waiters {
135            unsafe { goready(gp) };
136        }
137    }
138
139    /// Decrement the counter by one.
140    ///
141    /// Shorthand for `self.add(-1)`.
142    pub fn done(&self) {
143        self.add(-1);
144    }
145
146    /// Block until the counter is zero.
147    ///
148    /// When called from a goroutine: suspends the goroutine back into the
149    /// scheduler via `gopark` so the M and P remain free to run other
150    /// goroutines.  Resumed by `add` calling `goready` when the counter
151    /// reaches zero.
152    ///
153    /// When called from a bare OS thread (outside the go-lib scheduler):
154    /// blocks the thread on an internal `Condvar`.
155    pub fn wait(&self) {
156        // ── Goroutine path ──────────────────────────────────────────────────
157        // gopark suspends this goroutine without blocking the OS thread.
158        // The M+P are returned to the scheduler to run other goroutines
159        // (including whoever will call done() to reach count == 0).
160        let gp = current_g();
161        if !gp.is_null() {
162            // Critical section guarded by `m_lock` — see matching comment in
163            // `add()`.  Scope the guard so it is dropped *before* `gopark`:
164            // we must not hold `m.locks` across the yield (other goroutines
165            // scheduled on this M would inherit suppressed preemption until
166            // this goroutine wakes again, which can be milliseconds away).
167            {
168                let _lk = crate::runtime::m::m_lock();
169                let mut state = self.state.lock().unwrap();
170                if state.count == 0 {
171                    return; // fast path: already done
172                }
173                // Register as a waiter *before* releasing the lock so that any
174                // concurrent add() that drives count to zero will see us and
175                // call goready.  goready itself spins until our status reaches
176                // GWAITING, which closes the window between drop(state) and
177                // gopark().
178                state.waiters.push(gp);
179                drop(state);
180            }
181            // Suspend this goroutine.  Execution resumes here after add()
182            // calls goready(gp) once the counter reaches zero.
183            gopark(WaitReason::Semacquire);
184            return;
185        }
186
187        // ── Non-goroutine / loom path ───────────────────────────────────────
188        // Block the calling OS thread on the condvar.  Used by tests that call
189        // wait() from a bare thread and by loom model threads.
190        let mut state = self.state.lock().unwrap();
191        while state.count > 0 {
192            state = self.cond.wait(state).unwrap();
193        }
194    }
195}
196
197impl Default for WaitGroup {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
203// ---------------------------------------------------------------------------
204// Tests
205// ---------------------------------------------------------------------------
206
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use crate::runtime::sched::run_impl;
    use std::sync::atomic::{AtomicI32, Ordering};
    use std::sync::Arc;

    /// A freshly created WaitGroup has counter zero; wait() returns immediately.
    #[test]
    fn new_wait_returns_immediately() {
        let wg = WaitGroup::new();
        wg.wait(); // must not block
    }

    /// One spawned worker goroutine calls done(); wait() on the main
    /// goroutine unblocks only after it has run.
    #[test]
    fn single_worker() {
        run_impl(|| {
            let wg = Arc::new(WaitGroup::new());
            let done = Arc::new(AtomicI32::new(0));

            wg.add(1);
            let wg2 = Arc::clone(&wg);
            let done2 = Arc::clone(&done);
            crate::runtime::sched::spawn_goroutine(move || {
                done2.fetch_add(1, Ordering::Relaxed);
                wg2.done();
            });

            wg.wait();
            assert_eq!(done.load(Ordering::Acquire), 1);
        });
    }

    /// Five workers; wait unblocks only after all five call done.
    #[test]
    fn multiple_workers() {
        const N: i32 = 5;
        let count = Arc::new(AtomicI32::new(0));
        let count2 = Arc::clone(&count);

        run_impl(move || {
            let wg = Arc::new(WaitGroup::new());

            for _ in 0..N {
                wg.add(1);
                let wg2 = Arc::clone(&wg);
                let count3 = Arc::clone(&count2);
                crate::runtime::sched::spawn_goroutine(move || {
                    count3.fetch_add(1, Ordering::Relaxed);
                    wg2.done();
                });
            }

            wg.wait();
            assert_eq!(count2.load(Ordering::Acquire), N);
        });

        assert_eq!(count.load(Ordering::Acquire), N);
    }

    /// Two goroutines both wait; both unblock when the counter reaches zero.
    #[test]
    fn multiple_waiters() {
        let woke = Arc::new(AtomicI32::new(0));
        let woke2 = Arc::clone(&woke);

        run_impl(move || {
            let wg = Arc::new(WaitGroup::new());
            wg.add(1);

            // Spawn two waiters.
            for _ in 0..2 {
                let wg3 = Arc::clone(&wg);
                let woke3 = Arc::clone(&woke2);
                crate::runtime::sched::spawn_goroutine(move || {
                    wg3.wait();
                    woke3.fetch_add(1, Ordering::Relaxed);
                });
            }

            // Yield so the waiters have a chance to call wg.wait() and park.
            for _ in 0..20 {
                crate::gosched();
            }
            wg.done();

            // Poll until both waiter goroutines have run past wg.wait() and
            // incremented woke.  A fixed gosched-loop is not deterministic
            // under parallel test load; polling on the atomic is race-free.
            let deadline = std::time::Instant::now()
                + std::time::Duration::from_millis(500);
            loop {
                if woke2.load(Ordering::Acquire) >= 2 {
                    break;
                }
                assert!(
                    std::time::Instant::now() < deadline,
                    "timed out: only {} of 2 waiters woke",
                    woke2.load(Ordering::Relaxed),
                );
                crate::gosched();
            }
        });

        assert_eq!(woke.load(Ordering::Acquire), 2, "both waiters must wake");
    }

    /// WaitGroup is reusable: after wait() the counter can be incremented again.
    #[test]
    fn reuse_after_wait() {
        run_impl(|| {
            let wg = Arc::new(WaitGroup::new());

            // Round 1.
            wg.add(1);
            let wg2 = Arc::clone(&wg);
            crate::runtime::sched::spawn_goroutine(move || {
                wg2.done();
            });
            wg.wait();

            // Round 2.
            let done = Arc::new(AtomicI32::new(0));
            wg.add(1);
            let wg3 = Arc::clone(&wg);
            let done2 = Arc::clone(&done);
            crate::runtime::sched::spawn_goroutine(move || {
                done2.store(1, Ordering::Relaxed);
                wg3.done();
            });
            wg.wait();
            assert_eq!(done.load(Ordering::Acquire), 1);
        });
    }

    /// Workers on plain OS threads call done(); wait() on the (bare) test
    /// thread takes the condvar fallback path and must block until every
    /// worker has finished.  This is the only non-loom test in which the
    /// condvar path actually blocks.
    #[test]
    fn bare_thread_wait_blocks_until_done() {
        const N: i32 = 3;
        let wg = Arc::new(WaitGroup::new());
        let finished = Arc::new(AtomicI32::new(0));

        let mut workers = Vec::new();
        for _ in 0..N {
            // Register before the thread starts so wait() cannot see zero early.
            wg.add(1);
            let wg2 = Arc::clone(&wg);
            let finished2 = Arc::clone(&finished);
            workers.push(std::thread::spawn(move || {
                finished2.fetch_add(1, Ordering::Relaxed);
                wg2.done();
            }));
        }

        wg.wait();
        assert_eq!(finished.load(Ordering::Acquire), N);

        for w in workers {
            w.join().unwrap();
        }
    }

    /// add(-1) below zero panics.
    #[test]
    #[should_panic(expected = "sync: negative WaitGroup counter")]
    fn negative_counter_panics() {
        let wg = WaitGroup::new();
        wg.add(-1); // counter is 0 → -1 → panic
    }
}
345
346// ---------------------------------------------------------------------------
347// Loom model tests
348// ---------------------------------------------------------------------------
349
#[cfg(all(test, loom))]
mod loom_tests {
    use super::*;
    use loom::sync::Arc;

    /// One worker calls done(); the waiter must unblock without deadlocking.
    /// Loom explores all interleavings of done() vs wait().
    #[test]
    fn done_unblocks_wait() {
        loom::model(|| {
            let wg = Arc::new(WaitGroup::new());
            let wg2 = Arc::clone(&wg);

            wg.add(1);

            let worker = loom::thread::spawn(move || {
                wg2.done();
            });

            wg.wait(); // must not deadlock in any interleaving

            worker.join().unwrap();
        });
    }

    /// Two concurrent done() calls race to take the counter from two to zero;
    /// a concurrent wait() must return (not deadlock) in every interleaving.
    #[test]
    fn two_workers_unblock_wait() {
        loom::model(|| {
            let wg = Arc::new(WaitGroup::new());
            let wg2 = Arc::clone(&wg);
            let wg3 = Arc::clone(&wg);
            let wg4 = Arc::clone(&wg);

            wg.add(2);

            let t1 = loom::thread::spawn(move || wg2.done());
            let t2 = loom::thread::spawn(move || wg3.done());
            let waiter = loom::thread::spawn(move || wg4.wait());

            t1.join().unwrap();
            t2.join().unwrap();
            waiter.join().unwrap();
        });
    }

    /// A positive add() races with done() calls and a blocked wait(): wait()
    /// must neither deadlock nor return before the counter is truly zero, in
    /// every interleaving.
    #[test]
    fn add_and_done_interleave() {
        loom::model(|| {
            let wg = Arc::new(WaitGroup::new());
            let wg2 = Arc::clone(&wg);

            // The main thread's registration keeps the counter above zero
            // until the worker has released it, so the worker's add(1) is a
            // legal concurrent Add (counter > 0) racing with wait().
            wg.add(1);

            let worker = loom::thread::spawn(move || {
                wg2.add(1);
                wg2.done();
                wg2.done(); // releases the main thread's registration too
            });

            wg.wait();
            assert_eq!(
                wg.state.lock().unwrap().count,
                0,
                "wait() returned before the counter reached zero",
            );

            worker.join().unwrap();
        });
    }
}