1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
//! An async channel with request-response semantics.
//!
//! See the [`bounded`] function documentation for more details.

#[deny(missing_docs)]
use async_std::channel;
pub use async_std::channel::Receiver as Responder;
use derive_more::{AsMut, AsRef, Deref, DerefMut, From};
use futures::channel::oneshot;
/// The pending reply to a single request, as returned by [`Requester::send`].
///
/// Awaiting it yields `Ok(response)` once the responder answers, or an error if the request
/// was dropped without ever being responded to.
pub type FutureResponse<Resp> = oneshot::Receiver<Resp>;

/// Completes a request-response exchange by sending a reply back to the waiting [`Requester`].
///
/// Implemented by [`ReceivedRequest`] (which hands the original request back to you) and by
/// [`UnRespondedRequest`] (which owns nothing besides the return path).
pub trait Respond<Resp> {
    /// Data owned by the implementer. It is returned to the caller whether the reply was
    /// delivered or not, so nothing is lost on failure.
    type Owned;

    /// Fulfill the obligation to the [`Requester`] by sending it `response`.
    ///
    /// # Errors
    ///
    /// If the reply cannot be delivered, both the owned data and the undelivered `response`
    /// are handed back as `Err((owned, response))`.
    fn respond(self, response: Resp) -> Result<Self::Owned, (Self::Owned, Resp)>;
}

mod impl_respond {
    use super::{ReceivedRequest, Respond, UnRespondedRequest};

    /// Responding through a full [`ReceivedRequest`] gives the original request back to the
    /// caller, whether or not the reply reached the requester.
    impl<Req, Resp> Respond<Resp> for ReceivedRequest<Req, Resp> {
        type Owned = Req;

        fn respond(self, response: Resp) -> Result<Self::Owned, (Self::Owned, Resp)> {
            // Split the envelope so the request can be returned on either path.
            let ReceivedRequest {
                request,
                unresponded,
            } = self;
            match unresponded.respond(response) {
                Ok(()) => Ok(request),
                Err(((), undelivered)) => Err((request, undelivered)),
            }
        }
    }

    /// Responding through a bare [`UnRespondedRequest`] has no data to hand back; delivery
    /// fails when the receiving end of the one-shot channel has already been dropped.
    impl<Resp> Respond<Resp> for UnRespondedRequest<Resp> {
        type Owned = ();

        fn respond(self, response: Resp) -> Result<Self::Owned, (Self::Owned, Resp)> {
            let UnRespondedRequest { response_sender } = self;
            response_sender
                .send(response)
                .map_err(|undelivered| ((), undelivered))
        }
    }
}

/// The return path to a [`Requester`] that is still waiting for a response.
///
/// Consume it with [`Respond::respond`]. Dropping it without responding makes the requester's
/// [`FutureResponse`] resolve to an error instead of a reply.
#[must_use = "You must respond to the request"]
pub struct UnRespondedRequest<Resp> {
    /// One-shot sender connected to the requester's [`FutureResponse`].
    response_sender: oneshot::Sender<Resp>,
}

/// A request as delivered to the [`Responder`], bundled with the return path to its sender.
///
/// The request itself is reachable explicitly through [`AsRef`] and [`AsMut`], or through
/// [`Deref`] and [`DerefMut`] (explicitly or via deref coercion).
/// Consume it with [`Respond::respond`], or destructure it into its public fields.
#[must_use = "You must respond to the request"]
#[derive(AsRef, AsMut, From, Deref, DerefMut)]
pub struct ReceivedRequest<Req, Resp> {
    /// The payload sent by the [`Requester`].
    #[as_ref]
    #[as_mut]
    #[deref]
    #[deref_mut]
    pub request: Req,
    /// The return path; call [`Respond::respond`] on it to answer the request.
    pub unresponded: UnRespondedRequest<Resp>,
}

/// The initiating side of the request-response exchange.
///
/// Each call to [`Requester::send`] delivers one request to the [`Responder`] and returns a
/// [`FutureResponse`] for its reply. Cloning a `Requester` yields another handle to the same
/// channel, so several tasks can issue requests concurrently.
pub struct Requester<Req, Resp> {
    outgoing: channel::Sender<ReceivedRequest<Req, Resp>>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add `Req: Clone` and
// `Resp: Clone` bounds, while the underlying channel sender is cloneable for any payload.
impl<Req, Resp> Clone for Requester<Req, Resp> {
    fn clone(&self) -> Self {
        Self {
            outgoing: self.outgoing.clone(),
        }
    }
}

impl<Req, Resp> Requester<Req, Resp> {
    /// Send `request` to the [`Responder`] side of the channel.
    ///
    /// If the channel is full, this waits until there is room. On success, `await` the
    /// returned [`FutureResponse`] to obtain the responder's reply.
    ///
    /// # Errors
    ///
    /// If the request cannot be delivered because the channel is closed, it is handed back
    /// unchanged as `Err(request)`.
    pub async fn send(&self, request: Req) -> Result<FutureResponse<Resp>, Req> {
        // Every request carries its own one-shot return path for the reply.
        let (response_sender, future_response) = oneshot::channel();
        let envelope = ReceivedRequest {
            request,
            unresponded: UnRespondedRequest { response_sender },
        };
        match self.outgoing.send(envelope).await {
            Ok(()) => Ok(future_response),
            // Delivery failed: reclaim the caller's request from the rejected envelope.
            Err(rejected) => Err(rejected.into_inner().request),
        }
    }
}

/// Create a bounded [`Requester`]-[`Responder`] pair.
///
/// At most `capacity` requests can be pending in the channel. Once it is full,
/// [`Requester::send`] waits until the responder has made room again.
///
/// Terminology is as follows:
/// ```mermaid
/// sequenceDiagram
///     Requester ->> Responder: request
///     Responder ->> Requester: response
/// ```
///
/// Receiving from a [`Responder`] yields a [`ReceivedRequest`], which is then used to send the
/// response back to the requester.
///
/// # Panics
///
/// Panics if `capacity` is zero (inherited from `async_std::channel::bounded`).
///
/// # Examples
///
/// ```
/// # async_std::task::block_on( async {
/// use bidirectional_channel::Respond; // Don't forget to import this trait
/// let (requester, responder) = bidirectional_channel::bounded(10);
/// let future_response = requester.send(String::from("hello")).await.unwrap();
///
/// // This side of the channel receives Strings, and responds with their length
/// let received_request = responder.recv().await.unwrap();
/// let len = received_request.len(); // Deref coercion to the actual request
/// received_request.respond(len).unwrap();
///
/// assert_eq!(future_response.await.unwrap(), 5);
/// # })
/// ```
pub fn bounded<Req, Resp>(
    capacity: usize,
) -> (Requester<Req, Resp>, Responder<ReceivedRequest<Req, Resp>>) {
    let (outgoing, responder) = channel::bounded(capacity);
    let requester = Requester { outgoing };
    (requester, responder)
}