handoff/
lib.rs

1/*!
2`handoff` is a single-producer / single-consumer, unbuffered, asynchronous
3channel. It's intended for cases where you want blocking communication between
4two async components, where all sends block until the receiver receives the
5item.
6
7A new channel is created with [`channel`], which returns a [`Sender`] and
8[`Receiver`]. Items can be sent into the channel with [`Sender::send`], and
9received with [`Receiver::recv`]. [`Receiver`] also implements
10[`futures::Stream`][Stream]. Either end of the channel can be dropped, which will
11cause the other end to unblock and report channel disconnection.
12
13While the channel operates asynchronously, it can also be used in a fully
14synchronous way by using `block_on` or similar utilities provided in most
15async runtimes.
16
17# Examples
18
19## Basic example
20
21```rust
22# futures::executor::block_on(async move {
23use handoff::channel;
24use futures::future::join;
25
26let (mut sender, mut receiver) = channel();
27
28let send_task = async move {
29    for i in 0..100 {
30        sender.send(i).await.expect("channel disconnected");
31    }
32};
33
34let recv_task = async move {
35    for i in 0..100 {
36        let value = receiver.recv().await.expect("channel disconnected");
37        assert_eq!(value, i);
38    }
39};
40
41// All sends block until the receiver accepts the value, so we need to make
42// sure the tasks happen concurrently
43join(send_task, recv_task).await;
44# });
45```
46
47## Synchronous use
48
49```
50use std::thread;
51use handoff::channel;
52use futures::executor::block_on;
53
54let (mut sender, mut receiver) = channel();
55
56let sender_thread = thread::spawn(move || {
57    for i in 0..100 {
58        block_on(sender.send(i)).expect("receiver disconnected");
59    }
60});
61
62let receiver_thread = thread::spawn(move || {
63    for i in 0..100 {
64        let value = block_on(receiver.recv()).expect("sender disconnected");
65        assert_eq!(value, i);
66    }
67});
68
69sender_thread.join().expect("sender panicked");
70receiver_thread.join().expect("receiver panicked");
71```
72
73## Disconnect
74
75```
76# futures::executor::block_on(async move {
77use handoff::channel;
78use futures::future::join;
79
80let (mut sender, mut receiver) = channel();
81
82let send_task = async move {
83    for i in 0..50 {
84        sender.send(i).await.expect("channel disconnected");
85    }
86};
87
88let recv_task = async move {
89    for i in 0..50 {
90        let value = receiver.recv().await.expect("channel disconnected");
91        assert_eq!(value, i);
92    }
93
94    assert!(receiver.recv().await.is_none());
95};
96
97// All sends block until the receiver accepts the value, so we need to make
98// sure the tasks happen concurrently
99join(send_task, recv_task).await;
100# });
101```
102*/
103
104#![deny(missing_docs)]
105#![deny(missing_debug_implementations)]
106
107use std::{
108    cell::UnsafeCell,
109    fmt::Debug,
110    future::Future,
111    hint::unreachable_unchecked,
112    ops::Not,
113    pin::Pin,
114    ptr::{self, NonNull},
115    sync::atomic::{
116        AtomicPtr,
117        Ordering::{Acquire, Relaxed, Release},
118    },
119    task::{Context, Poll},
120    thread,
121};
122
/// Small convenience extension: fetch the interior pointer of an
/// `UnsafeCell` as a `NonNull`, so `Option<NonNull<T>>` stays niche-optimized
/// and callers never handle a raw nullable pointer.
trait UnsafeCellExt<T> {
    #[must_use]
    fn get_non_null(&self) -> NonNull<T>;
}

