1use std::sync::atomic::{AtomicU8, Ordering};
11use std::future::Future;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14
15use super::atomic_waker::AtomicWaker;
16
17const EMPTY: u8 = 0; const WAITING: u8 = 1; const NOTIFIED: u8 = 2; pub struct SingleWaiterNotify {
26 state: AtomicU8,
27 waker: AtomicWaker,
28}
29
30impl std::fmt::Debug for SingleWaiterNotify {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 let state = self.state.load(Ordering::Acquire);
33 let state_str = match state {
34 EMPTY => "Empty",
35 WAITING => "Waiting",
36 NOTIFIED => "Notified",
37 _ => "Unknown",
38 };
39 f.debug_struct("SingleWaiterNotify")
40 .field("state", &state_str)
41 .finish()
42 }
43}
44
45impl Default for SingleWaiterNotify {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl SingleWaiterNotify {
52 #[inline]
56 pub fn new() -> Self {
57 Self {
58 state: AtomicU8::new(EMPTY),
59 waker: AtomicWaker::new(),
60 }
61 }
62
63 #[inline]
67 pub fn notified(&self) -> Notified<'_> {
68 Notified {
69 notify: self,
70 registered: false,
71 }
72 }
73
74 #[inline]
82 pub fn notify_one(&self) {
83 let prev_state = self.state.swap(NOTIFIED, Ordering::AcqRel);
85
86 if prev_state == WAITING {
88 self.waker.wake();
89 }
90 }
91
92 #[inline]
100 fn register_waker(&self, waker: &std::task::Waker) -> bool {
101 self.waker.register(waker);
107
108 let current_state = self.state.load(Ordering::Acquire);
109
110 if current_state == NOTIFIED {
112 self.state.store(EMPTY, Ordering::Release);
114 return true;
115 }
116
117 match self.state.compare_exchange(
119 EMPTY,
120 WAITING,
121 Ordering::AcqRel,
122 Ordering::Acquire,
123 ) {
124 Ok(_) => {
125 if self.state.load(Ordering::Acquire) == NOTIFIED {
128 self.state.store(EMPTY, Ordering::Release);
130 true
131 } else {
132 false
133 }
134 }
135 Err(state) => {
136 if state == NOTIFIED {
138 self.state.store(EMPTY, Ordering::Release);
140 true
141 } else {
142 if self.state.load(Ordering::Acquire) == NOTIFIED {
145 self.state.store(EMPTY, Ordering::Release);
146 true
147 } else {
148 false
149 }
150 }
151 }
152 }
153 }
154}
155
156pub struct Notified<'a> {
166 notify: &'a SingleWaiterNotify,
167 registered: bool,
168}
169
170impl<'a> std::fmt::Debug for Notified<'a> {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 let state = self.notify.state.load(Ordering::Acquire);
173 let state_str = match state {
174 EMPTY => "Empty",
175 WAITING => "Waiting",
176 NOTIFIED => "Notified",
177 _ => "Unknown",
178 };
179 f.debug_struct("Notified")
180 .field("state", &state_str)
181 .field("registered", &self.registered)
182 .finish()
183 }
184}
185
186impl Future for Notified<'_> {
187 type Output = ();
188
189 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
190 if !self.registered {
192 self.registered = true;
193 if self.notify.register_waker(cx.waker()) {
194 return Poll::Ready(());
196 }
197 } else {
198 if self.notify.state.load(Ordering::Acquire) == NOTIFIED {
200 self.notify.state.store(EMPTY, Ordering::Release);
201 return Poll::Ready(());
202 }
203 if self.notify.register_waker(cx.waker()) {
206 return Poll::Ready(());
207 }
208 }
209
210 Poll::Pending
211 }
212}
213
214impl Drop for Notified<'_> {
215 fn drop(&mut self) {
216 if self.registered {
217 let _ = self.notify.state.compare_exchange(
229 WAITING,
230 EMPTY,
231 Ordering::Relaxed,
232 Ordering::Relaxed,
233 );
234 }
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    /// A notification stored before anyone waits is consumed by the first wait.
    #[tokio::test]
    async fn test_notify_before_wait() {
        let notify = Arc::new(SingleWaiterNotify::new());

        notify.notify_one();

        notify.notified().await;
    }

    /// A waiter that is already pending is woken by a later notification.
    #[tokio::test]
    async fn test_notify_after_wait() {
        let notify = Arc::new(SingleWaiterNotify::new());
        let notify_clone = Arc::clone(&notify);

        tokio::spawn(async move {
            sleep(Duration::from_millis(10)).await;
            notify_clone.notify_one();
        });

        notify.notified().await;
    }

    /// The notifier can be reused for many wait/notify rounds.
    #[tokio::test]
    async fn test_multiple_notify_cycles() {
        let notify = Arc::new(SingleWaiterNotify::new());

        for _ in 0..10 {
            let notify_clone = Arc::clone(&notify);
            tokio::spawn(async move {
                sleep(Duration::from_millis(5)).await;
                notify_clone.notify_one();
            });

            notify.notified().await;
        }
    }

    /// Several racing notifiers wake the single waiter; extra notifications
    /// coalesce into the stored state.
    #[tokio::test]
    async fn test_concurrent_notify() {
        let notify = Arc::new(SingleWaiterNotify::new());

        for _ in 0..5 {
            let n = Arc::clone(&notify);
            tokio::spawn(async move {
                sleep(Duration::from_millis(10)).await;
                n.notify_one();
            });
        }

        notify.notified().await;
    }

    /// Notifying with no waiter stores the notification (repeated calls coalesce).
    #[tokio::test]
    async fn test_notify_no_waiter() {
        let notify = SingleWaiterNotify::new();

        notify.notify_one();
        notify.notify_one();

        notify.notified().await;
    }

    /// Repeated rounds on a multi-threaded runtime with short, varying delays.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_stress_test() {
        let notify = Arc::new(SingleWaiterNotify::new());

        for i in 0..100 {
            let notify_clone = Arc::clone(&notify);
            tokio::spawn(async move {
                sleep(Duration::from_micros(i % 10)).await;
                notify_clone.notify_one();
            });

            notify.notified().await;
        }
    }

    /// Notifying right after spawning the waiter must never lose the wake-up.
    #[tokio::test]
    async fn test_immediate_notification_race() {
        for _ in 0..100 {
            let notify = Arc::new(SingleWaiterNotify::new());
            let notify_clone = Arc::clone(&notify);

            let waiter = tokio::spawn(async move {
                notify.notified().await;
            });

            notify_clone.notify_one();

            tokio::time::timeout(Duration::from_millis(100), waiter)
                .await
                .expect("Should not timeout")
                .expect("Task should complete");
        }
    }

    /// Dropping a waiter that registered but never completed resets the state to
    /// `Empty`, and a later notification is still delivered.
    #[tokio::test]
    async fn test_dropped_pending_waiter_resets_state() {
        let notify = SingleWaiterNotify::new();

        // `timeout` polls the inner future once (registering it as WAITING), then
        // drops it when the deadline passes.
        let outcome = tokio::time::timeout(Duration::from_millis(1), notify.notified()).await;
        assert!(outcome.is_err());
        assert_eq!(
            format!("{:?}", notify),
            r#"SingleWaiterNotify { state: "Empty" }"#
        );

        notify.notify_one();
        notify.notified().await;
    }
}
351