heph_inbox/
oneshot.rs

1//! One-shot channel.
2//!
//! The channel allows you to send a single value and that's it. It does allow
//! the channel's allocation to be reused via [`Receiver::try_reset`]. It is
//! designed to be used for [Remote Procedure Calls (RPC)].
6//!
7//! [Remote Procedure Calls (RPC)]: https://en.wikipedia.org/wiki/Remote_procedure_call
8//!
9//!
10//! # Examples
11//!
12//! Simple creation of a channel and sending a message over it.
13//!
14//! ```
15//! use std::thread;
16//!
17//! use heph_inbox::oneshot::{RecvError, new_oneshot};
18//!
19//! // Create a new small channel.
20//! let (sender, mut receiver) = new_oneshot();
21//!
22//! let sender_handle = thread::spawn(move || {
23//!     if let Err(err) = sender.try_send("Hello world!".to_owned()) {
24//!         panic!("Failed to send value: {}", err);
25//!     }
26//! });
27//!
28//! let receiver_handle = thread::spawn(move || {
29//! #   #[cfg(not(miri))] // `sleep` not supported.
30//! #   thread::sleep(std::time::Duration::from_millis(1)); // Don't waste cycles.
31//!     // NOTE: this is just an example don't actually use a loop like this, it
32//!     // will waste CPU cycles when the channel is empty!
33//!     loop {
34//!         match receiver.try_recv() {
35//!             Ok(value) => println!("Got a value: {}", value),
36//!             Err(RecvError::NoValue) => continue,
37//!             Err(RecvError::Disconnected) => break,
38//!         }
39//!     }
40//! });
41//!
42//! sender_handle.join().unwrap();
43//! receiver_handle.join().unwrap();
44//! ```
45
46use std::cell::UnsafeCell;
47use std::fmt;
48use std::future::Future;
49use std::mem::MaybeUninit;
50use std::pin::Pin;
51use std::ptr::{self, NonNull};
52use std::sync::atomic::{AtomicU8, Ordering};
53use std::sync::Mutex;
54use std::task::{self, Poll};
55
56/// Create a new one-shot channel.
57pub fn new_oneshot<T>() -> (Sender<T>, Receiver<T>) {
58    let shared = NonNull::from(Box::leak(Box::new(Shared::new())));
59    (Sender { shared }, Receiver { shared })
60}
61
/// Bit mask to mark the receiver as alive.
const RECEIVER_ALIVE: u8 = 1 << 7;
/// Bit mask to mark the sender as alive.
const SENDER_ALIVE: u8 = 1 << 6;
/// Bit mask to mark that the sender still has access to the shared data.
///
/// The sender clears this only after it's done touching the shared data in its
/// `Drop` impl, i.e. after it has cleared `SENDER_ALIVE`.
const SENDER_ACCESS: u8 = 1 << 5;
68
69/// Return `true` if the receiver is alive in `status`.
70#[inline(always)]
71const fn has_receiver(status: u8) -> bool {
72    status & RECEIVER_ALIVE != 0
73}
74
75/// Return `true` if the sender is alive in `status`.
76#[inline(always)]
77const fn has_sender(status: u8) -> bool {
78    status & SENDER_ALIVE != 0
79}
80
81/// Return `true` if the sender has access in `status`.
82#[inline(always)]
83const fn has_sender_access(status: u8) -> bool {
84    status & SENDER_ACCESS != 0
85}
86
87// Status of the message in `Shared`.
88const EMPTY: u8 = 0b0000_0000;
89const FILLED: u8 = 0b0000_0001;
90
91// Status transitions.
92const MARK_FILLED: u8 = 0b0000_0001; // ADD to go from EMPTY -> FILLED.
93const MARK_EMPTY: u8 = !MARK_FILLED; // AND to go from FILLED -> EMPTY.
94/// Initial state value, also used to reset the status.
95const INITIAL: u8 = RECEIVER_ALIVE | SENDER_ALIVE | SENDER_ACCESS | EMPTY;
96
97/// Returns `true` if `status` is empty.
98#[inline(always)]
99const fn is_empty(status: u8) -> bool {
100    status & FILLED == 0
101}
102
103/// Returns `true` if `status` is filled.
104#[inline(always)]
105const fn is_filled(status: u8) -> bool {
106    status & FILLED != 0
107}
108
109/// The sending half of the [one-shot channel].
110///
111/// This half can only be owned and used by one thread.
112///
113/// [one-shot channel]: crate::oneshot::new_oneshot
114pub struct Sender<T> {
115    // Safety: must always point to valid memory.
116    shared: NonNull<Shared<T>>,
117}
118
119impl<T> Sender<T> {
120    /// Attempts to send a `value` into the channel. If this returns an error it
121    /// means the receiver has disconnected (has been dropped).
122    pub fn try_send(self, value: T) -> Result<(), T> {
123        if !self.is_connected() {
124            return Err(value);
125        }
126
127        let shared = self.shared();
128
129        // This is safe because we're the only sender.
130        unsafe { ptr::write(shared.message.get(), MaybeUninit::new(value)) };
131
132        // Mark the item as filled.
133        // Safety: `AcqRel` is required here to ensure the write above is not
134        // moved after this status update.
135        let old_status = shared.status.fetch_add(MARK_FILLED, Ordering::AcqRel);
136        debug_assert!(is_empty(old_status));
137
138        // Note: we wake in the `Drop` impl.
139        Ok(())
140    }
141
142    /// Returns `true` if the [`Receiver`] is connected.
143    pub fn is_connected(&self) -> bool {
144        // Relaxed is fine here since there is always a bit of a race condition
145        // when using the method (and then doing something based on it).
146        let status = self.shared().status.load(Ordering::Relaxed);
147        has_receiver(status)
148    }
149
150    /// Returns `true` if this sender sends to the `receiver`.
151    pub fn sends_to(&self, receiver: &Receiver<T>) -> bool {
152        self.shared == receiver.shared
153    }
154
155    /// Reference the shared data.
156    fn shared(&self) -> &Shared<T> {
157        // Safety: see `shared` field.
158        unsafe { self.shared.as_ref() }
159    }
160}
161
162impl<T> fmt::Debug for Sender<T> {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        f.write_str("Sender")
165    }
166}
167
// Safety: the `Sender` moves a value of type `T` into the channel, so it can
// only be sent to another thread if the value can be sent across threads.
unsafe impl<T: Send> Send for Sender<T> {}