impl<T> UnsafeCellExt<T> for UnsafeCell<T> {
    #[inline]
    #[must_use]
    fn get_non_null(&self) -> NonNull<T> {
        let raw = self.get();
        NonNull::new(raw).expect("UnsafeCell shouldn't return a null pointer")
    }
}
135
136use futures_util::{
137    stream::{FusedStream, Stream, StreamExt},
138    task::AtomicWaker,
139};
140use pin_project::{pin_project, pinned_drop};
141use pinned_aliasable::Aliasable;
142use thiserror::Error;
143use twinsies::Joint;
144
/// Identical to `unreachable_unchecked`, but panics in debug mode. Still
/// requires unsafe.
macro_rules! debug_unreachable {
    ($($arg:tt)*) => {
        if cfg!(debug_assertions) {
            // Debug builds: loudly catch the "impossible" branch.
            unreachable!($($arg)*)
        } else {
            // Release builds: let the optimizer assume this is unreachable.
            unreachable_unchecked()
        }
    }
}
155
156/// Create an unbuffered channel for communicating between a pair of
157/// asynchronous components. All sends over this channel will block until the
158/// receiver receives the sent item. See [crate documentation][crate] for
159/// details.
160pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
161    let (send_joint, recv_joint) = Joint::new(Inner {
162        sent_item: AtomicPtr::default(),
163        sender_waker: AtomicWaker::new(),
164        receiver_waker: AtomicWaker::new(),
165    });
166
167    (Sender { inner: send_joint }, Receiver { inner: recv_joint })
168}
169
/// State shared between the [`Sender`] and [`Receiver`] halves of a channel,
/// jointly owned through a `twinsies::Joint`.
struct Inner<T> {
    // When this is not null, there's an object that a sender is trying to send
    // (and is asynchronously blocked until the send completes)
    //
    // The pointee is an `Option<T>` stored inside the pinned `SendFut`; the
    // receiver `take()`s the item through this pointer.
    sent_item: AtomicPtr<Option<T>>,

    // The waker owned by the sender. Should be signalled when the receiver
    // takes a value (or disconnects)
    sender_waker: AtomicWaker,

    // The waker owned by the receiver. Should be signalled when the sender has
    // an item to send (or disconnects)
    receiver_waker: AtomicWaker,
}

