handoff/lib.rs
1/*!
2`handoff` is a single-producer / single-consumer, unbuffered, asynchronous
3channel. It's intended for cases where you want blocking communication between
4two async components, where all sends block until the receiver receives the
5item.
6
7A new channel is created with [`channel`], which returns a [`Sender`] and
8[`Receiver`]. Items can be sent into the channel with [`Sender::send`], and
9received with [`Receiver::recv`]. [`Receiver`] also implements
10[`futures::Stream`][Stream]. Either end of the channel can be dropped, which will
11cause the other end to unblock and report channel disconnection.
12
13While the channel operates asynchronously, it can also be used in a fully
14synchronous way by using `block_on` or similar utilities provided in most
15async runtimes.
16
17# Examples
18
19## Basic example
20
21```rust
22# futures::executor::block_on(async move {
23use handoff::channel;
24use futures::future::join;
25
26let (mut sender, mut receiver) = channel();
27
28let send_task = async move {
29 for i in 0..100 {
30 sender.send(i).await.expect("channel disconnected");
31 }
32};
33
34let recv_task = async move {
35 for i in 0..100 {
36 let value = receiver.recv().await.expect("channel disconnected");
37 assert_eq!(value, i);
38 }
39};
40
41// All sends block until the receiver accepts the value, so we need to make
42// sure the tasks happen concurrently
43join(send_task, recv_task).await;
44# });
45```
46
47## Synchronous use
48
49```
50use std::thread;
51use handoff::channel;
52use futures::executor::block_on;
53
54let (mut sender, mut receiver) = channel();
55
56let sender_thread = thread::spawn(move || {
57 for i in 0..100 {
58 block_on(sender.send(i)).expect("receiver disconnected");
59 }
60});
61
62let receiver_thread = thread::spawn(move || {
63 for i in 0..100 {
64 let value = block_on(receiver.recv()).expect("sender disconnected");
65 assert_eq!(value, i);
66 }
67});
68
69sender_thread.join().expect("sender panicked");
70receiver_thread.join().expect("receiver panicked");
71```
72
73## Disconnect
74
75```
76# futures::executor::block_on(async move {
77use handoff::channel;
78use futures::future::join;
79
80let (mut sender, mut receiver) = channel();
81
82let send_task = async move {
83 for i in 0..50 {
84 sender.send(i).await.expect("channel disconnected");
85 }
86};
87
88let recv_task = async move {
89 for i in 0..50 {
90 let value = receiver.recv().await.expect("channel disconnected");
91 assert_eq!(value, i);
92 }
93
94 assert!(receiver.recv().await.is_none());
95};
96
97// All sends block until the receiver accepts the value, so we need to make
98// sure the tasks happen concurrently
99join(send_task, recv_task).await;
100# });
101```
102*/
103
104#![deny(missing_docs)]
105#![deny(missing_debug_implementations)]
106
107use std::{
108 cell::UnsafeCell,
109 fmt::Debug,
110 future::Future,
111 hint::unreachable_unchecked,
112 ops::Not,
113 pin::Pin,
114 ptr::{self, NonNull},
115 sync::atomic::{
116 AtomicPtr,
117 Ordering::{Acquire, Relaxed, Release},
118 },
119 task::{Context, Poll},
120 thread,
121};
122
/// Extension trait giving [`UnsafeCell`] a `NonNull`-returning accessor.
trait UnsafeCellExt<T> {
    /// Pointer to the cell's contents, typed as [`NonNull`] instead of a raw
    /// `*mut T`.
    #[must_use]
    fn get_non_null(&self) -> NonNull<T>;
}

