#![doc = include_str!("../README.md")]
use async_io_typed::{AsyncReadTyped, AsyncWriteTyped};
use futures_io::{AsyncRead, AsyncWrite};
use futures_util::{SinkExt, Stream, StreamExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
future::Future,
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use tokio::sync::{mpsc, oneshot, Mutex};
#[cfg(test)]
mod tests;
/// Envelope wrapping every message that travels over the connection.
///
/// Requests and replies both use this type. A reply (`is_reply == true`) carries the
/// `conversation_id` of the message it answers, which is how the read half routes it to
/// the waiting `send`/`ask` future.
///
/// NOTE: this type is (de)serialized with serde. With a positional encoding such as
/// bincode, the field order and field types are part of the wire protocol. Changing them
/// breaks compatibility with peers built from an older version.
#[derive(Deserialize, Serialize)]
struct InternalMessage<T> {
    /// The payload supplied by the user of this crate.
    user_message: T,
    /// Identifier shared by an outgoing message and the reply to it.
    conversation_id: u64,
    /// `true` when this message answers a message previously received from the peer.
    is_reply: bool,
}
/// A message initiated by the peer, as yielded by the [`AsyncReadConverse`] stream.
///
/// The payload can be borrowed or taken out. [`ReceivedMessage::reply`] answers it on the
/// same connection.
pub struct ReceivedMessage<W, T>
where
    W: AsyncWrite + Unpin,
    T: Serialize + DeserializeOwned + Unpin,
{
    /// The payload; becomes `None` once it has been taken.
    message: Option<T>,
    /// Conversation id of this message, echoed back by `reply`.
    conversation_id: u64,
    /// Writer shared with the write half, used to send the reply.
    raw_write: Arc<Mutex<AsyncWriteTyped<W, InternalMessage<T>>>>,
}
impl<W, T> ReceivedMessage<W, T>
where
    W: AsyncWrite + Unpin,
    T: Serialize + DeserializeOwned + Unpin,
{
    /// Borrows the payload.
    ///
    /// # Panics
    /// Panics if the payload has already been taken.
    pub fn message(&self) -> &T {
        self.message.as_ref().expect("message already taken")
    }

    /// Borrows the payload, or returns `None` if it has already been taken.
    pub fn message_opt(&self) -> Option<&T> {
        self.message.as_ref()
    }

    /// Moves the payload out of this message.
    ///
    /// # Panics
    /// Panics if the payload has already been taken.
    pub fn take_message(&mut self) -> T {
        self.message.take().expect("message already taken")
    }

    /// Moves the payload out, or returns `None` if it has already been taken.
    pub fn take_message_opt(&mut self) -> Option<T> {
        self.message.take()
    }

    /// Sends `reply` back to the peer, tagged with this message's conversation id.
    ///
    /// # Errors
    /// Returns the transport error if writing the reply fails.
    pub async fn reply(self, reply: T) -> Result<(), Error> {
        let envelope = InternalMessage {
            user_message: reply,
            is_reply: true,
            conversation_id: self.conversation_id,
        };
        let mut writer = self.raw_write.lock().await;
        SinkExt::send(&mut *writer, envelope).await?;
        Ok(())
    }
}
/// Slot registered by the write half for each outgoing message. The read half delivers
/// the peer's reply through it.
struct ReplySender<T> {
    /// Conversation id of the outgoing message this slot waits on.
    conversation_id: u64,
    /// Channel to the waiting future; taken (left as `None`) once the reply is delivered.
    reply_sender: Option<oneshot::Sender<Result<T, Error>>>,
}
/// Errors produced by this crate.
///
/// Most variants are converted from the errors of the underlying typed transport
/// (`async_io_typed`). `Timeout` and `ReadHalfDropped` arise while waiting for a reply.
#[derive(Debug)]
pub enum Error {
    /// I/O error from the underlying reader or writer.
    Io(io::Error),
    /// A message could not be encoded or decoded with bincode.
    Bincode(bincode::Error),
    /// An incoming message exceeded the configured size limit.
    ReceivedMessageTooLarge,
    /// An outgoing message exceeded the configured size limit.
    SentMessageTooLarge,
    /// The checksum sent with a message differs from the one computed on receipt.
    ChecksumMismatch {
        sent_checksum: u64,
        computed_checksum: u64,
    },
    /// The two ends speak different protocol versions.
    ProtocolVersionMismatch {
        our_version: u64,
        their_version: u64,
    },
    /// The checksum handshake with the peer failed.
    ChecksumHandshakeFailed {
        checksum_value: u8,
    },
    /// No reply arrived within the timeout.
    Timeout,
    /// The read half is gone, so a reply can never be delivered.
    ReadHalfDropped,
}

// A public library error type should implement `Display` + `std::error::Error`, so callers
// can propagate it with `?` into `Box<dyn Error>` / `anyhow` and inspect `source()`.
impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::Io(e) => write!(f, "I/O error: {e}"),
            Error::Bincode(e) => write!(f, "serialization error: {e}"),
            Error::ReceivedMessageTooLarge => {
                f.write_str("received message exceeds the size limit")
            }
            Error::SentMessageTooLarge => f.write_str("message to send exceeds the size limit"),
            Error::ChecksumMismatch {
                sent_checksum,
                computed_checksum,
            } => write!(
                f,
                "checksum mismatch: sent {sent_checksum:#x}, computed {computed_checksum:#x}"
            ),
            Error::ProtocolVersionMismatch {
                our_version,
                their_version,
            } => write!(
                f,
                "protocol version mismatch: ours is {our_version}, theirs is {their_version}"
            ),
            Error::ChecksumHandshakeFailed { checksum_value } => write!(
                f,
                "checksum handshake failed (checksum value {checksum_value})"
            ),
            Error::Timeout => f.write_str("timed out waiting for a reply"),
            Error::ReadHalfDropped => {
                f.write_str("read half dropped; no reply can be received")
            }
        }
    }
}