// SAFETY: `Inner` is private and only reachable through `Sender`/`Receiver`/
// `SendFut`, all of which gate their own `Send` impls on `T: Send`.
// NOTE(review): these impls are unconditional in `T`; soundness relies on the
// bounds at the public wrapper types — confirm this invariant is maintained.
unsafe impl<T> Send for Inner<T> {}
unsafe impl<T> Sync for Inner<T> {}
186
impl<T> Inner<T> {
    /// The sender uses this to take an item pointer that it placed there, to
    /// regain exclusive access to its item.
    ///
    /// Spins (with `thread::yield_now`) while the slot is null, which means
    /// the receiver currently holds the pointer and will restore it when it's
    /// done. On success, the `Acquire` ordering pairs with the receiver's
    /// `Release` restore of the pointer, so the sender observes any
    /// `take()` the receiver performed on the item.
    #[inline]
    fn reclaim_sent_item_pointer(&self, item_pointer: NonNull<Option<T>>) {
        loop {
            match self.sent_item.compare_exchange_weak(
                item_pointer.as_ptr(),
                ptr::null_mut(),
                Acquire,
                Relaxed,
            ) {
                Ok(_) => break,

                // Spurious failure (compare_exchange_weak may fail even when
                // the current value matches); just retry.
                Err(current) if current == item_pointer.as_ptr() => continue,

                // Receiver owns the value; spin while we wait for it
                //
                // TODO: consider using something like the spinner from
                // parking_lot_core. We're pretty certain that another thread is
                // working with the pointer, though, so for now we're content to
                // do a full yield and let it have a chance to finish its work.
                Err(current) if current.is_null() => thread::yield_now(),

                // Something very wrong happened
                Err(current) => unsafe {
                    debug_unreachable!(
                        "A new pointer ({current:p}) appeared in inner \
                        while a sender exists ({item_pointer:p}); this \
                        should never happen"
                    )
                },
            }
        }
    }
}
224
/// Whenever `Inner` drops, it means a disconnect is happening. Inform the
/// sender and receiver (though one of them, of course, is being dropped
/// anyway). It's guaranteed that, once `Inner::drop` is called, the `Joint`
/// can no longer be locked by the surviving half, so the woken task will
/// observe the disconnect on its next poll.
///
/// NOTE(review): the original doc comment was cut off after "the `Joint`";
/// the completion above is inferred from the lock-failure handling in the
/// `poll` implementations — confirm the intended wording.
impl<T> Drop for Inner<T> {
    fn drop(&mut self) {
        self.sender_waker.wake();
        self.receiver_waker.wake();
    }
}
234
/// The sending end of a handoff channel.
///
/// This object is created by the [`channel`] function. See [crate
/// documentation][crate] for details.
pub struct Sender<T> {
    // Jointly-owned shared state; `lock()` fails once the receiver has
    // dropped, which is how sends detect disconnection.
    inner: Joint<Inner<T>>,
}
242
243impl<T> Sender<T> {
244    /// Asynchronously send an item to the receiver.
245    ///
246    /// This method will asynchronously block until the receiver has received
247    /// the item. If the receiver disconnects, this will instead return a
248    /// [`SendError`] containing the item that failed to send.
249    #[inline]
250    #[must_use]
251    pub fn send(&mut self, item: T) -> SendFut<'_, T> {
252        SendFut {
253            item: Aliasable::new(UnsafeCell::new(Some(item))),
254            inner: &self.inner,
255            item_lent: false,
256        }
257    }
258
259    // TODO: `Sink` implementation. This will require wrapping the sender. Need
260    // to decide if we prefer a by-move or by-ref sink (probably the latter).
261    // Alternatively, create a crate with a general-purpose adapter between
262    // `async fn send` and `Sink`.
263}
264
265impl<T> Debug for Sender<T> {
266    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267        f.debug_struct("Sender")
268            .field("inner", &self.inner)
269            .finish()
270    }
271}
272
// SAFETY: a `Sender` is only a handle to the shared `Inner` state; moving it
// to another thread moves (at most) the in-transit `T`, hence the `T: Send`
// bound.
unsafe impl<T: Send> Send for Sender<T> {}
274
/// Future for sending a single item through a [`Sender`], created by the
/// [`send`][Sender::send] method. See its documentation for details.
#[pin_project(PinnedDrop)]
pub struct SendFut<'a, T> {
    // Implementation note: It is critically important to remember that the
    // contents of the cell here can be aliased even when we have a reference
    // to it, so long as it's pinned. See the `Aliasable` docs for details.
    //
    // Stored as `Option<T>` so the receiver can `take()` the item through
    // the pointer published in `Inner`; `None` afterward means the handoff
    // completed.
    #[pin]
    item: Aliasable<UnsafeCell<Option<T>>>,

    // We don't want to hold a `JointLock` on the Inner<T>; we want to
    // check each time we're polled if there was a disconnect.
    inner: &'a Joint<Inner<T>>,

    // If item_lent is true, it means that `Inner` has ownership of `item`
    // and we need to re-acquire the pointer before doing anything with
    // it.
    item_lent: bool,
}
294
295impl<T> Debug for SendFut<'_, T> {
296    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297        f.debug_struct("SendFut")
298            .field("item", &"<in transit>")
299            .field("inner", &self.inner)
300            .field("item_lent", &self.item_lent)
301            .finish()
302    }
303}
304
impl<T> Future for SendFut<'_, T> {
    type Output = Result<(), SendError<T>>;

    // Poll protocol: publish a pointer to our pinned item into `Inner`, then
    // wait to be woken; on each subsequent poll, reclaim the pointer and check
    // whether the receiver took the item.
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        let mut item_pointer = this.item.as_ref().get().get_non_null();

        let Some(lock) = this.inner.lock() else {
            return Poll::Ready(
                // Safety: if we couldn't acquire a lock, it means that the
                // `Inner` dropped, which means we definitely have exclusive
                // access to the value.
                match unsafe { item_pointer.as_mut() }
                    .take()
                {
                    Some(item) => Err(SendError(item)),
                    None => Ok(()),
                },
            )
        };

        // If we've published the item pointer for the receiver to take, check
        // to see if it successfully took the item.
        if *this.item_lent {
            // If we've previously polled, we're aiming to check and see if the
            // item has been taken by the receiver yet. We need to first take
            // the `item` pointer, to ensure we have exclusive access to the
            // item
            lock.reclaim_sent_item_pointer(item_pointer);

            // For consistency, we always update this field after reclaiming the
            // pointer. We specifically want it to be false so that our
            // destructor knows it doesn't need to do any additional work.
            *this.item_lent = false;

            // We've acquired exclusive access to the item pointer; we can check
            // to see if the item was taken yet.
            if unsafe { item_pointer.as_ref() }.is_none() {
                return Poll::Ready(Ok(()));
            }
        }

        // At this point, we've either never been polled before, or we have been
        // polled previously but we still have the item. The state is the same
        // either way: the `Inner` contains a null pointer and we need to notify
        // the receiver that a value is ready.
        //
        // Theoretically, the inner pointer could be non-null, but this only
        // happens if we leaked a `send` future, so we can just clobber it.
        // Similarly, we can theoretically not have the item, if we're polled
        // again after returning Ready. Neither of these cause unsoundness.
        debug_assert!(
            unsafe { item_pointer.as_ref() }.is_some(),
            "Don't poll futures after they returned success"
        );

        // Register our waker *before* publishing the pointer, so the receiver's
        // wake (after taking the item) can't be missed. The `Release` store
        // pairs with the receiver's `Acquire` swap, making the item's contents
        // visible to the receiving thread.
        lock.sender_waker.register(cx.waker());
        lock.sent_item.store(item_pointer.as_ptr(), Release);
        lock.receiver_waker.wake();
        *this.item_lent = true;

        Poll::Pending
    }
}
371
// Cancellation safety: if this future is dropped while `Inner` holds a
// pointer into our pinned item, we must reclaim that pointer before the
// item's storage is invalidated.
#[pinned_drop]
impl<T> PinnedDrop for SendFut<'_, T> {
    fn drop(self: Pin<&mut Self>) {
        let this = self.project();

        // We only need to do extra drop work if `Inner` has exclusive access to
        // our `item`.
        if this.item_lent.not() {
            return;
        };

        // If we disconnected, there's nothing else we need to do. Even if
        // `item_lent` was true, `inner` was dropped and implicitly doesn't have
        // access to the `item` anymore.
        let Some(lock) = this.inner.lock() else {
            return;
        };

        // When an individual send future drops, we can immediately
        // erase the waker. No send notifications are necessary until a
        // new send future appears.
        drop(lock.sender_waker.take());

        let item_pointer = this.item.into_ref().get().get_non_null();

        // Okay, we need to acquire the pointer before we can drop. This
        // might involve spinning if the receiver is working with it
        // right now.
        lock.reclaim_sent_item_pointer(item_pointer);
        // Now that we've reclaimed the pointer, we don't need to do
        // anything else. The drop can proceed normally.
    }
}
405
// SAFETY: the future owns the in-transit item, so moving the future across
// threads moves the item — hence the `T: Send` bound.
unsafe impl<T: Send> Send for SendFut<'_, T> {}