impl<T> UnsafeCellExt<T> for UnsafeCell<T> {
    #[inline]
    #[must_use]
    fn get_non_null(&self) -> NonNull<T> {
        // `UnsafeCell::get` points into `self`, so a null result would be a
        // broken invariant; panic with the same message if it ever happens.
        match NonNull::new(self.get()) {
            Some(contents) => contents,
            None => panic!("UnsafeCell shouldn't return a null pointer"),
        }
    }
}
135
136use futures_util::{
137 stream::{FusedStream, Stream, StreamExt},
138 task::AtomicWaker,
139};
140use pin_project::{pin_project, pinned_drop};
141use pinned_aliasable::Aliasable;
142use thiserror::Error;
143use twinsies::Joint;
144
/// Marks a code path as impossible. With debug assertions enabled this panics
/// via `unreachable!` (forwarding any format arguments); without them it
/// compiles down to `unreachable_unchecked()`. Either way the invocation must
/// sit inside an `unsafe` block.
macro_rules! debug_unreachable {
    ($($arg:tt)*) => {
        if cfg!(debug_assertions) {
            unreachable!($($arg)*)
        } else {
            unreachable_unchecked()
        }
    }
}
155
156/// Create an unbuffered channel for communicating between a pair of
157/// asynchronous components. All sends over this channel will block until the
158/// receiver receives the sent item. See [crate documentation][crate] for
159/// details.
160pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
161 let (send_joint, recv_joint) = Joint::new(Inner {
162 sent_item: AtomicPtr::default(),
163 sender_waker: AtomicWaker::new(),
164 receiver_waker: AtomicWaker::new(),
165 });
166
167 (Sender { inner: send_joint }, Receiver { inner: recv_joint })
168}
169
// State shared between the two ends of a channel (each end holds one half of
// a `Joint<Inner<T>>`).
struct Inner<T> {
    // Pointer to the item slot of the in-flight send, or null. When this is
    // not null, there's an object that a sender is trying to send (and is
    // asynchronously blocked until the send completes). Whoever swaps the
    // pointer out of this field gets exclusive access to the slot until the
    // pointer is put back.
    sent_item: AtomicPtr<Option<T>>,

    // The waker owned by the sender. Should be signalled when the receiver
    // takes a value (or disconnects)
    sender_waker: AtomicWaker,

