Skip to main content

moduvex_runtime/sync/
oneshot.rs

1//! One-shot channel — send exactly one value from producer to consumer.
2//!
3//! `Sender` consumes itself on `send`; `Receiver` implements `Future` and
4//! resolves to `Result<T, RecvError>`. Dropping the `Sender` before sending
5//! causes the `Receiver` to resolve with `RecvError::Closed`.
6
7use std::cell::UnsafeCell;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::atomic::{AtomicU8, Ordering};
11use std::sync::{Arc, Mutex};
12use std::task::{Context, Poll, Waker};
13
// ── State constants ───────────────────────────────────────────────────────────
//
// `Inner::state` starts at `EMPTY` and leaves it at most once, for either
// `SENT` or `CLOSED`; it never changes to a different value after that.

/// No value has been sent yet.
const EMPTY: u8 = 0;
/// A value has been written into the cell. The state stays `SENT` even after
/// the receiver has taken the value out.
const SENT: u8 = 1;
/// The channel closed before a value was delivered: either the `Sender` was
/// dropped without sending, or the `Receiver` was dropped first.
const CLOSED: u8 = 2;
22
// ── Inner shared state ────────────────────────────────────────────────────────

/// State shared by one `Sender` / `Receiver` pair through an `Arc`.
struct Inner<T> {
    /// Current channel state: `EMPTY`, `SENT` or `CLOSED`.
    state: AtomicU8,
    /// Storage for the transmitted value.
    ///
    /// `UnsafeCell` is required because both halves reach the cell through a
    /// shared `Arc`. Accesses never overlap:
    /// - only the single `Sender` writes, at most once, inside `send`, and it
    ///   does so *before* attempting the `EMPTY -> SENT` transition; if that
    ///   CAS fails (the receiver is already gone) the sender takes the value
    ///   straight back out;
    /// - the `Receiver` reads only after an `Acquire` load observes `SENT`,
    ///   which pairs with the sender's `Release` CAS, so the write is visible
    ///   and the sender never touches the cell again.
    value: UnsafeCell<Option<T>>,
    /// Waker registered by the receiver while it waits in the `EMPTY` state.
    /// The sender takes and wakes it when it sends, or when it is dropped
    /// without sending.
    waker: Mutex<Option<Waker>>,
}

// SAFETY: `Inner<T>` is shared across threads via `Arc`. The only part that
// is not automatically `Sync` is the `UnsafeCell`, and its accesses are
// serialized by the `state` protocol described on `Inner::value`. Exactly one
// thread (the sender's) writes it. The receiver reads it only after an
// `Acquire` load observes the sender's `Release` transition to `SENT`. Any
// leftover value is dropped together with `Inner` once the last `Arc` is
// released. The value is moved between threads but never shared by reference,
// so `T: Send` suffices for both impls. The waker slot is guarded by its
// `Mutex`.
unsafe impl<T: Send> Send for Inner<T> {}
unsafe impl<T: Send> Sync for Inner<T> {}
45
46impl<T> Inner<T> {
47    fn new() -> Self {
48        Self {
49            state: AtomicU8::new(EMPTY),
50            value: UnsafeCell::new(None),
51            waker: Mutex::new(None),
52        }
53    }
54}
55
// ── Public API ────────────────────────────────────────────────────────────────

/// Error returned when a `Receiver` future resolves without a value.
///
/// Fieldless and freely copyable, so callers can store, compare or re-report
/// it without cloning ceremony.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// The `Sender` was dropped without calling `send`.
    Closed,
}

impl std::fmt::Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RecvError::Closed => f.write_str("oneshot channel closed without a value"),
        }
    }
}

