Skip to main content

moduvex_runtime/sync/
oneshot.rs

1//! One-shot channel — send exactly one value from producer to consumer.
2//!
3//! `Sender` consumes itself on `send`; `Receiver` implements `Future` and
4//! resolves to `Result<T, RecvError>`. Dropping the `Sender` before sending
5//! causes the `Receiver` to resolve with `RecvError::Closed`.
6
7use std::cell::UnsafeCell;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::atomic::{AtomicU8, Ordering};
11use std::sync::{Arc, Mutex};
12use std::task::{Context, Poll, Waker};
13
14// ── State constants ───────────────────────────────────────────────────────────
15
/// Channel state: no value has been sent and the channel is still open.
const EMPTY: u8 = 0;
/// Channel state: the value has been written into the cell; the receiver may take it.
const SENT: u8 = 1;
/// Channel state: closed without delivery — either the `Sender` was dropped
/// without sending, or the `Receiver` was dropped before a value arrived.
const CLOSED: u8 = 2;
22
23// ── Inner shared state ────────────────────────────────────────────────────────
24
25struct Inner<T> {
26    /// Current channel state: EMPTY | SENT | CLOSED.
27    state: AtomicU8,
28    /// Storage for the transmitted value. Written exactly once (EMPTY → SENT).
29    ///
30    /// `UnsafeCell` is required because we write through a shared `Arc`.
31    /// Access is guarded by the `state` atomic: the sender writes while
32    /// `state == EMPTY` (exclusive via CAS), the receiver reads only after
33    /// observing `state == SENT`.
34    value: UnsafeCell<Option<T>>,
35    /// Waker for the blocked receiver (stored while state == EMPTY).
36    waker: Mutex<Option<Waker>>,
37}
38
39// SAFETY: `Inner<T>` is shared across threads via `Arc`. The `UnsafeCell`
40// holding the value is accessed in a sequenced, non-concurrent fashion:
41// the sender writes once (EMPTY → SENT CAS), the receiver reads once
42// (after observing SENT). The `Mutex<Option<Waker>>` guards the waker.
43unsafe impl<T: Send> Send for Inner<T> {}
44unsafe impl<T: Send> Sync for Inner<T> {}
45
46impl<T> Inner<T> {
47    fn new() -> Self {
48        Self {
49            state: AtomicU8::new(EMPTY),
50            value: UnsafeCell::new(None),
51            waker: Mutex::new(None),
52        }
53    }
54}
55
56// ── Public API ────────────────────────────────────────────────────────────────
57
/// Error returned when a `Receiver` future resolves without a value.
///
/// Fieldless, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecvError {
    /// The `Sender` was dropped without calling `send`.
    Closed,
}

impl std::fmt::Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RecvError::Closed => f.write_str("oneshot channel closed without a value"),
        }
    }
}