    // The waker owned by the receiver. Should be signalled when the sender has
    // an item to send (or disconnects)
    receiver_waker: AtomicWaker,
}
183
// SAFETY: `Inner` never stores a `T` by value; it only holds an atomic
// pointer to a sender-owned slot plus two `AtomicWaker`s.
// NOTE(review): these impls are deliberately unbounded in `T`; the `T: Send`
// requirement is carried by the `Send` impls on `Sender`, `SendFut` and
// `Receiver` instead. Confirm that no other path can hand an `Inner<T>` with
// a non-`Send` `T` to another thread.
unsafe impl<T> Send for Inner<T> {}
unsafe impl<T> Sync for Inner<T> {}
186
impl<T> Inner<T> {
    /// The sender uses this to take an item pointer that it placed there, to
    /// regain exclusive access to its item.
    ///
    /// Loops until `sent_item` has been swapped from `item_pointer` back to
    /// null. While the slot reads null (the receiver has temporarily taken
    /// the pointer) this yields the thread and retries, so it can spin for as
    /// long as the receiver holds the pointer.
    #[inline]
    fn reclaim_sent_item_pointer(&self, item_pointer: NonNull<Option<T>>) {
        loop {
            // `Acquire` on success pairs with the receiver's `Release` when it
            // puts the pointer back, so its writes to the slot are visible.
            match self.sent_item.compare_exchange_weak(
                item_pointer.as_ptr(),
                ptr::null_mut(),
                Acquire,
                Relaxed,
            ) {
                Ok(_) => break,

                // Spurious failure
                Err(current) if current == item_pointer.as_ptr() => continue,

                // Receiver owns the value; spin while we wait for it
                //
                // TODO: consider using something like the spinner from
                // parking_lot_core. We're pretty certain that another thread is
                // working with the pointer, though, so for now we're content to
                // do a full yield and let it have a chance to finish its work.
                Err(current) if current.is_null() => thread::yield_now(),

                // Something very wrong happened
                //
                // SAFETY (as argued by the message below): while this sender's
                // future exists, no pointer other than `item_pointer` or null
                // is expected in the slot. This panics in debug builds and is
                // undefined behaviour in release builds if that is ever wrong.
                Err(current) => unsafe {
                    debug_unreachable!(
                        "A new pointer ({current:p}) appeared in inner \
                         while a sender exists ({item_pointer:p}); this \
                         should never happen"
                    )
                },
            }
        }
    }
}
224
225/// Whenever `Inner` drops, it means a disconnect is happening. Inform the
226/// sender and receiver (though one of them, of course, is being dropped
227/// anyway). It's guaranteed that, once `Inner::drop` is called, the `Joint`
228impl<T> Drop for Inner<T> {
229 fn drop(&mut self) {
230 self.sender_waker.wake();
231 self.receiver_waker.wake();
232 }
233}
234
/// The sending end of a handoff channel.
///
/// This object is created by the [`channel`] function. See [crate
/// documentation][crate] for details.
pub struct Sender<T> {
    // This end's half of the joint around the shared channel state; the
    // `Receiver` holds the other half.
    inner: Joint<Inner<T>>,
}
242
243impl<T> Sender<T> {
244 /// Asynchronously send an item to the receiver.
245 ///
246 /// This method will asynchronously block until the receiver has received
247 /// the item. If the receiver disconnects, this will instead return a
248 /// [`SendError`] containing the item that failed to send.
249 #[inline]
250 #[must_use]
251 pub fn send(&mut self, item: T) -> SendFut<'_, T> {
252 SendFut {
253 item: Aliasable::new(UnsafeCell::new(Some(item))),
254 inner: &self.inner,
255 item_lent: false,
256 }
257 }
258
259 // TODO: `Sink` implementation. This will require wrapping the sender. Need
260 // to decide if we prefer a by-move or by-ref sink (probably the latter).
261 // Alternatively, create a crate with a general-purpose adapter between
262 // `async fn send` and `Sink`.
263}
264
265impl<T> Debug for Sender<T> {
266 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267 f.debug_struct("Sender")
268 .field("inner", &self.inner)
269 .finish()
270 }
271}
272
// SAFETY: sending through a `Sender` moves `T` values to whichever thread
// polls the receiver, so the sender may only cross threads when `T: Send`.
// NOTE(review): presumably this explicit impl exists to impose that bound on
// top of whatever `Joint<Inner<T>>` provides by itself — confirm.
unsafe impl<T: Send> Send for Sender<T> {}
274
/// Future for sending a single item through a [`Sender`], created by the
/// [`send`][Sender::send] method. See its documentation for details.
///
/// Resolves to `Ok(())` once the receiver has taken the item, or to a
/// [`SendError`] holding the item if the channel disconnected first.
#[pin_project(PinnedDrop)]
pub struct SendFut<'a, T> {
    // The slot holding the item being sent; `None` once it has been taken.
    //
    // Implementation note: It is critically important to remember that the
    // contents of the cell here can be aliased even when we have a reference
    // to it, so long as it's pinned. See the `Aliasable` docs for details.
    #[pin]
    item: Aliasable<UnsafeCell<Option<T>>>,

    // We don't want to hold a `JointLock` on the Inner<T>; we want to
    // check each time we're polled if there was a disconnect.
    inner: &'a Joint<Inner<T>>,

    // If item_lent is true, it means that `Inner` has ownership of `item`
    // and we need to re-acquire the pointer before doing anything with
    // it.
    item_lent: bool,
}
294
295impl<T> Debug for SendFut<'_, T> {
296 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 f.debug_struct("SendFut")
298 .field("item", &"<in transit>")
299 .field("inner", &self.inner)
300 .field("item_lent", &self.item_lent)
301 .finish()
302 }
303}
304
impl<T> Future for SendFut<'_, T> {
    type Output = Result<(), SendError<T>>;

    /// Drive the send forward.
    ///
    /// Returns:
    /// - `Ready(Ok(()))` once the item slot is found empty, i.e. the receiver
    ///   took the item;
    /// - `Ready(Err(SendError(item)))` if the channel is disconnected while
    ///   the item is still in the slot;
    /// - `Pending` otherwise, after (re-)publishing the slot pointer,
    ///   registering this task's waker and waking the receiver.
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        // Pointer to our own (pinned) item slot. Merely obtaining it is fine;
        // dereferencing it requires that we currently have exclusive access.
        let mut item_pointer = this.item.as_ref().get().get_non_null();

        let Some(lock) = this.inner.lock() else {
            return Poll::Ready(
                // Safety: if we couldn't acquire a lock, it means that the
                // `Inner` dropped, which means we definitely have exclusive
                // access to the value.
                //
                // If the slot is already empty the receiver took the item
                // before disconnecting, so the send still counts as a success.
                match unsafe { item_pointer.as_mut() }
                    .take()
                {
                    Some(item) => Err(SendError(item)),
                    None => Ok(()),
                },
            )
        };

        // If we've published the item pointer for the receiver to take, check
        // to see if it successfully took the item.
        if *this.item_lent {
            // If we've previously polled, we're aiming to check and see if the
            // item has been taken by the receiver yet. We need to first take
            // the `item` pointer, to ensure we have exclusive access to the
            // item
            lock.reclaim_sent_item_pointer(item_pointer);

            // For consistency, we always update this field after reclaiming the
            // pointer. We specifically want it to be false so that our
            // destructor knows it doesn't need to do any additional work.
            *this.item_lent = false;

            // We've acquired exclusive access to the item pointer; we can check
            // to see if the item was taken yet.
            //
            // SAFETY: the pointer was just reclaimed from `Inner`, so nothing
            // else can be accessing the slot.
            if unsafe { item_pointer.as_ref() }.is_none() {
                return Poll::Ready(Ok(()));
            }
        }

        // At this point, we've either never been polled before, or we have been
        // polled previously but we still have the item. The state is the same
        // either way: the `Inner` contains a null pointer and we need to notify
        // the receiver that a value is ready.
        //
        // Theoretically, the inner pointer could be non-null, but this only
        // happens if we leaked a `send` future, so we can just clobber it.
        // Similarly, we can theoretically not have the item, if we're polled
        // again after returning Ready. Neither of these cause unsoundness.
        //
        // SAFETY: `item_lent` is false here, so our slot pointer is not
        // published and we have exclusive access to the slot.
        debug_assert!(
            unsafe { item_pointer.as_ref() }.is_some(),
            "Don't poll futures after they returned success"
        );

        // Register our waker *before* publishing the pointer, so a receiver
        // that takes the item right away cannot wake us before we're
        // registered. The `Release` store pairs with the receiver's `Acquire`
        // swap.
        lock.sender_waker.register(cx.waker());
        lock.sent_item.store(item_pointer.as_ptr(), Release);
        lock.receiver_waker.wake();
        *this.item_lent = true;

        Poll::Pending
    }
}
371
372#[pinned_drop]
373impl<T> PinnedDrop for SendFut<'_, T> {
374 fn drop(self: Pin<&mut Self>) {
375 let this = self.project();
376
377 // We only need to do extra drop work if `Inner` has exclusive access to
378 // our `item`.
379 if this.item_lent.not() {
380 return;
381 };
382
383 // If we disconnected, there's nothing else we need to do. Even if
384 // `item_lent` was true, `inner` was dropped and implicitly doesn't have
385 // access to the `item` anymore.
386 let Some(lock) = this.inner.lock() else {
387 return;
388 };
389
390 // When an individual send future drops, we can immediately
391 // erase the waker. No send notification are necessary until a
392 // new send future appears.
393 drop(lock.sender_waker.take());
394
395 let item_pointer = this.item.into_ref().get().get_non_null();
396
397 // Okay, we need to acquire the pointer before we can drop. This
398 // might involve spinning if the receiver is working with it
399 // right now.
400 lock.reclaim_sent_item_pointer(item_pointer);
401 // Now that we've reclaimed the pointer, we don't need to do
402 // anything else. The drop can proceed normally.
403 }
404}
405
// SAFETY: the future owns the `T` being sent (inside its item slot), so it
// may only move to another thread when `T: Send`.
unsafe impl<T: Send> Send for SendFut<'_, T> {}

