lite_sync/
notify.rs

//! Lightweight single-waiter notification primitive.
//!
//! Optimized for the SPSC (single-producer, single-consumer) pattern, where
//! only one task waits at a time. Much lighter than `tokio::sync::Notify`.
//!
//! 轻量级单等待者通知原语
//!
//! 为 SPSC(单生产者单消费者)模式优化,其中每次只有一个任务等待。
//! 比 tokio::sync::Notify 更轻量。
10use std::sync::atomic::{AtomicU8, Ordering};
11use std::future::Future;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14
15use super::atomic_waker::AtomicWaker;
16
/// Notification state: no waiter is registered and no notification is pending.
const EMPTY: u8 = 0;
/// Notification state: a waiter has registered its waker and is waiting.
const WAITING: u8 = 1;
/// Notification state: a notification was sent and has not been consumed yet.
const NOTIFIED: u8 = 2;
21
22/// Lightweight single-waiter notifier optimized for SPSC pattern
23/// 
24/// 为 SPSC 模式优化的轻量级单等待者通知器
25pub struct SingleWaiterNotify {
26    state: AtomicU8,
27    waker: AtomicWaker,
28}
29
30impl std::fmt::Debug for SingleWaiterNotify {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        let state = self.state.load(Ordering::Acquire);
33        let state_str = match state {
34            EMPTY => "Empty",
35            WAITING => "Waiting",
36            NOTIFIED => "Notified",
37            _ => "Unknown",
38        };
39        f.debug_struct("SingleWaiterNotify")
40            .field("state", &state_str)
41            .finish()
42    }
43}
44
45impl Default for SingleWaiterNotify {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl SingleWaiterNotify {
52    /// Create a new single-waiter notifier
53    /// 
54    /// 创建一个新的单等待者通知器
55    #[inline]
56    pub fn new() -> Self {
57        Self {
58            state: AtomicU8::new(EMPTY),
59            waker: AtomicWaker::new(),
60        }
61    }
62    
63    /// Returns a future that completes when notified
64    /// 
65    /// 返回一个在收到通知时完成的 future
66    #[inline]
67    pub fn notified(&self) -> Notified<'_> {
68        Notified { 
69            notify: self,
70            registered: false,
71        }
72    }
73    
74    /// Wake the waiting task (if any)
75    /// 
76    /// If called before wait, the next wait will complete immediately.
77    /// 
78    /// 唤醒等待的任务(如果有)
79    /// 
80    /// 如果在等待之前调用,下一次等待将立即完成。
81    #[inline]
82    pub fn notify_one(&self) {
83        // Mark as notified
84        let prev_state = self.state.swap(NOTIFIED, Ordering::AcqRel);
85        
86        // If there was a waiter, wake it
87        if prev_state == WAITING {
88            self.waker.wake();
89        }
90    }
91    
92    /// Register a waker to be notified
93    /// 
94    /// Returns true if already notified (fast path)
95    /// 
96    /// 注册一个 waker 以接收通知
97    /// 
98    /// 如果已经被通知则返回 true(快速路径)
99    #[inline]
100    fn register_waker(&self, waker: &std::task::Waker) -> bool {
101        // CRITICAL: Register waker FIRST, before changing state to WAITING
102        // This prevents the race where notify_one() sees WAITING but waker isn't registered yet
103        //
104        // 关键:先注册 waker,再将状态改为 WAITING
105        // 这可以防止 notify_one() 看到 WAITING 但 waker 还未注册的竞态条件
106        self.waker.register(waker);
107        
108        let current_state = self.state.load(Ordering::Acquire);
109        
110        // Fast path: already notified
111        if current_state == NOTIFIED {
112            // Reset to EMPTY for next wait
113            self.state.store(EMPTY, Ordering::Release);
114            return true;
115        }
116        
117        // Try to transition from EMPTY to WAITING
118        match self.state.compare_exchange(
119            EMPTY,
120            WAITING,
121            Ordering::AcqRel,
122            Ordering::Acquire,
123        ) {
124            Ok(_) => {
125                // Successfully transitioned to WAITING
126                // Check if we were notified immediately after setting WAITING
127                if self.state.load(Ordering::Acquire) == NOTIFIED {
128                    // Race: notified between CAS and this check
129                    self.state.store(EMPTY, Ordering::Release);
130                    true
131                } else {
132                    false
133                }
134            }
135            Err(state) => {
136                // State changed, check what it is now
137                if state == NOTIFIED {
138                    // Already notified
139                    self.state.store(EMPTY, Ordering::Release);
140                    true
141                } else {
142                    // State is WAITING (subsequent poll updating waker)
143                    // Check if notified after waker update
144                    if self.state.load(Ordering::Acquire) == NOTIFIED {
145                        self.state.store(EMPTY, Ordering::Release);
146                        true
147                    } else {
148                        false
149                    }
150                }
151            }
152        }
153    }
154}
155
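// `SingleWaiterNotify` itself needs no explicit `Drop` impl: dropping it is
// handled by its `AtomicWaker` field's own drop implementation.
// (`Notified` below does have an explicit `Drop` impl, to roll back WAITING.)
//
// SingleWaiterNotify 的 Drop 由 AtomicWaker 的 drop 实现自动处理
// 无需显式的 drop 实现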
161
162/// Future returned by `SingleWaiterNotify::notified()`
163/// 
164/// `SingleWaiterNotify::notified()` 返回的 Future
165pub struct Notified<'a> {
166    notify: &'a SingleWaiterNotify,
167    registered: bool,
168}
169
170impl<'a> std::fmt::Debug for Notified<'a> {
171    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172        let state = self.notify.state.load(Ordering::Acquire);
173        let state_str = match state {
174            EMPTY => "Empty",
175            WAITING => "Waiting",
176            NOTIFIED => "Notified",
177            _ => "Unknown",
178        };
179        f.debug_struct("Notified")
180            .field("state", &state_str)
181            .field("registered", &self.registered)
182            .finish()
183    }
184}
185
186impl Future for Notified<'_> {
187    type Output = ();
188    
189    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
190        // On first poll, register the waker
191        if !self.registered {
192            self.registered = true;
193            if self.notify.register_waker(cx.waker()) {
194                // Already notified (fast path)
195                return Poll::Ready(());
196            }
197        } else {
198            // On subsequent polls, check if notified
199            if self.notify.state.load(Ordering::Acquire) == NOTIFIED {
200                self.notify.state.store(EMPTY, Ordering::Release);
201                return Poll::Ready(());
202            }
203            // Update waker in case it changed
204            // IMPORTANT: Check return value - may have been notified during registration
205            if self.notify.register_waker(cx.waker()) {
206                return Poll::Ready(());
207            }
208        }
209        
210        Poll::Pending
211    }
212}
213
214impl Drop for Notified<'_> {
215    fn drop(&mut self) {
216        if self.registered {
217            // If we registered but are being dropped, try to clean up
218            // PERFORMANCE: Direct compare_exchange without pre-check:
219            // - Single atomic operation instead of two (load + compare_exchange)
220            // - Relaxed ordering is sufficient - just cleaning up state
221            // - CAS will fail harmlessly if state is not WAITING
222            //
223            // 如果我们注册了但正在被 drop,尝试清理
224            // 性能优化:直接 compare_exchange 无需预检查:
225            // - 单次原子操作而不是两次(load + compare_exchange)
226            // - Relaxed ordering 就足够了 - 只是清理状态
227            // - 如果状态不是 WAITING,CAS 会无害地失败
228            let _ = self.notify.state.compare_exchange(
229                WAITING,
230                EMPTY,
231                Ordering::Relaxed,
232                Ordering::Relaxed,
233            );
234        }
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    #[tokio::test]
    async fn test_notify_before_wait() {
        let notify = Arc::new(SingleWaiterNotify::new());

        // Notify before waiting
        notify.notify_one();

        // Should complete immediately
        notify.notified().await;
    }

    #[tokio::test]
    async fn test_notify_after_wait() {
        let notify = Arc::new(SingleWaiterNotify::new());
        let notify_clone = notify.clone();

        // Spawn a task that notifies after a delay
        tokio::spawn(async move {
            sleep(Duration::from_millis(10)).await;
            notify_clone.notify_one();
        });

        // Wait for notification
        notify.notified().await;
    }

    #[tokio::test]
    async fn test_multiple_notify_cycles() {
        let notify = Arc::new(SingleWaiterNotify::new());

        for _ in 0..10 {
            let notify_clone = notify.clone();
            tokio::spawn(async move {
                sleep(Duration::from_millis(5)).await;
                notify_clone.notify_one();
            });

            notify.notified().await;
        }
    }

    #[tokio::test]
    async fn test_concurrent_notify() {
        let notify = Arc::new(SingleWaiterNotify::new());
        let notify_clone = notify.clone();

        // Multiple notifiers (only one should wake the waiter)
        for _ in 0..5 {
            let n = notify_clone.clone();
            tokio::spawn(async move {
                sleep(Duration::from_millis(10)).await;
                n.notify_one();
            });
        }

        notify.notified().await;
    }

    #[tokio::test]
    async fn test_notify_no_waiter() {
        let notify = SingleWaiterNotify::new();

        // Notify with no waiter should not panic
        notify.notify_one();
        notify.notify_one();

        // Next wait should complete immediately
        notify.notified().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_stress_test() {
        let notify = Arc::new(SingleWaiterNotify::new());

        for i in 0..100 {
            let notify_clone = notify.clone();
            tokio::spawn(async move {
                sleep(Duration::from_micros(i % 10)).await;
                notify_clone.notify_one();
            });

            notify.notified().await;
        }
    }

    #[tokio::test]
    async fn test_immediate_notification_race() {
        // Test the race between notification and registration
        for _ in 0..100 {
            let notify = Arc::new(SingleWaiterNotify::new());
            let notify_clone = notify.clone();

            let waiter = tokio::spawn(async move {
                notify.notified().await;
            });

            // Notify immediately (might happen before or after registration)
            notify_clone.notify_one();

            // Should complete without timeout
            tokio::time::timeout(Duration::from_millis(100), waiter)
                .await
                .expect("Should not timeout")
                .expect("Task should complete");
        }
    }

    /// A wait that is cancelled after registering must roll WAITING back to
    /// EMPTY (via `Notified`'s `Drop`) and must not swallow a later notification.
    #[tokio::test]
    async fn test_cancelled_wait_keeps_later_notification() {
        let notify = SingleWaiterNotify::new();

        // The inner future is polled (registering as WAITING) and then dropped
        // when the timeout fires, since nobody notifies.
        let res = tokio::time::timeout(Duration::from_millis(5), notify.notified()).await;
        assert!(res.is_err(), "no notification was sent, so the wait must time out");
        assert_eq!(
            notify.state.load(std::sync::atomic::Ordering::Acquire),
            EMPTY,
            "dropping a registered waiter must reset WAITING to EMPTY"
        );

        // A notification sent after the cancelled wait is delivered to the next wait.
        notify.notify_one();
        tokio::time::timeout(Duration::from_millis(100), notify.notified())
            .await
            .expect("a notification sent after a cancelled wait must be delivered");
    }
}
351