lite_sync/
atomic_waker.rs

//! Atomic waker storage using a state machine for safe concurrent access.
//!
//! Based on Tokio's `AtomicWaker` but simplified for our use cases.
//! Uses `UnsafeCell<Option<Waker>>` plus an atomic state machine to avoid a `Box` allocation
//! while maintaining safe concurrent access.
//!
//! 使用状态机进行安全并发访问的原子 waker 存储
//!
//! 基于 Tokio 的 `AtomicWaker` 但为我们的用例简化。
//! 使用 `UnsafeCell<Option<Waker>>` + 原子状态机避免 `Box` 分配,
//! 同时保持安全的并发访问。
12
13use std::cell::UnsafeCell;
14use std::sync::atomic::{AtomicUsize, Ordering};
15use std::task::Waker;
16
// State bits of the waker slot. `REGISTERING` and `WAKING` are independent bits and may
// both be set at once; `WAITING` is the value with neither bit set.

/// No bit set: the slot is idle and may be claimed by a registration or a take.
const WAITING: usize = 0;
/// Set by `register` for as long as it is writing the slot.
const REGISTERING: usize = 1 << 0;
/// Set by `take` to announce a wake-up, and to claim the slot when it was idle.
const WAKING: usize = 1 << 1;
21
/// A slot holding at most one [`Waker`], shared between a registering side and a waking side.
///
/// The `state` word arbitrates access to the `waker` cell: the methods only dereference the
/// cell after winning a `WAITING -> REGISTERING` or `WAITING -> WAKING` transition.
pub(crate) struct AtomicWaker {
    /// The currently stored waker, if any. Not synchronized by itself; guarded by `state`.
    waker: UnsafeCell<Option<Waker>>,
    /// Combination of the `WAITING` / `REGISTERING` / `WAKING` state bits.
    state: AtomicUsize,
}
29
30// SAFETY: AtomicWaker is Sync because access to waker is synchronized via atomic state machine
31unsafe impl Sync for AtomicWaker {}
32unsafe impl Send for AtomicWaker {}
33
34impl AtomicWaker {
35    /// Create a new atomic waker
36    /// 
37    /// 创建一个新的原子 waker
38    #[inline]
39    pub(crate) const fn new() -> Self {
40        Self {
41            state: AtomicUsize::new(WAITING),
42            waker: UnsafeCell::new(None),
43        }
44    }
45
46    /// Register a waker to be notified
47    /// 
48    /// This will store the waker and handle concurrent access safely.
49    /// If a concurrent wake happens during registration, the newly
50    /// registered waker will be woken immediately.
51    /// 
52    /// 注册一个要通知的 waker
53    /// 
54    /// 这将存储 waker 并安全地处理并发访问。
55    /// 如果在注册期间发生并发唤醒,新注册的 waker 将立即被唤醒。
56    #[inline]
57    pub(crate) fn register(&self, waker: &Waker) {
58        match self.state.compare_exchange(
59            WAITING,
60            REGISTERING,
61            Ordering::Acquire,
62            Ordering::Acquire,
63        ) {
64            Ok(_) => {
65                // SAFETY: We have exclusive access via REGISTERING lock
66                unsafe {
67                    // Replace the waker
68                    let old_waker = (*self.waker.get()).replace(waker.clone());
69                    
70                    // Try to release the lock
71                    match self.state.compare_exchange(
72                        REGISTERING,
73                        WAITING,
74                        Ordering::AcqRel,
75                        Ordering::Acquire,
76                    ) {
77                        Ok(_) => {
78                            // Successfully released, just drop old waker
79                            drop(old_waker);
80                        }
81                        Err(_) => {
82                            // Concurrent wake happened, take waker and wake it
83                            // State must be REGISTERING | WAKING
84                            let waker = (*self.waker.get()).take();
85                            self.state.store(WAITING, Ordering::Release);
86                            
87                            drop(old_waker);
88                            if let Some(waker) = waker {
89                                waker.wake();
90                            }
91                        }
92                    }
93                }
94            }
95            Err(WAKING) => {
96                // Currently waking, just wake the new waker directly
97                waker.wake_by_ref();
98            }
99            Err(_) => {
100                // Concurrent register (shouldn't happen in normal usage)
101                // Just drop this registration
102            }
103        }
104    }
105
106    /// Take the waker out for waking
107    /// 
108    /// Returns the waker if one was registered, None otherwise.
109    /// This atomically removes the waker from storage.
110    /// 
111    /// 取出 waker 用于唤醒
112    /// 
113    /// 如果注册了 waker 则返回它,否则返回 None。
114    /// 这会原子地从存储中移除 waker。
115    #[inline]
116    pub(crate) fn take(&self) -> Option<Waker> {
117        match self.state.fetch_or(WAKING, Ordering::AcqRel) {
118            WAITING => {
119                // SAFETY: We have exclusive access via WAKING lock
120                let waker = unsafe { (*self.waker.get()).take() };
121                
122                // Release the lock
123                self.state.store(WAITING, Ordering::Release);
124                
125                waker
126            }
127            _ => {
128                // Concurrent register or wake in progress
129                None
130            }
131        }
132    }
133
134    /// Wake the registered waker if any
135    /// 
136    /// 唤醒已注册的 waker(如果有)
137    #[inline]
138    pub(crate) fn wake(&self) {
139        if let Some(waker) = self.take() {
140            waker.wake();
141        }
142    }
143}
144
145impl Drop for AtomicWaker {
146    fn drop(&mut self) {
147        // SAFETY: We have exclusive access during drop
148        unsafe {
149            let _ = (*self.waker.get()).take();
150        }
151    }
152}
153
154impl std::fmt::Debug for AtomicWaker {
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        f.debug_struct("AtomicWaker").finish()
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// A registered waker can be taken once; after that the slot is empty.
    #[test]
    fn test_basic_register_and_take() {
        let slot = AtomicWaker::new();
        let waker = futures::task::noop_waker();

        slot.register(&waker);
        assert!(slot.take().is_some());

        // Nothing was registered in between, so the slot must now be empty.
        assert!(slot.take().is_none());
    }

    /// Smoke test: one thread registers while another takes, 100 iterations each.
    /// There are no assertions; the test passes if both threads finish without panicking.
    #[test]
    fn test_concurrent_access() {
        let shared = Arc::new(AtomicWaker::new());
        let waker = futures::task::noop_waker();

        let registrar = {
            let shared = Arc::clone(&shared);
            let thread_waker = waker.clone();
            std::thread::spawn(move || {
                (0..100).for_each(|_| shared.register(&thread_waker));
            })
        };

        let taker = {
            let shared = Arc::clone(&shared);
            std::thread::spawn(move || {
                (0..100).for_each(|_| {
                    let _ = shared.take();
                });
            })
        };

        registrar.join().unwrap();
        taker.join().unwrap();
    }
}
205