lite_sync/atomic_waker.rs
//! Atomic waker storage using a state machine for safe concurrent access.
//!
//! Based on Tokio's `AtomicWaker` but simplified for our use cases.
//! Uses `UnsafeCell<Option<Waker>>` plus an atomic state machine to avoid a `Box`
//! allocation while maintaining safe concurrent access.
//!
//! 使用状态机进行安全并发访问的原子 waker 存储
//!
//! 基于 Tokio 的 `AtomicWaker` 但为我们的用例简化。
//! 使用 `UnsafeCell<Option<Waker>>` + 原子状态机避免 `Box` 分配
//! 同时保持安全的并发访问。
12use std::cell::UnsafeCell;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::task::Waker;
15
// Waker registration states.
//
// `REGISTERING` and `WAKING` are independent bits, so the state is one of
// `WAITING`, `REGISTERING`, `WAKING` or `REGISTERING | WAKING`. The combined
// value means a `take` arrived while a `register` held the slot; the
// registering thread then wakes the waker it just stored.

/// No thread is accessing the waker slot.
const WAITING: usize = 0;
/// A `register` call holds exclusive access to the waker slot.
const REGISTERING: usize = 0b01;
/// A `take` call holds the slot, or has flagged a concurrent `register`.
const WAKING: usize = 0b10;

/// Atomic waker storage with state machine synchronization.
///
/// Stores at most one [`Waker`] without heap allocation. Access to the slot is
/// serialized by the lock bits in `state` rather than by a mutex.
pub(crate) struct AtomicWaker {
    /// `WAITING`, or a combination of the `REGISTERING` / `WAKING` lock bits.
    state: AtomicUsize,
    /// The registered waker. Only touched by the thread holding a lock bit in
    /// `state`, or through `&mut self`.
    waker: UnsafeCell<Option<Waker>>,
}

// SAFETY: `waker` is only read or written by the thread that moved `state` out
// of `WAITING` (acquiring the REGISTERING or WAKING bit), so no two threads
// access the cell at once; `Waker` is `Send + Sync`, so handing it between
// threads is sound. `Send` needs no manual impl: it is derived automatically
// because both fields are `Send`.
unsafe impl Sync for AtomicWaker {}

impl AtomicWaker {
    /// Creates an empty `AtomicWaker` with no registered waker.
    #[inline]
    pub(crate) const fn new() -> Self {
        Self {
            state: AtomicUsize::new(WAITING),
            waker: UnsafeCell::new(None),
        }
    }

    /// Registers `waker` to be notified by the next [`wake`](Self::wake) or
    /// [`take`](Self::take).
    ///
    /// If the stored waker would already wake the same task (per
    /// [`Waker::will_wake`]), it is kept and `waker` is not cloned.
    ///
    /// If a concurrent wake happens during registration, the newly registered
    /// waker is woken immediately. If a wake is already in progress, `waker` is
    /// woken directly instead of being stored. Concurrent calls to `register`
    /// are not supported: a registration that races with another one is
    /// dropped.
    #[inline]
    pub(crate) fn register(&self, waker: &Waker) {
        // Acquire: observe the slot contents released by the previous lock holder.
        match self.state.compare_exchange(
            WAITING,
            REGISTERING,
            Ordering::Acquire,
            Ordering::Acquire,
        ) {
            Ok(_) => {
                // SAFETY: the successful CAS gave us the REGISTERING lock, which
                // excludes every other access to `waker` until we release it.
                // NOTE: if `Waker::clone` panics here, the state stays
                // REGISTERING and later registrations are dropped.
                let old_waker = unsafe {
                    let slot = &mut *self.waker.get();
                    if matches!(&*slot, Some(current) if current.will_wake(waker)) {
                        // Same task already registered: skip the clone/drop pair.
                        None
                    } else {
                        slot.replace(waker.clone())
                    }
                };

                // Release the slot update and try to unlock.
                match self.state.compare_exchange(
                    REGISTERING,
                    WAITING,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                ) {
                    Ok(_) => {
                        // Dropped after unlocking, since dropping a waker runs
                        // the waker's own drop logic.
                        drop(old_waker);
                    }
                    Err(actual) => {
                        // Only `take` can change the state while we hold
                        // REGISTERING, and it only ever sets WAKING.
                        debug_assert_eq!(actual, REGISTERING | WAKING);

                        // SAFETY: we still hold the REGISTERING lock; the
                        // concurrent `take` saw it and left the slot alone.
                        let waker = unsafe { (*self.waker.get()).take() };

                        // An RMW (`swap`) rather than a plain store: it also
                        // acquires any `fetch_or(WAKING)` that landed after the
                        // failed CAS, and does not break their release sequence.
                        self.state.swap(WAITING, Ordering::AcqRel);

                        drop(old_waker);
                        if let Some(waker) = waker {
                            waker.wake();
                        }
                    }
                }
            }
            Err(WAKING) => {
                // A wake is in progress on the previously stored waker; wake
                // the new one directly so the notification is not lost.
                waker.wake_by_ref();
            }
            Err(actual) => {
                // Another `register` holds the lock. Concurrent registration is
                // unsupported; drop this call to stay memory-safe.
                debug_assert!(actual == REGISTERING || actual == (REGISTERING | WAKING));
            }
        }
    }

