1use log::error;
52use std::{
53 future::Future,
54 pin::Pin,
55 sync::{
56 atomic::{AtomicI64, Ordering},
57 Arc,
58 },
59 task::{Context, Poll, Waker},
60};
61use parking_lot::Mutex;
62
/// An asynchronous wait group, similar in spirit to Go's `sync.WaitGroup`:
/// tasks `add` to a shared counter, call `done` to decrement it, and a
/// single waiter can `wait`/`wait_to` for the counter to drop to a target.
/// Cloning the handle shares the same underlying counter.
pub struct WaitGroup(Arc<WaitGroupInner>);
78
79impl Clone for WaitGroup {
81 fn clone(&self) -> Self {
82 Self(self.0.clone())
83 }
84}
85
/// Logs the formatted message at `error` level, then panics with the same message.
macro_rules! log_and_panic {
    // Double braces: the expansion becomes a single block expression, so the
    // macro is valid in expression position as well as statement position
    // (the original bare `stmt; stmt;` expansion only worked as a statement).
    ($($arg:tt)+) => {{
        error!($($arg)+);
        panic!($($arg)+);
    }};
}
92
/// Emits a `log::trace!` record, compiled in only when the `trace_log`
/// cargo feature is enabled; with the feature off the whole expansion
/// (including its arguments) is removed by `cfg`.
macro_rules! trace_log {
    ($($arg:tt)+) => (
        // The cfg attribute on the block strips the entire statement
        // at compile time when the feature is disabled.
        #[cfg(feature="trace_log")]
        {
            log::trace!($($arg)+);
        }
    );
}
101
102impl WaitGroup {
103 pub fn new() -> Self {
104 Self(WaitGroupInner::new())
105 }
106
107 #[inline(always)]
109 pub fn left(&self) -> usize {
110 let count = self.0.left.load(Ordering::SeqCst);
111 if count < 0 {
112 log_and_panic!("WaitGroup.left {} < 0", count);
113 }
114 count as usize
115 }
116
117 #[inline(always)]
121 pub fn add(&self, i: usize) {
122 let _r = self.0.left.fetch_add(i as i64, Ordering::Acquire);
124 trace_log!("add {}->{}", i, _r + i as i64);
125 }
126
127 #[inline(always)]
148 pub fn add_guard(&self) -> WaitGroupGuard {
149 self.add(1);
150 WaitGroupGuard {
151 inner: self.0.clone(),
152 }
153 }
154
155 pub async fn wait_to(&self, target: usize) -> bool {
167 let _self = self.0.as_ref();
168 let left = _self.left.load(Ordering::Acquire);
170 if left <= target as i64 {
171 trace_log!("wait_to skip {} <= target {}", left, target);
172 return false;
173 }
174 WaitGroupFuture {
175 wg: &_self,
176 target,
177 waker: None,
178 }
179 .await;
180 return true;
181 }
182
183 #[inline(always)]
191 pub async fn wait(&self) {
192 self.wait_to(0).await;
193 }
194
195 #[inline]
197 pub fn done(&self) {
198 let inner = self.0.as_ref();
199 inner.done(1);
200 }
201
202 #[inline]
204 pub fn done_many(&self, count: usize) {
205 let inner = self.0.as_ref();
206 inner.done(count as i64);
207 }
208}
209
/// Guard returned by `WaitGroup::add_guard`; its `Drop` impl calls
/// `done(1)` on the shared state, so dropping it decrements the counter.
pub struct WaitGroupGuard {
    // Shared state whose counter is decremented when the guard drops.
    inner: Arc<WaitGroupInner>,
}
213
214impl Drop for WaitGroupGuard {
215 fn drop(&mut self) {
216 let inner = &self.inner;
217 inner.done(1);
218 }
219}
220
/// Shared state behind `WaitGroup` handles.
struct WaitGroupInner {
    // Number of outstanding `done()` calls still expected; must never go negative.
    left: AtomicI64,
    // Target the current waiter is waiting for, or -1 when nobody is waiting.
    waiting: AtomicI64,
    // Waker of the single registered waiter, if any.
    waker: Mutex<Option<Arc<Waker>>>,
}
228
229impl WaitGroupInner {
230 #[inline(always)]
231 fn new() -> Arc<Self> {
232 Arc::new(Self {
233 left: AtomicI64::new(0),
234 waiting: AtomicI64::new(-1),
235 waker: Mutex::new(None),
236 })
237 }
238 #[inline]
239 fn done(&self, count: i64) {
240 let left = self.left.fetch_sub(count, Ordering::SeqCst) - count;
242 if left < 0 {
243 log_and_panic!("WaitGroup.left {} < 0", left);
244 }
245 let waiting = self.waiting.load(Ordering::SeqCst);
246 if waiting < 0 {
247 trace_log!("done {}->{} not waiting", count, left);
248 return;
249 }
250 if left <= waiting {
251 if self.waiting.compare_exchange(waiting, -1, Ordering::SeqCst, Ordering::Relaxed).is_ok() {
252 let mut guard = self.waker.lock();
253 if let Some(waker) = guard.take() {
254 waker.wake_by_ref();
255 drop(guard);
256 trace_log!("done {}->{} wake {}", count, left, waiting);
257 } else {
258 drop(guard);
259 trace_log!("done {}->{} wake {} but no waker", count, left, waiting);
260 }
261 }
262 } else {
264 trace_log!("done {}->{} waiting {}", count, left, waiting);
265 }
266 }
267
268 #[inline]
270 fn set_waker(&self, waker: Arc<Waker>, target: usize, force: bool) {
271 trace_log!("set_waker {} force={}", target, force);
272 {
273 let mut guard = self.waker.lock();
274 if !force {
275 if guard.is_some() {
276 drop(guard);
277 log_and_panic!("concurrent wait detected");
278 }
279 }
280 guard.replace(waker);
281 let old_target = self.waiting.swap(target as i64, Ordering::SeqCst);
282 drop(guard);
283 if ! force && old_target >= 0 {
284 log_and_panic!("Concurrent wait() by multiple coroutines, enter unlikely code");
285 }
286 }
287 }
288
289 #[inline]
290 fn cancel_wait(&self) {
291 trace_log!("cancel_wait");
292 {
293 let mut guard = self.waker.lock();
294 self.waiting.store(-1, Ordering::SeqCst);
295 let _ = guard.take();
296 }
297 }
298}
299
/// Future that resolves once `wg.left` drops to `target` or below.
struct WaitGroupFuture<'a> {
    // Shared wait-group state being observed.
    wg: &'a WaitGroupInner,
    // Counter value at (or below) which this future completes.
    target: usize,
    // Waker we registered with `wg`, kept so `poll` can detect
    // re-registration and `drop` can cancel the wait.
    waker: Option<Arc<Waker>>,
}
305
306impl<'a> WaitGroupFuture<'a> {
307 #[inline(always)]
308 fn _poll(&mut self) -> bool {
309 let cur = self.wg.left.load(Ordering::SeqCst);
311 if cur <= self.target as i64 {
312 trace_log!("poll ready {}<={}", cur, self.target);
313 self._clear();
314 true
315 } else {
316 trace_log!("poll not ready {}>{}", cur, self.target);
317 false
318 }
319 }
320
321 #[inline(always)]
322 fn _clear(&mut self) {
323 if self.waker.take().is_some() {
324 self.wg.cancel_wait();
325 }
326 }
327}
328
impl<'a> Drop for WaitGroupFuture<'a> {
    // Dropping an incomplete future (e.g. on timeout/cancellation) must
    // deregister its waker so the wait group is reusable; `_clear` only
    // calls `cancel_wait` if a waker was actually registered.
    fn drop(&mut self) {
        self._clear();
    }
}
335
impl<'a> Future for WaitGroupFuture<'a> {
    type Output = ();

