lite_sync/
atomic_waker.rs

//! Atomic waker storage using a state machine for safe concurrent access.
//!
//! Based on Tokio's `AtomicWaker`, but simplified for our use cases.
//! Uses `UnsafeCell<Option<Waker>>` plus an atomic state machine to avoid a `Box`
//! allocation while maintaining safe concurrent access.
//!
//! 使用状态机进行安全并发访问的原子 waker 存储
//!
//! 基于 Tokio 的 AtomicWaker 但为我们的用例简化。
//! 使用 UnsafeCell<Option<Waker>> + 原子状态机避免 Box 分配
//! 同时保持安全的并发访问。
12use std::cell::UnsafeCell;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::task::Waker;
15
// Lock-state bits stored in `AtomicWaker::state`. `REGISTERING` and `WAKING`
// are independent flags, so the combination `REGISTERING | WAKING` is also a
// reachable state (a wake arrived while a registration held the lock).

/// Nobody holds the waker slot.
const WAITING: usize = 0;
/// A `register` call owns the slot and is replacing the stored waker.
const REGISTERING: usize = 1 << 0;
/// A `take` call owns the slot, or a wake was requested during a registration.
const WAKING: usize = 1 << 1;
20
/// Single-slot waker storage guarded by a small atomic state machine.
///
/// `state` acts as a lock: whichever operation moves it out of `WAITING`
/// (`register` by setting `REGISTERING`, `take` by setting `WAKING`) gains
/// exclusive access to `waker` until it resets the state to `WAITING`.
pub(crate) struct AtomicWaker {
    /// Bit combination of `WAITING` / `REGISTERING` / `WAKING`.
    state: AtomicUsize,
    /// The registered waker, if any. Accessed only while holding the state
    /// lock, or through `&mut self`.
    waker: UnsafeCell<Option<Waker>>,
}

// SAFETY: `UnsafeCell` makes this type `!Sync` by default. Every shared-reference
// access to `waker` happens only after winning the `state` transition out of
// `WAITING`, which serializes all readers and writers of the cell. The values
// moved in and out of the cell are `Waker`s, which are `Send`, so handing them
// across threads is sound.
unsafe impl Sync for AtomicWaker {}