// TODO: verify that this is sound. I believe it is in all practical
// cases, since there isn't actually any uncontrolled mechanism in this
// crate by which a reference to `item` might be used while it's owned
// by the channel
//
// NOTE(review): within this file, nothing reachable through a shared
// `&SendFut` reads the item slot (the `Debug` impl prints a placeholder for
// it), which is presumably what the argument above relies on — confirm.
unsafe impl<T> Sync for SendFut<'_, T> {}
413
/// The receiving end of a handoff channel.
///
/// This object is created by the [`channel`] function. See [crate
/// documentation][crate] for details.
///
/// [`Receiver`] only provides a simple [`recv`][Receiver::recv] method on its
/// own, but it also implements [`futures::Stream`][Stream], so the
/// iterator-like helper methods of [`futures::StreamExt`][StreamExt] are
/// available on it as well.
pub struct Receiver<T> {
    // This end's half of the joint around the shared channel state; the
    // `Sender` holds the other half.
    inner: Joint<Inner<T>>,
}
425
426impl<T> Receiver<T> {
427 /// Attempt to receive the next item from the sender.
428 ///
429 /// This method will asynchronously block until the sender sends an item,
430 /// then return that item. Alternatively, if the sender disconnects, this
431 /// will return `None`.
432 #[inline]
433 pub fn recv(&mut self) -> RecvFut<'_, T> {
434 RecvFut { receiver: self }
435 }
436}
437
// SAFETY: a `Receiver` takes `T` values out of the sender's slot on whatever
// thread polls it, so it may only cross threads when `T: Send`.
// NOTE(review): presumably this explicit impl exists to impose that bound on
// top of whatever `Joint<Inner<T>>` provides by itself — confirm.
unsafe impl<T: Send> Send for Receiver<T> {}
439
440impl<T> Debug for Receiver<T> {
441 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
442 f.debug_struct("Receiver")
443 .field("inner", &self.inner)
444 .finish()
445 }
446}
447
impl<T> Stream for Receiver<T> {
    type Item = T;