    /// Polls until the counter reaches `target`, registering this task's
    /// waker with the shared state and re-checking afterwards to close the
    /// race against a concurrent `done()`.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
        let _self = self.get_mut();
        // Fast path: already at or below target (also clears any stale waker).
        if _self._poll() {
            return Poll::Ready(());
        }
        let force = {
            if let Some(waker) = _self.waker.as_ref() {
                // A waker is already registered. If it is still armed
                // (waiting >= 0) and would wake this same task, there is
                // nothing to update.
                if _self.wg.waiting.load(Ordering::SeqCst) >= 0 &&
                    waker.will_wake(ctx.waker()) {
                    return Poll::Pending;
                }
                // Stale or different-task waker: re-register, bypassing the
                // concurrent-wait panic in `set_waker` (force = true).
                true
            } else {
                false
            }
        };
        let waker = Arc::new(ctx.waker().clone());
        _self.wg.set_waker(waker.clone(), _self.target, force);
        _self.waker.replace(waker);
        // Re-check after registering: `done()` may have fired between the
        // first check and `set_waker`; without this the wakeup would be lost.
        if _self._poll() {
            return Poll::Ready(());
        }
        Poll::Pending
    }
}
369
#[cfg(test)]
mod tests {
    // NOTE: removed unused `extern crate rand;` — nothing in this module uses it.

    use std::time::Duration;
    use tokio::time::{sleep, timeout};

    use super::*;

    /// Builds a multi-threaded tokio runtime with `threads` worker threads.
    fn make_runtime(threads: usize) -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(threads)
            .build()
            .unwrap()
    }

    /// Exercises waker registration/clearing and `wait_to` with a nonzero target.
    #[test]
    fn test_inner() {
        make_runtime(1).block_on(async move {
            let wg = WaitGroup::new();
            wg.add(2);
            let _wg = wg.clone();
            let th = tokio::spawn(async move {
                assert!(_wg.wait_to(1).await);
            });
            // Give the spawned task time to register its waker.
            sleep(Duration::from_secs(1)).await;
            {
                let guard = wg.0.waker.lock();
                assert!(guard.is_some());
                assert_eq!(wg.0.waiting.load(Ordering::Acquire), 1);
            }
            wg.done();
            let _ = th.await;
            // After the wake, the waiter slot must be cleared again.
            assert_eq!(wg.0.waiting.load(Ordering::Acquire), -1);
            assert_eq!(wg.left(), 1);
            wg.done();
            assert_eq!(wg.left(), 0);
            // Counter already at target: wait_to returns false without waiting.
            assert_eq!(wg.wait_to(0).await, false);
        });
    }

    /// Exercises wait cancellation via timeout and reuse of the group afterwards.
    #[test]
    fn test_cancel() {
        let wg = WaitGroup::new();
        make_runtime(1).block_on(async move {
            wg.add(1);
            println!("test timeout");
            assert!(timeout(Duration::from_secs(1), wg.wait()).await.is_err());
            println!("timeout happened");
            // Dropping the timed-out future must deregister the waiter.
            assert_eq!(wg.0.waiting.load(Ordering::Acquire), -1);
            wg.done();
            wg.add(2);
            wg.done_many(2);
            wg.add(2);
            let _wg = wg.clone();
            let th = tokio::spawn(async move {
                _wg.wait().await;
            });
            sleep(Duration::from_millis(200)).await;
            wg.done();
            wg.done();
            let _ = th.await;
        });
    }
}