use prost::DecodeError;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ReadError {
#[error("Cannot read data: {0}")]
IO(#[from] std::io::Error),
#[error("Negative message size: {size}")]
NegativeMessageSize { size: i32 },
#[error("Message too large, limit is {limit} bytes but got {actual} bytes")]
MessageTooLarge { limit: usize, actual: usize },
#[error("Fail to decode error response: {0}")]
ProtoErrorResponseDecodeError(#[from] DecodeError),
}
pub trait AsyncMessageRead {
fn read_message(
&mut self,
max_message_size: usize,
) -> impl Future<Output = Result<Vec<u8>, ReadError>> + Send;
}
/// Blanket framing reader for any async byte source.
///
/// Wire format: a 4-byte big-endian `i32` length prefix, then `len` payload bytes.
impl<R> AsyncMessageRead for R
where
    R: AsyncRead + Send + Unpin,
{
    async fn read_message(&mut self, max_message_size: usize) -> Result<Vec<u8>, ReadError> {
        // `read_i32` reads exactly 4 bytes in big-endian order.
        let len = self.read_i32().await?;
        let len = usize::try_from(len).map_err(|_| ReadError::NegativeMessageSize { size: len })?;

        if len > max_message_size {
            // Discard the oversized payload so the stream stays positioned at the
            // next frame boundary. A small fixed-size scratch buffer is used instead
            // of one sized by `max_message_size`: that would spin forever when the
            // limit is 0 (every step reads 0 bytes) and could allocate up to the
            // whole limit just to throw the data away.
            const DRAIN_CHUNK: usize = 8 * 1024;
            // `len > max_message_size >= 0`, so the scratch buffer is never empty.
            let mut scratch = vec![0u8; DRAIN_CHUNK.min(len)];
            let mut remaining = len;
            while remaining > 0 {
                let step = remaining.min(scratch.len());
                self.read_exact(&mut scratch[..step]).await?;
                remaining -= step;
            }
            return Err(ReadError::MessageTooLarge {
                limit: max_message_size,
                actual: len,
            });
        }

        let mut buf = vec![0u8; len];
        self.read_exact(&mut buf).await?;
        Ok(buf)
    }
}
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum WriteError {
#[error("Cannot write data: {0}")]
IO(#[from] std::io::Error),
#[error("Message too large: {size}")]
TooLarge { size: usize },
}
pub trait AsyncMessageWrite {
fn write_message(&mut self, msg: &[u8]) -> impl Future<Output = Result<(), WriteError>> + Send;
}
/// Blanket framing writer for any async byte sink.
///
/// Does not flush; callers using buffered writers must flush themselves.
impl<W> AsyncMessageWrite for W
where
    W: AsyncWrite + Send + Unpin,
{
    async fn write_message(&mut self, msg: &[u8]) -> Result<(), WriteError> {
        let size = msg.len();
        // The wire prefix is a signed 32-bit length; reject anything wider.
        let Ok(prefix) = i32::try_from(size) else {
            return Err(WriteError::TooLarge { size });
        };

        // `write_i32` emits the 4 bytes in big-endian order.
        self.write_i32(prefix).await?;

        // An empty payload needs no second write.
        if size > 0 {
            self.write_all(msg).await?;
        }
        Ok(())
    }
}