1use parking_lot::{Condvar, Mutex};
29use std::fmt;
30use std::sync::Arc;
31
32pub struct WaitGroup {
34 inner: Arc<Inner>,
35}
36
37struct Inner {
39 cvar: Condvar,
40 count: Mutex<usize>,
41}
42
43impl Default for WaitGroup {
44 #[inline]
45 fn default() -> Self {
46 WaitGroup::new()
47 }
48}
49
50impl WaitGroup {
51 #[inline]
53 pub fn new() -> Self {
54 Self {
55 inner: Arc::new(Inner {
56 cvar: Condvar::new(),
57 count: Mutex::new(1),
58 }),
59 }
60 }
61
62 #[inline]
64 pub fn wait(self) {
65 if *self.inner.count.lock() == 1 {
66 return;
67 }
68
69 let inner = self.inner.clone();
70 drop(self);
71
72 let mut count = inner.count.lock();
73 while *count > 0 {
74 inner.cvar.wait(&mut count);
75 }
76 }
77}
78
79impl Drop for WaitGroup {
80 #[inline]
81 fn drop(&mut self) {
82 let mut count = self.inner.count.lock();
83 *count -= 1;
84
85 if *count == 0 {
86 self.inner.cvar.notify_all();
87 }
88 }
89}
90
91impl Clone for WaitGroup {
92 #[inline]
93 fn clone(&self) -> WaitGroup {
94 let mut count = self.inner.count.lock();
95 *count += 1;
96
97 WaitGroup {
98 inner: self.inner.clone(),
99 }
100 }
101}
102
103impl fmt::Debug for WaitGroup {
104 #[inline]
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106 let count: &usize = &*self.inner.count.lock();
107 f.debug_struct("WaitGroup").field("count", count).finish()
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    const THREADS: usize = 10;

    /// Workers calling `wait` stay blocked until the main handle is released too.
    #[test]
    fn wait() {
        let wg = WaitGroup::new();
        let (tx, rx) = mpsc::channel();

        for _ in 0..THREADS {
            let (wg, tx) = (wg.clone(), tx.clone());
            thread::spawn(move || {
                wg.wait();
                tx.send(()).unwrap();
            });
        }

        // Give the workers a chance to (incorrectly) get past `wait`.
        thread::sleep(Duration::from_millis(100));
        // The main handle is still alive, so no worker may have finished waiting.
        assert!(rx.try_recv().is_err());

        wg.wait();

        // Releasing the last handle must wake every worker.
        (0..THREADS).for_each(|_| rx.recv().unwrap());
    }

    /// `wait` returns only after every other handle has been dropped.
    #[test]
    fn wait_and_drop() {
        let wg = WaitGroup::new();
        let (tx, rx) = mpsc::channel();

        for _ in 0..THREADS {
            let (wg, tx) = (wg.clone(), tx.clone());
            thread::spawn(move || {
                thread::sleep(Duration::from_millis(100));
                // Send before releasing the handle so the message is already
                // queued by the time `wait` can return.
                tx.send(()).unwrap();
                drop(wg);
            });
        }

        // Workers are still sleeping, so nothing has been sent yet.
        assert!(rx.try_recv().is_err());

        wg.wait();

        // Every worker sent before dropping, so all messages are available now.
        (0..THREADS).for_each(|_| rx.try_recv().unwrap());
    }
}