// TODO: verify that this is sound. I believe it is in all practical
// cases, since there isn't actually any uncontrolled mechanism in this
// crate by which a reference to `item` might be used while it's owned
// by the channel
unsafe impl<T> Sync for SendFut<'_, T> {}
413
/// The receiving end of a handoff channel.
///
/// This object is created by the [`channel`] function. See [crate
/// documentation][crate] for details.
///
/// [`Receiver`] only provides a simple [`recv`][Receiver::recv] method on its
/// own, but it also implements [`futures::Stream`][Stream], which means the
/// adapter methods of [`futures::StreamExt`][StreamExt] — a number of helpful
/// iterator-like methods — are available as well.
pub struct Receiver<T> {
    // Jointly-owned shared state; `lock()` fails once the sender has dropped,
    // which is how receives detect disconnection.
    inner: Joint<Inner<T>>,
}
425
426impl<T> Receiver<T> {
427    /// Attempt to receive the next item from the sender.
428    ///
429    /// This method will asynchronously block until the sender sends an item,
430    /// then return that item. Alternatively, if the sender disconnects, this
431    /// will return `None`.
432    #[inline]
433    pub fn recv(&mut self) -> RecvFut<'_, T> {
434        RecvFut { receiver: self }
435    }
436}
437
// SAFETY: a `Receiver` is only a handle to the shared `Inner` state; moving
// it to another thread moves (at most) the received `T`s, hence `T: Send`.
unsafe impl<T: Send> Send for Receiver<T> {}
439
440impl<T> Debug for Receiver<T> {
441    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
442        f.debug_struct("Receiver")
443            .field("inner", &self.inner)
444            .finish()
445    }
446}
447
impl<T> Stream for Receiver<T> {
    type Item = T;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // A failed lock means the sender dropped: the stream is finished.
        let Some(lock) = self.inner.lock() else { return Poll::Ready(None) };

