go_lib/sync/waitgroup.rs
1// SPDX-License-Identifier: Apache-2.0
2//! `WaitGroup` — ported from `src/sync/waitgroup.go`.
3//!
4//! ## Semantics (match Go)
5//!
6//! - [`WaitGroup::add`]`(delta)` — increment the counter by `delta`. `delta`
7//! may be negative (used internally by [`done`][WaitGroup::done]).
8//! Panics if the counter goes negative.
9//! - [`WaitGroup::done`]`()` — shorthand for `add(-1)`.
10//! - [`WaitGroup::wait`]`()` — block until the counter reaches zero.
11//! Multiple goroutines may call `wait` concurrently; all are unblocked when
12//! the last worker calls `done`.
13//!
14//! ## Implementation
15//!
16//! ### Goroutine path (the common case)
17//!
18//! `wait` uses [`gopark`][crate::runtime::park::gopark] to suspend the calling
19//! goroutine back into the scheduler **without blocking the OS thread**. The
20//! M and its P remain free to run other goroutines. When `add` decrements the
21//! counter to zero it drains the waiters list and calls
22//! [`goready`][crate::runtime::park::goready] on each, re-enqueuing them for
23//! scheduling.
24//!
25//! ### Non-goroutine / loom path
26//!
//! If `wait` is called from a bare OS thread (outside the go-lib scheduler, or
//! from a loom model thread) it falls back to blocking on a `Condvar`. This
//! is the path taken by the `new_wait_returns_immediately` unit test, which
//! calls `wait` without ever entering the scheduler.
31//!
//! Ported from Go's `src/sync/waitgroup.go`.
33
34use crate::loom_shim::{Condvar, Mutex};
35use crate::runtime::g::{current_g, WaitReason, G};
36use crate::runtime::park::{gopark, goready};
37
38// ---------------------------------------------------------------------------
39// Internal state
40// ---------------------------------------------------------------------------
41
42struct WgState {
43 /// Number of outstanding workers (`Add` increments, `Done` decrements).
44 count: i64,
45 /// Goroutines suspended in [`WaitGroup::wait`]. Drained by
46 /// [`WaitGroup::add`] when the counter reaches zero; each entry is woken
47 /// via [`goready`].
48 waiters: Vec<*mut G>,
49}
50
51// SAFETY: `WgState` is always accessed under the `WaitGroup`'s `Mutex`.
52// The `*mut G` pointers are goroutines owned by the scheduler; we never
53// dereference them without holding the lock (except to pass to `goready`,
54// which is safe once the goroutine has reached `GWAITING`).
55unsafe impl Send for WgState {}
56
57// ---------------------------------------------------------------------------
58// WaitGroup
59// ---------------------------------------------------------------------------
60
61/// A synchronisation barrier: wait for a set of goroutines to complete.
62///
63/// The typical usage pattern is:
64///
65/// ```no_run
66/// use std::sync::Arc;
67/// use go_lib::sync::WaitGroup;
68///
69/// let wg = Arc::new(WaitGroup::new());
70/// for i in 0..5 {
71/// let wg = Arc::clone(&wg);
72/// go_lib::run(move || {
73/// wg.add(1);
74/// // ... spawn goroutine that calls wg.done() when finished ...
75/// });
76/// }
77/// // wg.wait(); // blocks until all Done() calls have been made
78/// ```
79pub struct WaitGroup {
80 state: Mutex<WgState>,
81 /// Condvar used only on the non-goroutine fallback path.
82 cond: Condvar,
83}
84
85impl WaitGroup {
86 /// Create a new `WaitGroup` with a counter of zero.
87 pub fn new() -> Self {
88 Self {
89 state: Mutex::new(WgState { count: 0, waiters: Vec::new() }),
90 cond: Condvar::new(),
91 }
92 }
93
94 /// Add `delta` to the counter.
95 ///
96 /// `delta` is typically positive when called before spawning goroutines and
97 /// negative when they finish (see [`done`][Self::done]).
98 ///
99 /// # Panics
100 ///
101 /// Panics if the counter drops below zero.
102 pub fn add(&self, delta: i64) {
103 // Collect goroutine waiters to wake (if counter reaches zero).
104 let goroutine_waiters: Vec<*mut G> = {
105 let mut state = self.state.lock().unwrap();
106 state.count += delta;
107 if state.count < 0 {
108 drop(state);
109 panic!("sync: negative WaitGroup counter");
110 }
111 if state.count == 0 {
112 // Wake condvar waiters (non-goroutine / loom path).
113 // Drain goroutine waiters to wake via goready below.
114 let w = std::mem::take(&mut state.waiters);
115 drop(state);
116 self.cond.notify_all();
117 w
118 } else {
119 Vec::new()
120 }
121 };
122
123 // Wake goroutine waiters outside the lock so we don't hold it during
124 // the goready spin (which waits for GRUNNING → GWAITING).
125 for gp in goroutine_waiters {
126 unsafe { goready(gp) };
127 }
128 }
129
130 /// Decrement the counter by one.
131 ///
132 /// Shorthand for `self.add(-1)`.
133 pub fn done(&self) {
134 self.add(-1);
135 }
136
137 /// Block until the counter is zero.
138 ///
139 /// When called from a goroutine: suspends the goroutine back into the
140 /// scheduler via `gopark` so the M and P remain free to run other
141 /// goroutines. Resumed by `add` calling `goready` when the counter
142 /// reaches zero.
143 ///
144 /// When called from a bare OS thread (outside the go-lib scheduler):
145 /// blocks the thread on an internal `Condvar`.
146 pub fn wait(&self) {
147 // ── Goroutine path ──────────────────────────────────────────────────
148 // gopark suspends this goroutine without blocking the OS thread.
149 // The M+P are returned to the scheduler to run other goroutines
150 // (including whoever will call done() to reach count == 0).
151 let gp = current_g();
152 if !gp.is_null() {
153 let mut state = self.state.lock().unwrap();
154 if state.count == 0 {
155 return; // fast path: already done
156 }
157 // Register as a waiter *before* releasing the lock so that any
158 // concurrent add() that drives count to zero will see us and call
159 // goready. goready itself spins until our status reaches GWAITING,
160 // which closes the window between drop(state) and gopark().
161 state.waiters.push(gp);
162 drop(state);
163 // Suspend this goroutine. Execution resumes here after add()
164 // calls goready(gp) once the counter reaches zero.
165 unsafe { gopark(WaitReason::Semacquire) };
166 return;
167 }
168
169 // ── Non-goroutine / loom path ───────────────────────────────────────
170 // Block the calling OS thread on the condvar. Used by tests that call
171 // wait() from a bare thread and by loom model threads.
172 let mut state = self.state.lock().unwrap();
173 while state.count > 0 {
174 state = self.cond.wait(state).unwrap();
175 }
176 }
177}
178
179impl Default for WaitGroup {
180 fn default() -> Self {
181 Self::new()
182 }
183}
184
185// ---------------------------------------------------------------------------
186// Tests
187// ---------------------------------------------------------------------------
188
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use crate::runtime::sched::{run_impl, spawn_goroutine};
    use std::sync::atomic::{AtomicI32, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    // NOTE(review): every `unsafe { spawn_goroutine(..) }` below runs inside
    // `run_impl`; this assumes a running scheduler is all `spawn_goroutine`
    // requires — confirm against its `# Safety` contract.

    /// With a zero counter there is nothing to wait for, so `wait()` called
    /// outside `run_impl` must return straight away.
    #[test]
    fn new_wait_returns_immediately() {
        WaitGroup::new().wait(); // must not block
    }

    /// One spawned worker calls `done()`; `wait()` returns after it has run.
    #[test]
    fn single_worker() {
        run_impl(|| {
            let group = Arc::new(WaitGroup::new());
            let hits = Arc::new(AtomicI32::new(0));

            group.add(1);
            let worker_group = Arc::clone(&group);
            let worker_hits = Arc::clone(&hits);
            unsafe {
                spawn_goroutine(move || {
                    worker_hits.fetch_add(1, Ordering::Relaxed);
                    worker_group.done();
                });
            }

            group.wait();
            assert_eq!(hits.load(Ordering::Acquire), 1);
        });
    }

    /// Five workers each bump a shared counter; `wait()` must not return
    /// until all five have called `done()`.
    #[test]
    fn multiple_workers() {
        const WORKERS: i32 = 5;
        let total = Arc::new(AtomicI32::new(0));
        let total_in = Arc::clone(&total);

        run_impl(move || {
            let group = Arc::new(WaitGroup::new());

            for _ in 0..WORKERS {
                group.add(1);
                let worker_group = Arc::clone(&group);
                let worker_total = Arc::clone(&total_in);
                unsafe {
                    spawn_goroutine(move || {
                        worker_total.fetch_add(1, Ordering::Relaxed);
                        worker_group.done();
                    });
                }
            }

            group.wait();
            assert_eq!(total_in.load(Ordering::Acquire), WORKERS);
        });

        assert_eq!(total.load(Ordering::Acquire), WORKERS);
    }

    /// Two goroutines wait on the same group; a single `done()` that brings
    /// the counter to zero must release both of them.
    #[test]
    fn multiple_waiters() {
        let woke = Arc::new(AtomicI32::new(0));
        let woke_in = Arc::clone(&woke);

        run_impl(move || {
            let group = Arc::new(WaitGroup::new());
            group.add(1);

            // Start the two waiters.
            for _ in 0..2 {
                let waiter_group = Arc::clone(&group);
                let waiter_woke = Arc::clone(&woke_in);
                unsafe {
                    spawn_goroutine(move || {
                        waiter_group.wait();
                        waiter_woke.fetch_add(1, Ordering::Relaxed);
                    });
                }
            }

            // Give the waiters a chance to reach wait() and park before the
            // counter is released.
            for _ in 0..20 {
                crate::gosched();
            }
            group.done();

            // Poll the atomic, with a deadline, until both waiters have run
            // past wait(); a fixed number of yields would not be reliable
            // when tests run in parallel.
            let deadline = Instant::now() + Duration::from_millis(500);
            while woke_in.load(Ordering::Acquire) < 2 {
                assert!(
                    Instant::now() < deadline,
                    "timed out: only {} of 2 waiters woke",
                    woke_in.load(Ordering::Relaxed),
                );
                crate::gosched();
            }
        });

        assert_eq!(woke.load(Ordering::Acquire), 2, "both waiters must wake");
    }

    /// After one add/done/wait cycle the same group can be used for another.
    #[test]
    fn reuse_after_wait() {
        run_impl(|| {
            let group = Arc::new(WaitGroup::new());

            // First cycle.
            group.add(1);
            let first_worker = Arc::clone(&group);
            unsafe {
                spawn_goroutine(move || {
                    first_worker.done();
                });
            }
            group.wait();

            // Second cycle, this time with an observable side effect.
            let flag = Arc::new(AtomicI32::new(0));
            group.add(1);
            let second_worker = Arc::clone(&group);
            let worker_flag = Arc::clone(&flag);
            unsafe {
                spawn_goroutine(move || {
                    worker_flag.store(1, Ordering::Relaxed);
                    second_worker.done();
                });
            }
            group.wait();
            assert_eq!(flag.load(Ordering::Acquire), 1);
        });
    }

    /// Driving the counter below zero with `add(-1)` panics.
    #[test]
    #[should_panic(expected = "sync: negative WaitGroup counter")]
    fn negative_counter_panics() {
        let group = WaitGroup::new();
        group.add(-1); // counter is 0 → -1 → panic
    }
}
337
338// ---------------------------------------------------------------------------
339// Loom model tests
340// ---------------------------------------------------------------------------
341
#[cfg(all(test, loom))]
mod loom_tests {
    use super::*;
    use loom::sync::Arc;
    use loom::thread;