impl std::error::Error for RecvError {}
74
75/// Create a new one-shot channel, returning `(Sender, Receiver)`.
76pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
77    let inner = Arc::new(Inner::new());
78    (
79        Sender {
80            inner: inner.clone(),
81            sent: false,
82        },
83        Receiver { inner },
84    )
85}
86
87// ── Sender ────────────────────────────────────────────────────────────────────
88
89/// Sending half of a one-shot channel. Consumed on `send`.
90pub struct Sender<T> {
91    inner: Arc<Inner<T>>,
92    /// Guards against accidental double-send through raw pointer tricks.
93    sent: bool,
94}
95
96impl<T> Sender<T> {
97    /// Send `value` to the receiver. Consumes `self`.
98    ///
99    /// Returns `Err(value)` if the receiver has already been dropped.
100    pub fn send(mut self, value: T) -> Result<(), T> {
101        // Write the value before transitioning state so the receiver always
102        // sees a fully initialized `Option<T>` when it observes `SENT`.
103        //
104        // SAFETY: We hold exclusive write rights while state == EMPTY.
105        // The CAS below succeeds only once; no other thread writes here.
106        unsafe { *self.inner.value.get() = Some(value) };
107
108        match self.inner.state.compare_exchange(
109            EMPTY,
110            SENT,
111            Ordering::Release, // publish the write above
112            Ordering::Relaxed,
113        ) {
114            Ok(_) => {
115                self.sent = true;
116                // Wake the receiver if it registered a waker.
117                if let Some(w) = self.inner.waker.lock().unwrap().take() {
118                    w.wake();
119                }
120                Ok(())
121            }
122            Err(_) => {
123                // Receiver already dropped (state == CLOSED) — reclaim value.
124                // SAFETY: We just wrote it above and the CAS failed, so
125                // the receiver will never read it.
126                let val = unsafe { (*self.inner.value.get()).take() }.unwrap();
127                Err(val)
128            }
129        }
130    }
131}
132
133impl<T> Drop for Sender<T> {
134    fn drop(&mut self) {
135        if self.sent {
136            return; // value already transferred
137        }
138        // Signal the receiver that no value is coming.
139        let prev = self.inner.state.swap(CLOSED, Ordering::Release);
140        if prev == EMPTY {
141            if let Some(w) = self.inner.waker.lock().unwrap().take() {
142                w.wake();
143            }
144        }
145    }
146}
147
148// ── Receiver ──────────────────────────────────────────────────────────────────
149
150/// Receiving half of a one-shot channel. Implements `Future`.
151pub struct Receiver<T> {
152    inner: Arc<Inner<T>>,
153}
154
155impl<T> Future for Receiver<T> {
156    type Output = Result<T, RecvError>;
157
158    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
159        let state = self.inner.state.load(Ordering::Acquire);
160        match state {
161            SENT => {
162                // SAFETY: state == SENT guarantees the sender wrote the value
163                // and will not write again. We are the sole reader (Receiver
164                // is not Clone), so `take` is safe.
165                let val = unsafe { (*self.inner.value.get()).take() }
166                    .expect("oneshot: SENT state but value is None (logic error)");
167                Poll::Ready(Ok(val))
168            }
169            CLOSED => Poll::Ready(Err(RecvError::Closed)),
170            _ => {
171                // EMPTY — register waker and yield.
172                *self.inner.waker.lock().unwrap() = Some(cx.waker().clone());
173                // Re-check state after registering to avoid lost wake.
174                let state2 = self.inner.state.load(Ordering::Acquire);
175                if state2 == SENT {
176                    // SAFETY: same as above — SENT, sole reader.
177                    let val = unsafe { (*self.inner.value.get()).take() }
178                        .expect("oneshot: SENT but value None after re-check");
179                    Poll::Ready(Ok(val))
180                } else if state2 == CLOSED {
181                    Poll::Ready(Err(RecvError::Closed))
182                } else {
183                    Poll::Pending
184                }
185            }
186        }
187    }
188}
189
190impl<T> Drop for Receiver<T> {
191    fn drop(&mut self) {
192        // Inform the sender (if still alive) that nobody will read the value.
193        // CAS EMPTY → CLOSED; if already SENT we just leave the value to be
194        // dropped when `inner` is freed.
195        let _ =
196            self.inner
197                .state
198                .compare_exchange(EMPTY, CLOSED, Ordering::Relaxed, Ordering::Relaxed);
199    }
200}
201
202// ── Tests ─────────────────────────────────────────────────────────────────────
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::{block_on, block_on_with_spawn, spawn};
    use std::sync::atomic::AtomicUsize;
    use std::task::Wake;

    /// Test waker that counts how many times it is woken, so wake-ups can be
    /// checked without an executor.
    struct CountingWaker(AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Build a `Waker` backed by a fresh `CountingWaker`, returning both.
    fn counting_waker() -> (Arc<CountingWaker>, Waker) {
        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(Arc::clone(&counter));
        (counter, waker)
    }

    #[test]
    fn send_then_recv() {
        let result = block_on(async {
            let (tx, rx) = channel::<u32>();
            tx.send(42).unwrap();
            rx.await
        });
        assert_eq!(result, Ok(42));
    }

    #[test]
    fn recv_then_send_via_spawn() {
        let result = block_on_with_spawn(async {
            let (tx, rx) = channel::<String>();
            let jh = spawn(async move {
                tx.send("hello".to_string()).unwrap();
            });
            let val = rx.await.unwrap();
            jh.await.unwrap();
            val
        });
        assert_eq!(result, "hello");
    }

    #[test]
    fn sender_drop_closes_channel() {
        let result = block_on(async {
            let (tx, rx) = channel::<u32>();
            drop(tx);
            rx.await
        });
        assert_eq!(result, Err(RecvError::Closed));
    }

    #[test]
    fn send_after_receiver_drop_returns_err() {
        let (tx, rx) = channel::<u32>();
        drop(rx);
        // The rejected value must be handed back to the caller intact.
        assert_eq!(tx.send(1), Err(1));
    }

    #[test]
    fn value_types_roundtrip() {
        block_on(async {
            let (tx, rx) = channel::<Vec<u8>>();
            tx.send(vec![1, 2, 3]).unwrap();
            assert_eq!(rx.await.unwrap(), vec![1, 2, 3]);
        });
    }

    #[test]
    fn pending_receiver_is_woken_by_send() {
        let (counter, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        let (tx, mut rx) = channel::<u32>();

        assert!(Pin::new(&mut rx).poll(&mut cx).is_pending());
        tx.send(7).unwrap();
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
        assert_eq!(Pin::new(&mut rx).poll(&mut cx), Poll::Ready(Ok(7)));
    }

    #[test]
    fn pending_receiver_is_woken_by_sender_drop() {
        let (counter, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        let (tx, mut rx) = channel::<u32>();

        assert!(Pin::new(&mut rx).poll(&mut cx).is_pending());
        drop(tx);
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
        assert_eq!(
            Pin::new(&mut rx).poll(&mut cx),
            Poll::Ready(Err(RecvError::Closed))
        );
    }
}