        // All of the logic for actually attempting to take an item from
        // `inner`. We have to call this twice, because we first try to take an
        // item, then register a waker, then we have to try again after
        // registering the waker. This avoids a race where we fail to retrieve
        // an item, then the sender places an item, then the sender calls
        // wake() before we've registered our waker.
        //
        // TODO: bench {recv; register(waker); recv} against {register(waker); recv}

        let try_recv = || loop {
            // Acquire the pointer. As long as we have it, we have exclusive
            // access to the item. The sender will wait for us to return the
            // pointer before dropping (or, if it leaks, the value is pinned, so
            // the pointer is valid forever in that case).
            let sent_item_ptr = lock.sent_item.swap(ptr::null_mut(), Acquire);

            // If there wasn't a pointer available, we've already registered our
            // waker, so at this point we're waiting for a signal to try another
            // receive operation.
            let Some(mut sent_item_ptr) = NonNull::new(sent_item_ptr) else {
                return Poll::Pending
            };

            // Try to read the item from the pointer. It's possible that we've
            // already taken it and this is a spurious poll.
            //
            // SAFETY: Because we acquired the `sent_item_ptr` (replacing it
            // with a null ptr), we have exclusive access to it.
            let sent_item = unsafe { sent_item_ptr.as_mut() }.take();

            // Restore the pointer so the sender can reclaim it. The `Release`
            // ordering pairs with the sender's `Acquire` in
            // `reclaim_sent_item_pointer`.
            //
            // We don't need to retry (non-spurious) failures, since the
            // presence of a new non-null pointer indicates a sender leak, which
            // means we can simply drop the `sent_item_ptr` outright.
            match lock.sent_item.compare_exchange(
                ptr::null_mut(),
                sent_item_ptr.as_ptr(),
                Release,
                Relaxed,
            ) {
                // We restored the pointer, so we need to wake the sender so it
                // can proceed with the drop
                Ok(_) => lock.sender_waker.wake(),

                // Somehow the pointer to a pinned object found its way back
                // into the slot. This shouldn't be possible, since that memory
                // should be usable until the sender finishes sending, and it
                // can't drop until we restore the pointer.
                Err(p) if p == sent_item_ptr.as_ptr() => unsafe { debug_unreachable!() },

                // There was a leak and a new sent item arrived while we were
                // working. If we didn't receive an item, we can retry receiving
                // this *new* item.
                Err(_) if sent_item.is_none() => continue,

                // There was a leak and a new sent item arrived while we were
                // working. We already got an item, so we have to leave the new
                // one there until a subsequent `recv`.
                Err(_) => {}
            }

            return match sent_item {
                Some(item) => Poll::Ready(Some(item)),
                None => Poll::Pending,
            };
        };

