// lite_sync/notify.rs

//! Lightweight single-waiter notification primitive.
//!
//! Optimized for the SPSC (single-producer, single-consumer) pattern, where
//! only one task waits at a time. Much lighter than `tokio::sync::Notify`.
//!
//! 轻量级单等待者通知原语
//!
//! 为 SPSC(单生产者单消费者)模式优化,其中每次只有一个任务等待。
//! 比 tokio::sync::Notify 更轻量。
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU8, Ordering};
use std::task::{Context, Poll, Waker};

use super::atomic_waker::AtomicWaker;

// Notification state machine stored in `SingleWaiterNotify::state`.

/// No waiter is registered and no notification is pending.
const EMPTY: u8 = 0;
/// A waiter has published its waker and is waiting.
const WAITING: u8 = 1;
/// A notification has been sent and not yet consumed.
const NOTIFIED: u8 = 2;

22/// Lightweight single-waiter notifier optimized for SPSC pattern
23/// 
24/// 为 SPSC 模式优化的轻量级单等待者通知器
25pub struct SingleWaiterNotify {
26    state: AtomicU8,
27    waker: AtomicWaker,
28}
29
30impl Default for SingleWaiterNotify {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl SingleWaiterNotify {
37    /// Create a new single-waiter notifier
38    /// 
39    /// 创建一个新的单等待者通知器
40    #[inline]
41    pub fn new() -> Self {
42        Self {
43            state: AtomicU8::new(EMPTY),
44            waker: AtomicWaker::new(),
45        }
46    }
47    
48    /// Returns a future that completes when notified
49    /// 
50    /// 返回一个在收到通知时完成的 future
51    #[inline]
52    pub fn notified(&self) -> Notified<'_> {
53        Notified { 
54            notify: self,
55            registered: false,
56        }
57    }
58    
59    /// Wake the waiting task (if any)
60    /// 
61    /// If called before wait, the next wait will complete immediately.
62    /// 
63    /// 唤醒等待的任务(如果有)
64    /// 
65    /// 如果在等待之前调用,下一次等待将立即完成。
66    #[inline]
67    pub fn notify_one(&self) {
68        // Mark as notified
69        let prev_state = self.state.swap(NOTIFIED, Ordering::AcqRel);
70        
71        // If there was a waiter, wake it
72        if prev_state == WAITING {
73            self.waker.wake();
74        }
75    }
76    
77    /// Register a waker to be notified
78    /// 
79    /// Returns true if already notified (fast path)
80    /// 
81    /// 注册一个 waker 以接收通知
82    /// 
83    /// 如果已经被通知则返回 true(快速路径)
84    #[inline]
85    fn register_waker(&self, waker: &std::task::Waker) -> bool {
86        // CRITICAL: Register waker FIRST, before changing state to WAITING
87        // This prevents the race where notify_one() sees WAITING but waker isn't registered yet
88        //
89        // 关键:先注册 waker,再将状态改为 WAITING
90        // 这可以防止 notify_one() 看到 WAITING 但 waker 还未注册的竞态条件
91        self.waker.register(waker);
92        
93        let current_state = self.state.load(Ordering::Acquire);
94        
95        // Fast path: already notified
96        if current_state == NOTIFIED {
97            // Reset to EMPTY for next wait
98            self.state.store(EMPTY, Ordering::Release);
99            return true;
100        }
101        
102        // Try to transition from EMPTY to WAITING
103        match self.state.compare_exchange(
104            EMPTY,
105            WAITING,
106            Ordering::AcqRel,
107            Ordering::Acquire,
108        ) {
109            Ok(_) => {
110                // Successfully transitioned to WAITING
111                // Check if we were notified immediately after setting WAITING
112                if self.state.load(Ordering::Acquire) == NOTIFIED {
113                    // Race: notified between CAS and this check
114                    self.state.store(EMPTY, Ordering::Release);
115                    true
116                } else {
117                    false
118                }
119            }
120            Err(state) => {
121                // State changed, check what it is now
122                if state == NOTIFIED {
123                    // Already notified
124                    self.state.store(EMPTY, Ordering::Release);
125                    true
126                } else {
127                    // State is WAITING (subsequent poll updating waker)
128                    // Check if notified after waker update
129                    if self.state.load(Ordering::Acquire) == NOTIFIED {
130                        self.state.store(EMPTY, Ordering::Release);
131                        true
132                    } else {
133                        false
134                    }
135                }
136            }
137        }
138    }
139}
140
// `SingleWaiterNotify` needs no explicit `Drop` impl: cleanup is handled by
// `AtomicWaker`'s own drop implementation.
//
// Drop 由 AtomicWaker 的 drop 实现自动处理
// 无需显式的 drop 实现