// Safety: through a `&Sender` only the atomic status can be read and the
// pointer compared (`is_connected`, `sends_to`). Sending a value requires
// ownership (`try_send` takes `self`), so no bound on `T` is needed here.
unsafe impl<T> Sync for Sender<T> {}
172
173impl<T> Drop for Sender<T> {
174    fn drop(&mut self) {
175        // Mark ourselves as dropped, but still holding access.
176        let shared = self.shared();
177        let old_status = shared.status.fetch_and(!SENDER_ALIVE, Ordering::AcqRel);
178
179        if has_receiver(old_status) {
180            // Receiver is still alive, so we need to wake it.
181            if let Some(waker) = shared.receiver_waker.lock().unwrap().take() {
182                waker.wake();
183            }
184        }
185
186        // Now mark that we don't have access anymore.
187        let old_status = shared.status.fetch_and(!SENDER_ACCESS, Ordering::AcqRel);
188        if !has_receiver(old_status) {
189            // Receiver is already dropped so we need to drop the shared memory.
190            unsafe { drop(Box::from_raw(self.shared.as_ptr())) }
191        }
192    }
193}
194
195/// The receiving half of the [one-shot channel].
196///
197/// This half can only be owned and used by one thread.
198///
199/// [one-shot channel]: crate::oneshot::new_oneshot
200pub struct Receiver<T> {
201    // Safety: must always point to valid memory.
202    shared: NonNull<Shared<T>>,
203}
204
/// Error returned by [`Receiver::try_recv`].
#[derive(PartialEq, Eq, Debug)]
pub enum RecvError {
    /// No value is available, but the sender is still connected.
    NoValue,
    /// Sender is disconnected and no value is available.
    Disconnected,
}