        // First attempt; register the waker only if it fails, then attempt
        // once more to close the race window described above.
        match try_recv() {
            Poll::Ready(item) => Poll::Ready(item),
            Poll::Pending => {
                lock.receiver_waker.register(cx.waker());
                try_recv()
            }
        }
    }

    // An unbuffered channel never has items queued, so the lower bound is
    // always 0; the upper bound is only known (0) once the sender is gone.
    #[inline]
    #[must_use]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, if self.inner.alive() { None } else { Some(0) })
    }
}
535
536impl<T> FusedStream for Receiver<T> {
537    fn is_terminated(&self) -> bool {
538        !self.inner.alive()
539    }
540}
541
542// TODO: this drop is only necessary if the receiver has been acting as a
543// `futures::Stream`. Consider creating a separate wrapper around it that
544// implements stream.
545impl<T> Drop for Receiver<T> {
546    fn drop(&mut self) {
547        let Some(lock) = self.inner.lock() else { return };
548        drop(lock.receiver_waker.take())
549    }
550}
551
/// Future type for receiving a single item from a [`Receiver`]. Created by the
/// [`recv`][Receiver::recv] method; see its documentation for details.
pub struct RecvFut<'a, T> {
    // Exclusive borrow of the receiver for the lifetime of this future; all
    // polling is delegated to the receiver's `Stream` implementation.
    receiver: &'a mut Receiver<T>,
}
557
558impl<T> Debug for RecvFut<'_, T> {
559    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
560        f.debug_struct("Recv")
561            .field("receiver", &self.receiver)
562            .finish()
563    }
564}
565
566impl<T> Future for RecvFut<'_, T> {
567    type Output = Option<T>;
568
569    #[inline]
570    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
571        self.receiver.poll_next_unpin(cx)
572    }
573}
574
575impl<T> Drop for RecvFut<'_, T> {
576    #[inline]
577    fn drop(&mut self) {
578        let Some(lock) = self.receiver.inner.lock() else { return };
579        drop(lock.receiver_waker.take())
580    }
581}
582
583/// An error from a [`send()`][Sender::send] operation.
584///
585/// This error means the send failed due to a disconnect; this is the only way
586/// sends can fail. The error contains the item that failed to send.
587#[derive(Error, Clone, Debug, Copy)]
588#[error("tried to send on a disconnected channel")]
589pub struct SendError<T>(
590    /// The item that failed to send
591    pub T,
592);
593
#[cfg(test)]
mod tests {
    use std::thread;

    use cool_asserts::assert_matches;
    use futures::{executor::block_on, StreamExt};

    use super::{channel, SendError};

    // Items arrive in order, and the stream ends when the sender drops.
    #[tokio::test]
    async fn basic_test() {
        let (mut sender, receiver) = channel();

        let sender_task = tokio::task::spawn(async move {
            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();
            sender.send(4).await.unwrap();
        });

        let data: Vec<i32> = receiver.collect().await;
        sender_task.await.unwrap();

        assert_eq!(data, [1, 2, 3, 4]);
    }

    // Both halves can be driven concurrently on a single task via `join`,
    // with either half's future listed first.
    #[tokio::test]
    async fn taskless() {
        let (mut sender, mut receiver) = channel();

        let (send, recv) = futures::future::join(sender.send(1), receiver.next()).await;
        send.unwrap();
        assert_eq!(recv.unwrap(), 1);

        let (recv, send) = futures::future::join(receiver.next(), sender.send(2)).await;
        send.unwrap();
        assert_eq!(recv.unwrap(), 2);
    }