    /// Takes the registered waker out of storage, leaving the slot empty.
    ///
    /// Returns `None` if no waker is registered, or if another thread is
    /// concurrently registering or taking. When a `register` is in progress,
    /// the `WAKING` bit stays set so that the registering thread wakes the
    /// waker it stores.
    #[inline]
    pub(crate) fn take(&self) -> Option<Waker> {
        // AcqRel: acquire the slot contents published by `register`, and
        // release the caller's prior writes to whoever next reads the state.
        match self.state.fetch_or(WAKING, Ordering::AcqRel) {
            WAITING => {
                // SAFETY: we moved the state from WAITING to WAKING, which
                // excludes `register` and other `take` calls from the slot.
                let waker = unsafe { (*self.waker.get()).take() };

                // Clear only our bit with an RMW so the release sequences of
                // concurrent `fetch_or(WAKING)` calls are preserved.
                self.state.fetch_and(!WAKING, Ordering::Release);

                waker
            }
            actual => {
                // A register or another take is in progress.
                debug_assert!(
                    actual == REGISTERING || actual == (REGISTERING | WAKING) || actual == WAKING
                );
                None
            }
        }
    }

    /// Wakes the registered waker, if any, consuming it.
    #[inline]
    pub(crate) fn wake(&self) {
        if let Some(waker) = self.take() {
            waker.wake();
        }
    }
}

// No explicit `Drop` impl is needed: dropping the `UnsafeCell<Option<Waker>>`
// field drops any stored waker.

impl std::fmt::Debug for AtomicWaker {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state_str = match self.state.load(Ordering::Acquire) {
            WAITING => "Waiting",
            REGISTERING => "Registering",
            WAKING => "Waking",
            // Guard, not a pattern: `REGISTERING | WAKING` in a pattern would
            // be an or-pattern rather than the combined bit value.
            s if s == REGISTERING | WAKING => "Registering | Waking",
            _ => "Unknown",
        };
        f.debug_struct("AtomicWaker")
            .field("state", &state_str)
            .finish()
    }
}
167
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::task::Wake;

    /// Test waker that counts how many times it has been woken.
    struct CountingWaker(AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Returns the shared counter together with a `Waker` that increments it.
    fn counting_waker() -> (Arc<CountingWaker>, Waker) {
        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(Arc::clone(&counter));
        (counter, waker)
    }

    /// Number of times the counting waker has been woken so far.
    fn wakes(counter: &CountingWaker) -> usize {
        counter.0.load(Ordering::SeqCst)
    }

    #[test]
    fn test_basic_register_and_take() {
        let atomic_waker = AtomicWaker::new();
        let (_counter, waker) = counting_waker();

        atomic_waker.register(&waker);
        assert!(atomic_waker.take().is_some());

        // The first take empties the slot.
        assert!(atomic_waker.take().is_none());
    }

    #[test]
    fn test_wake_invokes_registered_waker_once() {
        let atomic_waker = AtomicWaker::new();
        let (counter, waker) = counting_waker();

        // Nothing registered yet: waking is a no-op.
        atomic_waker.wake();
        assert_eq!(wakes(&counter), 0);

        atomic_waker.register(&waker);
        atomic_waker.wake();
        assert_eq!(wakes(&counter), 1);

        // The waker was consumed by the first wake.
        atomic_waker.wake();
        assert_eq!(wakes(&counter), 1);
    }

    #[test]
    fn test_register_replaces_previous_waker() {
        let atomic_waker = AtomicWaker::new();
        let (first, first_waker) = counting_waker();
        let (second, second_waker) = counting_waker();

        atomic_waker.register(&first_waker);
        atomic_waker.register(&second_waker);

        // The stored clone of the first waker has been released.
        assert_eq!(Arc::strong_count(&first), 2);

        atomic_waker.wake();
        assert_eq!(wakes(&first), 0);
        assert_eq!(wakes(&second), 1);
    }

    #[test]
    fn test_drop_releases_registered_waker() {
        let (counter, waker) = counting_waker();
        let atomic_waker = AtomicWaker::new();

        atomic_waker.register(&waker);
        assert_eq!(Arc::strong_count(&counter), 3);

        drop(atomic_waker);
        assert_eq!(Arc::strong_count(&counter), 2);
    }

    #[test]
    fn test_concurrent_access() {
        use std::thread;

        let atomic_waker = Arc::new(AtomicWaker::new());
        let (counter, waker) = counting_waker();

        let aw1 = Arc::clone(&atomic_waker);
        let w1 = waker.clone();
        let h1 = thread::spawn(move || {
            for _ in 0..100 {
                aw1.register(&w1);
            }
        });

        let aw2 = Arc::clone(&atomic_waker);
        let h2 = thread::spawn(move || {
            for _ in 0..100 {
                aw2.take();
            }
        });

        h1.join().unwrap();
        h2.join().unwrap();

        // Whatever interleaving occurred, the state machine must be back in
        // WAITING: a fresh registration is stored and can be woken.
        let before = wakes(&counter);
        atomic_waker.register(&waker);
        atomic_waker.wake();
        assert_eq!(wakes(&counter), before + 1);
    }
}
213