atomic_waitgroup/
lib.rs

//! A waitgroup that supports async, with advanced features.
//!
//! Implemented with atomic operations to reduce locking.
//!
//! # Features & restrictions
//!
//! * wait_to() supports waiting for a value >= zero.
//!
//! * wait() & wait_to() can be canceled by wrapping them in a timeout or futures::select!.
//!
//! * Assumes only one thread calls wait(). If multiple concurrent wait() calls are detected,
//!   it will panic on this invalid usage.
//!
//! * done() & wait() are allowed to be called concurrently.
//!
//! * add() & done() are allowed to be called concurrently.
//!
//! * Assumes add() and wait() are called from the same thread.
//!
//! # Example
//!
//! ```
//! extern crate atomic_waitgroup;
//! use atomic_waitgroup::WaitGroup;
//! use tokio::runtime::Runtime;
//!
//! let rt = Runtime::new().unwrap();
//! let wg = WaitGroup::new();
//! rt.block_on(async move {
//!     for _i in 0..2 {
//!         let _guard = wg.add_guard();
//!         tokio::spawn(async move {
//!             // Do something
//!             drop(_guard);
//!         });
//!     }
//!     match tokio::time::timeout(
//!         tokio::time::Duration::from_secs(1),
//!         wg.wait_to(1)).await {
//!         Ok(_) => {
//!             assert!(wg.left() <= 1);
//!         }
//!         Err(_) => {
//!             println!("wg.wait_to(1) timed out");
//!         }
//!     }
//! });
//! ```

use log::error;
use parking_lot::Mutex;
use std::{
    future::Future,
    pin::Pin,
    sync::{
        atomic::{AtomicI64, Ordering},
        Arc,
    },
    task::{Context, Poll, Waker},
};

/*

NOTE: Multiple atomic operations must happen in the order shown below:

WaitGroupFuture |   done()
----------------|----------------
left.load()     |
                |   left -= 1
                |   waiting.load()
waiting = true  |
left.load()     |
---------------------------------

*/
/// An async-aware waitgroup implemented with atomic operations.
pub struct WaitGroup(Arc<WaitGroupInner>);

// Cloning shares the same inner state; concurrent wait() is still not allowed.
impl Clone for WaitGroup {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

macro_rules! log_and_panic {
    ($($arg:tt)+) => (
        error!($($arg)+);
        panic!($($arg)+);
    );
}

macro_rules! trace_log {
    ($($arg:tt)+) => (
        #[cfg(feature = "trace_log")]
        {
            log::trace!($($arg)+);
        }
    );
}

impl WaitGroup {
    pub fn new() -> Self {
        Self(WaitGroupInner::new())
    }

    /// Return the count left inside this WaitGroup.
    #[inline(always)]
    pub fn left(&self) -> usize {
        let count = self.0.left.load(Ordering::SeqCst);
        if count < 0 {
            log_and_panic!("WaitGroup.left {} < 0", count);
        }
        count as usize
    }

    /// Add specified count.
    ///
    /// NOTE: You should always add() before done().
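    ///
    /// # Example
    ///
    /// A minimal usage sketch (the count must never drop below zero):
    ///
    /// ```
    /// use atomic_waitgroup::WaitGroup;
    ///
    /// let wg = WaitGroup::new();
    /// wg.add(2);
    /// assert_eq!(wg.left(), 2);
    /// wg.done_many(2);
    /// assert_eq!(wg.left(), 0);
    /// ```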
    #[inline(always)]
    pub fn add(&self, i: usize) {
        // Acquire prevents the code below from being reordered before this add.
        let _r = self.0.left.fetch_add(i as i64, Ordering::Acquire);
        trace_log!("add {}->{}", i, _r + i as i64);
    }

    /// Add one to the WaitGroup, return a guard to decrease the count on drop.
    ///
    /// # Example
    ///
    /// ```
    /// extern crate atomic_waitgroup;
    /// use atomic_waitgroup::WaitGroup;
    /// use tokio::runtime::Runtime;
    ///
    /// let wg = WaitGroup::new();
    /// let rt = Runtime::new().unwrap();
    ///
    /// rt.block_on(async move {
    ///     let _guard = wg.add_guard();
    ///     tokio::spawn(async move {
    ///         // Do something
    ///         drop(_guard);
    ///     });
    ///     wg.wait().await;
    /// });
    /// ```
    #[inline(always)]
    pub fn add_guard(&self) -> WaitGroupGuard {
        self.add(1);
        WaitGroupGuard {
            inner: self.0.clone(),
        }
    }

    /// Wait until the specified count is left in the WaitGroup.
    ///
    /// Returns false if no waiting happened (the count was already low enough).
    ///
    /// Returns true if blocking actually happened.
    ///
    /// # NOTE
    ///
    /// * Assumes only one waiting future at a time, otherwise it will panic.
    ///
    /// * Canceling the future is supported.
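    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a Tokio runtime as in the crate-level example:
    ///
    /// ```
    /// use atomic_waitgroup::WaitGroup;
    /// use tokio::runtime::Runtime;
    ///
    /// let rt = Runtime::new().unwrap();
    /// let wg = WaitGroup::new();
    /// rt.block_on(async move {
    ///     wg.add(2);
    ///     let _wg = wg.clone();
    ///     tokio::spawn(async move {
    ///         // Finish one unit of work.
    ///         _wg.done();
    ///     });
    ///     // Returns once at most one count is left.
    ///     wg.wait_to(1).await;
    ///     assert!(wg.left() <= 1);
    /// });
    /// ```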
    pub async fn wait_to(&self, target: usize) -> bool {
        let _self = self.0.as_ref();
        // We will check again with SeqCst later to prevent deadlock
        let left = _self.left.load(Ordering::Acquire);
        if left <= target as i64 {
            trace_log!("wait_to skip {} <= target {}", left, target);
            return false;
        }
        WaitGroupFuture {
            wg: _self,
            target,
            waker: None,
        }
        .await;
        true
    }

    /// Wait until zero count is left in the WaitGroup.
    ///
    /// # NOTE
    ///
    /// * Assumes only one waiting future at a time, otherwise it will panic.
    ///
    /// * Canceling the future is supported.
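    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a Tokio runtime as in the crate-level example:
    ///
    /// ```
    /// use atomic_waitgroup::WaitGroup;
    /// use tokio::runtime::Runtime;
    ///
    /// let rt = Runtime::new().unwrap();
    /// let wg = WaitGroup::new();
    /// rt.block_on(async move {
    ///     let _guard = wg.add_guard();
    ///     tokio::spawn(async move {
    ///         // Do something, then release the guard.
    ///         drop(_guard);
    ///     });
    ///     // Blocks until the count drops back to zero.
    ///     wg.wait().await;
    ///     assert_eq!(wg.left(), 0);
    /// });
    /// ```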
    #[inline(always)]
    pub async fn wait(&self) {
        self.wait_to(0).await;
    }

    /// Decrease count by one.
    #[inline]
    pub fn done(&self) {
        let inner = self.0.as_ref();
        inner.done(1);
    }

    /// Decrease count by the specified value.
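    ///
    /// # Example
    ///
    /// A minimal sketch; the count must never go below zero:
    ///
    /// ```
    /// use atomic_waitgroup::WaitGroup;
    ///
    /// let wg = WaitGroup::new();
    /// wg.add(3);
    /// wg.done_many(2);
    /// assert_eq!(wg.left(), 1);
    /// wg.done();
    /// assert_eq!(wg.left(), 0);
    /// ```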
    #[inline]
    pub fn done_many(&self, count: usize) {
        let inner = self.0.as_ref();
        inner.done(count as i64);
    }
}

/// Guard returned by [`WaitGroup::add_guard`]; decreases the count by one on drop.
pub struct WaitGroupGuard {
    inner: Arc<WaitGroupInner>,
}

impl Drop for WaitGroupGuard {
    fn drop(&mut self) {
        let inner = &self.inner;
        inner.done(1);
    }
}

struct WaitGroupInner {
    /// The current count
    left: AtomicI64,
    /// The target count (>= 0) if someone is waiting; -1 if no one is waiting
    waiting: AtomicI64,
    waker: Mutex<Option<Arc<Waker>>>,
}

impl WaitGroupInner {
    #[inline(always)]
    fn new() -> Arc<Self> {
        Arc::new(Self {
            left: AtomicI64::new(0),
            waiting: AtomicI64::new(-1),
            waker: Mutex::new(None),
        })
    }

    #[inline]
    fn done(&self, count: i64) {
        // A SeqCst load of `waiting` follows, so a weaker ordering would also do here.
        let left = self.left.fetch_sub(count, Ordering::SeqCst) - count;
        if left < 0 {
            log_and_panic!("WaitGroup.left {} < 0", left);
        }
        let waiting = self.waiting.load(Ordering::SeqCst);
        if waiting < 0 {
            trace_log!("done {}->{} not waiting", count, left);
            return;
        }
        if left <= waiting {
            if self
                .waiting
                .compare_exchange(waiting, -1, Ordering::SeqCst, Ordering::Relaxed)
                .is_ok()
            {
                let mut guard = self.waker.lock();
                if let Some(waker) = guard.take() {
                    waker.wake_by_ref();
                    drop(guard);
                    trace_log!("done {}->{} wake {}", count, left, waiting);
                } else {
                    drop(guard);
                    trace_log!("done {}->{} wake {} but no waker", count, left, waiting);
                }
            }
            // someone else already woke the waiter
        } else {
            trace_log!("done {}->{} waiting {}", count, left, waiting);
        }
    }

    /// Once the waker is set, it may be spuriously woken many times.
    #[inline]
    fn set_waker(&self, waker: Arc<Waker>, target: usize, force: bool) {
        trace_log!("set_waker {} force={}", target, force);
        {
            let mut guard = self.waker.lock();
            if !force && guard.is_some() {
                drop(guard);
                log_and_panic!("concurrent wait detected");
            }
            guard.replace(waker);
            let old_target = self.waiting.swap(target as i64, Ordering::SeqCst);
            drop(guard);
            if !force && old_target >= 0 {
                log_and_panic!("Concurrent wait() by multiple tasks detected; entered an unlikely code path");
            }
        }
    }

    #[inline]
    fn cancel_wait(&self) {
        trace_log!("cancel_wait");
        {
            let mut guard = self.waker.lock();
            self.waiting.store(-1, Ordering::SeqCst);
            let _ = guard.take();
        }
    }
}

struct WaitGroupFuture<'a> {
    wg: &'a WaitGroupInner,
    target: usize,
    waker: Option<Arc<Waker>>,
}

impl<'a> WaitGroupFuture<'a> {
    #[inline(always)]
    fn _poll(&mut self) -> bool {
        // Use SeqCst to avoid reading a stale value
        let cur = self.wg.left.load(Ordering::SeqCst);
        if cur <= self.target as i64 {
            trace_log!("poll ready {}<={}", cur, self.target);
            self._clear();
            true
        } else {
            trace_log!("poll not ready {}>{}", cur, self.target);
            false
        }
    }

    #[inline(always)]
    fn _clear(&mut self) {
        if self.waker.take().is_some() {
            self.wg.cancel_wait();
        }
    }
}

/// When wait() is canceled (e.g. by timeout()), make sure the waker is cleared.
impl<'a> Drop for WaitGroupFuture<'a> {
    fn drop(&mut self) {
        self._clear();
    }
}

impl<'a> Future for WaitGroupFuture<'a> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
        let _self = self.get_mut();
        if _self._poll() {
            return Poll::Ready(());
        }
        let force = {
            if let Some(waker) = _self.waker.as_ref() {
                // First check whether someone has taken the waker
                if _self.wg.waiting.load(Ordering::SeqCst) >= 0 &&
                    // Sometimes tokio invalidates the waker,
                    // so always check before reusing the same waker.
                    waker.will_wake(ctx.waker()) {
                    return Poll::Pending;
                }
                // The waker is not usable; register another one
                true
            } else {
                false
            }
        };
        // The Arc allows checking the waker without taking the lock
        let waker = Arc::new(ctx.waker().clone());
        _self.wg.set_waker(waker.clone(), _self.target, force);
        _self.waker.replace(waker);
        if _self._poll() {
            return Poll::Ready(());
        }
        Poll::Pending
    }
}

#[cfg(test)]
mod tests {
    extern crate rand;

    use std::time::Duration;
    use tokio::time::{sleep, timeout};

    use super::*;

    fn make_runtime(threads: usize) -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(threads)
            .build()
            .unwrap()
    }

    #[test]
    fn test_inner() {
        make_runtime(1).block_on(async move {
            let wg = WaitGroup::new();
            wg.add(2);
            let _wg = wg.clone();
            let th = tokio::spawn(async move {
                assert!(_wg.wait_to(1).await);
            });
            sleep(Duration::from_secs(1)).await;
            {
                let guard = wg.0.waker.lock();
                assert!(guard.is_some());
                assert_eq!(wg.0.waiting.load(Ordering::Acquire), 1);
            }
            wg.done();
            let _ = th.await;
            assert_eq!(wg.0.waiting.load(Ordering::Acquire), -1);
            assert_eq!(wg.left(), 1);
            wg.done();
            assert_eq!(wg.left(), 0);
            assert_eq!(wg.wait_to(0).await, false);
        });
    }

    #[test]
    fn test_cancel() {
        let wg = WaitGroup::new();
        make_runtime(1).block_on(async move {
            wg.add(1);
            println!("test timeout");
            assert!(timeout(Duration::from_secs(1), wg.wait()).await.is_err());
            println!("timeout happened");
            assert_eq!(wg.0.waiting.load(Ordering::Acquire), -1);
            wg.done();
            wg.add(2);
            wg.done_many(2);
            wg.add(2);
            let _wg = wg.clone();
            let th = tokio::spawn(async move {
                _wg.wait().await;
            });
            sleep(Duration::from_millis(200)).await;
            wg.done();
            wg.done();
            let _ = th.await;
        });
    }
}
435}