lite_sync/notify.rs
//! Lightweight single-waiter notification primitive.
//!
//! Optimized for the SPSC (single-producer, single-consumer) pattern, where
//! only one task waits at a time. Much lighter than `tokio::sync::Notify`.
//!
//! 轻量级单等待者通知原语
//!
//! 为 SPSC(单生产者单消费者)模式优化,其中每次只有一个任务等待。
//! 比 tokio::sync::Notify 更轻量。
10use std::sync::atomic::{AtomicU8, Ordering};
11use std::future::Future;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14
15use super::atomic_waker::AtomicWaker;
16
// Notification states held in `SingleWaiterNotify::state`.

/// No waiter is registered and no notification is pending.
const EMPTY: u8 = 0;
/// A waiter has registered its waker and is waiting to be woken.
const WAITING: u8 = 1;
/// A notification has been sent and not yet consumed.
const NOTIFIED: u8 = 2;
21
22/// Lightweight single-waiter notifier optimized for SPSC pattern
23///
24/// 为 SPSC 模式优化的轻量级单等待者通知器
25pub struct SingleWaiterNotify {
26 state: AtomicU8,
27 waker: AtomicWaker,
28}
29
30impl Default for SingleWaiterNotify {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl SingleWaiterNotify {
37 /// Create a new single-waiter notifier
38 ///
39 /// 创建一个新的单等待者通知器
40 #[inline]
41 pub fn new() -> Self {
42 Self {
43 state: AtomicU8::new(EMPTY),
44 waker: AtomicWaker::new(),
45 }
46 }
47
48 /// Returns a future that completes when notified
49 ///
50 /// 返回一个在收到通知时完成的 future
51 #[inline]
52 pub fn notified(&self) -> Notified<'_> {
53 Notified {
54 notify: self,
55 registered: false,
56 }
57 }
58
59 /// Wake the waiting task (if any)
60 ///
61 /// If called before wait, the next wait will complete immediately.
62 ///
63 /// 唤醒等待的任务(如果有)
64 ///
65 /// 如果在等待之前调用,下一次等待将立即完成。
66 #[inline]
67 pub fn notify_one(&self) {
68 // Mark as notified
69 let prev_state = self.state.swap(NOTIFIED, Ordering::AcqRel);
70
71 // If there was a waiter, wake it
72 if prev_state == WAITING {
73 self.waker.wake();
74 }
75 }
76
77 /// Register a waker to be notified
78 ///
79 /// Returns true if already notified (fast path)
80 ///
81 /// 注册一个 waker 以接收通知
82 ///
83 /// 如果已经被通知则返回 true(快速路径)
84 #[inline]
85 fn register_waker(&self, waker: &std::task::Waker) -> bool {
86 // CRITICAL: Register waker FIRST, before changing state to WAITING
87 // This prevents the race where notify_one() sees WAITING but waker isn't registered yet
88 //
89 // 关键:先注册 waker,再将状态改为 WAITING
90 // 这可以防止 notify_one() 看到 WAITING 但 waker 还未注册的竞态条件
91 self.waker.register(waker);
92
93 let current_state = self.state.load(Ordering::Acquire);
94
95 // Fast path: already notified
96 if current_state == NOTIFIED {
97 // Reset to EMPTY for next wait
98 self.state.store(EMPTY, Ordering::Release);
99 return true;
100 }
101
102 // Try to transition from EMPTY to WAITING
103 match self.state.compare_exchange(
104 EMPTY,
105 WAITING,
106 Ordering::AcqRel,
107 Ordering::Acquire,
108 ) {
109 Ok(_) => {
110 // Successfully transitioned to WAITING
111 // Check if we were notified immediately after setting WAITING
112 if self.state.load(Ordering::Acquire) == NOTIFIED {
113 // Race: notified between CAS and this check
114 self.state.store(EMPTY, Ordering::Release);
115 true
116 } else {
117 false
118 }
119 }
120 Err(state) => {
121 // State changed, check what it is now
122 if state == NOTIFIED {
123 // Already notified
124 self.state.store(EMPTY, Ordering::Release);
125 true
126 } else {
127 // State is WAITING (subsequent poll updating waker)
128 // Check if notified after waker update
129 if self.state.load(Ordering::Acquire) == NOTIFIED {
130 self.state.store(EMPTY, Ordering::Release);
131 true
132 } else {
133 false
134 }
135 }
136 }
137 }
138 }
139}
140
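// `SingleWaiterNotify` needs no explicit `Drop` impl: its only field that needs
// cleanup is the `AtomicWaker`, whose own `Drop` releases any stored waker.
// (The `Notified` future below does implement `Drop`, to withdraw a pending
// registration.)
//
// `SingleWaiterNotify` 无需显式的 `Drop` 实现:其唯一需要清理的字段是
// `AtomicWaker`,由其自身的 `Drop` 释放已存储的 waker。
// (下方的 `Notified` future 实现了 `Drop`,用于撤销尚未完成的注册。)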
146
147/// Future returned by `SingleWaiterNotify::notified()`
148///
149/// `SingleWaiterNotify::notified()` 返回的 Future
150pub struct Notified<'a> {
151 notify: &'a SingleWaiterNotify,
152 registered: bool,
153}
154
155impl Future for Notified<'_> {
156 type Output = ();
157
158 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
159 // On first poll, register the waker
160 if !self.registered {
161 self.registered = true;
162 if self.notify.register_waker(cx.waker()) {
163 // Already notified (fast path)
164 return Poll::Ready(());
165 }
166 } else {
167 // On subsequent polls, check if notified
168 if self.notify.state.load(Ordering::Acquire) == NOTIFIED {
169 self.notify.state.store(EMPTY, Ordering::Release);
170 return Poll::Ready(());
171 }
172 // Update waker in case it changed
173 // IMPORTANT: Check return value - may have been notified during registration
174 if self.notify.register_waker(cx.waker()) {
175 return Poll::Ready(());
176 }
177 }
178
179 Poll::Pending
180 }
181}
182
183impl Drop for Notified<'_> {
184 fn drop(&mut self) {
185 if self.registered {
186 // If we registered but are being dropped, try to clean up
187 // PERFORMANCE: Direct compare_exchange without pre-check:
188 // - Single atomic operation instead of two (load + compare_exchange)
189 // - Relaxed ordering is sufficient - just cleaning up state
190 // - CAS will fail harmlessly if state is not WAITING
191 //
192 // 如果我们注册了但正在被 drop,尝试清理
193 // 性能优化:直接 compare_exchange 无需预检查:
194 // - 单次原子操作而不是两次(load + compare_exchange)
195 // - Relaxed ordering 就足够了 - 只是清理状态
196 // - 如果状态不是 WAITING,CAS 会无害地失败
197 let _ = self.notify.state.compare_exchange(
198 WAITING,
199 EMPTY,
200 Ordering::Relaxed,
201 Ordering::Relaxed,
202 );
203 }
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::{sleep, timeout, Duration};

    #[tokio::test]
    async fn test_notify_before_wait() {
        let notify = Arc::new(SingleWaiterNotify::new());

        // Notify before waiting: the notification is stored.
        notify.notify_one();

        // The stored notification makes the wait complete immediately.
        notify.notified().await;
    }

    #[tokio::test]
    async fn test_notify_after_wait() {
        let notify = Arc::new(SingleWaiterNotify::new());
        let notify_clone = Arc::clone(&notify);

        // Spawn a task that notifies after a delay.
        tokio::spawn(async move {
            sleep(Duration::from_millis(10)).await;
            notify_clone.notify_one();
        });

        // Wait for the notification.
        notify.notified().await;
    }

    #[tokio::test]
    async fn test_multiple_notify_cycles() {
        let notify = Arc::new(SingleWaiterNotify::new());

        for _ in 0..10 {
            let notify_clone = Arc::clone(&notify);
            tokio::spawn(async move {
                sleep(Duration::from_millis(5)).await;
                notify_clone.notify_one();
            });

            notify.notified().await;
        }
    }

    #[tokio::test]
    async fn test_concurrent_notify() {
        let notify = Arc::new(SingleWaiterNotify::new());

        // Several notifiers race; only the first one to observe WAITING wakes the
        // waiter, the rest merely leave a stored notification behind.
        for _ in 0..5 {
            let n = Arc::clone(&notify);
            tokio::spawn(async move {
                sleep(Duration::from_millis(10)).await;
                n.notify_one();
            });
        }

        notify.notified().await;
    }

    #[tokio::test]
    async fn test_notify_no_waiter() {
        let notify = SingleWaiterNotify::new();

        // Notifying with no waiter must not panic.
        notify.notify_one();
        notify.notify_one();

        // The next wait completes immediately...
        notify.notified().await;

        // ...and the two notifications were coalesced into one, so a second wait
        // has nothing to consume.
        assert!(
            timeout(Duration::from_millis(20), notify.notified())
                .await
                .is_err(),
            "repeated notifications should coalesce into a single one"
        );
    }

    #[tokio::test]
    async fn test_dropped_waiter_does_not_lose_notification() {
        let notify = SingleWaiterNotify::new();

        // Register a waiter, then drop it (via the timeout) while still pending.
        assert!(timeout(Duration::from_millis(10), notify.notified())
            .await
            .is_err());

        // A notification sent after the drop must still reach the next waiter.
        notify.notify_one();
        timeout(Duration::from_millis(100), notify.notified())
            .await
            .expect("notification sent after a dropped waiter should not be lost");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_stress_test() {
        let notify = Arc::new(SingleWaiterNotify::new());

        for i in 0..100 {
            let notify_clone = Arc::clone(&notify);
            tokio::spawn(async move {
                sleep(Duration::from_micros(i % 10)).await;
                notify_clone.notify_one();
            });

            notify.notified().await;
        }
    }

    #[tokio::test]
    async fn test_immediate_notification_race() {
        // Race a notification against the waiter's registration.
        for _ in 0..100 {
            let notify = Arc::new(SingleWaiterNotify::new());
            let notify_clone = Arc::clone(&notify);

            let waiter = tokio::spawn(async move {
                notify.notified().await;
            });

            // Notify immediately (may land before or after registration).
            notify_clone.notify_one();

            // Either way the waiter must complete.
            timeout(Duration::from_millis(100), waiter)
                .await
                .expect("Should not timeout")
                .expect("Task should complete");
        }
    }
}
320