1use log::error;
52use std::{
53 future::Future,
54 pin::Pin,
55 sync::{
56 atomic::{AtomicI64, AtomicU64, Ordering},
57 Arc,
58 },
59 task::{Context, Poll, Waker},
60};
61
62use parking_lot::Mutex;
63
/// An asynchronous wait-group, similar in spirit to Go's `sync.WaitGroup`:
/// work is registered with `add`/`add_guard`, completion is signalled with
/// `done`/`done_many`, and a single waiter can `wait`/`wait_to` until the
/// outstanding count drops. Clones share the same underlying counter.
/// (Single waiter only — `WaitGroupInner::set_waker` panics on a second
/// concurrent waiter.)
pub struct WaitGroup(Arc<WaitGroupInner>);
77
78impl Clone for WaitGroup {
80 fn clone(&self) -> Self {
81 Self(self.0.clone())
82 }
83}
84
85impl WaitGroup {
86 pub fn new() -> Self {
87 Self(WaitGroupInner::new())
88 }
89
90 #[inline(always)]
92 pub fn left(&self) -> usize {
93 let count = self.0.left.load(Ordering::SeqCst);
94 if count < 0 {
95 error!("WaitGroup.left {} < 0", count);
96 panic!("WaitGroup.left {} < 0", count);
97 }
98 count as usize
99 }
100
101 #[inline(always)]
103 pub fn add(&self, i: usize) {
104 self.0.left.fetch_add(i as i64, Ordering::SeqCst);
105 }
106
107 #[inline(always)]
128 pub fn add_guard(&self) -> WaitGroupGuard {
129 self.0.left.fetch_add(1, Ordering::SeqCst);
130 WaitGroupGuard {
131 inner: self.0.clone(),
132 }
133 }
134
135 pub async fn wait_to(&self, target: usize) -> bool {
147 let _self = self.0.as_ref();
148 let left = _self.left.load(Ordering::Acquire);
149 if left <= target as i64 {
150 return false;
151 }
152 WaitGroupFuture {
153 wg: &_self,
154 target,
155 waker_id: 0,
156 }
157 .await;
158 return true;
159 }
160
161 #[inline(always)]
169 pub async fn wait(&self) {
170 self.wait_to(0).await;
171 }
172
173 #[inline]
175 pub fn done(&self) {
176 let inner = self.0.as_ref();
177 inner.done(1);
178 }
179
180 #[inline]
182 pub fn done_many(&self, count: usize) {
183 let inner = self.0.as_ref();
184 inner.done(count as i64);
185 }
186}
187
/// RAII guard returned by `WaitGroup::add_guard`; its `Drop` impl calls
/// `done(1)` on the shared state, so the unit of work is marked finished
/// even if the holder panics or returns early.
pub struct WaitGroupGuard {
    inner: Arc<WaitGroupInner>,
}
191
192impl Drop for WaitGroupGuard {
193 fn drop(&mut self) {
194 let inner = &self.inner;
195 inner.done(1);
196 }
197}
198
/// Shared state behind a `WaitGroup` (always held in an `Arc`).
struct WaitGroupInner {
    // Outstanding work count; `done`/`left` panic if it goes below zero.
    left: AtomicI64,
    // Target the current waiter is waiting for; -1 means "no waiter".
    waiting: AtomicI64,
    // Waker of the single pending waiter, if any.
    waker: Mutex<Option<Waker>>,
    // Monotonic registration id; lets `cancel_wait` avoid clearing a
    // registration newer than the one being cancelled.
    waker_id: AtomicU64,
}
205
impl WaitGroupInner {
    /// Creates the shared state: zero outstanding work, no waiter.
    #[inline(always)]
    fn new() -> Arc<Self> {
        Arc::new(Self {
            left: AtomicI64::new(0),
            // -1 = no waiter registered.
            waiting: AtomicI64::new(-1),
            waker: Mutex::new(None),
            waker_id: AtomicU64::new(0),
        })
    }
    /// Decrements `left` by `count` and wakes the registered waiter if
    /// the new value has reached its target.
    ///
    /// # Panics
    /// Panics if `left` goes negative (more dones than adds).
    #[inline]
    fn done(&self, count: i64) {
        // fetch_sub returns the previous value; subtract again to get
        // the post-decrement count.
        let left = self.left.fetch_sub(count, Ordering::SeqCst) - count;
        let waiting = self.waiting.load(Ordering::Acquire);
        if left < 0 {
            error!("WaitGroup.left {} < 0", left);
            panic!("WaitGroup.left {} < 0", left);
        }
        // No waiter registered — nothing to wake.
        if waiting < 0 {
            return;
        }
        if left <= waiting {
            // Wake by reference so the waker stays registered; the
            // future clears it via cancel_wait when it completes.
            if let Some(waker) = self.waker.lock().as_ref() {
                waker.wake_by_ref();
            }
        }
    }

