1use crate::shim::atomic::{AtomicU8, Ordering};
11use std::future::Future;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14
15use super::atomic_waker::AtomicWaker;
16
17const EMPTY: u8 = 0; const WAITING: u8 = 1; const NOTIFIED: u8 = 2; pub struct SingleWaiterNotify {
26 state: AtomicU8,
27 waker: AtomicWaker,
28}
29
30impl std::fmt::Debug for SingleWaiterNotify {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 let state = self.state.load(Ordering::Acquire);
33 let state_str = match state {
34 EMPTY => "Empty",
35 WAITING => "Waiting",
36 NOTIFIED => "Notified",
37 _ => "Unknown",
38 };
39 f.debug_struct("SingleWaiterNotify")
40 .field("state", &state_str)
41 .finish()
42 }
43}
44
45impl Default for SingleWaiterNotify {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl SingleWaiterNotify {
52 #[inline]
56 pub fn new() -> Self {
57 Self {
58 state: AtomicU8::new(EMPTY),
59 waker: AtomicWaker::new(),
60 }
61 }
62
63 #[inline]
67 pub fn notified(&self) -> Notified<'_> {
68 Notified {
69 notify: self,
70 registered: false,
71 }
72 }
73
74 #[inline]
82 pub fn notify_one(&self) {
83 let prev_state = self.state.swap(NOTIFIED, Ordering::AcqRel);
85
86 if prev_state == WAITING {
88 self.waker.wake();
89 }
90 }
91
92 #[inline]
100 fn register_waker(&self, waker: &std::task::Waker) -> bool {
101 self.waker.register(waker);
107
108 let current_state = self.state.load(Ordering::Acquire);
109
110 if current_state == NOTIFIED {
112 self.state.store(EMPTY, Ordering::Release);
114 return true;
115 }
116
117 match self
119 .state
120 .compare_exchange(EMPTY, WAITING, Ordering::AcqRel, Ordering::Acquire)
121 {
122 Ok(_) => {
123 if self.state.load(Ordering::Acquire) == NOTIFIED {
126 self.state.store(EMPTY, Ordering::Release);
128 true
129 } else {
130 false
131 }
132 }
133 Err(state) => {
134 if state == NOTIFIED {
136 self.state.store(EMPTY, Ordering::Release);
138 true
139 } else {
140 if self.state.load(Ordering::Acquire) == NOTIFIED {
143 self.state.store(EMPTY, Ordering::Release);
144 true
145 } else {
146 false
147 }
148 }
149 }
150 }
151 }
152}
153
154pub struct Notified<'a> {
164 notify: &'a SingleWaiterNotify,
165 registered: bool,
166}
167
168impl<'a> std::fmt::Debug for Notified<'a> {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 let state = self.notify.state.load(Ordering::Acquire);
171 let state_str = match state {
172 EMPTY => "Empty",
173 WAITING => "Waiting",
174 NOTIFIED => "Notified",
175 _ => "Unknown",
176 };
177 f.debug_struct("Notified")
178 .field("state", &state_str)
179 .field("registered", &self.registered)
180 .finish()
181 }
182}
183
184impl Future for Notified<'_> {
185 type Output = ();
186
187 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
188 if !self.registered {
190 self.registered = true;
191 if self.notify.register_waker(cx.waker()) {
192 return Poll::Ready(());
194 }
195 } else {
196 if self.notify.state.load(Ordering::Acquire) == NOTIFIED {
198 self.notify.state.store(EMPTY, Ordering::Release);
199 return Poll::Ready(());
200 }
201 if self.notify.register_waker(cx.waker()) {
204 return Poll::Ready(());
205 }
206 }
207
208 Poll::Pending
209 }
210}
211
212impl Drop for Notified<'_> {
213 fn drop(&mut self) {
214 if self.registered {
215 let _ = self.notify.state.compare_exchange(
227 WAITING,
228 EMPTY,
229 Ordering::Relaxed,
230 Ordering::Relaxed,
231 );
232 }
233 }
234}
235
#[cfg(all(test, not(feature = "loom")))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::{Duration, sleep};

    /// A notification sent before anyone waits is stored and consumed later.
    #[tokio::test]
    async fn test_notify_before_wait() {
        let shared = Arc::new(SingleWaiterNotify::new());

        shared.notify_one();

        shared.notified().await;
    }

    /// A waiter that is already pending is woken by a later notification.
    #[tokio::test]
    async fn test_notify_after_wait() {
        let shared = Arc::new(SingleWaiterNotify::new());
        let sender = Arc::clone(&shared);

        tokio::spawn(async move {
            sleep(Duration::from_millis(10)).await;
            sender.notify_one();
        });

        shared.notified().await;
    }

    /// The same notifier can go through many wait/notify rounds.
    #[tokio::test]
    async fn test_multiple_notify_cycles() {
        let shared = Arc::new(SingleWaiterNotify::new());

        for _round in 0..10 {
            let sender = Arc::clone(&shared);
            tokio::spawn(async move {
                sleep(Duration::from_millis(5)).await;
                sender.notify_one();
            });

            shared.notified().await;
        }
    }

    /// Several tasks notifying at about the same time still wake the waiter.
    #[tokio::test]
    async fn test_concurrent_notify() {
        let shared = Arc::new(SingleWaiterNotify::new());

        for _task in 0..5 {
            let sender = Arc::clone(&shared);
            tokio::spawn(async move {
                sleep(Duration::from_millis(10)).await;
                sender.notify_one();
            });
        }

        shared.notified().await;
    }

    /// Repeated notifications without a waiter do not block a later wait.
    #[tokio::test]
    async fn test_notify_no_waiter() {
        let local = SingleWaiterNotify::new();

        local.notify_one();
        local.notify_one();

        local.notified().await;
    }

    /// Many short rounds on a multi-threaded runtime.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_stress_test() {
        let shared = Arc::new(SingleWaiterNotify::new());

        for round in 0..100 {
            let sender = Arc::clone(&shared);
            tokio::spawn(async move {
                sleep(Duration::from_micros(round % 10)).await;
                sender.notify_one();
            });

            shared.notified().await;
        }
    }

    /// A notification racing with the waiter's first poll must not be lost.
    #[tokio::test]
    async fn test_immediate_notification_race() {
        for _attempt in 0..100 {
            let shared = Arc::new(SingleWaiterNotify::new());
            let sender = Arc::clone(&shared);

            let waiting_task = tokio::spawn(async move {
                shared.notified().await;
            });

            sender.notify_one();

            let joined = tokio::time::timeout(Duration::from_millis(100), waiting_task).await;
            joined
                .expect("Should not timeout")
                .expect("Task should complete");
        }
    }
}