impl std::error::Error for Error {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::Io(e) => Some(e),
            Error::Bincode(e) => Some(e),
            Error::ReceivedMessageTooLarge
            | Error::SentMessageTooLarge
            | Error::ChecksumMismatch { .. }
            | Error::ProtocolVersionMismatch { .. }
            | Error::ChecksumHandshakeFailed { .. }
            | Error::Timeout
            | Error::ReadHalfDropped => None,
        }
    }
}
pub use async_io_typed::ChecksumEnabled;
impl From<async_io_typed::Error> for Error {
fn from(e: async_io_typed::Error) -> Self {
match e {
async_io_typed::Error::Io(e) => Error::Io(e),
async_io_typed::Error::Bincode(e) => Error::Bincode(e),
async_io_typed::Error::ReceivedMessageTooLarge => Error::ReceivedMessageTooLarge,
async_io_typed::Error::SentMessageTooLarge => Error::SentMessageTooLarge,
async_io_typed::Error::ChecksumMismatch {
sent_checksum,
computed_checksum,
} => Error::ChecksumMismatch {
sent_checksum,
computed_checksum,
},
async_io_typed::Error::ProtocolVersionMismatch {
our_version,
their_version,
} => Error::ProtocolVersionMismatch {
our_version,
their_version,
},
async_io_typed::Error::ChecksumHandshakeFailed { checksum_value } => {
Error::ChecksumHandshakeFailed { checksum_value }
}
}
}
}
/// Reply timeout used by `ask` and `send` when the caller does not supply one.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Builds a paired read/write half over `raw_read` and `raw_write`.
///
/// `size_limit` and `checksum_enabled` are passed to both typed halves of the transport.
/// The typed writer is shared between the two halves, so messages yielded by the read
/// half can be answered with `ReceivedMessage::reply`.
pub fn new_duplex_connection_with_limit<
    T: DeserializeOwned + Serialize + Unpin,
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
>(
    size_limit: u64,
    checksum_enabled: ChecksumEnabled,
    raw_read: R,
    raw_write: W,
) -> (AsyncReadConverse<R, W, T>, AsyncWriteConverse<W, T>) {
    let shared_writer = Arc::new(Mutex::new(AsyncWriteTyped::new_with_limit(
        raw_write,
        size_limit,
        checksum_enabled,
    )));
    let typed_reader = AsyncReadTyped::new_with_limit(raw_read, size_limit, checksum_enabled);
    // Reply slots flow from the write half (which registers them) to the read half
    // (which fulfils them).
    let (slot_tx, slot_rx) = mpsc::unbounded_channel();

    let reader = AsyncReadConverse {
        raw: typed_reader,
        raw_write: Arc::clone(&shared_writer),
        reply_data_receiver: slot_rx,
        pending_reply: Vec::new(),
    };
    let writer = AsyncWriteConverse {
        raw: shared_writer,
        reply_data_sender: slot_tx,
        next_id: 0,
    };
    (reader, writer)
}
/// Like [`new_duplex_connection_with_limit`], using a message size limit of 1 MiB.
pub fn new_duplex_connection<
    T: DeserializeOwned + Serialize + Unpin,
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
>(
    checksum_enabled: ChecksumEnabled,
    raw_read: R,
    raw_write: W,
) -> (AsyncReadConverse<R, W, T>, AsyncWriteConverse<W, T>) {
    const ONE_MIB: u64 = 1024 * 1024;
    new_duplex_connection_with_limit(ONE_MIB, checksum_enabled, raw_read, raw_write)
}
/// Read half of a conversation-aware connection.
///
/// As a [`Stream`] it yields messages initiated by the peer. Replies to messages sent
/// through the paired [`AsyncWriteConverse`] are handed to the waiting `send`/`ask` futures
/// instead of being yielded. Those replies are only delivered while this stream is being
/// polled, for example through `drive_forever`.
pub struct AsyncReadConverse<R, W, T>
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
    T: Serialize + DeserializeOwned + Unpin,
{
    /// Typed frame reader over the raw byte stream.
    raw: AsyncReadTyped<R, InternalMessage<T>>,
    /// Writer shared with the write half; cloned into every `ReceivedMessage`.
    raw_write: Arc<Mutex<AsyncWriteTyped<W, InternalMessage<T>>>>,
    /// Reply slots registered by the write half that have not been moved into
    /// `pending_reply` yet.
    reply_data_receiver: mpsc::UnboundedReceiver<ReplySender<T>>,
    /// Reply slots still waiting for their reply.
    pending_reply: Vec<ReplySender<T>>,
}
impl<R, W, T> AsyncReadConverse<R, W, T>
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
    T: Serialize + DeserializeOwned + Unpin,
{
    /// Borrows the underlying raw reader.
    pub fn inner(&self) -> &R {
        let typed = &self.raw;
        typed.inner()
    }

    /// Forwards to the typed reader's `optimize_memory_usage`.
    pub fn optimize_memory_usage(&mut self) {
        let typed = &mut self.raw;
        typed.optimize_memory_usage();
    }
}
impl<R, W, T> AsyncReadConverse<R, W, T>
where
    R: AsyncRead + Unpin + Send + 'static,
    W: AsyncWrite + Unpin + Send + 'static,
    T: Serialize + DeserializeOwned + Unpin + Send + 'static,
{
    /// Polls this stream until it ends, discarding every item it yields.
    ///
    /// Useful when only replies to our own messages matter: they are routed while the
    /// stream is polled. Unsolicited messages and errors are dropped.
    ///
    /// NOTE(review): polling continues after an error. If the underlying reader can keep
    /// yielding errors without ever ending, this never returns; confirm against
    /// `async_io_typed`.
    pub async fn drive_forever(mut self) {
        loop {
            if StreamExt::next(&mut self).await.is_none() {
                break;
            }
        }
    }
}
impl<R: AsyncRead + Unpin, W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> Stream
    for AsyncReadConverse<R, W, T>
{
    type Item = Result<ReceivedMessage<W, T>, Error>;

    /// Yields the next message initiated by the peer.
    ///
    /// Replies to messages sent through the paired `AsyncWriteConverse` are delivered to the
    /// `send`/`ask` future waiting on them and are never yielded. Replies that nobody waits
    /// for any more are dropped.
    ///
    /// When the underlying reader ends, every outstanding reply slot is dropped, so its
    /// waiter resolves to `Error::ReadHalfDropped` instead of sitting out its timeout.
    fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Self {
            ref mut raw,
            ref mut reply_data_receiver,
            ref mut pending_reply,
            ref raw_write,
        } = self.get_mut();

        // Pull in newly registered reply slots and drop those whose waiter has gone away.
        // `poll_recv` (unlike `try_recv`) registers our waker, so this also runs while the
        // peer stays silent. Otherwise slots abandoned by fire-and-forget `send`s would pile
        // up in the unbounded channel until the next incoming message.
        while let Poll::Ready(Some(reply_data)) = reply_data_receiver.poll_recv(cx) {
            pending_reply.push(reply_data);
        }
        pending_reply
            .retain(|slot| matches!(&slot.reply_sender, Some(sender) if !sender.is_closed()));

        loop {
            let incoming = match futures_core::ready!(Pin::new(&mut *raw).poll_next(cx)) {
                Some(r) => r?,
                None => {
                    // The peer can never answer again. Dropping the slots closes their
                    // oneshot channels, which fails the waiters right away. Closing the
                    // receiver makes later sends fail fast instead of waiting for a timeout.
                    reply_data_receiver.close();
                    while reply_data_receiver.try_recv().is_ok() {}
                    pending_reply.clear();
                    return Poll::Ready(None);
                }
            };

            // The write half registers a slot before sending its message, so the slot for
            // any reply we can receive is already queued. Drain the queue without relying
            // on the `poll_recv` loop above, which may have stopped early.
            while let Ok(reply_data) = reply_data_receiver.try_recv() {
                pending_reply.push(reply_data);
            }

            let mut user_message = Some(incoming.user_message);
            pending_reply.retain_mut(|slot| {
                if let Some(reply_sender) = slot.reply_sender.as_ref() {
                    if reply_sender.is_closed() {
                        return false;
                    }
                }
                let matches = incoming.is_reply && slot.conversation_id == incoming.conversation_id;
                if matches {
                    // The waiter may vanish between the check above and this send; that is fine.
                    let _ = slot
                        .reply_sender
                        .take()
                        .expect("infallible")
                        .send(Ok(user_message.take().expect("infallible")));
                }
                !matches
            });

            if !incoming.is_reply {
                return Poll::Ready(Some(Ok(ReceivedMessage {
                    message: Some(user_message.take().expect("infallible")),
                    conversation_id: incoming.conversation_id,
                    raw_write: Arc::clone(raw_write),
                })));
            }
            // It was a reply: routed above, or dropped if nobody waits for it. Keep reading.
        }
    }
}
/// Write half of a conversation-aware connection.
///
/// Each outgoing message carries a conversation id taken from `next_id`. The peer's reply
/// echoes that id, which lets the read half hand it to the future returned by
/// `send`/`ask`.
pub struct AsyncWriteConverse<W, T>
where
    W: AsyncWrite + Unpin,
    T: Serialize + DeserializeOwned + Unpin,
{
    /// Typed writer, shared with the read half so received messages can be replied to.
    raw: Arc<Mutex<AsyncWriteTyped<W, InternalMessage<T>>>>,
    /// Registers reply slots with the read half.
    reply_data_sender: mpsc::UnboundedSender<ReplySender<T>>,
    /// Conversation id source for outgoing messages (wraps on overflow).
    next_id: u64,
}
impl<W, T> AsyncWriteConverse<W, T>
where
    W: AsyncWrite + Unpin,
    T: Serialize + DeserializeOwned + Unpin,
{
    /// Runs `f` on the underlying raw writer.
    ///
    /// The writer lock is held while `f` runs, so replies and sends wait until it returns.
    pub async fn with_inner<F: FnOnce(&W) -> R, R>(&self, f: F) -> R {
        let writer = self.raw.lock().await;
        f(writer.inner())
    }

    /// Forwards to the typed writer's `optimize_memory_usage` under the writer lock.
    pub async fn optimize_memory_usage(&mut self) {
        let acquire = self.raw.lock();
        acquire.await.optimize_memory_usage();
    }
}
impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncWriteConverse<W, T> {
    /// Sends `message` and waits for the peer's reply, giving up after the default
    /// timeout of 5 seconds.
    ///
    /// # Errors
    /// Same as [`Self::ask_timeout`].
    pub async fn ask(&mut self, message: T) -> Result<T, Error> {
        self.ask_timeout(DEFAULT_TIMEOUT, message).await
    }

    /// Sends `message` and waits up to `timeout` for the peer's reply.
    ///
    /// # Errors
    /// Fails with the transport error if the write fails. Otherwise it fails with whatever
    /// the reply future from [`Self::send_timeout`] resolves to: `Error::Timeout` or
    /// `Error::ReadHalfDropped`.
    pub async fn ask_timeout(&mut self, timeout: Duration, message: T) -> Result<T, Error> {
        self.send_timeout(timeout, message).await?.await
    }

    /// Sends `message` now and returns a future resolving to the peer's reply, using the
    /// default timeout of 5 seconds.
    ///
    /// # Errors
    /// Same as [`Self::send_timeout`].
    pub async fn send(
        &mut self,
        message: T,
    ) -> Result<impl Future<Output = Result<T, Error>>, Error> {
        self.send_timeout(DEFAULT_TIMEOUT, message).await
    }

    /// Sends `message` now and returns a future resolving to the peer's reply.
    ///
    /// The `timeout` clock starts when the returned future is first polled. Dropping that
    /// future abandons the reply; a late reply is then discarded by the read half.
    ///
    /// # Errors
    /// The outer `Result` fails if the message could not be written. The returned future
    /// resolves to one of:
    /// - `Error::Timeout` if no reply arrives in time;
    /// - `Error::ReadHalfDropped` if the read half is gone. The message has still been
    ///   sent in that case.
    pub async fn send_timeout(
        &mut self,
        timeout: Duration,
        message: T,
    ) -> Result<impl Future<Output = Result<T, Error>>, Error> {
        // Claim the conversation id before anything can fail or be cancelled. The frame may
        // already have reached the peer when a write fails (or when this future is dropped
        // mid-write). If the id were advanced only after a successful write, the next
        // message would reuse it and receive the peer's reply to the lost message.
        let conversation_id = self.next_id;
        self.next_id = conversation_id.wrapping_add(1);

        // Register the reply slot *before* the message goes out, so the read half already
        // has it by the time the reply can arrive.
        let (reply_sender, reply_receiver) = oneshot::channel();
        let read_half_dropped = self
            .reply_data_sender
            .send(ReplySender {
                reply_sender: Some(reply_sender),
                conversation_id,
            })
            .is_err();

        SinkExt::send(
            &mut *self.raw.lock().await,
            InternalMessage {
                user_message: message,
                conversation_id,
                is_reply: false,
            },
        )
        .await?;

        Ok(async move {
            if read_half_dropped {
                return Err(Error::ReadHalfDropped);
            }
            match tokio::time::timeout(timeout, reply_receiver).await {
                // The slot carries either the reply or an error forwarded by the read half.
                Ok(Ok(reply)) => reply,
                // The read half dropped our slot without answering.
                Ok(Err(_)) => Err(Error::ReadHalfDropped),
                Err(_) => Err(Error::Timeout),
            }
        })
    }
}