147/// Future returned by `SingleWaiterNotify::notified()`
148/// 
149/// `SingleWaiterNotify::notified()` 返回的 Future
150pub struct Notified<'a> {
151    notify: &'a SingleWaiterNotify,
152    registered: bool,
153}
154
155impl Future for Notified<'_> {
156    type Output = ();
157    
158    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
159        // On first poll, register the waker
160        if !self.registered {
161            self.registered = true;
162            if self.notify.register_waker(cx.waker()) {
163                // Already notified (fast path)
164                return Poll::Ready(());
165            }
166        } else {
167            // On subsequent polls, check if notified
168            if self.notify.state.load(Ordering::Acquire) == NOTIFIED {
169                self.notify.state.store(EMPTY, Ordering::Release);
170                return Poll::Ready(());
171            }
172            // Update waker in case it changed
173            // IMPORTANT: Check return value - may have been notified during registration
174            if self.notify.register_waker(cx.waker()) {
175                return Poll::Ready(());
176            }
177        }
178        
179        Poll::Pending
180    }
181}
182
183impl Drop for Notified<'_> {
184    fn drop(&mut self) {
185        if self.registered {
186            // If we registered but are being dropped, try to clean up
187            // PERFORMANCE: Direct compare_exchange without pre-check:
188            // - Single atomic operation instead of two (load + compare_exchange)
189            // - Relaxed ordering is sufficient - just cleaning up state
190            // - CAS will fail harmlessly if state is not WAITING
191            //
192            // 如果我们注册了但正在被 drop,尝试清理
193            // 性能优化:直接 compare_exchange 无需预检查:
194            // - 单次原子操作而不是两次(load + compare_exchange)
195            // - Relaxed ordering 就足够了 - 只是清理状态
196            // - 如果状态不是 WAITING,CAS 会无害地失败
197            let _ = self.notify.state.compare_exchange(
198                WAITING,
199                EMPTY,
200                Ordering::Relaxed,
201                Ordering::Relaxed,
202            );
203        }
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    /// Spawns a detached task that sleeps for `delay`, then calls `notify_one`.
    fn notify_after(notify: &Arc<SingleWaiterNotify>, delay: Duration) {
        let notifier = Arc::clone(notify);
        tokio::spawn(async move {
            sleep(delay).await;
            notifier.notify_one();
        });
    }

    #[tokio::test]
    async fn test_notify_before_wait() {
        let notify = Arc::new(SingleWaiterNotify::new());

        // A notification sent before anyone waits must let the next wait finish at once.
        notify.notify_one();
        notify.notified().await;
    }

    #[tokio::test]
    async fn test_notify_after_wait() {
        let notify = Arc::new(SingleWaiterNotify::new());

        // The producer fires only after the waiter has had time to register.
        notify_after(&notify, Duration::from_millis(10));
        notify.notified().await;
    }

    #[tokio::test]
    async fn test_multiple_notify_cycles() {
        let notify = Arc::new(SingleWaiterNotify::new());

        // The notifier must be reusable across many wait/notify rounds.
        for _ in 0..10 {
            notify_after(&notify, Duration::from_millis(5));
            notify.notified().await;
        }
    }

    #[tokio::test]
    async fn test_concurrent_notify() {
        let notify = Arc::new(SingleWaiterNotify::new());

        // Several producers race; the single waiter only needs one of them.
        (0..5).for_each(|_| notify_after(&notify, Duration::from_millis(10)));
        notify.notified().await;
    }

    #[tokio::test]
    async fn test_notify_no_waiter() {
        let notify = SingleWaiterNotify::new();

        // Notifying with nobody waiting must not panic ...
        for _ in 0..2 {
            notify.notify_one();
        }
        // ... and the next wait completes immediately.
        notify.notified().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_stress_test() {
        let notify = Arc::new(SingleWaiterNotify::new());

        for round in 0..100u64 {
            notify_after(&notify, Duration::from_micros(round % 10));
            notify.notified().await;
        }
    }

    #[tokio::test]
    async fn test_immediate_notification_race() {
        // The notification may land before or after the waiter registers;
        // either way the waiter must finish.
        for _ in 0..100 {
            let notify = Arc::new(SingleWaiterNotify::new());
            let notifier = Arc::clone(&notify);

            let waiter = tokio::spawn(async move { notify.notified().await });
            notifier.notify_one();

            tokio::time::timeout(Duration::from_millis(100), waiter)
                .await
                .expect("Should not timeout")
                .expect("Task should complete");
        }
    }
}
