1use std::{
2 collections::LinkedList,
3 future::Future,
4 pin::Pin,
5 sync::{
6 Arc,
7 atomic::{AtomicBool, Ordering},
8 },
9 task::{Context, Poll, Waker},
10};
11
12use parking_lot::Mutex;
13
14struct NotifyOnceInner {
15 loaded: AtomicBool,
16 wakers: Mutex<LinkedList<Waker>>,
17}
18
19#[derive(Clone)]
24pub struct NotifyOnce(Arc<NotifyOnceInner>);
25
26impl NotifyOnce {
27 pub fn new() -> Self {
28 Self(Arc::new(NotifyOnceInner {
29 loaded: AtomicBool::new(false),
30 wakers: Mutex::new(LinkedList::new()),
31 }))
32 }
33
34 #[inline]
35 pub fn done(&self) {
36 let _self = self.0.as_ref();
37 _self.loaded.store(true, Ordering::Release);
38 {
39 let mut guard = _self.wakers.lock();
40 while let Some(waker) = guard.pop_front() {
41 waker.wake();
42 }
43 }
44 }
45
46 #[inline]
47 pub async fn wait(&self) {
48 NotifyOnceWaitFuture { inner: self.0.as_ref(), is_new: true }.await;
49 }
50}
51
52struct NotifyOnceWaitFuture<'a> {
53 inner: &'a NotifyOnceInner,
54 is_new: bool,
55}
56
57impl<'a> Future for NotifyOnceWaitFuture<'a> {
58 type Output = ();
59
60 fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
61 let _self = self.get_mut();
62 if _self.inner.loaded.load(Ordering::Acquire) {
63 return Poll::Ready(());
64 }
65 if _self.is_new {
66 {
67 let mut guard = _self.inner.wakers.lock();
68 guard.push_back(ctx.waker().clone());
69 }
70 _self.is_new = false;
71 if _self.inner.loaded.load(Ordering::Acquire) {
72 return Poll::Ready(());
73 }
74 }
75 Poll::Pending
76 }
77}
78
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use tokio::time::{Duration, sleep, timeout};

    use super::*;

    /// Builds a small multi-threaded runtime so waiters run on worker threads.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .expect("failed to build tokio runtime")
    }

    #[test]
    fn test_notify_once() {
        runtime().block_on(async move {
            let noti = NotifyOnce::new();
            let completed = Arc::new(AtomicUsize::new(0));
            let mut handles = Vec::new();
            for _ in 0..10 {
                let waiter = noti.clone();
                let counter = completed.clone();
                handles.push(tokio::spawn(async move {
                    waiter.wait().await;
                    counter.fetch_add(1, Ordering::SeqCst);
                }));
            }
            // Give the waiters time to park; none may finish before `done`.
            sleep(Duration::from_millis(100)).await;
            assert_eq!(completed.load(Ordering::Acquire), 0);
            noti.done();
            for handle in handles {
                handle.await.expect("waiter task panicked");
            }
            assert_eq!(completed.load(Ordering::Acquire), 10);
        });
    }

    #[test]
    fn wait_after_done_completes_immediately() {
        runtime().block_on(async {
            let noti = NotifyOnce::new();
            noti.done();
            timeout(Duration::from_secs(1), noti.wait())
                .await
                .expect("wait() should not block after done()");
        });
    }

    #[test]
    fn cancelled_waiter_does_not_block_others() {
        runtime().block_on(async {
            let noti = NotifyOnce::new();
            // A wait dropped before `done` must not stop later waiters from waking.
            assert!(timeout(Duration::from_millis(20), noti.wait()).await.is_err());
            let waiter = noti.clone();
            let handle = tokio::spawn(async move { waiter.wait().await });
            sleep(Duration::from_millis(20)).await;
            noti.done();
            timeout(Duration::from_secs(1), handle)
                .await
                .expect("waiter was not woken by done()")
                .expect("waiter task panicked");
        });
    }
}