    // Stress ordering across real OS threads on a multi-threaded runtime.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn multi_thread_tasks() {
        let (mut sender, mut receiver) = channel();

        let sender_task = tokio::task::spawn(async move {
            for i in 0..1_000 {
                sender.send(i).await.unwrap();
            }
        });

        let receiver_task = tokio::task::spawn(async move {
            for i in 0..1_000 {
                assert_eq!(receiver.next().await.unwrap(), i);
            }
        });

        sender_task.await.unwrap();
        receiver_task.await.unwrap();
    }

    // Fully synchronous use: sender on a dedicated thread, receiver driven
    // with `block_on` from the test thread.
    #[test]
    fn sync_test() {
        let (mut sender, mut receiver) = channel();

        let sender_thread = thread::Builder::new()
            .name("sender thread".to_owned())
            .spawn(move || {
                block_on(async move {
                    for i in 0..1_000 {
                        sender.send(i).await.unwrap();
                    }
                })
            })
            .expect("failed to spawn thread");

        for i in 0..1_000 {
            assert_eq!(block_on(receiver.next()).unwrap(), i);
        }

        sender_thread.join().unwrap();
    }

    // Fully synchronous use with both halves on their own threads (mirrors
    // the "Synchronous use" example in the crate docs).
    #[test]
    fn sync_test_two_threads() {
        let (mut sender, mut receiver) = channel();

        let sender_thread = thread::spawn(move || {
            for i in 0..100 {
                block_on(sender.send(i)).expect("receiver disconnected");
            }
        });

        let receiver_thread = thread::spawn(move || {
            for i in 0..100 {
                let value = block_on(receiver.recv()).expect("sender disconnected");
                assert_eq!(value, i);
            }
        });

        sender_thread.join().expect("sender panicked");
        receiver_thread.join().expect("receiver panicked");
    }

    // Receiving after the sender dropped resolves immediately to `None`.
    #[tokio::test]
    async fn basic_sender_close() {
        let (sender, mut receiver) = channel();

        drop(sender);

        let out: Option<i32> = receiver.recv().await;
        assert_eq!(out, None);
    }

    // Sending after the receiver dropped fails, returning the item.
    #[tokio::test]
    async fn basic_receiver_close() {
        let (mut sender, receiver) = channel();

        drop(receiver);

        assert_matches!(sender.send(1).await, Err(SendError(1)));
    }

    // A receive that's already pending unblocks with `None` when the sender
    // drops mid-wait.
    #[tokio::test]
    async fn sender_close_while_waiting() {
        let (sender, mut receiver) = channel();

        let sender_task = tokio::task::spawn(async move {
            tokio::task::yield_now().await;
            drop(sender);
        });

        let out: Option<i32> = receiver.recv().await;
        assert_eq!(out, None);
        sender_task.await.unwrap();
    }

    // A send that's already pending unblocks with `SendError` (returning the
    // item) when the receiver drops mid-wait.
    #[tokio::test]
    async fn receiver_close_while_waiting() {
        let (mut sender, receiver) = channel();

        let receiver_task = tokio::task::spawn(async move {
            tokio::task::yield_now().await;
            drop(receiver);
        });

        assert_matches!(sender.send(1).await, Err(SendError(1)));
        receiver_task.await.unwrap();
    }

    // Aborting the sender's task mid-send cancels the in-flight `SendFut`
    // (exercising `PinnedDrop`) and terminates the receiver's stream.
    #[tokio::test]
    async fn sender_cancels() {
        let (mut sender, mut receiver) = channel();

        let sender_task = tokio::task::spawn(async move {
            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
        });

        assert_eq!(receiver.next().await.unwrap(), 1);
        sender_task.abort();
        assert_matches!(receiver.next().await, None);
        assert_matches!(sender_task.await, Err(err) => assert!(err.is_cancelled()));
    }

    // TODO: test sender leak

    // TODO: bench compare various channels
}