// No manual `unsafe impl Send` is needed: `AtomicUsize` and
// `UnsafeCell<Option<Waker>>` are both `Send`, so the auto trait applies. Relying
// on it keeps the compiler checking `Send` if the fields ever change.
32
33impl AtomicWaker {
34    /// Create a new atomic waker
35    /// 
36    /// 创建一个新的原子 waker
37    #[inline]
38    pub(crate) const fn new() -> Self {
39        Self {
40            state: AtomicUsize::new(WAITING),
41            waker: UnsafeCell::new(None),
42        }
43    }
44
45    /// Register a waker to be notified
46    /// 
47    /// This will store the waker and handle concurrent access safely.
48    /// If a concurrent wake happens during registration, the newly
49    /// registered waker will be woken immediately.
50    /// 
51    /// 注册一个要通知的 waker
52    /// 
53    /// 这将存储 waker 并安全地处理并发访问。
54    /// 如果在注册期间发生并发唤醒,新注册的 waker 将立即被唤醒。
55    #[inline]
56    pub(crate) fn register(&self, waker: &Waker) {
57        match self.state.compare_exchange(
58            WAITING,
59            REGISTERING,
60            Ordering::Acquire,
61            Ordering::Acquire,
62        ) {
63            Ok(_) => {
64                // SAFETY: We have exclusive access via REGISTERING lock
65                unsafe {
66                    // Replace the waker
67                    let old_waker = (*self.waker.get()).replace(waker.clone());
68                    
69                    // Try to release the lock
70                    match self.state.compare_exchange(
71                        REGISTERING,
72                        WAITING,
73                        Ordering::AcqRel,
74                        Ordering::Acquire,
75                    ) {
76                        Ok(_) => {
77                            // Successfully released, just drop old waker
78                            drop(old_waker);
79                        }
80                        Err(_) => {
81                            // Concurrent wake happened, take waker and wake it
82                            // State must be REGISTERING | WAKING
83                            let waker = (*self.waker.get()).take();
84                            self.state.store(WAITING, Ordering::Release);
85                            
86                            drop(old_waker);
87                            if let Some(waker) = waker {
88                                waker.wake();
89                            }
90                        }
91                    }
92                }
93            }
94            Err(WAKING) => {
95                // Currently waking, just wake the new waker directly
96                waker.wake_by_ref();
97            }
98            Err(_) => {
99                // Concurrent register (shouldn't happen in normal usage)
100                // Just drop this registration
101            }
102        }
103    }
104
105    /// Take the waker out for waking
106    /// 
107    /// Returns the waker if one was registered, None otherwise.
108    /// This atomically removes the waker from storage.
109    /// 
110    /// 取出 waker 用于唤醒
111    /// 
112    /// 如果注册了 waker 则返回它,否则返回 None。
113    /// 这会原子地从存储中移除 waker。
114    #[inline]
115    pub(crate) fn take(&self) -> Option<Waker> {
116        match self.state.fetch_or(WAKING, Ordering::AcqRel) {
117            WAITING => {
118                // SAFETY: We have exclusive access via WAKING lock
119                let waker = unsafe { (*self.waker.get()).take() };
120                
121                // Release the lock
122                self.state.store(WAITING, Ordering::Release);
123                
124                waker
125            }
126            _ => {
127                // Concurrent register or wake in progress
128                None
129            }
130        }
131    }
132
133    /// Wake the registered waker if any
134    /// 
135    /// 唤醒已注册的 waker(如果有)
136    #[inline]
137    pub(crate) fn wake(&self) {
138        if let Some(waker) = self.take() {
139            waker.wake();
140        }
141    }
142}
143
144impl Drop for AtomicWaker {
145    fn drop(&mut self) {
146        // SAFETY: We have exclusive access during drop
147        unsafe {
148            let _ = (*self.waker.get()).take();
149        }
150    }
151}
152
153impl std::fmt::Debug for AtomicWaker {
154    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155        let state = self.state.load(Ordering::Acquire);
156        let state_str = match state {
157            WAITING => "Waiting",
158            REGISTERING => "Registering",
159            WAKING => "Waking",
160            _ => "Unknown",
161        };
162        f.debug_struct("AtomicWaker")
163            .field("state", &state_str)
164            .finish()
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::task::{Wake, Waker};
    use std::thread;

    /// Test waker that records how many times it has been woken.
    #[derive(Default)]
    struct CountingWaker {
        wakes: AtomicUsize,
    }

    impl CountingWaker {
        fn count(&self) -> usize {
            self.wakes.load(Ordering::SeqCst)
        }
    }

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.wakes.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Builds a `Waker` plus a handle for observing its wake count.
    fn counting_waker() -> (Arc<CountingWaker>, Waker) {
        let counter = Arc::new(CountingWaker::default());
        let waker = Waker::from(Arc::clone(&counter));
        (counter, waker)
    }

    #[test]
    fn test_basic_register_and_take() {
        let atomic_waker = AtomicWaker::new();
        let (counter, waker) = counting_waker();

        atomic_waker.register(&waker);
        let taken = atomic_waker
            .take()
            .expect("a registered waker should be returned by take");
        // Taking alone must not wake anything.
        assert_eq!(counter.count(), 0);
        // The returned waker is the one that was registered.
        taken.wake();
        assert_eq!(counter.count(), 1);

        // Second take should return None
        assert!(atomic_waker.take().is_none());
    }

    #[test]
    fn test_wake_invokes_registered_waker_once() {
        let atomic_waker = AtomicWaker::new();
        let (counter, waker) = counting_waker();

        // Nothing registered yet: wake is a no-op.
        atomic_waker.wake();
        assert_eq!(counter.count(), 0);

        atomic_waker.register(&waker);
        atomic_waker.wake();
        assert_eq!(counter.count(), 1);

        // The waker was consumed by the first wake.
        atomic_waker.wake();
        assert_eq!(counter.count(), 1);
    }

    #[test]
    fn test_register_replaces_previous_waker() {
        let atomic_waker = AtomicWaker::new();
        let (first, w1) = counting_waker();
        let (second, w2) = counting_waker();

        atomic_waker.register(&w1);
        atomic_waker.register(&w2);
        atomic_waker.wake();

        assert_eq!(first.count(), 0);
        assert_eq!(second.count(), 1);
    }

    #[test]
    fn test_register_during_wake_wakes_immediately() {
        let atomic_waker = AtomicWaker::new();
        let (counter, waker) = counting_waker();

        // Simulate a `take` that currently owns the slot.
        atomic_waker.state.store(WAKING, Ordering::SeqCst);
        atomic_waker.register(&waker);
        assert_eq!(counter.count(), 1);

        // The waker was delivered directly rather than stored.
        atomic_waker.state.store(WAITING, Ordering::SeqCst);
        assert!(atomic_waker.take().is_none());
    }

    #[test]
    fn test_debug_reports_state() {
        let atomic_waker = AtomicWaker::new();
        assert!(format!("{:?}", atomic_waker).contains("Waiting"));
    }

    #[test]
    fn test_concurrent_access() {
        let atomic_waker = Arc::new(AtomicWaker::new());
        let (counter, waker) = counting_waker();

        let aw1 = Arc::clone(&atomic_waker);
        let w1 = waker.clone();
        let h1 = thread::spawn(move || {
            for _ in 0..100 {
                aw1.register(&w1);
            }
        });

        let aw2 = Arc::clone(&atomic_waker);
        let h2 = thread::spawn(move || {
            for _ in 0..100 {
                aw2.wake();
            }
        });

        h1.join().unwrap();
        h2.join().unwrap();

        // Whatever interleaving occurred, the lock must be released again...
        assert_eq!(atomic_waker.state.load(Ordering::SeqCst), WAITING);

        // ...and the slot must still deliver a fresh registration.
        let before = counter.count();
        atomic_waker.register(&waker);
        atomic_waker.wake();
        assert_eq!(counter.count(), before + 1);
    }
}
213