    /// Try to take the item currently offered by the sender.
    ///
    /// Returns `Ready(None)` if the channel is disconnected, `Ready(Some(_))`
    /// if an item was taken, and `Pending` (with this task's waker
    /// registered) if no item is available yet.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Some(lock) = self.inner.lock() else { return Poll::Ready(None) };

        // All of the logic for actually attempting to take an item from
        // `inner`. We have to call this twice, because we first try to take an
        // item, then register a waker, then we have to try again after
        // registering the waker. This avoids a race where we fail to retrieve
        // an item, then the sender places an item, then the sender calls
        // wake() before we've registered our waker.
        //
        // TODO: bench {recv; register(waker); recv} against {register(waker); recv}

        let try_recv = || loop {
            // Acquire the pointer. As long as we have it, we have exclusive
            // access to the item. The sender will wait for us to return the
            // pointer before dropping (or, if it leaks, the value is pinned, so
            // the pointer is valid forever in that case).
            let sent_item_ptr = lock.sent_item.swap(ptr::null_mut(), Acquire);

            // If there wasn't a pointer available, there is nothing to receive
            // right now. Report `Pending`; the caller below registers our
            // waker and retries once, after which we're waiting for a signal
            // to try another receive operation.
            let Some(mut sent_item_ptr) = NonNull::new(sent_item_ptr) else {
                return Poll::Pending
            };

            // Try to read the item from the pointer. It's possible that we've
            // already taken it and this is a spurious poll.
            //
            // SAFETY: Because we acquired the `sent_item_ptr` (replacing it
            // with a null ptr), we have exclusive access to it.
            let sent_item = unsafe { sent_item_ptr.as_mut() }.take();

            // Put the pointer back so the sender can reclaim it. The `Release`
            // pairs with the sender's `Acquire` when it reclaims.
            //
            // We don't need to retry (non-spurious) failures, since the
            // presence of a new non-null pointer indicates a sender leak, which
            // means we can simply drop the `sent_item_ptr` outright.
            match lock.sent_item.compare_exchange(
                ptr::null_mut(),
                sent_item_ptr.as_ptr(),
                Release,
                Relaxed,
            ) {
                // We restored the pointer, so we need to wake the sender so it
                // can reclaim it and observe whether the item was taken
                Ok(_) => lock.sender_waker.wake(),

                // Somehow the pointer to a pinned object found its way back
                // into the slot. This shouldn't be possible, since that memory
                // should be usable until the sender finishes sending, and it
                // can't drop until we restore the pointer.
                //
                // SAFETY: relies on the argument above; panics in debug builds
                // and is undefined behaviour in release builds if it is wrong.
                Err(p) if p == sent_item_ptr.as_ptr() => unsafe { debug_unreachable!() },

                // There was a leak and a new sent item arrived while we were
                // working. If we didn't receive an item, we can retry receiving
                // this *new* item.
                Err(_) if sent_item.is_none() => continue,

                // There was a leak and a new sent item arrived while we were
                // working. We already got an item, so we have to leave the new
                // one there until a subsequent `recv`.
                Err(_) => {}
            }

            return match sent_item {
                Some(item) => Poll::Ready(Some(item)),
                None => Poll::Pending,
            };
        };