    /// A single worker thread calls `done()` while the model's main thread
    /// calls `wait()`; loom checks that no interleaving of the two deadlocks.
    #[test]
    fn done_unblocks_wait() {
        loom::model(|| {
            let group = Arc::new(WaitGroup::new());
            let worker_group = Arc::clone(&group);

            group.add(1);

            let worker = thread::spawn(move || {
                worker_group.done();
            });

            group.wait(); // must not deadlock in any interleaving

            worker.join().unwrap();
        });
    }

    /// Two threads each call `done()` on a counter of two while a third
    /// thread calls `wait()`; the waiter must be released in every
    /// interleaving once the second `done()` brings the counter to zero.
    #[test]
    fn two_workers_unblock_wait() {
        loom::model(|| {
            let group = Arc::new(WaitGroup::new());
            let first = Arc::clone(&group);
            let second = Arc::clone(&group);
            let waiting = Arc::clone(&group);

            group.add(2);

            let first_worker = thread::spawn(move || first.done());
            let second_worker = thread::spawn(move || second.done());
            let waiter = thread::spawn(move || waiting.wait());

            first_worker.join().unwrap();
            second_worker.join().unwrap();
            waiter.join().unwrap();
        });
    }

    /// After an up-front `add(1)`, a spawned thread calls `done()` through
    /// one handle while the model's main thread calls `wait()` through
    /// another; `wait()` must return in every interleaving.
    #[test]
    fn add_and_done_interleave() {
        loom::model(|| {
            let group = Arc::new(WaitGroup::new());
            let finisher_group = Arc::clone(&group);
            let waiter_group = Arc::clone(&group);

            // The add(1) happens before either concurrent operation starts.
            group.add(1);

            let finisher = thread::spawn(move || finisher_group.done());
            waiter_group.wait();

            finisher.join().unwrap();
        });
    }
}
406}