lite_sync/oneshot/
generic.rs

1//! Generic oneshot channel for arbitrary types.
2//!
3//! 用于任意类型的通用一次性通道。
4
5use crate::shim::atomic::{AtomicU8, Ordering};
6use crate::shim::cell::UnsafeCell;
7use std::mem::MaybeUninit;
8
9use super::common::{self, OneshotStorage, TakeResult};
10
11// Re-export common types
12pub use super::common::RecvError;
13pub use super::common::TryRecvError;
14pub use super::common::error;
15
// Encodings stored in `GenericStorage::state`.

/// No value has been stored.
const EMPTY: u8 = 0;
/// The sender has written the value into the slot.
const READY: u8 = 1;
/// The sender was dropped without sending a value.
const SENDER_CLOSED: u8 = 2;
/// The receiver side closed the channel.
const RECEIVER_CLOSED: u8 = 3;
21
22// ============================================================================
23// Generic Storage
24// ============================================================================
25
26/// Storage for generic types using `UnsafeCell<MaybeUninit<T>>`
27///
28/// 使用 `UnsafeCell<MaybeUninit<T>>` 存储泛型类型
29pub struct GenericStorage<T> {
30    state: AtomicU8,
31    value: UnsafeCell<MaybeUninit<T>>,
32}
33
34// SAFETY: GenericStorage<T> is Send + Sync as long as T is Send
35// - UnsafeCell<MaybeUninit<T>> is protected by atomic state transitions
36// - Only one thread can access the value at a time (enforced by state machine)
37unsafe impl<T: Send> Send for GenericStorage<T> {}
38unsafe impl<T: Send> Sync for GenericStorage<T> {}
39
40impl<T: Send> OneshotStorage for GenericStorage<T> {
41    type Value = T;
42
43    #[inline]
44    fn new() -> Self {
45        Self {
46            state: AtomicU8::new(EMPTY),
47            value: UnsafeCell::new(MaybeUninit::uninit()),
48        }
49    }
50
51    #[inline]
52    fn store(&self, value: T) {
53        // SAFETY: Only called once by sender (enforced by ownership)
54        self.value.with_mut(|v| unsafe { (*v).write(value) });
55        self.state.store(READY, Ordering::Release);
56    }
57
58    #[inline]
59    fn try_take(&self) -> TakeResult<T> {
60        let state = self.state.swap(EMPTY, Ordering::Acquire);
61        match state {
62            READY => {
63                // SAFETY: State was READY, value is initialized
64                self.value
65                    .with(|v| unsafe { TakeResult::Ready((*v).assume_init_read()) })
66            }
67            SENDER_CLOSED | RECEIVER_CLOSED => TakeResult::Closed,
68            _ => TakeResult::Pending,
69        }
70    }
71
72    #[inline]
73    fn is_sender_dropped(&self) -> bool {
74        self.state.load(Ordering::Acquire) == SENDER_CLOSED
75    }
76
77    #[inline]
78    fn mark_sender_dropped(&self) {
79        self.state.store(SENDER_CLOSED, Ordering::Release);
80    }
81
82    #[inline]
83    fn is_receiver_closed(&self) -> bool {
84        self.state.load(Ordering::Acquire) == RECEIVER_CLOSED
85    }
86
87    #[inline]
88    fn mark_receiver_closed(&self) {
89        self.state.store(RECEIVER_CLOSED, Ordering::Release);
90    }
91}
92
93impl<T> Drop for GenericStorage<T> {
94    fn drop(&mut self) {
95        // Clean up the value if it was sent but not received
96        // Note: In drop, strict ordering isn't required if we own the object, but we follow protocol
97        if self.state.load(Ordering::Acquire) == READY {
98            self.value.with_mut(|v| unsafe {
99                (*v).assume_init_drop();
100            });
101        }
102    }
103}
104
105// ============================================================================
106// Type Aliases
107// ============================================================================
108
109/// Sender for one-shot value transfer of generic types
110///
111/// 用于泛型类型一次性值传递的发送器
112pub type Sender<T> = common::Sender<GenericStorage<T>>;
113
114/// Receiver for one-shot value transfer of generic types
115///
116/// 用于泛型类型一次性值传递的接收器
117pub type Receiver<T> = common::Receiver<GenericStorage<T>>;
118
119/// Create a new oneshot channel for generic types
120///
121/// 创建一个用于泛型类型的新 oneshot 通道
122#[inline]
123pub fn channel<T: Send>() -> (Sender<T>, Receiver<T>) {
124    Sender::new()
125}
126
127// ============================================================================
128// Receiver Extension Methods
129// ============================================================================
130
131impl<T: Send> Receiver<T> {
132    /// Try to receive a value without blocking
133    ///
134    /// Returns `Ok(value)` if value is ready, `Err(TryRecvError::Empty)` if pending,
135    /// or `Err(TryRecvError::Closed)` if sender was dropped.
136    ///
137    /// 尝试接收值而不阻塞
138    ///
139    /// 如果值就绪返回 `Ok(value)`,如果待处理返回 `Err(TryRecvError::Empty)`,
140    /// 如果发送器被丢弃返回 `Err(TryRecvError::Closed)`
141    #[inline]
142    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
143        match self.inner.try_recv() {
144            super::common::TakeResult::Ready(v) => Ok(v),
145            super::common::TakeResult::Pending => Err(TryRecvError::Empty),
146            super::common::TakeResult::Closed => Err(TryRecvError::Closed),
147        }
148    }
149}
150
#[cfg(all(test, not(feature = "loom")))]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Delay before the background sender acts, so the receiver really has to wait.
    const SEND_DELAY: Duration = Duration::from_millis(10);

    #[tokio::test]
    async fn test_oneshot_string() {
        let (tx, rx) = Sender::<String>::new();

        tokio::spawn(async move {
            tokio::time::sleep(SEND_DELAY).await;
            tx.send("Hello".to_string()).unwrap();
        });

        assert_eq!(rx.wait().await.unwrap(), "Hello");
    }

    #[tokio::test]
    async fn test_oneshot_integer() {
        let (tx, rx) = Sender::<i32>::new();

        tokio::spawn(async move {
            tokio::time::sleep(SEND_DELAY).await;
            tx.send(42).unwrap();
        });

        assert_eq!(rx.wait().await.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_oneshot_immediate() {
        let (tx, rx) = Sender::<String>::new();

        // Value is already there when we start waiting (fast path).
        tx.send("Immediate".to_string()).unwrap();

        assert_eq!(rx.wait().await.unwrap(), "Immediate");
    }

    #[tokio::test]
    async fn test_oneshot_custom_struct() {
        #[derive(Debug, Clone, PartialEq)]
        struct CustomData {
            id: u64,
            name: String,
        }

        let (tx, rx) = Sender::<CustomData>::new();
        let payload = CustomData {
            id: 123,
            name: "Test".to_string(),
        };

        tokio::spawn(async move {
            tokio::time::sleep(SEND_DELAY).await;
            tx.send(payload).unwrap();
        });

        let received = rx.wait().await.unwrap();
        assert_eq!(received.id, 123);
        assert_eq!(received.name, "Test");
    }

    #[tokio::test]
    async fn test_oneshot_direct_await() {
        let (tx, rx) = Sender::<i32>::new();

        tokio::spawn(async move {
            tokio::time::sleep(SEND_DELAY).await;
            tx.send(99).unwrap();
        });

        // The receiver itself is a future; no `.wait()` needed.
        assert_eq!(rx.await.unwrap(), 99);
    }

    #[tokio::test]
    async fn test_oneshot_await_mut_reference() {
        let (tx, mut rx) = Sender::<String>::new();

        tokio::spawn(async move {
            tokio::time::sleep(SEND_DELAY).await;
            tx.send("Mutable".to_string()).unwrap();
        });

        // `&mut Receiver` can be awaited as well.
        assert_eq!((&mut rx).await.unwrap(), "Mutable");
    }

    #[tokio::test]
    async fn test_oneshot_immediate_await() {
        let (tx, rx) = Sender::<Vec<u8>>::new();

        // Send first (fast path), then await directly.
        tx.send(vec![1, 2, 3]).unwrap();

        assert_eq!(rx.await.unwrap(), vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_oneshot_try_recv() {
        let (tx, mut rx) = Sender::<i32>::new();

        // Nothing sent yet.
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));

        tx.send(42).unwrap();

        // Now the value is available.
        assert_eq!(rx.try_recv(), Ok(42));
    }

    #[tokio::test]
    async fn test_oneshot_try_recv_closed() {
        let (tx, mut rx) = Sender::<i32>::new();

        // The sender goes away without sending anything.
        drop(tx);

        assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
    }

    #[tokio::test]
    async fn test_oneshot_dropped() {
        let (tx, rx) = Sender::<i32>::new();
        drop(tx);
        assert_eq!(rx.await, Err(RecvError));
    }

    #[tokio::test]
    async fn test_oneshot_large_data() {
        const LEN: usize = 1024 * 1024; // 1MB
        let (tx, rx) = Sender::<Vec<u8>>::new();
        let blob = vec![0u8; LEN];

        tokio::spawn(async move {
            tx.send(blob).unwrap();
        });

        assert_eq!(rx.await.unwrap().len(), LEN);
    }

    // --- Sender::is_closed ---

    #[test]
    fn test_sender_is_closed_initially_false() {
        let (tx, _rx) = Sender::<i32>::new();
        assert!(!tx.is_closed());
    }

    #[test]
    fn test_sender_is_closed_after_receiver_drop() {
        let (tx, rx) = Sender::<i32>::new();
        drop(rx);
        assert!(tx.is_closed());
    }

    #[test]
    fn test_sender_is_closed_after_receiver_close() {
        let (tx, mut rx) = Sender::<i32>::new();
        rx.close();
        assert!(tx.is_closed());
    }

    // --- Receiver::close ---

    #[test]
    fn test_receiver_close_prevents_send() {
        let (tx, mut rx) = Sender::<i32>::new();
        rx.close();

        // A closed receiver must reject the send.
        assert!(tx.send(42).is_err());
    }

    // --- Receiver::blocking_recv ---

    #[test]
    fn test_blocking_recv_immediate() {
        let (tx, rx) = Sender::<i32>::new();

        // Value is already present (fast path).
        tx.send(42).unwrap();

        assert_eq!(rx.blocking_recv(), Ok(42));
    }

    #[test]
    fn test_blocking_recv_with_thread() {
        let (tx, rx) = Sender::<String>::new();

        std::thread::spawn(move || {
            std::thread::sleep(SEND_DELAY);
            tx.send("hello".to_string()).unwrap();
        });

        assert_eq!(rx.blocking_recv(), Ok("hello".to_string()));
    }

    #[test]
    fn test_blocking_recv_sender_dropped() {
        let (tx, rx) = Sender::<i32>::new();

        std::thread::spawn(move || {
            std::thread::sleep(SEND_DELAY);
            drop(tx);
        });

        assert_eq!(rx.blocking_recv(), Err(RecvError));
    }
}