futures_ext/future/
conservative_receiver.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under both the MIT license found in the
 * LICENSE-MIT file in the root directory of this source tree and the Apache
 * License, Version 2.0 found in the LICENSE-APACHE file in the root directory
 * of this source tree.
 */
9
10use std::pin::Pin;
11
12use futures::channel::oneshot::Canceled;
13use futures::channel::oneshot::Receiver;
14use futures::task::Context;
15use futures::task::Poll;
16use futures::Future;
17use pin_project::pin_project;
18use thiserror::Error;
19
20/// This is a wrapper around [Receiver] that will return error when the receiver was polled
21/// and the result was not ready. This is a very strict way of preventing deadlocks in code when
22/// receiver is polled before the sender has send the result
23#[pin_project]
24pub struct ConservativeReceiver<T>(#[pin] Receiver<T>);
25
26impl<T> ConservativeReceiver<T> {
27    /// Return an instance of [ConservativeReceiver] wrapping the [Receiver]
28    pub fn new(recv: Receiver<T>) -> Self {
29        ConservativeReceiver(recv)
30    }
31}
32
33impl<T> Future for ConservativeReceiver<T> {
34    type Output = Result<T, ConservativeReceiverError>;
35
36    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
37        let mut this = self.project();
38
39        match this.0.as_mut().poll(cx) {
40            Poll::Ready(Ok(output)) => Poll::Ready(Ok(output)),
41            Poll::Ready(Err(Canceled)) => Poll::Ready(Err(ConservativeReceiverError::Canceled)),
42            Poll::Pending => Poll::Ready(Err(ConservativeReceiverError::ReceiveBeforeSend)),
43        }
44    }
45}
46
47/// Error that can be returned by [ConservativeReceiver]
48#[derive(Error, Debug)]
49pub enum ConservativeReceiverError {
50    /// The underlying [Receiver] returned [Canceled]
51    #[error("oneshot canceled")]
52    Canceled,
53    /// The underlying [Receiver] returned [Poll::Pending], which means it was polled
54    /// before the [::futures::oneshot::Sender] send some data
55    #[error("recv called on channel before send")]
56    ReceiveBeforeSend,
57}
58
#[cfg(test)]
mod test {
    use assert_matches::assert_matches;
    use futures::channel::oneshot::channel;

    use super::*;

    #[tokio::test]
    async fn recv_after_send() {
        let (tx, rx) = channel();
        tx.send(42).expect("Failed to send");

        // The value is already in the channel, so the first poll resolves with it.
        let outcome = ConservativeReceiver::new(rx).await;
        assert_matches!(outcome, Ok(42));
    }

    #[tokio::test]
    async fn recv_before_send() {
        let (tx, rx) = channel();

        // Nothing has been sent yet, so the receiver must fail instead of waiting.
        let outcome = ConservativeReceiver::new(rx).await;
        assert_matches!(outcome, Err(ConservativeReceiverError::ReceiveBeforeSend));

        // The awaited receiver has been dropped by now, so sending must fail.
        tx.send(42).expect_err("Should fail to send");
    }

    #[tokio::test]
    async fn recv_canceled_send() {
        let (tx, rx) = channel::<()>();
        // Drop the sender without sending anything.
        drop(tx);

        let outcome = ConservativeReceiver::new(rx).await;
        assert_matches!(outcome, Err(ConservativeReceiverError::Canceled));
    }
}