sync_utils/
notifier.rs

1use std::{
2    collections::LinkedList,
3    future::Future,
4    pin::Pin,
5    sync::{
6        Arc,
7        atomic::{AtomicBool, Ordering},
8    },
9    task::{Context, Poll, Waker},
10};
11
12use parking_lot::Mutex;
13
14struct NotifyOnceInner {
15    loaded: AtomicBool,
16    wakers: Mutex<LinkedList<Waker>>,
17}
18
19/// NotifyOnce Assumes:
20///
21/// One coroutine issue some loading job, multiple coroutines wait for it to complete.
22///
23/// ## exmaple:
24/// ``` rust
25///
26/// async fn foo() {
27///     use sync_utils::notifier::NotifyOnce;
28///     use tokio::time::*;
29///     use std::sync::{Arc, atomic::{AtomicBool, AtomicUsize, Ordering}};
30///     let noti = NotifyOnce::new();
31///     let done = Arc::new(AtomicBool::new(false));
32///     for _ in 0..10 {
33///         let _noti = noti.clone();
34///         let _done = done.clone();
35///         tokio::spawn(async move {
36///             assert_eq!(_done.load(Ordering::Acquire), false);
37///             _noti.wait().await;
38///             assert_eq!(_done.load(Ordering::Acquire), true);
39///         });
40///     }
41///     sleep(Duration::from_secs(1)).await;
42///     done.store(true, Ordering::Release);
43///     noti.done();
44/// }
45/// ```
46
47#[derive(Clone)]
48pub struct NotifyOnce(Arc<NotifyOnceInner>);
49
50impl NotifyOnce {
51    pub fn new() -> Self {
52        Self(Arc::new(NotifyOnceInner {
53            loaded: AtomicBool::new(false),
54            wakers: Mutex::new(LinkedList::new()),
55        }))
56    }
57
58    #[inline]
59    pub fn done(&self) {
60        let _self = self.0.as_ref();
61        _self.loaded.store(true, Ordering::Release);
62        {
63            let mut guard = _self.wakers.lock();
64            while let Some(waker) = guard.pop_front() {
65                waker.wake();
66            }
67        }
68    }
69
70    #[inline]
71    pub async fn wait(&self) {
72        NotifyOnceWaitFuture { inner: self.0.as_ref(), is_new: true }.await;
73    }
74}
75
76struct NotifyOnceWaitFuture<'a> {
77    inner: &'a NotifyOnceInner,
78    is_new: bool,
79}
80
81impl<'a> Future for NotifyOnceWaitFuture<'a> {
82    type Output = ();
83
84    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
85        let _self = self.get_mut();
86        if _self.inner.loaded.load(Ordering::Acquire) {
87            return Poll::Ready(());
88        }
89        if _self.is_new {
90            {
91                let mut guard = _self.inner.wakers.lock();
92                guard.push_back(ctx.waker().clone());
93            }
94            _self.is_new = false;
95            if _self.inner.loaded.load(Ordering::Acquire) {
96                return Poll::Ready(());
97            }
98        }
99        Poll::Pending
100    }
101}
102
#[cfg(test)]
mod tests {
    use std::{
        future::Future,
        sync::{
            Arc,
            atomic::{AtomicBool, AtomicUsize, Ordering},
        },
        task::{Context, Wake, Waker},
    };

    use tokio::time::{Duration, sleep};

    use super::*;

    /// Test waker that only counts how many times it has been woken.
    struct CountingWake(AtomicUsize);

    impl CountingWake {
        fn count(&self) -> usize {
            self.0.load(Ordering::SeqCst)
        }
    }

    impl Wake for CountingWake {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Ten spawned waiters stay blocked until `done` is called, then all complete.
    #[test]
    fn test_notify_once() {
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();

        rt.block_on(async move {
            let noti = NotifyOnce::new();
            let done = Arc::new(AtomicBool::new(false));
            let wait_count = Arc::new(AtomicUsize::new(0));
            let mut handles = Vec::with_capacity(10);
            for _ in 0..10 {
                let noti = noti.clone();
                let done = done.clone();
                let wait_count = wait_count.clone();
                handles.push(tokio::spawn(async move {
                    assert!(!done.load(Ordering::Acquire));
                    noti.wait().await;
                    wait_count.fetch_add(1, Ordering::SeqCst);
                    assert!(done.load(Ordering::Acquire));
                }));
            }
            sleep(Duration::from_secs(1)).await;
            assert_eq!(wait_count.load(Ordering::Acquire), 0);
            done.store(true, Ordering::Release);
            noti.done();
            for handle in handles {
                handle.await.expect("waiter task panicked");
            }
            assert_eq!(wait_count.load(Ordering::Acquire), 10);
        });
    }

    /// A `wait` started after `done` resolves on its first poll.
    #[test]
    fn test_wait_after_done_is_ready() {
        let noti = NotifyOnce::new();
        noti.done();
        let waker = Waker::from(Arc::new(CountingWake(AtomicUsize::new(0))));
        let mut fut = Box::pin(noti.wait());
        assert!(fut.as_mut().poll(&mut Context::from_waker(&waker)).is_ready());
    }

    /// `done` wakes a registered waiter exactly once; a repeated `done` is a no-op.
    #[test]
    fn test_done_wakes_registered_waiter_once() {
        let noti = NotifyOnce::new();
        let counter = Arc::new(CountingWake(AtomicUsize::new(0)));
        let waker = Waker::from(counter.clone());
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(noti.wait());

        assert!(fut.as_mut().poll(&mut cx).is_pending());
        assert_eq!(counter.count(), 0);
        noti.done();
        assert_eq!(counter.count(), 1);
        noti.done();
        assert_eq!(counter.count(), 1);
        assert!(fut.as_mut().poll(&mut cx).is_ready());
    }
}
149}