atomic_waitgroup/
lib.rs

1//!
//! A waitgroup with async support and advanced features,
//! implemented with atomic operations to keep locking to a minimum.
4//!
5//! # Features & restrictions
6//!
7//! * wait_to() is supported to wait for a value larger than zero.
8//!
9//! * wait() & wait_to() can be canceled by tokio::time::timeout or futures::select!.
10//!
//! * Assumes only one task calls wait() at a time. If multiple concurrent wait() calls
//! are detected, it panics to signal this invalid usage.
13//!
//! * done() & wait() are allowed to be called concurrently.
15//!
//! * add() & done() are allowed to be called concurrently.
17//!
18//! * add() & wait() will not conflict, but concurrent calls are not a good pattern.
19//!
20//! # Example
21//!
22//! ```
23//! extern crate atomic_waitgroup;
24//! use atomic_waitgroup::WaitGroup;
25//! use tokio::runtime::Runtime;
26//!
27//! let rt = Runtime::new().unwrap();
28//! let wg = WaitGroup::new();
29//! rt.block_on(async move {
30//!     for i in 0..2 {
31//!         let _guard = wg.add_guard();
32//!         tokio::spawn(async move {
33//!            // Do something
34//!             drop(_guard);
35//!         });
36//!     }
37//!     match tokio::time::timeout(
38//!         tokio::time::Duration::from_secs(1),
39//!         wg.wait_to(1)).await {
40//!         Ok(_) => {
41//!             assert!(wg.left() <= 1);
42//!         }
43//!         Err(_) => {
44//!             println!("wg.wait_to(1) timeouted");
45//!         }
46//!     }
47//! });
//! ```
49//!
50
51use log::error;
52use std::{
53    future::Future,
54    pin::Pin,
55    sync::{
56        atomic::{AtomicI64, AtomicU64, Ordering},
57        Arc,
58    },
59    task::{Context, Poll, Waker},
60};
61
62use parking_lot::Mutex;
63
64/*
65
NOTE: Multiple atomic operations must happen in the same order
67
68WaitGroupFuture |   done()
69----------
70left.load()     |   left -=1
71waiting = true  |   load_waiting
72left.load ()    |
73------------
74
75*/
76pub struct WaitGroup(Arc<WaitGroupInner>);
77
78// do not allow multiple wait
79impl Clone for WaitGroup {
80    fn clone(&self) -> Self {
81        Self(self.0.clone())
82    }
83}
84
85impl WaitGroup {
86    pub fn new() -> Self {
87        Self(WaitGroupInner::new())
88    }
89
90    /// Return the count left inside this WaitGroup
91    #[inline(always)]
92    pub fn left(&self) -> usize {
93        let count = self.0.left.load(Ordering::SeqCst);
94        if count < 0 {
95            error!("WaitGroup.left {} < 0", count);
96            panic!("WaitGroup.left {} < 0", count);
97        }
98        count as usize
99    }
100
101    /// Add specified count.
102    #[inline(always)]
103    pub fn add(&self, i: usize) {
104        self.0.left.fetch_add(i as i64, Ordering::SeqCst);
105    }
106
107    /// Add one to the WaitGroup, return a guard to decrease the count on drop.
108    ///
109    /// # Example
110    ///
111    /// ```
112    /// extern crate atomic_waitgroup;
113    /// use atomic_waitgroup::WaitGroup;
114    /// use tokio::runtime::Runtime;
115    ///
116    /// let wg = WaitGroup::new();
117    /// let rt = Runtime::new().unwrap();
118
119    /// rt.block_on(async move {
120    ///     let _guard = wg.add_guard();
121    ///     tokio::spawn(async move {
122    ///         // Do something
123    ///         drop(_guard);
124    ///     });
125    ///     wg.wait().await;
126    /// });
127    #[inline(always)]
128    pub fn add_guard(&self) -> WaitGroupGuard {
129        self.0.left.fetch_add(1, Ordering::SeqCst);
130        WaitGroupGuard {
131            inner: self.0.clone(),
132        }
133    }
134
135    /// Wait until specified count is left in the WaitGroup.
136    ///
137    /// Return false means there's no waiting happened.
138    ///
139    /// Return true means the blocking actually happened.
140    ///
141    /// # NOTE
142    ///
143    /// * Only assume one waiting future at the same time, otherwise will panic.
144    ///
145    /// * Canceling future is supported.
146    pub async fn wait_to(&self, target: usize) -> bool {
147        let _self = self.0.as_ref();
148        let left = _self.left.load(Ordering::Acquire);
149        if left <= target as i64 {
150            return false;
151        }
152        WaitGroupFuture {
153            wg: &_self,
154            target,
155            waker_id: 0,
156        }
157        .await;
158        return true;
159    }
160
161    /// Wait until zero count in the WaitGroup.
162    ///
163    /// # NOTE
164    ///
165    /// * Only assume one waiting future at the same time, otherwise will panic.
166    ///
167    /// * Canceling future is supported.
168    #[inline(always)]
169    pub async fn wait(&self) {
170        self.wait_to(0).await;
171    }
172
173    /// Decrease count by one.
174    #[inline]
175    pub fn done(&self) {
176        let inner = self.0.as_ref();
177        inner.done(1);
178    }
179
180    /// Decrease count by specified value
181    #[inline]
182    pub fn done_many(&self, count: usize) {
183        let inner = self.0.as_ref();
184        inner.done(count as i64);
185    }
186}
187
188pub struct WaitGroupGuard {
189    inner: Arc<WaitGroupInner>,
190}
191
192impl Drop for WaitGroupGuard {
193    fn drop(&mut self) {
194        let inner = &self.inner;
195        inner.done(1);
196    }
197}
198
199struct WaitGroupInner {
200    left: AtomicI64,
201    waiting: AtomicI64,
202    waker: Mutex<Option<Waker>>,
203    waker_id: AtomicU64,
204}
205
206impl WaitGroupInner {
207    #[inline(always)]
208    fn new() -> Arc<Self> {
209        Arc::new(Self {
210            left: AtomicI64::new(0),
211            waiting: AtomicI64::new(-1),
212            waker: Mutex::new(None),
213            waker_id: AtomicU64::new(0),
214        })
215    }
216    #[inline]
217    fn done(&self, count: i64) {
218        let left = self.left.fetch_sub(count, Ordering::SeqCst) - count;
219        let waiting = self.waiting.load(Ordering::Acquire);
220        if left < 0 {
221            error!("WaitGroup.left {} < 0", left);
222            panic!("WaitGroup.left {} < 0", left);
223        }
224        if waiting < 0 {
225            return;
226        }
227        if left <= waiting {
228            // Do not take waker, it may be false waken when done() happened before newer wait()
229            if let Some(waker) = self.waker.lock().as_ref() {
230                waker.wake_by_ref();
231            }
232        }
233    }
234
235    /// Once waker set, waker might be false waken many times
236    /// Returns: waker_id
237    #[inline]
238    fn set_waker(&self, waker: Waker, target: usize) -> u64 {
239        let waker_id = self.waker_id.fetch_add(1, Ordering::SeqCst) + 1;
240        {
241            let mut guard = self.waker.lock();
242            guard.replace(waker);
243            let old_target = self.waiting.swap(target as i64, Ordering::SeqCst);
244            if old_target >= 0 {
245                panic!("Concurrent wait() by multiple coroutines is not supported")
246            }
247        }
248        waker_id
249    }
250
251    #[inline]
252    fn cancel_wait(&self, waker_id: u64) {
253        let mut guard = self.waker.lock();
254        // In case wait() is canceled, eg. tokio timeout, do not disrupt other thread wait()
255        if self.waker_id.load(Ordering::Acquire) == waker_id {
256            self.waiting.store(-1, Ordering::Release);
257            let _ = guard.take();
258        }
259    }
260}
261
262struct WaitGroupFuture<'a> {
263    wg: &'a WaitGroupInner,
264    target: usize,
265    waker_id: u64,
266}
267
268impl<'a> WaitGroupFuture<'a> {
269    #[inline(always)]
270    fn _poll(&mut self) -> bool {
271        let cur = self.wg.left.load(Ordering::Acquire);
272        if cur <= self.target as i64 {
273            self._clear();
274            true
275        } else {
276            false
277        }
278    }
279
280    #[inline(always)]
281    fn _clear(&mut self) {
282        if self.waker_id == 0 {
283            return;
284        }
285        self.wg.cancel_wait(self.waker_id);
286        self.waker_id = 0;
287    }
288}
289
290/// When wait() is canceled with timeout(),  make sure it clear the waker.
291impl<'a> Drop for WaitGroupFuture<'a> {
292    fn drop(&mut self) {
293        self._clear();
294    }
295}
296
297impl<'a> Future for WaitGroupFuture<'a> {
298    type Output = ();
299
300    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
301        let _self = self.get_mut();
302        if _self.waker_id == 0 {
303            if _self._poll() {
304                return Poll::Ready(());
305            }
306            _self.waker_id = _self.wg.set_waker(ctx.waker().clone(), _self.target);
307        }
308        if _self._poll() {
309            return Poll::Ready(());
310        }
311        Poll::Pending
312    }
313}
314
#[cfg(test)]
mod tests {
    extern crate rand;