        // First attempt without touching the waker; only if that comes up
        // empty do we register and try once more (see the race described
        // above).
        match try_recv() {
            Poll::Ready(item) => Poll::Ready(item),
            Poll::Pending => {
                lock.receiver_waker.register(cx.waker());
                try_recv()
            }
        }
    }

    /// Lower bound is always 0. The upper bound is unknown (`None`) while
    /// `inner.alive()` is true and `Some(0)` once it is false.
    #[inline]
    #[must_use]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, if self.inner.alive() { None } else { Some(0) })
    }
}
535
536impl<T> FusedStream for Receiver<T> {
537 fn is_terminated(&self) -> bool {
538 !self.inner.alive()
539 }
540}
541
542// TODO: this drop is only necessary if the receiver has been acting as a
543// `futures::Stream`. Consider creating a separate wrapper around it that
544// implements stream.
545impl<T> Drop for Receiver<T> {
546 fn drop(&mut self) {
547 let Some(lock) = self.inner.lock() else { return };
548 drop(lock.receiver_waker.take())
549 }
550}
551
/// Future type for receiving a single item from a [`Receiver`]. Created by the
/// [`recv`][Receiver::recv] method; see its documentation for details.
///
/// Resolves to `Some(item)` when an item is received, or `None` when the
/// channel is disconnected.
pub struct RecvFut<'a, T> {
    // The receiver being polled; borrowed mutably for the lifetime of the
    // future.
    receiver: &'a mut Receiver<T>,
}
557
558impl<T> Debug for RecvFut<'_, T> {
559 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
560 f.debug_struct("Recv")
561 .field("receiver", &self.receiver)
562 .finish()
563 }
564}
565
566impl<T> Future for RecvFut<'_, T> {
567 type Output = Option<T>;
568
569 #[inline]
570 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
571 self.receiver.poll_next_unpin(cx)
572 }
573}
574
575impl<T> Drop for RecvFut<'_, T> {
576 #[inline]
577 fn drop(&mut self) {
578 let Some(lock) = self.receiver.inner.lock() else { return };
579 drop(lock.receiver_waker.take())
580 }
581}
582
/// An error from a [`send()`][Sender::send] operation.
///
/// This error means the send failed due to a disconnect; this is the only way
/// sends can fail. The error contains the item that failed to send, so the
/// caller can recover it from the public tuple field.
#[derive(Error, Clone, Debug, Copy)]
#[error("tried to send on a disconnected channel")]
pub struct SendError<T>(
    /// The item that failed to send
    pub T,
);
593
#[cfg(test)]
mod tests {
    use std::thread;

    use cool_asserts::assert_matches;
    use futures::{executor::block_on, StreamExt};

    use super::{channel, SendError};

    // A spawned sender task pushes four items; collecting the receiver as a
    // stream yields them in order and ends once the sender is dropped.
    #[tokio::test]
    async fn basic_test() {
        let (mut sender, receiver) = channel();

        let sender_task = tokio::task::spawn(async move {
            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();
            sender.send(4).await.unwrap();
        });

        let data: Vec<i32> = receiver.collect().await;
        sender_task.await.unwrap();

        assert_eq!(data, [1, 2, 3, 4]);
    }

    // Send and receive joined inside a single task, in both poll orders
    // (send polled first, then receive polled first).
    #[tokio::test]
    async fn taskless() {
        let (mut sender, mut receiver) = channel();

        let (send, recv) = futures::future::join(sender.send(1), receiver.next()).await;
        send.unwrap();
        assert_eq!(recv.unwrap(), 1);

        let (recv, send) = futures::future::join(receiver.next(), sender.send(2)).await;
        send.unwrap();
        assert_eq!(recv.unwrap(), 2);
    }

