kestrel_timer/utils/notify.rs
//! Lightweight single-waiter notification primitive.
//!
//! Optimized for the SPSC (single-producer, single-consumer) pattern, where
//! only one task waits at a time. Much lighter than `tokio::sync::Notify`.
//!
//! 轻量级单等待者通知原语
//!
//! 为 SPSC(单生产者单消费者)模式优化,其中每次只有一个任务等待。
//! 比 tokio::sync::Notify 更轻量。
10
11use std::sync::atomic::{AtomicU8, AtomicPtr, Ordering};
12use std::ptr;
13use std::future::Future;
14use std::pin::Pin;
15use std::task::{Context, Poll, Waker};
16
// Values held by the notifier's `state` byte.

/// No waiter is registered and no notification is pending.
const EMPTY: u8 = 0;
/// A waiter has registered itself.
const WAITING: u8 = 1;
/// A notification has been sent.
const NOTIFIED: u8 = 2;
21
/// Lightweight single-waiter notifier optimized for the SPSC pattern.
///
/// Designed as a lighter alternative to `tokio::sync::Notify`:
/// - No wait list: the whole state is one atomic byte plus one atomic pointer
/// - Direct waker management (no intermediate state machine)
/// - Cheap to create and to notify
/// - A notification sent before the wait starts is remembered, so the next
///   wait completes immediately
///
/// Only one task is expected to wait at a time.
#[derive(Debug)]
pub struct SingleWaiterNotify {
    /// Current state: `EMPTY`, `WAITING` or `NOTIFIED`.
    state: AtomicU8,
    /// Boxed waker of the registered waiter (a `Box::into_raw` pointer), or
    /// null when no waker is stored.
    waker: AtomicPtr<Waker>,
}
41
42impl SingleWaiterNotify {
43 /// Create a new single-waiter notifier
44 ///
45 /// 创建一个新的单等待者通知器
46 #[inline]
47 pub fn new() -> Self {
48 Self {
49 state: AtomicU8::new(EMPTY),
50 waker: AtomicPtr::new(ptr::null_mut()),
51 }
52 }
53
54 /// Returns a future that completes when notified
55 ///
56 /// 返回一个在收到通知时完成的 future
57 #[inline]
58 pub fn notified(&self) -> Notified<'_> {
59 Notified {
60 notify: self,
61 registered: false,
62 }
63 }
64
65 /// Wake the waiting task (if any)
66 ///
67 /// If called before wait, the next wait will complete immediately.
68 ///
69 /// 唤醒等待的任务(如果有)
70 ///
71 /// 如果在等待之前调用,下一次等待将立即完成。
72 #[inline]
73 pub fn notify_one(&self) {
74 // Mark as notified
75 let prev_state = self.state.swap(NOTIFIED, Ordering::AcqRel);
76
77 // If there was a waiter, wake it
78 if prev_state == WAITING {
79 let waker_ptr = self.waker.swap(ptr::null_mut(), Ordering::AcqRel);
80 if !waker_ptr.is_null() {
81 // SAFETY: This pointer was created by Box::into_raw in register_waker
82 unsafe {
83 let waker = Box::from_raw(waker_ptr);
84 waker.wake();
85 }
86 }
87 }
88 }
89
90 /// Register a waker to be notified
91 ///
92 /// Returns true if already notified (fast path)
93 ///
94 /// 注册一个 waker 以接收通知
95 ///
96 /// 如果已经被通知则返回 true(快速路径)
97 #[inline]
98 fn register_waker(&self, waker: &Waker) -> bool {
99 // Try to transition from EMPTY to WAITING
100 match self.state.compare_exchange(
101 EMPTY,
102 WAITING,
103 Ordering::AcqRel,
104 Ordering::Acquire,
105 ) {
106 Ok(_) => {
107 // Successfully transitioned, store the waker
108 let new_waker = Box::into_raw(Box::new(waker.clone()));
109 let old_waker = self.waker.swap(new_waker, Ordering::AcqRel);
110
111 if !old_waker.is_null() {
112 // SAFETY: This pointer was created by Box::into_raw
113 unsafe {
114 let _ = Box::from_raw(old_waker);
115 }
116 }
117 false // Not notified yet
118 }
119 Err(state) => {
120 // Already notified or waiting
121 if state == NOTIFIED {
122 // Reset to EMPTY for next wait
123 self.state.store(EMPTY, Ordering::Release);
124 true // Already notified
125 } else {
126 // State is WAITING, update the waker
127 let new_waker = Box::into_raw(Box::new(waker.clone()));
128 let old_waker = self.waker.swap(new_waker, Ordering::AcqRel);
129
130 if !old_waker.is_null() {
131 // SAFETY: This pointer was created by Box::into_raw
132 unsafe {
133 let _ = Box::from_raw(old_waker);
134 }
135 }
136
137 // Check if notified while we were updating waker
138 if self.state.load(Ordering::Acquire) == NOTIFIED {
139 self.state.store(EMPTY, Ordering::Release);
140 true
141 } else {
142 false
143 }
144 }
145 }
146 }
147 }
148}
149
150impl Drop for SingleWaiterNotify {
151 fn drop(&mut self) {
152 // Clean up any remaining waker
153 let waker_ptr = self.waker.load(Ordering::Acquire);
154 if !waker_ptr.is_null() {
155 // SAFETY: This pointer was created by Box::into_raw
156 unsafe {
157 let _ = Box::from_raw(waker_ptr);
158 }
159 }
160 }
161}
162
163/// Future returned by `SingleWaiterNotify::notified()`
164///
165/// `SingleWaiterNotify::notified()` 返回的 Future
166pub struct Notified<'a> {
167 notify: &'a SingleWaiterNotify,
168 registered: bool,
169}
170
171impl Future for Notified<'_> {
172 type Output = ();
173
174 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
175 // On first poll, register the waker
176 if !self.registered {
177 self.registered = true;
178 if self.notify.register_waker(cx.waker()) {
179 // Already notified (fast path)
180 return Poll::Ready(());
181 }
182 } else {
183 // On subsequent polls, check if notified
184 if self.notify.state.load(Ordering::Acquire) == NOTIFIED {
185 self.notify.state.store(EMPTY, Ordering::Release);
186 return Poll::Ready(());
187 }
188 // Update waker in case it changed
189 self.notify.register_waker(cx.waker());
190 }
191
192 Poll::Pending
193 }
194}
195
196impl Drop for Notified<'_> {
197 fn drop(&mut self) {
198 if self.registered {
199 // If we registered but are being dropped, try to clean up
200 if self.notify.state.load(Ordering::Acquire) == WAITING {
201 // Try to transition back to EMPTY
202 let _ = self.notify.state.compare_exchange(
203 WAITING,
204 EMPTY,
205 Ordering::AcqRel,
206 Ordering::Acquire,
207 );
208 }
209 }
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    /// Spawns a detached task that calls `notify_one` on `target` after `delay`.
    fn notify_after(target: Arc<SingleWaiterNotify>, delay: Duration) {
        tokio::spawn(async move {
            sleep(delay).await;
            target.notify_one();
        });
    }

    /// A notification sent before the wait starts completes the wait.
    #[tokio::test]
    async fn test_notify_before_wait() {
        let shared = Arc::new(SingleWaiterNotify::new());

        shared.notify_one();

        // Must resolve right away.
        shared.notified().await;
    }

    /// A waiter that is already pending is released by a later notification.
    #[tokio::test]
    async fn test_notify_after_wait() {
        let shared = Arc::new(SingleWaiterNotify::new());

        notify_after(Arc::clone(&shared), Duration::from_millis(10));

        shared.notified().await;
    }

    /// The notifier can be reused for several wait/notify rounds.
    #[tokio::test]
    async fn test_multiple_notify_cycles() {
        let shared = Arc::new(SingleWaiterNotify::new());

        for _ in 0..10 {
            notify_after(Arc::clone(&shared), Duration::from_millis(5));
            shared.notified().await;
        }
    }

    /// Several notifiers firing together still release the single waiter.
    #[tokio::test]
    async fn test_concurrent_notify() {
        let shared = Arc::new(SingleWaiterNotify::new());

        for _ in 0..5 {
            notify_after(Arc::clone(&shared), Duration::from_millis(10));
        }

        shared.notified().await;
    }

    /// Notifying without a waiter must not panic, and the next wait completes.
    #[tokio::test]
    async fn test_notify_no_waiter() {
        let local = SingleWaiterNotify::new();

        local.notify_one();
        local.notify_one();

        local.notified().await;
    }

    /// Many rounds with short, varying delays on a multi-threaded runtime.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_stress_test() {
        let shared = Arc::new(SingleWaiterNotify::new());

        for round in 0..100 {
            notify_after(Arc::clone(&shared), Duration::from_micros(round % 10));
            shared.notified().await;
        }
    }

    /// Notification and registration may happen in either order; the waiter
    /// must finish either way.
    #[tokio::test]
    async fn test_immediate_notification_race() {
        for _ in 0..100 {
            let waiter_side = Arc::new(SingleWaiterNotify::new());
            let notifier_side = Arc::clone(&waiter_side);

            let waiter = tokio::spawn(async move {
                waiter_side.notified().await;
            });

            // May run before or after the waiter registers.
            notifier_side.notify_one();

            let joined = tokio::time::timeout(Duration::from_millis(100), waiter).await;
            joined
                .expect("Should not timeout")
                .expect("Task should complete");
        }
    }
}
326