lite_sync/atomic_waker.rs
//! Atomic waker storage using a state machine for safe concurrent access.
//!
//! Based on Tokio's `AtomicWaker` but simplified for our use cases.
//! Uses `UnsafeCell<Option<Waker>>` plus an atomic state machine to avoid a `Box`
//! allocation while maintaining safe concurrent access.
//!
//! 使用状态机进行安全并发访问的原子 waker 存储
//!
//! 基于 Tokio 的 AtomicWaker 但为我们的用例简化。
//! 使用 UnsafeCell<Option<Waker>> + 原子状态机避免 Box 分配
//! 同时保持安全的并发访问。
12
13use std::cell::UnsafeCell;
14use std::sync::atomic::{AtomicUsize, Ordering};
15use std::task::Waker;
16
// States of the registration state machine stored in `AtomicWaker::state`.
// `WAITING` is the idle state with no bit set; `REGISTERING` and `WAKING` are
// independent bit flags and can both be set while a `register` and a `take` race.
const WAITING: usize = 0;
const REGISTERING: usize = 1 << 0;
const WAKING: usize = 1 << 1;
21
/// Atomic waker storage with state-machine synchronization.
///
/// `state` holds the `WAITING` / `REGISTERING` / `WAKING` bits. A thread that
/// atomically moves `state` out of `WAITING` (to `REGISTERING` in `register`,
/// to `WAKING` in `take`) gains exclusive access to `waker` until it stores
/// `WAITING` again.
pub(crate) struct AtomicWaker {
    /// Lock/state bits guarding `waker`.
    state: AtomicUsize,
    /// The registered waker, if any. Only accessed by the thread that currently
    /// owns the lock encoded in `state`, or through `&mut self`.
    waker: UnsafeCell<Option<Waker>>,
}

// SAFETY: `waker` (an `UnsafeCell`) is the only field that is not `Sync`. Through
// `&self` it is accessed only after the accessing thread has atomically moved
// `state` from `WAITING` to `REGISTERING` or `WAKING`. At most one thread can do
// that at a time, and the slot is handed back by restoring `WAITING`. All other
// accesses go through `&mut self`. `Waker` is `Send`, so the stored value may be
// moved between threads through the cell.
//
// No manual `Send` impl is needed: `AtomicUsize` and `UnsafeCell<Option<Waker>>`
// are both `Send`, so `AtomicWaker` is `Send` automatically.
unsafe impl Sync for AtomicWaker {}
33
34impl AtomicWaker {
35 /// Create a new atomic waker
36 ///
37 /// 创建一个新的原子 waker
38 #[inline]
39 pub(crate) const fn new() -> Self {
40 Self {
41 state: AtomicUsize::new(WAITING),
42 waker: UnsafeCell::new(None),
43 }
44 }
45
46 /// Register a waker to be notified
47 ///
48 /// This will store the waker and handle concurrent access safely.
49 /// If a concurrent wake happens during registration, the newly
50 /// registered waker will be woken immediately.
51 ///
52 /// 注册一个要通知的 waker
53 ///
54 /// 这将存储 waker 并安全地处理并发访问。
55 /// 如果在注册期间发生并发唤醒,新注册的 waker 将立即被唤醒。
56 #[inline]
57 pub(crate) fn register(&self, waker: &Waker) {
58 match self.state.compare_exchange(
59 WAITING,
60 REGISTERING,
61 Ordering::Acquire,
62 Ordering::Acquire,
63 ) {
64 Ok(_) => {
65 // SAFETY: We have exclusive access via REGISTERING lock
66 unsafe {
67 // Replace the waker
68 let old_waker = (*self.waker.get()).replace(waker.clone());
69
70 // Try to release the lock
71 match self.state.compare_exchange(
72 REGISTERING,
73 WAITING,
74 Ordering::AcqRel,
75 Ordering::Acquire,
76 ) {
77 Ok(_) => {
78 // Successfully released, just drop old waker
79 drop(old_waker);
80 }
81 Err(_) => {
82 // Concurrent wake happened, take waker and wake it
83 // State must be REGISTERING | WAKING
84 let waker = (*self.waker.get()).take();
85 self.state.store(WAITING, Ordering::Release);
86
87 drop(old_waker);
88 if let Some(waker) = waker {
89 waker.wake();
90 }
91 }
92 }
93 }
94 }
95 Err(WAKING) => {
96 // Currently waking, just wake the new waker directly
97 waker.wake_by_ref();
98 }
99 Err(_) => {
100 // Concurrent register (shouldn't happen in normal usage)
101 // Just drop this registration
102 }
103 }
104 }
105
106 /// Take the waker out for waking
107 ///
108 /// Returns the waker if one was registered, None otherwise.
109 /// This atomically removes the waker from storage.
110 ///
111 /// 取出 waker 用于唤醒
112 ///
113 /// 如果注册了 waker 则返回它,否则返回 None。
114 /// 这会原子地从存储中移除 waker。
115 #[inline]
116 pub(crate) fn take(&self) -> Option<Waker> {
117 match self.state.fetch_or(WAKING, Ordering::AcqRel) {
118 WAITING => {
119 // SAFETY: We have exclusive access via WAKING lock
120 let waker = unsafe { (*self.waker.get()).take() };
121
122 // Release the lock
123 self.state.store(WAITING, Ordering::Release);
124
125 waker
126 }
127 _ => {
128 // Concurrent register or wake in progress
129 None
130 }
131 }
132 }
133
134 /// Wake the registered waker if any
135 ///
136 /// 唤醒已注册的 waker(如果有)
137 #[inline]
138 pub(crate) fn wake(&self) {
139 if let Some(waker) = self.take() {
140 waker.wake();
141 }
142 }
143}
144
145impl Drop for AtomicWaker {
146 fn drop(&mut self) {
147 // SAFETY: We have exclusive access during drop
148 unsafe {
149 let _ = (*self.waker.get()).take();
150 }
151 }
152}
153
154impl std::fmt::Debug for AtomicWaker {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 f.debug_struct("AtomicWaker").finish()
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;
    use std::sync::Arc;

    /// A registered waker can be taken exactly once.
    #[test]
    fn test_basic_register_and_take() {
        let atomic_waker = AtomicWaker::new();
        let waker = futures::task::noop_waker();

        atomic_waker.register(&waker);
        assert!(atomic_waker.take().is_some());

        // The slot is now empty, so a second take yields nothing.
        assert!(atomic_waker.take().is_none());
    }

    /// Races `register` against `take` from two threads, then checks that the
    /// state machine settled back to idle and is still fully usable.
    #[test]
    fn test_concurrent_access() {
        use std::thread;

        let atomic_waker = Arc::new(AtomicWaker::new());
        let waker = futures::task::noop_waker();

        let aw1 = Arc::clone(&atomic_waker);
        let w1 = waker.clone();
        let h1 = thread::spawn(move || {
            for _ in 0..100 {
                aw1.register(&w1);
            }
        });

        let aw2 = Arc::clone(&atomic_waker);
        let h2 = thread::spawn(move || {
            for _ in 0..100 {
                aw2.take();
            }
        });

        h1.join().unwrap();
        h2.join().unwrap();

        // With both threads finished, no lock bit may be left set, and a fresh
        // register/take round-trip must succeed.
        assert_eq!(atomic_waker.state.load(Ordering::SeqCst), WAITING);
        atomic_waker.register(&waker);
        assert!(atomic_waker.take().is_some());
    }
}
205