    /// Registers `waker` as the single waiter for `target` and returns a
    /// fresh registration id for later cancellation.
    ///
    /// # Panics
    /// Panics if another waiter is already registered (`waiting >= 0`) —
    /// only one concurrent waiter is supported.
    #[inline]
    fn set_waker(&self, waker: Waker, target: usize) -> u64 {
        // New id is the post-increment value (first registration gets 1).
        let waker_id = self.waker_id.fetch_add(1, Ordering::SeqCst) + 1;
        {
            let mut guard = self.waker.lock();
            guard.replace(waker);
            let old_target = self.waiting.swap(target as i64, Ordering::SeqCst);
            if old_target >= 0 {
                panic!("Concurrent wait() by multiple coroutines is not supported")
            }
        }
        waker_id
    }

    /// Removes the registration identified by `waker_id`, if it is still
    /// the current one; a stale id (superseded registration) is a no-op.
    #[inline]
    fn cancel_wait(&self, waker_id: u64) {
        let mut guard = self.waker.lock();
        if self.waker_id.load(Ordering::Acquire) == waker_id {
            // Back to "no waiter" and drop the stored waker.
            self.waiting.store(-1, Ordering::Release);
            let _ = guard.take();
        }
    }
}
261
/// Future created by `WaitGroup::wait_to`; resolves when `wg.left`
/// drops to `target` or below.
struct WaitGroupFuture<'a> {
    wg: &'a WaitGroupInner,
    target: usize,
    // Current waker registration id; 0 means "not registered".
    waker_id: u64,
}
267
268impl<'a> WaitGroupFuture<'a> {
269 #[inline(always)]
270 fn _poll(&mut self) -> bool {
271 let cur = self.wg.left.load(Ordering::Acquire);
272 if cur <= self.target as i64 {
273 self._clear();
274 true
275 } else {
276 false
277 }
278 }
279
280 #[inline(always)]
281 fn _clear(&mut self) {
282 if self.waker_id == 0 {
283 return;
284 }
285 self.wg.cancel_wait(self.waker_id);
286 self.waker_id = 0;
287 }
288}
289
impl<'a> Drop for WaitGroupFuture<'a> {
    // Cancellation safety: if the future is dropped before completion
    // (e.g. a timeout or select), its waker registration must be
    // removed so a later `done` doesn't wake a dead task and so a new
    // waiter can register.
    fn drop(&mut self) {
        self._clear();
    }
}
296
impl<'a> Future for WaitGroupFuture<'a> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
        let _self = self.get_mut();
        if _self.waker_id == 0 {
            // First poll: fast-path check before registering a waker.
            if _self._poll() {
                return Poll::Ready(());
            }
            _self.waker_id = _self.wg.set_waker(ctx.waker().clone(), _self.target);
        }
        // Re-check after registration: a `done` racing between the first
        // check and `set_waker` would otherwise be missed, leaving this
        // future pending forever.
        if _self._poll() {
            return Poll::Ready(());
        }
        Poll::Pending
    }
}
314
#[cfg(test)]
mod tests {
    extern crate rand;

    use std::time::Duration;
    use tokio::time::{sleep, timeout};

    use super::*;

    // Builds a multi-threaded tokio runtime with `threads` worker threads.
    fn make_runtime(threads: usize) -> tokio::runtime::Runtime {
        return tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(threads)
            .build()
            .unwrap();
    }

    // Exercises the internal waker bookkeeping: a waiter registers
    // (waker_id becomes 1, waiting holds the target), is woken by
    // `done`, and the registration is cleared (waiting back to -1).
    #[test]
    fn test_inner() {
        make_runtime(1).block_on(async move {
            let wg = WaitGroup::new();
            wg.add(2);
            let _wg = wg.clone();
            let th = tokio::spawn(async move {
                assert!(_wg.wait_to(1).await);
            });
            // Give the spawned task time to register its waker.
            sleep(Duration::from_secs(1)).await;
            assert_eq!(wg.0.waker_id.load(Ordering::Acquire), 1);
            {
                let guard = wg.0.waker.lock();
                assert!(guard.is_some());
                assert_eq!(wg.0.waiting.load(Ordering::Acquire), 1);
            }
            wg.done();
            let _ = th.await;
            assert_eq!(wg.0.waker_id.load(Ordering::Acquire), 1);
            assert_eq!(wg.0.waiting.load(Ordering::Acquire), -1);
            assert_eq!(wg.left(), 1);
            wg.done();
            assert_eq!(wg.left(), 0);
            // Already at target: wait_to returns false without waiting.
            assert_eq!(wg.wait_to(0).await, false);
        });
    }

    // Verifies that a timed-out (dropped) wait cleans up its
    // registration, and that a fresh wait can register afterwards.
    #[test]
    fn test_cancel() {
        let wg = WaitGroup::new();
        make_runtime(1).block_on(async move {
            wg.add(1);
            println!("test timeout");
            assert!(timeout(Duration::from_secs(1), wg.wait()).await.is_err());
            println!("timeout happened");
            // Drop of the timed-out future must reset `waiting` to -1.
            assert_eq!(wg.0.waiting.load(Ordering::Acquire), -1);
            wg.done();
            wg.add(2);
            wg.done_many(2);
            wg.add(2);
            let _wg = wg.clone();
            let th = tokio::spawn(async move {
                _wg.wait().await;
            });
            sleep(Duration::from_millis(200)).await;
            // Second registration overall, hence id 2.
            assert_eq!(wg.0.waker_id.load(Ordering::Acquire), 2);
            wg.done();
            wg.done();
            let _ = th.await;
        });
    }
}