    use std::time::Duration;
    use tokio::time::{sleep, timeout};

    use super::*;

    /// Build a multi-threaded tokio runtime with `threads` worker threads.
    fn make_runtime(threads: usize) -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(threads)
            .build()
            .unwrap()
    }

    /// Follow the internal registration state across one wait_to(1).
    #[test]
    fn test_inner() {
        let rt = make_runtime(1);
        rt.block_on(async move {
            let wg = WaitGroup::new();
            wg.add(2);

            let waiter_wg = wg.clone();
            let waiter = tokio::spawn(async move {
                assert!(waiter_wg.wait_to(1).await);
            });

            // Give the spawned waiter time to register itself.
            sleep(Duration::from_secs(1)).await;
            assert_eq!(wg.0.waker_id.load(Ordering::Acquire), 1);
            {
                let slot = wg.0.waker.lock();
                assert!(slot.is_some());
                assert_eq!(wg.0.waiting.load(Ordering::Acquire), 1);
            }

            // Reaching the target releases the waiter and clears its registration.
            wg.done();
            let _ = waiter.await;
            assert_eq!(wg.0.waker_id.load(Ordering::Acquire), 1);
            assert_eq!(wg.0.waiting.load(Ordering::Acquire), -1);
            assert_eq!(wg.left(), 1);

            wg.done();
            assert_eq!(wg.left(), 0);
            // With nothing left, wait_to() returns without waiting.
            assert_eq!(wg.wait_to(0).await, false);
        });
    }

    /// A wait cancelled by a timeout must leave the group usable for a later wait.
    #[test]
    fn test_cancel() {
        let wg = WaitGroup::new();
        let rt = make_runtime(1);
        rt.block_on(async move {
            wg.add(1);
            println!("test timeout");
            let timed_out = timeout(Duration::from_secs(1), wg.wait()).await;
            assert!(timed_out.is_err());
            println!("timeout happened");
            assert_eq!(wg.0.waiting.load(Ordering::Acquire), -1);

            wg.done();
            wg.add(2);
            wg.done_many(2);
            wg.add(2);

            let waiter_wg = wg.clone();
            let waiter = tokio::spawn(async move {
                waiter_wg.wait().await;
            });
            sleep(Duration::from_millis(200)).await;
            // The second wait is the second registration overall.
            assert_eq!(wg.0.waker_id.load(Ordering::Acquire), 2);

            wg.done();
            wg.done();
            let _ = waiter.await;
        });
    }
}
383}