Skip to main content

sync_oneshot/
lib.rs

1//! A minimal oneshot channel for synchronous Rust.
2//!
3//! A oneshot channel is used for sending a single message between threads.
4//! The [`channel`] function is used to create a [`Sender`] and [`Receiver`]
5//! handle pair that form the channel.
6//!
7//! - The [`Sender`] handle is used by the producer to send the value.
8//! - The [`Receiver`] handle is used by the consumer to receive the value.
9//!
10//! Each handle can be used on other threads.
11//!
//! - [`Sender::send`] will not block the calling thread.
13//! - [`Receiver::recv`] will **block** the calling thread.
14//!
15//! # Example
16//! ```rust
17//! # use std::time::Duration;
18//! let (tx, rx) = sync_oneshot::channel();
19//!
20//! std::thread::spawn(move || {
21//!     std::thread::sleep(Duration::from_millis(200));
22//!     tx.send(5).unwrap();
23//! });
24//!
//! // block the thread until a message is available
26//! let val = rx.recv().unwrap();
27//! assert_eq!(val, 5);
28//! ```
29#[cfg(loom)]
30use loom::{
31    sync::{
32        Arc,
33        atomic::{AtomicUsize, Ordering},
34    },
35    thread,
36};
37
38use std::fmt;
39#[cfg(not(loom))]
40use std::{
41    sync::{
42        Arc,
43        atomic::{AtomicUsize, Ordering},
44    },
45    thread,
46};
47
48use crate::{notify::Notify, slot::Slot};
49
50mod error;
51mod notify;
52mod slot;
53
54pub use error::{RecvError, TryRecvError};
55
56/// Creates a new oneshot channel, returning the sender/receiver halves.
57///
58/// The [`Sender`] is used by the producer to send the value.
59/// The [`Receiver`] handle is used by the consumer to receive the value.
60///
61/// [`send`](Sender::send) will no block the calling thread. [`recv`](Receiver::recv)
62/// will **block** until a message is available.
63#[inline]
64pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
65    let inner = Arc::new(Inner {
66        state: AtomicUsize::new(0),
67        value: Slot::new(),
68        notify: Notify::new(),
69    });
70
71    (
72        Sender {
73            inner: Some(inner.clone()),
74        },
75        Receiver { inner: Some(inner) },
76    )
77}
78
/// Sends a value to the associated [`Receiver`].
///
/// This is created by the [`channel`] function.
/// A value is sent with [`send`](Sender::send), which consumes the handle.
/// Dropping a `Sender` without sending marks the channel as finished, and the
/// receiver then observes an error.
#[derive(Debug)]
pub struct Sender<T> {
    // Shared channel state. Always `Some` while the handle is alive; it is only
    // taken by `send` and by `Drop`, both of which consume the handle.
    inner: Option<Arc<Inner<T>>>,
}
87
/// Receives a value from the associated [`Sender`].
///
/// This is created by the [`channel`] function.
/// The value sent on the channel can be retrieved with [`recv`](Receiver::recv),
/// which blocks the calling thread, or polled with
/// [`try_recv`](Receiver::try_recv), which does not.
#[derive(Debug)]
pub struct Receiver<T> {
    // Shared channel state. Taken by `recv`, and cleared by `try_recv` once it
    // has produced a value or observed that the channel is closed; after that
    // the handle no longer refers to the channel.
    inner: Option<Arc<Inner<T>>>,
}
97
// SAFETY: a `Sender` only ever moves a `T` *into* the shared slot (`send`), so
// it may be moved to another thread whenever `T: Send`. Through `&Sender` the
// only operations are `is_closed` and `Debug`, both of which only load the
// atomic state, so sharing `&Sender` between threads is sound under the same
// bound.
// NOTE(review): these manual impls are presumably needed because `Slot` /
// `Notify` contain non-`Sync` interior mutability — confirm in slot.rs/notify.rs.
unsafe impl<T> Send for Sender<T> where T: Send {}
unsafe impl<T> Sync for Sender<T> where T: Send {}