impl fmt::Display for RecvError {
    /// Writes a short, lowercase description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match *self {
            Self::NoValue => "no value available",
            Self::Disconnected => "sender disconnected",
        };
        f.write_str(description)
    }
}
222
223impl<T> Receiver<T> {
224    /// Attempts to receive a value and reset the channel.
225    ///
226    /// If it succeeds it returns the value and resets the channel, returning a
227    /// new [`Sender`] (which can send a value to this `Receiver`).
228    pub fn try_recv(&mut self) -> Result<T, RecvError> {
229        let shared = self.shared();
230        // Safety: `AcqRel` is required here to ensure it syncs with
231        // `Sender::try_send`'s status update after the write.
232        let status = shared.status.fetch_and(MARK_EMPTY, Ordering::AcqRel);
233
234        if is_empty(status) {
235            if has_sender(status) {
236                // The sender is still connected, thus hasn't send a value yet.
237                Err(RecvError::NoValue)
238            } else {
239                // Sender is disconnected and no value was send.
240                Err(RecvError::Disconnected)
241            }
242        } else {
243            // Safety: since we're the only thread with access this is safe.
244            let msg = unsafe { (&*shared.message.get()).assume_init_read() };
245            Ok(msg)
246        }
247    }
248
249    /// Returns a future that receives a value from the channel, waiting if the
250    /// channel is empty.
251    ///
252    /// If the returned [`Future`] returns `None` it means the [`Sender`] is
253    /// [disconnected] without sending a value. This is the same error as
254    /// [`RecvError::Disconnected`]. [`RecvError::NoValue`] will never be
255    /// returned, the `Future` will return [`Poll::Pending`] instead.
256    ///
257    /// [disconnected]: Receiver::is_connected
258    pub fn recv(&mut self) -> RecvValue<T> {
259        RecvValue { receiver: self }
260    }
261
262    /// Returns an owned version of [`Receiver::recv`] that can only be used
263    /// once.
264    ///
265    /// See [`Receiver::recv`] for more information.
266    pub fn recv_once(self) -> RecvOnce<T> {
267        RecvOnce { receiver: self }
268    }
269
270    /// Attempt to reset the channel.
271    ///
272    /// If the sender is disconnected this will return a new `Sender`. If the
273    /// sender is still connected this will return `None`.
274    ///
275    /// # Notes
276    ///
277    /// If the channel contains a value it will be dropped.
278    pub fn try_reset(&mut self) -> Option<Sender<T>> {
279        let shared = self.shared();
280        // Safety: `Acquire` is required here to ensure it syncs with
281        // `Sender::try_send`'s status update after the write.
282        let status = shared.status.load(Ordering::Acquire);
283
284        // NOTE: we need to check `SENDER_ACCESS` here as we're going to
285        // overwrite (`store`) the status below. If the `Sender` was not yet
286        // fully dropped (i.e. unset `SENDER_ACCESS`) this can lead to
287        // use-after-free and double-free.
288        if has_sender_access(status) {
289            // The sender is still connected, can't reset yet.
290            return None;
291        } else if is_filled(status) {
292            // Sender send a value we need to drop.
293            // Safety: since the sender is no longer alive (checked above) we're
294            // the only type (and thread) with access making this safe.
295            unsafe { (&mut *shared.message.get()).assume_init_drop() }
296        }
297
298        // Reset the status.
299        // Safety: since the `Sender` has been dropped we have unique access to
300        // `shared` making Relaxed ordering fine.
301        shared.status.store(INITIAL, Ordering::Release);
302
303        Some(Sender {
304            shared: self.shared,
305        })
306    }
307
308    /// Returns `true` if the `Sender` is connected.
309    pub fn is_connected(&self) -> bool {
310        // Relaxed is fine here since there is always a bit of a race condition
311        // when using the method (and then doing something based on it).
312        let status = self.shared().status.load(Ordering::Relaxed);
313        has_sender(status)
314    }
315
316    /// Set the receiver's waker to `waker`, if they are different. Returns
317    /// `true` if the waker is changed, `false` otherwise.
318    ///
319    /// This is useful if you can't call [`Receiver::recv`] but still want a
320    /// wake-up notification once messages are added to the inbox.
321    pub fn register_waker(&mut self, waker: &task::Waker) -> bool {
322        let shared = self.shared();
323        let mut receiver_waker = shared.receiver_waker.lock().unwrap();
324
325        if let Some(receiver_waker) = &*receiver_waker {
326            if receiver_waker.will_wake(waker) {
327                return false;
328            }
329        }
330
331        *receiver_waker = Some(waker.clone());
332        drop(receiver_waker);
333
334        true
335    }
336
337    /// Reference the shared data.
338    fn shared(&self) -> &Shared<T> {
339        // Safety: see `shared` field.
340        unsafe { self.shared.as_ref() }
341    }
342}
343
344impl<T> fmt::Debug for Receiver<T> {
345    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
346        f.write_str("Receiver")
347    }
348}
349
// Safety: the `Receiver` moves values of type `T` out of the channel, so it
// can only be sent to another thread if `T` can be.
unsafe impl<T: Send> Send for Receiver<T> {}
// Safety: every method that touches the message or the waker takes
// `&mut self`; through a `&Receiver` only the atomic status can be read (see
// `Receiver::is_connected`).
unsafe impl<T: Send> Sync for Receiver<T> {}
352
353impl<T> Drop for Receiver<T> {
354    fn drop(&mut self) {
355        // Mark ourselves as dropped.
356        let shared = self.shared();
357        let old_status = shared.status.fetch_and(!RECEIVER_ALIVE, Ordering::AcqRel);
358
359        if !has_sender_access(old_status) {
360            // Sender was already dropped, we need to drop the shared memory.
361            unsafe { drop(Box::from_raw(self.shared.as_ptr())) }
362        }
363    }
364}
365
366/// [`Future`] implementation behind [`Receiver::recv`].
367#[derive(Debug)]
368#[must_use = "futures do nothing unless you `.await` or poll them"]
369pub struct RecvValue<'r, T> {
370    receiver: &'r mut Receiver<T>,
371}
372
// Shared body of `Future::poll` for `RecvValue` and `RecvOnce`.
//
// `$self` is the pinned future (it must have a `receiver` field holding a
// `Receiver` or `&mut Receiver`) and `$ctx` is the `task::Context` passed to
// `poll`. Expands to an expression of type `Poll<Option<T>>`; the early
// `return` inside returns from the enclosing `poll` method.
macro_rules! recv_future_impl {
    ($self: ident, $ctx: ident) => {
        match $self.receiver.try_recv() {
            Ok(ok) => Poll::Ready(Some(ok)),
            Err(RecvError::NoValue) => {
                // The sender hasn't sent a value yet, we'll set the waker.
                if !$self.receiver.register_waker($ctx.waker()) {
                    // Waker already set. It is still registered, so the
                    // `Sender`'s `Drop` impl (which `take`s it) hasn't run its
                    // wake-up yet and will still wake us.
                    return Poll::Pending;
                }

                // It could be the case that the sender sent a value in the time
                // between we last checked and we actually marked ourselves as
                // needing a wake up, so we need to check again.
                match $self.receiver.try_recv() {
                    Ok(ok) => Poll::Ready(Some(ok)),
                    // The `Sender` will wake us when the message is sent.
                    Err(RecvError::NoValue) => Poll::Pending,
                    Err(RecvError::Disconnected) => Poll::Ready(None),
                }
            }
            Err(RecvError::Disconnected) => Poll::Ready(None),
        }
    };
}
398
399impl<'r, T> Future for RecvValue<'r, T> {
400    type Output = Option<T>;
401
402    fn poll(mut self: Pin<&mut Self>, ctx: &mut task::Context) -> Poll<Self::Output> {
403        recv_future_impl!(self, ctx)
404    }
405}
406
// `RecvValue` only holds a `&mut Receiver` and is never self-referential, so it
// may be moved after being pinned. `poll` relies on this to reach `receiver`
// mutably through `Pin<&mut Self>`.
impl<'r, T> Unpin for RecvValue<'r, T> {}
408
409/// [`Future`] implementation behind [`Receiver::recv_once`].
410#[derive(Debug)]
411#[must_use = "futures do nothing unless you `.await` or poll them"]
412pub struct RecvOnce<T> {
413    receiver: Receiver<T>,
414}
415
416impl<T> Future for RecvOnce<T> {
417    type Output = Option<T>;
418
419    fn poll(mut self: Pin<&mut Self>, ctx: &mut task::Context) -> Poll<Self::Output> {
420        recv_future_impl!(self, ctx)
421    }
422}
423
// `RecvOnce` only holds a `Receiver`, which itself only stores a pointer, so it
// is never self-referential and may be moved after being pinned. `poll` relies
// on this to reach `receiver` mutably through `Pin<&mut Self>`.
impl<T> Unpin for RecvOnce<T> {}
425
/// State shared between a [`Sender`] and its [`Receiver`].
///
/// Allocated in `new_oneshot` and freed by whichever half is dropped last.
struct Shared<T> {
    /// The liveness bits of both halves (`RECEIVER_ALIVE`, `SENDER_ALIVE`,
    /// `SENDER_ACCESS`) combined with the status of `message` (`EMPTY` or
    /// `FILLED`).
    status: AtomicU8,
    /// The message, initialised whenever `status` is filled.
    message: UnsafeCell<MaybeUninit<T>>,
    /// Waker used to wake the receiving end, set via `Receiver::register_waker`.
    receiver_waker: Mutex<Option<task::Waker>>,
}
436
437impl<T> Shared<T> {
438    /// Create a new `Shared` structure.
439    const fn new() -> Shared<T> {
440        Shared {
441            status: AtomicU8::new(INITIAL),
442            message: UnsafeCell::new(MaybeUninit::uninit()),
443            receiver_waker: Mutex::new(None),
444        }
445    }
446}
447
448impl<T> Drop for Shared<T> {
449    fn drop(&mut self) {
450        let status = self.status.load(Ordering::Relaxed);
451        if is_filled(status) {
452            unsafe { ptr::drop_in_place((&mut *self.message.get()).as_mut_ptr()) }
453        }
454    }
455}