1use std::{
2 collections::LinkedList,
3 future::Future,
4 pin::Pin,
5 sync::{
6 Arc,
7 atomic::{AtomicBool, Ordering},
8 },
9 task::{Context, Poll, Waker},
10};
11
12use parking_lot::Mutex;
13
14struct NotifyOnceInner {
15 loaded: AtomicBool,
16 wakers: Mutex<LinkedList<Waker>>,
17}
18
19#[derive(Clone)]
48pub struct NotifyOnce(Arc<NotifyOnceInner>);
49
50impl NotifyOnce {
51 pub fn new() -> Self {
52 Self(Arc::new(NotifyOnceInner {
53 loaded: AtomicBool::new(false),
54 wakers: Mutex::new(LinkedList::new()),
55 }))
56 }
57
58 #[inline]
59 pub fn done(&self) {
60 let _self = self.0.as_ref();
61 _self.loaded.store(true, Ordering::Release);
62 {
63 let mut guard = _self.wakers.lock();
64 while let Some(waker) = guard.pop_front() {
65 waker.wake();
66 }
67 }
68 }
69
70 #[inline]
71 pub async fn wait(&self) {
72 NotifyOnceWaitFuture { inner: self.0.as_ref(), is_new: true }.await;
73 }
74}
75
76struct NotifyOnceWaitFuture<'a> {
77 inner: &'a NotifyOnceInner,
78 is_new: bool,
79}
80
81impl<'a> Future for NotifyOnceWaitFuture<'a> {
82 type Output = ();
83
84 fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
85 let _self = self.get_mut();
86 if _self.inner.loaded.load(Ordering::Acquire) {
87 return Poll::Ready(());
88 }
89 if _self.is_new {
90 {
91 let mut guard = _self.inner.wakers.lock();
92 guard.push_back(ctx.waker().clone());
93 }
94 _self.is_new = false;
95 if _self.inner.loaded.load(Ordering::Acquire) {
96 return Poll::Ready(());
97 }
98 }
99 Poll::Pending
100 }
101}
102
#[cfg(test)]
mod tests {

    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use tokio::time::{Duration, sleep};

    use super::*;

    /// Spawns several waiters, checks that none of them gets past `wait()` before `done()`
    /// is called, and that all of them finish afterwards.
    #[test]
    fn test_notify_once() {
        const WAITERS: usize = 10;

        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();

        rt.block_on(async move {
            let noti = NotifyOnce::new();
            // Mirrors the notifier state so the waiters can assert on ordering.
            let done = Arc::new(AtomicBool::new(false));
            let wait_count = Arc::new(AtomicUsize::new(0));

            let mut tasks = Vec::with_capacity(WAITERS);
            for _ in 0..WAITERS {
                let task_noti = noti.clone();
                let task_done = done.clone();
                let task_count = wait_count.clone();
                tasks.push(tokio::spawn(async move {
                    assert!(!task_done.load(Ordering::Acquire));
                    task_noti.wait().await;
                    task_count.fetch_add(1, Ordering::SeqCst);
                    assert!(task_done.load(Ordering::Acquire));
                }));
            }

            // Give the waiters time to park; none may complete before `done()`.
            sleep(Duration::from_secs(1)).await;
            assert_eq!(wait_count.load(Ordering::Acquire), 0);

            done.store(true, Ordering::Release);
            noti.done();

            for task in tasks {
                task.await.expect("waiter task panicked");
            }
            assert_eq!(wait_count.load(Ordering::Acquire), WAITERS);
        });
    }
}