// SAFETY: a `Receiver` only ever moves a `T` *out of* the shared slot, so it
// may be moved to another thread whenever `T: Send`. Every method that touches
// the slot or the waker takes `&mut self` or `self`; `&Receiver` only exposes
// `Debug`, which loads the atomic state.
unsafe impl<T> Send for Receiver<T> where T: Send {}
unsafe impl<T> Sync for Receiver<T> where T: Send {}
103
/// State shared by a [`Sender`] and its [`Receiver`] through an `Arc`.
struct Inner<T> {
    // Bit set of `WAITING`, `VALUE_SENT` and `CLOSED`; starts at 0 in `channel`.
    state: AtomicUsize,
    // Written by `Sender::send`; taken by the receiver once `VALUE_SENT` is
    // observed, or by the sender itself if the channel was already closed.
    value: Slot<T>,
    // Registered by `Receiver::recv` (via `set_current`) before it parks; the
    // sender calls `notify` on it when it observes `WAITING`.
    notify: Notify,
}
109
110/*
111 *
112 * ===== impl Sender =====
113 *
114 */
115impl<T> Sender<T> {
116    /// Attempts to send a value on this channel, returning it back if it could not be sent.
117    ///
118    /// A successful send occurs when it is determined that the other end of the
119    /// channel has not hung up already. An unsuccessful send would be one where
120    /// the corresponding receiver has already been deallocated. Note that a
121    /// return value of [`Err`] means that the data will never be received, but
122    /// a return value of [`Ok`] does *not* mean that the data will be received.
123    /// It is possible for the corresponding receiver to hang up immediately
124    /// after this function returns [`Ok`].
125    ///
126    /// This method will never block the current thread.
127    /// # Example
128    /// ```rust
129    /// let (tx, rx) = sync_oneshot::channel();
130    /// std::thread::spawn(move || {
131    ///     if let Err(e) = tx.send(5) {
132    ///         println!("the receiver dropped");
133    ///     }
134    /// });
135    ///
136    /// match rx.recv() {
137    ///     Ok(v) => println!("got = {:?}", v),
138    ///     Err(_) => println!("the sender dropped"),
139    /// }
140    /// ```
141    #[inline]
142    pub fn send(mut self, value: T) -> Result<(), T> {
143        // take inner
144        // The case inner None is unreachable
145        let inner = self.inner.take().unwrap();
146
147        // set value
148        unsafe {
149            // SAFETY:
150            // Receiver don't access inner value until set status as VALUE_SENT
151            inner.value.set(value);
152        }
153
154        // set state as VALUE_SEND and notify
155        let prev_state = inner.set_complete();
156
157        if prev_state.is_closed() {
158            // SAFETY:
159            // Receiver already has been droped. So can access inner value.
160            return Err(unsafe { inner.consume_value().unwrap() });
161        }
162
163        if prev_state.is_waiting() {
164            unsafe {
165                inner.notify();
166            }
167        }
168
169        Ok(())
170    }
171
172    /// Returns true if the associated Receiver handle has been dropped.
173    ///
174    /// A Receiver is closed by either calling close explicitly or the Receiver value is dropped.
175    /// If true is returned, a call to send will always result in an error.
176    pub fn is_closed(&self) -> bool {
177        let inner = self.inner.as_ref().unwrap();
178        State(inner.state.load(Ordering::Acquire)).is_closed()
179    }
180}
181
182impl<T> Drop for Sender<T> {
183    fn drop(&mut self) {
184        if let Some(inner) = self.inner.take() {
185            let prev_state = inner.set_complete();
186
187            if prev_state.is_waiting() {
188                unsafe {
189                    inner.notify.notify();
190                }
191            }
192        }
193    }
194}
195
196/*
197 *
198 * ===== impl Receiver =====
199 *
200 */
201impl<T> Receiver<T> {
202    /// Attempts to wait for a value on this receiver, returning an error if
203    /// the corresponding channel has hung up.
204    ///
205    /// This function will always block the current thread if there is no data
206    /// available. Once a message is sent to the corresponding [`Sender`],
207    /// this receiver will wake up and return that message.
208    ///
209    /// If the corresponding [`Sender`] has disconnected, or it disconnects while
210    /// this call is blocking, this call will wake up and return [`Err`] to
211    /// indicate that no more messages can ever be received on this channel.
212    /// # Example
213    /// ```rust
214    /// let (tx, rx) = sync_oneshot::channel();
215    ///
216    /// let th_handle = std::thread::spawn(move || {
217    ///     tx.send(5).unwrap();
218    /// });
219    ///
220    /// th_handle.join().unwrap();
221    ///
222    /// assert_eq!(5, rx.recv().unwrap());
223    /// ```
224    #[inline]
225    pub fn recv(mut self) -> Result<T, RecvError> {
226        let inner = self.inner.take().unwrap();
227
228        let mut state = inner.state.load(Ordering::Acquire);
229        loop {
230            if State(state).is_complete() {
231                let value = unsafe { inner.consume_value() };
232                return value.ok_or(RecvError);
233            } else if State(state).is_closed() {
234                return Err(RecvError);
235            }
236
237            unsafe {
238                // SAFETY:
239                // Notify::notify dose not call until state is WAITING.
240                // So we can access notify.
241
242                // Prevent double write due to spurious wake-up.
243                if !State(state).is_waiting() {
244                    inner.notify.set_current();
245                }
246            }
247
248            match inner.state.compare_exchange(
249                state,
250                state | WAITING,
251                Ordering::Release,
252                Ordering::Acquire,
253            ) {
254                Ok(_) => {
255                    thread::park();
256                    state = inner.state.load(Ordering::Acquire);
257                }
258                Err(actual) => state = actual,
259            }
260        }
261    }
262
263    /// Attempts to return a pending value on this receiver without blocking.
264    ///
265    /// This method will never block the caller in order to wait for data to
266    /// become available. Instead, this will always return immediately with a
267    /// possible option of pending data on the channel.
268    ///
269    /// This is useful for a flavor of “optimistic check” before deciding to
270    /// block on a receiver.
271    ///
272    /// Compared with recv, this function has two failure cases instead of one (one for disconnection, one for an empty buffer).
273    #[inline]
274    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
275        let result = if let Some(inner) = self.inner.as_ref() {
276            let state = State(inner.state.load(Ordering::Acquire));
277
278            if state.is_complete() {
279                unsafe {
280                    // SAFETY:
281                    // When state is complete, Sender no longer access value
282                    // Can access value safely
283                    match inner.consume_value() {
284                        Some(value) => Ok(value),
285                        None => Err(TryRecvError::Closed),
286                    }
287                }
288            } else if state.is_closed() {
289                Err(TryRecvError::Closed)
290            } else {
291                return Err(TryRecvError::Empty);
292            }
293        } else {
294            Err(TryRecvError::Closed)
295        };
296
297        self.inner = None;
298        result
299    }
300
301    /// Prevents the associated [`Sender`] handle from sending a value.
302    ///
303    /// Any `send` operation which happens after calling `close` is guaranteed
304    /// to fail. After calling `close`, [`try_recv`] should be called to
305    /// receive a value if one was sent **before** the call to `close`
306    /// completed.
307    ///
308    /// This function is useful to perform a graceful shutdown and ensure that a
309    /// value will not be sent into the channel and never received.
310    ///
311    /// `close` is no-op if a message is already received or the channel
312    /// is already closed.
313    ///
314    /// [`Sender`]: Sender
315    /// [`try_recv`]: Receiver::try_recv
316    ///
317    /// # Examples
318    ///
319    /// Prevent a value from being sent
320    ///
321    /// ```
322    /// use sync_oneshot::TryRecvError;
323    ///
324    /// # fn main() {
325    /// let (tx, mut rx) = sync_oneshot::channel();
326    ///
327    /// assert!(!tx.is_closed());
328    ///
329    /// rx.close();
330    ///
331    /// assert!(tx.is_closed());
332    /// assert!(tx.send("never received").is_err());
333    ///
334    /// match rx.try_recv() {
335    ///     Err(TryRecvError::Closed) => {}
336    ///     _ => unreachable!(),
337    /// }
338    /// # }
339    /// ```
340    ///
341    /// Receive a value sent **before** calling `close`
342    ///
343    /// ```
344    /// # fn main() {
345    /// let (tx, mut rx) = sync_oneshot::channel();
346    ///
347    /// assert!(tx.send("will receive").is_ok());
348    ///
349    /// rx.close();
350    ///
351    /// let msg = rx.try_recv().unwrap();
352    /// assert_eq!(msg, "will receive");
353    /// # }
354    /// ```
355    pub fn close(&mut self) {
356        if let Some(inner) = self.inner.as_ref() {
357            let _ = inner.set_close();
358        }
359    }
360}
361
362impl<T> Drop for Receiver<T> {
363    fn drop(&mut self) {
364        // if inner is some, Receiver::recv is not called before drop.
365        // Drop value or change state
366        if let Some(inner) = self.inner.take() {
367            let prev_state = inner.set_close();
368            if prev_state.is_complete() {
369                unsafe {
370                    inner.consume_value();
371                }
372            }
373        }
374    }
375}
376
377/*
378 *
379 * ===== impl Inner =====
380 *
381 */
382impl<T> Inner<T> {
383    #[inline]
384    fn set_complete(&self) -> State {
385        let mut state = self.state.load(Ordering::Relaxed);
386        loop {
387            if State(state).is_closed() {
388                break;
389            }
390
391            match self.state.compare_exchange_weak(
392                state,
393                state | VALUE_SENT,
394                Ordering::AcqRel,
395                Ordering::Relaxed,
396            ) {
397                Ok(_) => break,
398                Err(actual) => state = actual,
399            }
400        }
401        State(state)
402    }
403
404    #[inline]
405    fn set_close(&self) -> State {
406        State(self.state.fetch_or(CLOSED, Ordering::AcqRel))
407    }
408
409    #[inline]
410    unsafe fn notify(&self) {
411        unsafe {
412            self.notify.notify();
413        }
414    }
415
416    #[inline]
417    unsafe fn consume_value(&self) -> Option<T> {
418        unsafe { self.value.take() }
419    }
420}
421
422impl<T: fmt::Debug> fmt::Debug for Inner<T> {
423    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
424        f.debug_struct("Inner")
425            .field("state", &State(self.state.load(Ordering::Relaxed)))
426            .finish()
427    }
428}
429
/// A snapshot of `Inner::state`, decoded through the flag helpers below.
struct State(usize);

/// The receiver registered itself in `notify` and may be parked in `recv`.
const WAITING: usize = 0b0001;
/// The sender has finished: it either stored a value or was dropped without one.
const VALUE_SENT: usize = 0b0010;
/// The receiver closed the channel, via `close` or by being dropped.
const CLOSED: usize = 0b0100;

/*
 *
 * ===== impl State =====
 *
 */
impl State {
    /// Returns whether every bit of `flag` is set in this snapshot.
    #[inline]
    fn has_flag(&self, flag: usize) -> bool {
        let bits = self.0;
        (bits & flag) == flag
    }

    #[inline]
    fn is_closed(&self) -> bool {
        self.has_flag(CLOSED)
    }

    #[inline]
    fn is_waiting(&self) -> bool {
        self.has_flag(WAITING)
    }

    #[inline]
    fn is_complete(&self) -> bool {
        self.has_flag(VALUE_SENT)
    }
}

impl fmt::Debug for State {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("State");
        builder.field("is_complete", &self.is_complete());
        builder.field("is_closed", &self.is_closed());
        builder.field("is_waiting", &self.is_waiting());
        builder.finish()
    }
}