impl std::error::Error for RecvError {}
74
75/// Create a new one-shot channel, returning `(Sender, Receiver)`.
76pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
77    let inner = Arc::new(Inner::new());
78    (
79        Sender {
80            inner: inner.clone(),
81            sent: false,
82        },
83        Receiver { inner },
84    )
85}
86
87// ── Sender ────────────────────────────────────────────────────────────────────
88
89/// Sending half of a one-shot channel. Consumed on `send`.
90pub struct Sender<T> {
91    inner: Arc<Inner<T>>,
92    /// Guards against accidental double-send through raw pointer tricks.
93    sent: bool,
94}
95
96impl<T> Sender<T> {
97    /// Send `value` to the receiver. Consumes `self`.
98    ///
99    /// Returns `Err(value)` if the receiver has already been dropped.
100    pub fn send(mut self, value: T) -> Result<(), T> {
101        // Write the value before transitioning state so the receiver always
102        // sees a fully initialized `Option<T>` when it observes `SENT`.
103        //
104        // SAFETY: We hold exclusive write rights while state == EMPTY.
105        // The CAS below succeeds only once; no other thread writes here.
106        unsafe { *self.inner.value.get() = Some(value) };
107
108        match self.inner.state.compare_exchange(
109            EMPTY,
110            SENT,
111            Ordering::Release, // publish the write above
112            Ordering::Relaxed,
113        ) {
114            Ok(_) => {
115                self.sent = true;
116                // Wake the receiver if it registered a waker.
117                if let Some(w) = self.inner.waker.lock().unwrap().take() {
118                    w.wake();
119                }
120                Ok(())
121            }
122            Err(_) => {
123                // Receiver already dropped (state == CLOSED) — reclaim value.
124                // SAFETY: We just wrote it above and the CAS failed, so
125                // the receiver will never read it.
126                let val = unsafe { (*self.inner.value.get()).take() }.unwrap();
127                Err(val)
128            }
129        }
130    }
131}
132
133impl<T> Drop for Sender<T> {
134    fn drop(&mut self) {
135        if self.sent {
136            return; // value already transferred
137        }
138        // Signal the receiver that no value is coming.
139        let prev = self.inner.state.swap(CLOSED, Ordering::Release);
140        if prev == EMPTY {
141            if let Some(w) = self.inner.waker.lock().unwrap().take() {
142                w.wake();
143            }
144        }
145    }
146}
147
148// ── Receiver ──────────────────────────────────────────────────────────────────
149
150/// Receiving half of a one-shot channel. Implements `Future`.
151pub struct Receiver<T> {
152    inner: Arc<Inner<T>>,
153}
154
155impl<T> Future for Receiver<T> {
156    type Output = Result<T, RecvError>;
157
158    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
159        let state = self.inner.state.load(Ordering::Acquire);
160        match state {
161            SENT => {
162                // SAFETY: state == SENT guarantees the sender wrote the value
163                // and will not write again. We are the sole reader (Receiver
164                // is not Clone), so `take` is safe.
165                let val = unsafe { (*self.inner.value.get()).take() }
166                    .expect("oneshot: SENT state but value is None (logic error)");
167                Poll::Ready(Ok(val))
168            }
169            CLOSED => Poll::Ready(Err(RecvError::Closed)),
170            _ => {
171                // EMPTY — register waker and yield.
172                *self.inner.waker.lock().unwrap() = Some(cx.waker().clone());
173                // Re-check state after registering to avoid lost wake.
174                let state2 = self.inner.state.load(Ordering::Acquire);
175                if state2 == SENT {
176                    // SAFETY: same as above — SENT, sole reader.
177                    let val = unsafe { (*self.inner.value.get()).take() }
178                        .expect("oneshot: SENT but value None after re-check");
179                    Poll::Ready(Ok(val))
180                } else if state2 == CLOSED {
181                    Poll::Ready(Err(RecvError::Closed))
182                } else {
183                    Poll::Pending
184                }
185            }
186        }
187    }
188}
189
190impl<T> Drop for Receiver<T> {
191    fn drop(&mut self) {
192        // Inform the sender (if still alive) that nobody will read the value.
193        // CAS EMPTY → CLOSED; if already SENT we just leave the value to be
194        // dropped when `inner` is freed.
195        let _ =
196            self.inner
197                .state
198                .compare_exchange(EMPTY, CLOSED, Ordering::Relaxed, Ordering::Relaxed);
199    }
200}
201
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::{block_on, block_on_with_spawn, spawn};
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    #[test]
    fn send_then_recv() {
        let result = block_on(async {
            let (tx, rx) = channel::<u32>();
            tx.send(42).unwrap();
            rx.await
        });
        assert_eq!(result, Ok(42));
    }

    #[test]
    fn recv_then_send_via_spawn() {
        let result = block_on_with_spawn(async {
            let (tx, rx) = channel::<String>();
            let jh = spawn(async move {
                tx.send("hello".to_string()).unwrap();
            });
            let val = rx.await.unwrap();
            jh.await.unwrap();
            val
        });
        assert_eq!(result, "hello");
    }

    #[test]
    fn sender_drop_closes_channel() {
        let result = block_on(async {
            let (tx, rx) = channel::<u32>();
            drop(tx);
            rx.await
        });
        assert_eq!(result, Err(RecvError::Closed));
    }

    #[test]
    fn send_after_receiver_drop_returns_err() {
        let (tx, rx) = channel::<u32>();
        drop(rx);
        assert!(tx.send(1).is_err());
    }

    #[test]
    fn value_types_roundtrip() {
        block_on(async {
            let (tx, rx) = channel::<Vec<u8>>();
            tx.send(vec![1, 2, 3]).unwrap();
            assert_eq!(rx.await.unwrap(), vec![1, 2, 3]);
        });
    }

    // ── Additional oneshot tests ───────────────────────────────────────────

    #[test]
    fn oneshot_send_string() {
        let result = block_on(async {
            let (tx, rx) = channel::<String>();
            tx.send("world".to_string()).unwrap();
            rx.await
        });
        assert_eq!(result.unwrap(), "world");
    }

    #[test]
    fn oneshot_send_struct() {
        #[derive(Debug, PartialEq)]
        struct Point {
            x: i32,
            y: i32,
        }
        let result = block_on(async {
            let (tx, rx) = channel::<Point>();
            tx.send(Point { x: 1, y: 2 }).unwrap();
            rx.await
        });
        assert_eq!(result.unwrap(), Point { x: 1, y: 2 });
    }

    #[test]
    fn oneshot_send_vec() {
        let result = block_on(async {
            let (tx, rx) = channel::<Vec<u8>>();
            tx.send(vec![1, 2, 3, 4, 5]).unwrap();
            rx.await
        });
        assert_eq!(result.unwrap(), vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn oneshot_multiple_pairs_concurrent() {
        block_on_with_spawn(async {
            let mut rxs = Vec::new();
            for i in 0u32..10 {
                let (tx, rx) = channel::<u32>();
                spawn(async move {
                    tx.send(i).unwrap();
                });
                rxs.push(rx);
            }
            let mut results: Vec<u32> = Vec::new();
            for rx in rxs {
                results.push(rx.await.unwrap());
            }
            results.sort();
            assert_eq!(results, (0..10).collect::<Vec<_>>());
        });
    }

    #[test]
    fn oneshot_recv_error_display() {
        let err = RecvError::Closed;
        let s = format!("{err}");
        assert!(s.contains("closed") || s.contains("Closed"));
    }

    #[test]
    fn oneshot_send_returns_err_when_rx_dropped() {
        let (tx, rx) = channel::<i32>();
        drop(rx);
        let result = tx.send(42);
        assert_eq!(result, Err(42));
    }

    #[test]
    fn oneshot_send_value_then_recv_in_separate_block_on() {
        // Verify that a value sent synchronously (before polling) is received correctly.
        let (tx, rx) = channel::<u64>();
        tx.send(12345).unwrap();
        let val = block_on(async { rx.await.unwrap() });
        assert_eq!(val, 12345);
    }

    #[test]
    fn oneshot_sender_drop_closes_from_spawn() {
        let result = block_on_with_spawn(async {
            let (tx, rx) = channel::<u32>();
            // Sender dropped inside a spawned task without sending
            let jh = spawn(async move {
                drop(tx);
            });
            jh.await.unwrap();
            rx.await
        });
        assert_eq!(result, Err(RecvError::Closed));
    }

    #[test]
    fn oneshot_recv_error_is_error_trait() {
        let err = RecvError::Closed;
        // RecvError implements std::error::Error
        let _e: &dyn std::error::Error = &err;
    }

    #[test]
    fn oneshot_u8_roundtrip() {
        let result = block_on(async {
            let (tx, rx) = channel::<u8>();
            tx.send(255).unwrap();
            rx.await.unwrap()
        });
        assert_eq!(result, 255);
    }

    #[test]
    fn oneshot_bool_roundtrip() {
        let result = block_on(async {
            let (tx, rx) = channel::<bool>();
            tx.send(true).unwrap();
            rx.await.unwrap()
        });
        assert!(result);
    }

    #[test]
    fn oneshot_unit_roundtrip() {
        let result = block_on(async {
            let (tx, rx) = channel::<()>();
            tx.send(()).unwrap();
            rx.await.unwrap()
        });
        assert_eq!(result, ());
    }

    #[test]
    fn oneshot_10_pairs_in_parallel() {
        block_on_with_spawn(async {
            let mut rxs = Vec::new();
            for i in 0..10u32 {
                let (tx, rx) = channel::<u32>();
                let val = i * 3;
                spawn(async move { tx.send(val).unwrap() });
                rxs.push((i, rx));
            }
            for (i, rx) in rxs {
                let v = rx.await.unwrap();
                assert_eq!(v, i * 3);
            }
        });
    }

    #[test]
    fn oneshot_send_before_poll_synchronous() {
        // Sender sends synchronously before receiver is polled — should be ready immediately
        let (tx, rx) = channel::<u32>();
        tx.send(777).unwrap();
        let v = block_on(async { rx.await.unwrap() });
        assert_eq!(v, 777);
    }

    #[test]
    fn oneshot_send_i64_max() {
        let result = block_on(async {
            let (tx, rx) = channel::<i64>();
            tx.send(i64::MAX).unwrap();
            rx.await.unwrap()
        });
        assert_eq!(result, i64::MAX);
    }

    #[test]
    fn oneshot_send_i64_min() {
        let result = block_on(async {
            let (tx, rx) = channel::<i64>();
            tx.send(i64::MIN).unwrap();
            rx.await.unwrap()
        });
        assert_eq!(result, i64::MIN);
    }

    #[test]
    fn oneshot_send_empty_vec() {
        let result = block_on(async {
            let (tx, rx) = channel::<Vec<u8>>();
            tx.send(Vec::new()).unwrap();
            rx.await.unwrap()
        });
        assert!(result.is_empty());
    }

    #[test]
    fn oneshot_two_separate_channels_independent() {
        block_on(async {
            let (tx1, rx1) = channel::<u32>();
            let (tx2, rx2) = channel::<u32>();
            tx1.send(1).unwrap();
            tx2.send(2).unwrap();
            assert_eq!(rx1.await.unwrap(), 1);
            assert_eq!(rx2.await.unwrap(), 2);
        });
    }

    // ── Manually polled tests (pending path) ───────────────────────────────
    //
    // These drive `Receiver::poll` by hand, so the receiver deterministically
    // registers its waker *before* the sender acts.

    /// Waker that counts how many times it has been woken.
    struct CountingWaker(AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Returns the wake counter together with a `Waker` that bumps it.
    fn counting_waker() -> (Arc<CountingWaker>, Waker) {
        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(Arc::clone(&counter));
        (counter, waker)
    }

    #[test]
    fn pending_receiver_is_woken_by_send() {
        let (tx, mut rx) = channel::<u32>();
        let (counter, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        assert!(Pin::new(&mut rx).poll(&mut cx).is_pending());
        tx.send(9).unwrap();
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
        assert_eq!(Pin::new(&mut rx).poll(&mut cx), Poll::Ready(Ok(9)));
    }

    #[test]
    fn pending_receiver_is_woken_by_sender_drop() {
        let (tx, mut rx) = channel::<u32>();
        let (counter, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        assert!(Pin::new(&mut rx).poll(&mut cx).is_pending());
        drop(tx);
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
        assert_eq!(
            Pin::new(&mut rx).poll(&mut cx),
            Poll::Ready(Err(RecvError::Closed))
        );
    }

    #[test]
    fn repolling_pending_receiver_wakes_once() {
        let (tx, mut rx) = channel::<u32>();
        let (counter, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        for _ in 0..3 {
            assert!(Pin::new(&mut rx).poll(&mut cx).is_pending());
        }
        tx.send(1).unwrap();
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
        assert_eq!(Pin::new(&mut rx).poll(&mut cx), Poll::Ready(Ok(1)));
    }

    #[test]
    fn send_from_other_thread_wakes_pending_receiver() {
        let (tx, mut rx) = channel::<String>();
        let (counter, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        assert!(Pin::new(&mut rx).poll(&mut cx).is_pending());
        std::thread::spawn(move || tx.send("threaded".to_string()).unwrap())
            .join()
            .unwrap();
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
        assert_eq!(
            Pin::new(&mut rx).poll(&mut cx),
            Poll::Ready(Ok("threaded".to_string()))
        );
    }

    #[test]
    fn dropped_pending_receiver_makes_send_fail_without_waking() {
        let (tx, mut rx) = channel::<u32>();
        let (counter, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        assert!(Pin::new(&mut rx).poll(&mut cx).is_pending());
        drop(rx);
        assert_eq!(tx.send(5), Err(5));
        assert_eq!(counter.0.load(Ordering::SeqCst), 0);
    }
}