lite_sync/
notify.rs

//! Lightweight single-waiter notification primitive.
//!
//! Optimized for the SPSC (single-producer, single-consumer) pattern, where
//! only one task waits at a time. Much lighter than `tokio::sync::Notify`.
//!
//! 轻量级单等待者通知原语
//!
//! 为 SPSC(单生产者单消费者)模式优化,其中每次只有一个任务等待。
//! 比 tokio::sync::Notify 更轻量。
10use std::sync::atomic::{AtomicU8, Ordering};
11use std::future::Future;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14
15use super::atomic_waker::AtomicWaker;
16
// Values stored in `SingleWaiterNotify::state`.

/// No waiter is registered and no notification is pending.
const EMPTY: u8 = 0;
/// A waiter is registered and waiting for a notification.
const WAITING: u8 = 1;
/// A notification has been sent and not yet consumed.
const NOTIFIED: u8 = 2;
21
22/// Lightweight single-waiter notifier optimized for SPSC pattern
23/// 
24/// 为 SPSC 模式优化的轻量级单等待者通知器
25pub struct SingleWaiterNotify {
26    state: AtomicU8,
27    waker: AtomicWaker,
28}
29
30impl Default for SingleWaiterNotify {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl SingleWaiterNotify {
37    /// Create a new single-waiter notifier
38    /// 
39    /// 创建一个新的单等待者通知器
40    #[inline]
41    pub fn new() -> Self {
42        Self {
43            state: AtomicU8::new(EMPTY),
44            waker: AtomicWaker::new(),
45        }
46    }
47    
48    /// Returns a future that completes when notified
49    /// 
50    /// 返回一个在收到通知时完成的 future
51    #[inline]
52    pub fn notified(&self) -> Notified<'_> {
53        Notified { 
54            notify: self,
55            registered: false,
56        }
57    }
58    
59    /// Wake the waiting task (if any)
60    /// 
61    /// If called before wait, the next wait will complete immediately.
62    /// 
63    /// 唤醒等待的任务(如果有)
64    /// 
65    /// 如果在等待之前调用,下一次等待将立即完成。
66    #[inline]
67    pub fn notify_one(&self) {
68        // Mark as notified
69        let prev_state = self.state.swap(NOTIFIED, Ordering::AcqRel);
70        
71        // If there was a waiter, wake it
72        if prev_state == WAITING {
73            self.waker.wake();
74        }
75    }
76    
77    /// Register a waker to be notified
78    /// 
79    /// Returns true if already notified (fast path)
80    /// 
81    /// 注册一个 waker 以接收通知
82    /// 
83    /// 如果已经被通知则返回 true(快速路径)
84    #[inline]
85    fn register_waker(&self, waker: &std::task::Waker) -> bool {
86        // Try to transition from EMPTY to WAITING
87        match self.state.compare_exchange(
88            EMPTY,
89            WAITING,
90            Ordering::AcqRel,
91            Ordering::Acquire,
92        ) {
93            Ok(_) => {
94                // Successfully transitioned, store the waker
95                self.waker.register(waker);
96                false // Not notified yet
97            }
98            Err(state) => {
99                // Already notified or waiting
100                if state == NOTIFIED {
101                    // Reset to EMPTY for next wait
102                    self.state.store(EMPTY, Ordering::Release);
103                    true // Already notified
104                } else {
105                    // State is WAITING, update the waker
106                    self.waker.register(waker);
107                    
108                    // Check if notified while we were updating waker
109                    if self.state.load(Ordering::Acquire) == NOTIFIED {
110                        self.state.store(EMPTY, Ordering::Release);
111                        true
112                    } else {
113                        false
114                    }
115                }
116            }
117        }
118    }
119}
120
121// Drop is automatically handled by AtomicWaker's drop implementation
122// No need for explicit drop implementation
123//
124// Drop 由 AtomicWaker 的 drop 实现自动处理
125// 无需显式的 drop 实现
126
127/// Future returned by `SingleWaiterNotify::notified()`
128/// 
129/// `SingleWaiterNotify::notified()` 返回的 Future
130pub struct Notified<'a> {
131    notify: &'a SingleWaiterNotify,
132    registered: bool,
133}
134
135impl Future for Notified<'_> {
136    type Output = ();
137    
138    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
139        // On first poll, register the waker
140        if !self.registered {
141            self.registered = true;
142            if self.notify.register_waker(cx.waker()) {
143                // Already notified (fast path)
144                return Poll::Ready(());
145            }
146        } else {
147            // On subsequent polls, check if notified
148            if self.notify.state.load(Ordering::Acquire) == NOTIFIED {
149                self.notify.state.store(EMPTY, Ordering::Release);
150                return Poll::Ready(());
151            }
152            // Update waker in case it changed
153            self.notify.register_waker(cx.waker());
154        }
155        
156        Poll::Pending
157    }
158}
159
160impl Drop for Notified<'_> {
161    fn drop(&mut self) {
162        if self.registered {
163            // If we registered but are being dropped, try to clean up
164            // PERFORMANCE: Direct compare_exchange without pre-check:
165            // - Single atomic operation instead of two (load + compare_exchange)
166            // - Relaxed ordering is sufficient - just cleaning up state
167            // - CAS will fail harmlessly if state is not WAITING
168            //
169            // 如果我们注册了但正在被 drop,尝试清理
170            // 性能优化:直接 compare_exchange 无需预检查:
171            // - 单次原子操作而不是两次(load + compare_exchange)
172            // - Relaxed ordering 就足够了 - 只是清理状态
173            // - 如果状态不是 WAITING,CAS 会无害地失败
174            let _ = self.notify.state.compare_exchange(
175                WAITING,
176                EMPTY,
177                Ordering::Relaxed,
178                Ordering::Relaxed,
179            );
180        }
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};
    use tokio::time::{sleep, Duration};

    /// Waker that records how many times it has been woken.
    struct CountingWaker(AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Returns a waker together with a handle for reading its wake count.
    fn counting_waker() -> (Arc<CountingWaker>, Waker) {
        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(Arc::clone(&counter));
        (counter, waker)
    }

    /// Polls `fut` exactly once using `waker`.
    fn poll_once<F: Future + Unpin>(fut: &mut F, waker: &Waker) -> Poll<F::Output> {
        Pin::new(fut).poll(&mut Context::from_waker(waker))
    }

    #[test]
    fn test_manual_poll_wakes_latest_waker() {
        let notify = SingleWaiterNotify::new();
        let (_first, first_waker) = counting_waker();
        let (second, second_waker) = counting_waker();
        let mut fut = notify.notified();

        assert!(poll_once(&mut fut, &first_waker).is_pending());
        // Re-polling with another waker must replace the registered one.
        assert!(poll_once(&mut fut, &second_waker).is_pending());

        notify.notify_one();
        assert_eq!(second.0.load(Ordering::SeqCst), 1);
        assert!(poll_once(&mut fut, &second_waker).is_ready());
    }

    #[test]
    fn test_cancelled_waiter_keeps_notification() {
        let notify = SingleWaiterNotify::new();
        let (counter, waker) = counting_waker();

        let mut cancelled = notify.notified();
        assert!(poll_once(&mut cancelled, &waker).is_pending());
        drop(cancelled);

        // Dropping the waiter withdrew it, so notify_one has nobody to wake...
        notify.notify_one();
        assert_eq!(counter.0.load(Ordering::SeqCst), 0);

        // ...but the notification is kept for the next waiter.
        let mut next = notify.notified();
        assert!(poll_once(&mut next, &waker).is_ready());
    }

    #[tokio::test]
    async fn test_notify_before_wait() {
        let notify = Arc::new(SingleWaiterNotify::new());

        // Notify before waiting
        notify.notify_one();

        // Should complete immediately
        notify.notified().await;
    }

    #[tokio::test]
    async fn test_notify_after_wait() {
        let notify = Arc::new(SingleWaiterNotify::new());
        let notify_clone = notify.clone();

        // Spawn a task that notifies after a delay
        tokio::spawn(async move {
            sleep(Duration::from_millis(10)).await;
            notify_clone.notify_one();
        });

        // Wait for notification
        notify.notified().await;
    }

    #[tokio::test]
    async fn test_multiple_notify_cycles() {
        let notify = Arc::new(SingleWaiterNotify::new());

        for _ in 0..10 {
            let notify_clone = notify.clone();
            tokio::spawn(async move {
                sleep(Duration::from_millis(5)).await;
                notify_clone.notify_one();
            });

            notify.notified().await;
        }
    }

    #[tokio::test]
    async fn test_concurrent_notify() {
        let notify = Arc::new(SingleWaiterNotify::new());
        let notify_clone = notify.clone();

        // Several notifiers race. Notifications coalesce, so the single waiter
        // completes once no matter how many of them fire first.
        for _ in 0..5 {
            let n = notify_clone.clone();
            tokio::spawn(async move {
                sleep(Duration::from_millis(10)).await;
                n.notify_one();
            });
        }

        notify.notified().await;
    }

    #[tokio::test]
    async fn test_notify_no_waiter() {
        let notify = SingleWaiterNotify::new();

        // Notify with no waiter should not panic
        notify.notify_one();
        notify.notify_one();

        // Next wait should complete immediately
        notify.notified().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_stress_test() {
        let notify = Arc::new(SingleWaiterNotify::new());

        for i in 0..100 {
            let notify_clone = notify.clone();
            tokio::spawn(async move {
                sleep(Duration::from_micros(i % 10)).await;
                notify_clone.notify_one();
            });

            notify.notified().await;
        }
    }

    #[tokio::test]
    async fn test_immediate_notification_race() {
        // Test the race between notification and registration
        for _ in 0..100 {
            let notify = Arc::new(SingleWaiterNotify::new());
            let notify_clone = notify.clone();

            let waiter = tokio::spawn(async move {
                notify.notified().await;
            });

            // Notify immediately (might happen before or after registration)
            notify_clone.notify_one();

            // Should complete without timeout
            tokio::time::timeout(Duration::from_millis(100), waiter)
                .await
                .expect("Should not timeout")
                .expect("Task should complete");
        }
    }
}
297