1#![doc = include_str!("../README.md")]
2
3use async_io_typed::{AsyncReadTyped, AsyncWriteTyped};
4use futures_io::{AsyncRead, AsyncWrite};
5use futures_util::{SinkExt, Stream, StreamExt};
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7use std::{
8 future::Future,
9 io,
10 pin::Pin,
11 sync::Arc,
12 task::{Context, Poll},
13 time::Duration,
14};
15use tokio::sync::{mpsc, oneshot, Mutex};
16
17#[cfg(test)]
18mod tests;
19
/// Envelope wrapped around every message written to or read from the connection.
///
/// Requests and replies share this type: `is_reply` tells them apart, and `conversation_id`
/// links a reply to the request that prompted it.
// NOTE(review): if the codec (bincode, judging by `Error::Bincode`) encodes struct fields
// positionally, the field order below is part of the wire format — do not reorder casually.
#[derive(Deserialize, Serialize)]
struct InternalMessage<T> {
    /// The application-level payload.
    user_message: T,
    /// Id assigned by the requesting side; a reply echoes the id of the request it answers.
    conversation_id: u64,
    /// `false` for a new request, `true` for a reply to an earlier request.
    is_reply: bool,
}
26
27pub struct ReceivedMessage<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> {
29 message: Option<T>,
30 conversation_id: u64,
31 raw_write: Arc<Mutex<AsyncWriteTyped<W, InternalMessage<T>>>>,
32}
33
34impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> ReceivedMessage<W, T> {
35 pub fn message(&self) -> &T {
37 self.message_opt().expect("message already taken")
38 }
39
40 pub fn message_opt(&self) -> Option<&T> {
42 self.message.as_ref()
43 }
44
45 pub fn take_message(&mut self) -> T {
47 self.take_message_opt().expect("message already taken")
48 }
49
50 pub fn take_message_opt(&mut self) -> Option<T> {
52 self.message.take()
53 }
54
55 pub async fn reply(self, reply: T) -> Result<(), Error> {
60 SinkExt::send(
61 &mut *self.raw_write.lock().await,
62 InternalMessage {
63 user_message: reply,
64 is_reply: true,
65 conversation_id: self.conversation_id,
66 },
67 )
68 .await
69 .map_err(Into::into)
70 }
71}
72
73struct ReplySender<T> {
74 reply_sender: Option<oneshot::Sender<Result<T, Error>>>,
75 conversation_id: u64,
76}
77
78#[derive(Debug)]
80pub enum Error {
81 Io(io::Error),
83 Bincode(bincode::Error),
85 ReceivedMessageTooLarge,
87 SentMessageTooLarge,
89 ChecksumMismatch {
90 sent_checksum: u64,
91 computed_checksum: u64,
92 },
93 ProtocolVersionMismatch {
94 our_version: u64,
95 their_version: u64,
96 },
97 ChecksumHandshakeFailed {
98 checksum_value: u8,
99 },
100 Timeout,
102 ReadHalfDropped,
104}
105
106pub use async_io_typed::ChecksumEnabled;
107
108impl From<async_io_typed::Error> for Error {
109 fn from(e: async_io_typed::Error) -> Self {
110 match e {
111 async_io_typed::Error::Io(e) => Error::Io(e),
112 async_io_typed::Error::Bincode(e) => Error::Bincode(e),
113 async_io_typed::Error::ReceivedMessageTooLarge => Error::ReceivedMessageTooLarge,
114 async_io_typed::Error::SentMessageTooLarge => Error::SentMessageTooLarge,
115 async_io_typed::Error::ChecksumMismatch {
116 sent_checksum,
117 computed_checksum,
118 } => Error::ChecksumMismatch {
119 sent_checksum,
120 computed_checksum,
121 },
122 async_io_typed::Error::ProtocolVersionMismatch {
123 our_version,
124 their_version,
125 } => Error::ProtocolVersionMismatch {
126 our_version,
127 their_version,
128 },
129 async_io_typed::Error::ChecksumHandshakeFailed { checksum_value } => {
130 Error::ChecksumHandshakeFailed { checksum_value }
131 }
132 }
133 }
134}
135
/// How long `ask` and `send` wait for a reply before failing with `Error::Timeout` (5 s).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);
137
138pub fn new_duplex_connection_with_limit<
139 T: DeserializeOwned + Serialize + Unpin,
140 R: AsyncRead + Unpin,
141 W: AsyncWrite + Unpin,
142>(
143 size_limit: u64,
144 checksum_enabled: ChecksumEnabled,
145 raw_read: R,
146 raw_write: W,
147) -> (AsyncReadConverse<R, W, T>, AsyncWriteConverse<W, T>) {
148 let write = Arc::new(Mutex::new(AsyncWriteTyped::new_with_limit(
149 raw_write,
150 size_limit,
151 checksum_enabled,
152 )));
153 let write_clone = Arc::clone(&write);
154 let (reply_data_sender, reply_data_receiver) = mpsc::unbounded_channel();
155 let read = AsyncReadConverse {
156 raw: AsyncReadTyped::new_with_limit(raw_read, size_limit, checksum_enabled),
157 raw_write: write_clone,
158 reply_data_receiver,
159 pending_reply: Vec::new(),
160 };
161 let write = AsyncWriteConverse {
162 raw: write,
163 reply_data_sender,
164 next_id: 0,
165 };
166 (read, write)
167}
168
169pub fn new_duplex_connection<
170 T: DeserializeOwned + Serialize + Unpin,
171 R: AsyncRead + Unpin,
172 W: AsyncWrite + Unpin,
173>(
174 checksum_enabled: ChecksumEnabled,
175 raw_read: R,
176 raw_write: W,
177) -> (AsyncReadConverse<R, W, T>, AsyncWriteConverse<W, T>) {
178 new_duplex_connection_with_limit(1024u64.pow(2), checksum_enabled, raw_read, raw_write)
179}
180
181pub struct AsyncReadConverse<
183 R: AsyncRead + Unpin,
184 W: AsyncWrite + Unpin,
185 T: Serialize + DeserializeOwned + Unpin,
186> {
187 raw: AsyncReadTyped<R, InternalMessage<T>>,
188 raw_write: Arc<Mutex<AsyncWriteTyped<W, InternalMessage<T>>>>,
189 reply_data_receiver: mpsc::UnboundedReceiver<ReplySender<T>>,
190 pending_reply: Vec<ReplySender<T>>,
191}
192
193impl<R: AsyncRead + Unpin, W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin>
194 AsyncReadConverse<R, W, T>
195{
196 pub fn inner(&self) -> &R {
197 self.raw.inner()
198 }
199
200 pub fn optimize_memory_usage(&mut self) {
206 self.raw.optimize_memory_usage()
207 }
208}
209
210impl<
211 R: AsyncRead + Unpin + Send + 'static,
212 W: AsyncWrite + Unpin + Send + 'static,
213 T: Serialize + DeserializeOwned + Unpin + Send + 'static,
214 > AsyncReadConverse<R, W, T>
215{
216 pub async fn drive_forever(mut self) {
223 while StreamExt::next(&mut self).await.is_some() {}
224 }
225}
226
227impl<R: AsyncRead + Unpin, W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> Stream
228 for AsyncReadConverse<R, W, T>
229{
230 type Item = Result<ReceivedMessage<W, T>, Error>;
231
232 fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
233 let Self {
234 ref mut raw,
235 ref mut reply_data_receiver,
236 ref mut pending_reply,
237 ref raw_write,
238 } = self.get_mut();
239 loop {
240 match futures_core::ready!(Pin::new(&mut *raw).poll_next(cx)) {
241 Some(r) => {
242 let i = r?;
243 while let Ok(reply_data) = reply_data_receiver.try_recv() {
244 pending_reply.push(reply_data);
245 }
246 let mut user_message = Some(i.user_message);
247 pending_reply.retain_mut(|pending_reply| {
248 if let Some(reply_sender) = pending_reply.reply_sender.as_ref() {
249 if reply_sender.is_closed() {
250 return false;
251 }
252 }
253 let matches =
254 i.is_reply && pending_reply.conversation_id == i.conversation_id;
255 if matches {
256 let _ = pending_reply
257 .reply_sender
258 .take()
259 .expect("infallible")
260 .send(Ok(user_message.take().expect("infallible")));
261 }
262 !matches
263 });
264 if !i.is_reply {
265 return Poll::Ready(Some(Ok(ReceivedMessage {
266 message: Some(user_message.take().expect("infallible")),
267 conversation_id: i.conversation_id,
268 raw_write: Arc::clone(raw_write),
269 })));
270 } else {
271 continue;
272 }
273 }
274 None => return Poll::Ready(None),
275 }
276 }
277 }
278}
279
280pub struct AsyncWriteConverse<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> {
285 raw: Arc<Mutex<AsyncWriteTyped<W, InternalMessage<T>>>>,
286 reply_data_sender: mpsc::UnboundedSender<ReplySender<T>>,
287 next_id: u64,
288}
289
290impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncWriteConverse<W, T> {
291 pub async fn with_inner<F: FnOnce(&W) -> R, R>(&self, f: F) -> R {
292 f(self.raw.lock().await.inner())
293 }
294
295 pub async fn optimize_memory_usage(&mut self) {
301 self.raw.lock().await.optimize_memory_usage()
302 }
303}
304
305impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncWriteConverse<W, T> {
306 pub async fn ask(&mut self, message: T) -> Result<T, Error> {
308 self.ask_timeout(DEFAULT_TIMEOUT, message).await
309 }
310
311 pub async fn ask_timeout(&mut self, timeout: Duration, message: T) -> Result<T, Error> {
313 match self.send_timeout(timeout, message).await {
314 Ok(fut) => fut.await,
315 Err(e) => Err(e),
316 }
317 }
318
319 pub async fn send(
323 &mut self,
324 message: T,
325 ) -> Result<impl Future<Output = Result<T, Error>>, Error> {
326 self.send_timeout(DEFAULT_TIMEOUT, message).await
327 }
328
329 pub async fn send_timeout(
334 &mut self,
335 timeout: Duration,
336 message: T,
337 ) -> Result<impl Future<Output = Result<T, Error>>, Error> {
338 let (reply_sender, reply_receiver) = oneshot::channel();
339 let read_half_dropped = self
340 .reply_data_sender
341 .send(ReplySender {
342 reply_sender: Some(reply_sender),
343 conversation_id: self.next_id,
344 })
345 .is_err();
346 SinkExt::send(
347 &mut *self.raw.lock().await,
348 InternalMessage {
349 user_message: message,
350 conversation_id: self.next_id,
351 is_reply: false,
352 },
353 )
354 .await?;
355 self.next_id = self.next_id.wrapping_add(1);
356 Ok(async move {
357 if read_half_dropped {
358 return Err(Error::ReadHalfDropped);
359 }
360 let res = tokio::time::timeout(timeout, reply_receiver).await;
361 match res {
362 Ok(Ok(Ok(value))) => Ok(value),
363 Ok(Ok(Err(e))) => Err(e),
364 Ok(Err(_)) => Err(Error::ReadHalfDropped),
365 Err(_) => Err(Error::Timeout),
366 }
367 })
368 }
369}