    // Sender and receiver run as separate tasks on a multi-threaded runtime;
    // 1000 items must arrive in order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn multi_thread_tasks() {
        let (mut sender, mut receiver) = channel();

        let sender_task = tokio::task::spawn(async move {
            for i in 0..1_000 {
                sender.send(i).await.unwrap();
            }
        });

        let receiver_task = tokio::task::spawn(async move {
            for i in 0..1_000 {
                assert_eq!(receiver.next().await.unwrap(), i);
            }
        });

        sender_task.await.unwrap();
        receiver_task.await.unwrap();
    }

    // Synchronous use: the sender runs under `block_on` on its own named
    // thread while the test thread receives with `block_on`.
    #[test]
    fn sync_test() {
        let (mut sender, mut receiver) = channel();

        let sender_thread = thread::Builder::new()
            .name("sender thread".to_owned())
            .spawn(move || {
                block_on(async move {
                    for i in 0..1_000 {
                        sender.send(i).await.unwrap();
                    }
                })
            })
            .expect("failed to spawn thread");

        for i in 0..1_000 {
            assert_eq!(block_on(receiver.next()).unwrap(), i);
        }

        sender_thread.join().unwrap();
    }

    // Synchronous use with both ends on dedicated threads, calling `block_on`
    // once per operation (mirrors the "Synchronous use" crate example).
    #[test]
    fn sync_test_two_threads() {
        let (mut sender, mut receiver) = channel();

        let sender_thread = thread::spawn(move || {
            for i in 0..100 {
                block_on(sender.send(i)).expect("receiver disconnected");
            }
        });

        let receiver_thread = thread::spawn(move || {
            for i in 0..100 {
                let value = block_on(receiver.recv()).expect("sender disconnected");
                assert_eq!(value, i);
            }
        });

        sender_thread.join().expect("sender panicked");
        receiver_thread.join().expect("receiver panicked");
    }

    // Dropping the sender before any receive makes `recv` return `None`.
    #[tokio::test]
    async fn basic_sender_close() {
        let (sender, mut receiver) = channel();

        drop(sender);

        let out: Option<i32> = receiver.recv().await;
        assert_eq!(out, None);
    }

    // Dropping the receiver before any send makes `send` fail and hand the
    // item back inside the error.
    #[tokio::test]
    async fn basic_receiver_close() {
        let (mut sender, receiver) = channel();

        drop(receiver);

        assert_matches!(sender.send(1).await, Err(SendError(1)));
    }

    // The sender is dropped from another task after a yield; the `recv`
    // awaited here must resolve to `None`.
    #[tokio::test]
    async fn sender_close_while_waiting() {
        let (sender, mut receiver) = channel();

        let sender_task = tokio::task::spawn(async move {
            tokio::task::yield_now().await;
            drop(sender);
        });

        let out: Option<i32> = receiver.recv().await;
        assert_eq!(out, None);
        sender_task.await.unwrap();
    }

    // The receiver is dropped from another task after a yield; the `send`
    // awaited here must fail and return the item.
    #[tokio::test]
    async fn receiver_close_while_waiting() {
        let (mut sender, receiver) = channel();

        let receiver_task = tokio::task::spawn(async move {
            tokio::task::yield_now().await;
            drop(receiver);
        });

        assert_matches!(sender.send(1).await, Err(SendError(1)));
        receiver_task.await.unwrap();
    }

    // Aborting the sender task after its first item was received: the next
    // receive must report disconnection (`None`) and the task must report
    // that it was cancelled.
    #[tokio::test]
    async fn sender_cancels() {
        let (mut sender, mut receiver) = channel();

        let sender_task = tokio::task::spawn(async move {
            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
        });

        assert_eq!(receiver.next().await.unwrap(), 1);
        sender_task.abort();
        assert_matches!(receiver.next().await, None);
        assert_matches!(sender_task.await, Err(err) => assert!(err.is_cancelled()));
    }

    // TODO: test sender leak

    // TODO: bench compare various channels
}