lite_sync/
notify.rs

//! Lightweight single-waiter notification primitive
//!
//! Optimized for SPSC (Single Producer Single Consumer) pattern where
//! only one task waits at a time. Much lighter than tokio::sync::Notify.
//!
//! 轻量级单等待者通知原语
//!
//! 为 SPSC(单生产者单消费者)模式优化,其中每次只有一个任务等待。
//! 比 tokio::sync::Notify 更轻量。
10use crate::shim::atomic::{AtomicU8, Ordering};
11use core::future::Future;
12use core::pin::Pin;
13use core::task::{Context, Poll};
14
15use super::atomic_waker::AtomicWaker;
16
// Notification states stored in `SingleWaiterNotify::state`.

/// No waiter is registered and no notification is pending.
const EMPTY: u8 = 0;
/// A waiter has registered its waker and is waiting to be woken.
const WAITING: u8 = 1;
/// A notification has been sent and not yet consumed by a waiter.
const NOTIFIED: u8 = 2;
21
22/// Lightweight single-waiter notifier optimized for SPSC pattern
23///
24/// 为 SPSC 模式优化的轻量级单等待者通知器
25pub struct SingleWaiterNotify {
26    state: AtomicU8,
27    waker: AtomicWaker,
28}
29
30impl core::fmt::Debug for SingleWaiterNotify {
31    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
32        let state = self.state.load(Ordering::Acquire);
33        let state_str = match state {
34            EMPTY => "Empty",
35            WAITING => "Waiting",
36            NOTIFIED => "Notified",
37            _ => "Unknown",
38        };
39        f.debug_struct("SingleWaiterNotify")
40            .field("state", &state_str)
41            .finish()
42    }
43}
44
45impl Default for SingleWaiterNotify {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl SingleWaiterNotify {
52    /// Create a new single-waiter notifier
53    ///
54    /// 创建一个新的单等待者通知器
55    #[inline]
56    pub fn new() -> Self {
57        Self {
58            state: AtomicU8::new(EMPTY),
59            waker: AtomicWaker::new(),
60        }
61    }
62
63    /// Returns a future that completes when notified
64    ///
65    /// 返回一个在收到通知时完成的 future
66    #[inline]
67    pub fn notified(&self) -> Notified<'_> {
68        Notified {
69            notify: self,
70            registered: false,
71        }
72    }
73
74    /// Wake the waiting task (if any)
75    ///
76    /// If called before wait, the next wait will complete immediately.
77    ///
78    /// 唤醒等待的任务(如果有)
79    ///
80    /// 如果在等待之前调用,下一次等待将立即完成。
81    #[inline]
82    pub fn notify_one(&self) {
83        // Mark as notified
84        let prev_state = self.state.swap(NOTIFIED, Ordering::AcqRel);
85
86        // If there was a waiter, wake it
87        if prev_state == WAITING {
88            self.waker.wake();
89        }
90    }
91
92    /// Register a waker to be notified
93    ///
94    /// Returns true if already notified (fast path)
95    ///
96    /// 注册一个 waker 以接收通知
97    ///
98    /// 如果已经被通知则返回 true(快速路径)
99    #[inline]
100    fn register_waker(&self, waker: &core::task::Waker) -> bool {
101        // CRITICAL: Register waker FIRST, before changing state to WAITING
102        // This prevents the race where notify_one() sees WAITING but waker isn't registered yet
103        //
104        // 关键:先注册 waker,再将状态改为 WAITING
105        // 这可以防止 notify_one() 看到 WAITING 但 waker 还未注册的竞态条件
106        self.waker.register(waker);
107
108        let current_state = self.state.load(Ordering::Acquire);
109
110        // Fast path: already notified
111        if current_state == NOTIFIED {
112            // Reset to EMPTY for next wait
113            self.state.store(EMPTY, Ordering::Release);
114            return true;
115        }
116
117        // Try to transition from EMPTY to WAITING
118        match self
119            .state
120            .compare_exchange(EMPTY, WAITING, Ordering::AcqRel, Ordering::Acquire)
121        {
122            Ok(_) => {
123                // Successfully transitioned to WAITING
124                // Check if we were notified immediately after setting WAITING
125                if self.state.load(Ordering::Acquire) == NOTIFIED {
126                    // Race: notified between CAS and this check
127                    self.state.store(EMPTY, Ordering::Release);
128                    true
129                } else {
130                    false
131                }
132            }
133            Err(state) => {
134                // State changed, check what it is now
135                if state == NOTIFIED {
136                    // Already notified
137                    self.state.store(EMPTY, Ordering::Release);
138                    true
139                } else {
140                    // State is WAITING (subsequent poll updating waker)
141                    // Check if notified after waker update
142                    if self.state.load(Ordering::Acquire) == NOTIFIED {
143                        self.state.store(EMPTY, Ordering::Release);
144                        true
145                    } else {
146                        false
147                    }
148                }
149            }
150        }
151    }
152}
153
// `SingleWaiterNotify` needs no explicit `Drop` implementation: cleanup is
// handled by the `Drop` implementation of its `AtomicWaker` field.
//
// Drop 由 AtomicWaker 的 drop 实现自动处理
// 无需显式的 drop 实现
159
160/// Future returned by `SingleWaiterNotify::notified()`
161///
162/// `SingleWaiterNotify::notified()` 返回的 Future
163pub struct Notified<'a> {
164    notify: &'a SingleWaiterNotify,
165    registered: bool,
166}
167
168impl<'a> core::fmt::Debug for Notified<'a> {
169    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
170        let state = self.notify.state.load(Ordering::Acquire);
171        let state_str = match state {
172            EMPTY => "Empty",
173            WAITING => "Waiting",
174            NOTIFIED => "Notified",
175            _ => "Unknown",
176        };
177        f.debug_struct("Notified")
178            .field("state", &state_str)
179            .field("registered", &self.registered)
180            .finish()
181    }
182}
183
184impl Future for Notified<'_> {
185    type Output = ();
186
187    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
188        // On first poll, register the waker
189        if !self.registered {
190            self.registered = true;
191            if self.notify.register_waker(cx.waker()) {
192                // Already notified (fast path)
193                return Poll::Ready(());
194            }
195        } else {
196            // On subsequent polls, check if notified
197            if self.notify.state.load(Ordering::Acquire) == NOTIFIED {
198                self.notify.state.store(EMPTY, Ordering::Release);
199                return Poll::Ready(());
200            }
201            // Update waker in case it changed
202            // IMPORTANT: Check return value - may have been notified during registration
203            if self.notify.register_waker(cx.waker()) {
204                return Poll::Ready(());
205            }
206        }
207
208        Poll::Pending
209    }
210}
211
212impl Drop for Notified<'_> {
213    fn drop(&mut self) {
214        if self.registered {
215            // If we registered but are being dropped, try to clean up
216            // PERFORMANCE: Direct compare_exchange without pre-check:
217            // - Single atomic operation instead of two (load + compare_exchange)
218            // - Relaxed ordering is sufficient - just cleaning up state
219            // - CAS will fail harmlessly if state is not WAITING
220            //
221            // 如果我们注册了但正在被 drop,尝试清理
222            // 性能优化:直接 compare_exchange 无需预检查:
223            // - 单次原子操作而不是两次(load + compare_exchange)
224            // - Relaxed ordering 就足够了 - 只是清理状态
225            // - 如果状态不是 WAITING,CAS 会无害地失败
226            let _ = self.notify.state.compare_exchange(
227                WAITING,
228                EMPTY,
229                Ordering::Relaxed,
230                Ordering::Relaxed,
231            );
232        }
233    }
234}
235
#[cfg(all(test, not(feature = "loom")))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::{Duration, sleep};

    /// Spawns a task that sleeps for `delay` and then calls `notify_one`.
    fn notify_after(notify: &Arc<SingleWaiterNotify>, delay: Duration) {
        let notifier = Arc::clone(notify);
        tokio::spawn(async move {
            sleep(delay).await;
            notifier.notify_one();
        });
    }

    #[tokio::test]
    async fn test_notify_before_wait() {
        let notify = Arc::new(SingleWaiterNotify::new());

        // A notification sent first must let the wait finish at once.
        notify.notify_one();
        notify.notified().await;
    }

    #[tokio::test]
    async fn test_notify_after_wait() {
        let notify = Arc::new(SingleWaiterNotify::new());

        // The waiter parks first; the spawned task wakes it a little later.
        notify_after(&notify, Duration::from_millis(10));
        notify.notified().await;
    }

    #[tokio::test]
    async fn test_multiple_notify_cycles() {
        let notify = Arc::new(SingleWaiterNotify::new());

        // The same notifier must be reusable across many wait/notify rounds.
        for _round in 0..10 {
            notify_after(&notify, Duration::from_millis(5));
            notify.notified().await;
        }
    }

    #[tokio::test]
    async fn test_concurrent_notify() {
        let notify = Arc::new(SingleWaiterNotify::new());

        // Several notifiers race; any one of them is enough to wake the waiter.
        (0..5).for_each(|_| notify_after(&notify, Duration::from_millis(10)));
        notify.notified().await;
    }

    #[tokio::test]
    async fn test_notify_no_waiter() {
        let notify = SingleWaiterNotify::new();

        // Notifying with nobody waiting must not panic, and the
        // notifications coalesce into one stored permit.
        for _ in 0..2 {
            notify.notify_one();
        }
        notify.notified().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_stress_test() {
        let notify = Arc::new(SingleWaiterNotify::new());

        for i in 0..100 {
            notify_after(&notify, Duration::from_micros(i % 10));
            notify.notified().await;
        }
    }

    #[tokio::test]
    async fn test_immediate_notification_race() {
        // The notification may land before or after the waiter registers;
        // both orders must complete.
        for _ in 0..100 {
            let waiter_side = Arc::new(SingleWaiterNotify::new());
            let notifier_side = Arc::clone(&waiter_side);

            let waiter = tokio::spawn(async move {
                waiter_side.notified().await;
            });

            notifier_side.notify_one();

            let outcome = tokio::time::timeout(Duration::from_millis(100), waiter).await;
            outcome
                .expect("Should not timeout")
                .expect("Task should complete");
        }
    }
}