async_io_converse/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use async_io_typed::{AsyncReadTyped, AsyncWriteTyped};
4use futures_io::{AsyncRead, AsyncWrite};
5use futures_util::{SinkExt, Stream, StreamExt};
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7use std::{
8    future::Future,
9    io,
10    pin::Pin,
11    sync::Arc,
12    task::{Context, Poll},
13    time::Duration,
14};
15use tokio::sync::{mpsc, oneshot, Mutex};
16
17#[cfg(test)]
18mod tests;
19
20#[derive(Deserialize, Serialize)]
21struct InternalMessage<T> {
22    user_message: T,
23    conversation_id: u64,
24    is_reply: bool,
25}
26
27/// A message received from the connected peer, which you may choose to reply to.
28pub struct ReceivedMessage<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> {
29    message: Option<T>,
30    conversation_id: u64,
31    raw_write: Arc<Mutex<AsyncWriteTyped<W, InternalMessage<T>>>>,
32}
33
34impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> ReceivedMessage<W, T> {
35    /// Peeks at the message, panicking if the message had already been taken prior.
36    pub fn message(&self) -> &T {
37        self.message_opt().expect("message already taken")
38    }
39
40    /// Peeks at the message, returning `None` if the message had already been taken prior.
41    pub fn message_opt(&self) -> Option<&T> {
42        self.message.as_ref()
43    }
44
45    /// Pulls the message from this, panicking if the message had already been taken prior.
46    pub fn take_message(&mut self) -> T {
47        self.take_message_opt().expect("message already taken")
48    }
49
50    /// Pulls the message from this, returning `None` if the message had already been taken prior.
51    pub fn take_message_opt(&mut self) -> Option<T> {
52        self.message.take()
53    }
54
55    /// Sends the given message as a reply to this one. There are two ways for the peer to receive this reply
56    ///
57    /// 1. `.await` both layers of [AsyncWriteConverse::send] or [AsyncWriteConverse::send_timeout]
58    /// 2. They'll receive it as the return value of [AsyncWriteConverse::ask] or [AsyncWriteConverse::ask_timeout].
59    pub async fn reply(self, reply: T) -> Result<(), Error> {
60        SinkExt::send(
61            &mut *self.raw_write.lock().await,
62            InternalMessage {
63                user_message: reply,
64                is_reply: true,
65                conversation_id: self.conversation_id,
66            },
67        )
68        .await
69        .map_err(Into::into)
70    }
71}
72
73struct ReplySender<T> {
74    reply_sender: Option<oneshot::Sender<Result<T, Error>>>,
75    conversation_id: u64,
76}
77
78/// Errors which can occur on an `async-io-converse` connection.
79#[derive(Debug)]
80pub enum Error {
81    /// Error from `std::io`
82    Io(io::Error),
83    /// Error from the `bincode` crate
84    Bincode(bincode::Error),
85    /// A message was received that exceeded the configured length limit
86    ReceivedMessageTooLarge,
87    /// A message was sent that exceeded the configured length limit
88    SentMessageTooLarge,
89    ChecksumMismatch {
90        sent_checksum: u64,
91        computed_checksum: u64,
92    },
93    ProtocolVersionMismatch {
94        our_version: u64,
95        their_version: u64,
96    },
97    ChecksumHandshakeFailed {
98        checksum_value: u8,
99    },
100    /// A reply wasn't received within the timeout specified
101    Timeout,
102    /// The read half was dropped, crippling the ability to receive replies.
103    ReadHalfDropped,
104}
105
106pub use async_io_typed::ChecksumEnabled;
107
108impl From<async_io_typed::Error> for Error {
109    fn from(e: async_io_typed::Error) -> Self {
110        match e {
111            async_io_typed::Error::Io(e) => Error::Io(e),
112            async_io_typed::Error::Bincode(e) => Error::Bincode(e),
113            async_io_typed::Error::ReceivedMessageTooLarge => Error::ReceivedMessageTooLarge,
114            async_io_typed::Error::SentMessageTooLarge => Error::SentMessageTooLarge,
115            async_io_typed::Error::ChecksumMismatch {
116                sent_checksum,
117                computed_checksum,
118            } => Error::ChecksumMismatch {
119                sent_checksum,
120                computed_checksum,
121            },
122            async_io_typed::Error::ProtocolVersionMismatch {
123                our_version,
124                their_version,
125            } => Error::ProtocolVersionMismatch {
126                our_version,
127                their_version,
128            },
129            async_io_typed::Error::ChecksumHandshakeFailed { checksum_value } => {
130                Error::ChecksumHandshakeFailed { checksum_value }
131            }
132        }
133    }
134}
135
/// How long `ask` and `send` wait for a reply when no explicit timeout is given (5 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
137
138pub fn new_duplex_connection_with_limit<
139    T: DeserializeOwned + Serialize + Unpin,
140    R: AsyncRead + Unpin,
141    W: AsyncWrite + Unpin,
142>(
143    size_limit: u64,
144    checksum_enabled: ChecksumEnabled,
145    raw_read: R,
146    raw_write: W,
147) -> (AsyncReadConverse<R, W, T>, AsyncWriteConverse<W, T>) {
148    let write = Arc::new(Mutex::new(AsyncWriteTyped::new_with_limit(
149        raw_write,
150        size_limit,
151        checksum_enabled,
152    )));
153    let write_clone = Arc::clone(&write);
154    let (reply_data_sender, reply_data_receiver) = mpsc::unbounded_channel();
155    let read = AsyncReadConverse {
156        raw: AsyncReadTyped::new_with_limit(raw_read, size_limit, checksum_enabled),
157        raw_write: write_clone,
158        reply_data_receiver,
159        pending_reply: Vec::new(),
160    };
161    let write = AsyncWriteConverse {
162        raw: write,
163        reply_data_sender,
164        next_id: 0,
165    };
166    (read, write)
167}
168
169pub fn new_duplex_connection<
170    T: DeserializeOwned + Serialize + Unpin,
171    R: AsyncRead + Unpin,
172    W: AsyncWrite + Unpin,
173>(
174    checksum_enabled: ChecksumEnabled,
175    raw_read: R,
176    raw_write: W,
177) -> (AsyncReadConverse<R, W, T>, AsyncWriteConverse<W, T>) {
178    new_duplex_connection_with_limit(1024u64.pow(2), checksum_enabled, raw_read, raw_write)
179}
180
181/// Used to receive messages from the connected peer. ***You must drive this in order to receive replies on [AsyncWriteConverse]***
182pub struct AsyncReadConverse<
183    R: AsyncRead + Unpin,
184    W: AsyncWrite + Unpin,
185    T: Serialize + DeserializeOwned + Unpin,
186> {
187    raw: AsyncReadTyped<R, InternalMessage<T>>,
188    raw_write: Arc<Mutex<AsyncWriteTyped<W, InternalMessage<T>>>>,
189    reply_data_receiver: mpsc::UnboundedReceiver<ReplySender<T>>,
190    pending_reply: Vec<ReplySender<T>>,
191}
192
193impl<R: AsyncRead + Unpin, W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin>
194    AsyncReadConverse<R, W, T>
195{
196    pub fn inner(&self) -> &R {
197        self.raw.inner()
198    }
199
200    /// `AsyncReadConverse` keeps a memory buffer for receiving values which is the same size as the largest
201    /// message that's been received. If the message size varies a lot, you might find yourself wasting
202    /// memory space. This function will reduce the memory usage as much as is possible without impeding
203    /// functioning. Overuse of this function may cause excessive memory allocations when the buffer
204    /// needs to grow.
205    pub fn optimize_memory_usage(&mut self) {
206        self.raw.optimize_memory_usage()
207    }
208}
209
210impl<
211        R: AsyncRead + Unpin + Send + 'static,
212        W: AsyncWrite + Unpin + Send + 'static,
213        T: Serialize + DeserializeOwned + Unpin + Send + 'static,
214    > AsyncReadConverse<R, W, T>
215{
216    /// Returns a future that will drive the receive mechanism. It's recommended to spawn this onto an `async`
217    /// runtime, such as `tokio`. This allows you to receive replies to your messages, while completely
218    /// ignoring any non-reply messages you get.
219    ///
220    /// If instead you'd like to see the non-reply messages then you'll need to drive the `Stream` implementation
221    /// for `AsyncReadConverse`.
222    pub async fn drive_forever(mut self) {
223        while StreamExt::next(&mut self).await.is_some() {}
224    }
225}
226
227impl<R: AsyncRead + Unpin, W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> Stream
228    for AsyncReadConverse<R, W, T>
229{
230    type Item = Result<ReceivedMessage<W, T>, Error>;
231
232    fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
233        let Self {
234            ref mut raw,
235            ref mut reply_data_receiver,
236            ref mut pending_reply,
237            ref raw_write,
238        } = self.get_mut();
239        loop {
240            match futures_core::ready!(Pin::new(&mut *raw).poll_next(cx)) {
241                Some(r) => {
242                    let i = r?;
243                    while let Ok(reply_data) = reply_data_receiver.try_recv() {
244                        pending_reply.push(reply_data);
245                    }
246                    let mut user_message = Some(i.user_message);
247                    pending_reply.retain_mut(|pending_reply| {
248                        if let Some(reply_sender) = pending_reply.reply_sender.as_ref() {
249                            if reply_sender.is_closed() {
250                                return false;
251                            }
252                        }
253                        let matches =
254                            i.is_reply && pending_reply.conversation_id == i.conversation_id;
255                        if matches {
256                            let _ = pending_reply
257                                .reply_sender
258                                .take()
259                                .expect("infallible")
260                                .send(Ok(user_message.take().expect("infallible")));
261                        }
262                        !matches
263                    });
264                    if !i.is_reply {
265                        return Poll::Ready(Some(Ok(ReceivedMessage {
266                            message: Some(user_message.take().expect("infallible")),
267                            conversation_id: i.conversation_id,
268                            raw_write: Arc::clone(raw_write),
269                        })));
270                    } else {
271                        continue;
272                    }
273                }
274                None => return Poll::Ready(None),
275            }
276        }
277    }
278}
279
280/// Used to send messages to the connected peer. You may optionally receive replies to your messages as well.
281///
282/// ***You must drive the corresponding [AsyncReadConverse] in order to receive replies to your messages.***
283/// You can do this either by driving the `Stream` implementation, or calling [AsyncReadConverse::drive_forever].
284pub struct AsyncWriteConverse<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> {
285    raw: Arc<Mutex<AsyncWriteTyped<W, InternalMessage<T>>>>,
286    reply_data_sender: mpsc::UnboundedSender<ReplySender<T>>,
287    next_id: u64,
288}
289
290impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncWriteConverse<W, T> {
291    pub async fn with_inner<F: FnOnce(&W) -> R, R>(&self, f: F) -> R {
292        f(self.raw.lock().await.inner())
293    }
294
295    /// `AsyncWriteConverse` keeps a memory buffer for sending values which is the same size as the largest
296    /// message that's been sent. If the message size varies a lot, you might find yourself wasting
297    /// memory space. This function will reduce the memory usage as much as is possible without impeding
298    /// functioning. Overuse of this function may cause excessive memory allocations when the buffer
299    /// needs to grow.
300    pub async fn optimize_memory_usage(&mut self) {
301        self.raw.lock().await.optimize_memory_usage()
302    }
303}
304
305impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncWriteConverse<W, T> {
306    /// Send a message, and wait for a reply, with the default timeout. Shorthand for `.await`ing both layers of `.send(message)`.
307    pub async fn ask(&mut self, message: T) -> Result<T, Error> {
308        self.ask_timeout(DEFAULT_TIMEOUT, message).await
309    }
310
311    /// Send a message, and wait for a reply, up to timeout. Shorthand for `.await`ing both layers of `.send_timeout(message)`.
312    pub async fn ask_timeout(&mut self, timeout: Duration, message: T) -> Result<T, Error> {
313        match self.send_timeout(timeout, message).await {
314            Ok(fut) => fut.await,
315            Err(e) => Err(e),
316        }
317    }
318
319    /// Sends a message to the peer on the other side of the connection. This returns a future wrapped in a future. You must
320    /// `.await` the first layer to send the message, however `.await`ing the second layer is optional. You only need to
321    /// `.await` the second layer if you care about the reply to your message. Waits up to the default timeout for a reply.
322    pub async fn send(
323        &mut self,
324        message: T,
325    ) -> Result<impl Future<Output = Result<T, Error>>, Error> {
326        self.send_timeout(DEFAULT_TIMEOUT, message).await
327    }
328
329    /// Sends a message to the peer on the other side of the connection, waiting up to the specified timeout for a reply.
330    /// This returns a future wrapped in a future. You must  `.await` the first layer to send the message, however
331    /// `.await`ing the second layer is optional. You only need to  `.await` the second layer if you care about the
332    /// reply to your message.
333    pub async fn send_timeout(
334        &mut self,
335        timeout: Duration,
336        message: T,
337    ) -> Result<impl Future<Output = Result<T, Error>>, Error> {
338        let (reply_sender, reply_receiver) = oneshot::channel();
339        let read_half_dropped = self
340            .reply_data_sender
341            .send(ReplySender {
342                reply_sender: Some(reply_sender),
343                conversation_id: self.next_id,
344            })
345            .is_err();
346        SinkExt::send(
347            &mut *self.raw.lock().await,
348            InternalMessage {
349                user_message: message,
350                conversation_id: self.next_id,
351                is_reply: false,
352            },
353        )
354        .await?;
355        self.next_id = self.next_id.wrapping_add(1);
356        Ok(async move {
357            if read_half_dropped {
358                return Err(Error::ReadHalfDropped);
359            }
360            let res = tokio::time::timeout(timeout, reply_receiver).await;
361            match res {
362                Ok(Ok(Ok(value))) => Ok(value),
363                Ok(Ok(Err(e))) => Err(e),
364                Ok(Err(_)) => Err(Error::ReadHalfDropped),
365                Err(_) => Err(Error::Timeout),
366            }
367        })
368    }
369}