use prost::decode_length_delimiter;
use prost::length_delimiter_len;
use prost::Message;
use std::io::Read;
use std::io::Write;
use thiserror::Error;
#[cfg(feature = "async")]
use tokio::io::{AsyncReadExt, AsyncWriteExt};
#[derive(Error, Debug)]
pub enum Error {
#[error("io error: {0}")]
IoError(#[from] std::io::Error),
#[error("prost decode error: {0}")]
ProstDecodeError(#[from] prost::DecodeError),
#[error("prost encode error: {0}")]
ProstEncodeError(#[from] prost::EncodeError),
}
pub type Result<T> = std::result::Result<T, Error>;
/// Blocking framer for length-delimited protobuf messages.
///
/// Its `new`/`send`/`recv` methods are available when `T: Read + Write`.
pub struct Stream<T> {
    /// The wrapped transport that frames are read from and written to.
    stream: T,
    /// Scratch space that received payloads are read into before decoding.
    buf: Vec<u8>,
    /// Scratch space that outgoing frames are encoded into before writing.
    send_buf: Vec<u8>,
}
impl<T: Read + Write> Stream<T> {
    /// Wraps `stream` so that protobuf messages can be sent and received as
    /// varint-length-prefixed frames.
    ///
    /// The receive buffer starts at 1 KiB and grows on demand. The send buffer
    /// starts with 1 KiB of capacity and is reused across calls to `send`.
    pub fn new(stream: T) -> Self {
        Self {
            stream,
            buf: vec![0; 1024],
            send_buf: Vec::with_capacity(1024),
        }
    }

    /// Encodes `msg` with a varint length prefix and writes the whole frame.
    ///
    /// This does not call `flush`; if `T` buffers writes, the caller must
    /// flush it.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ProstEncodeError`] if encoding fails and
    /// [`Error::IoError`] if writing to the stream fails.
    pub fn send(&mut self, msg: &impl Message) -> Result<()> {
        // The scratch buffer is reused to avoid an allocation per message, so the
        // previous frame must be discarded first. Otherwise every call would
        // re-send all earlier messages too.
        self.send_buf.clear();
        msg.encode_length_delimited(&mut self.send_buf)?;
        self.stream.write_all(&self.send_buf)?;
        Ok(())
    }

    /// Reads one length-delimited frame from the stream and decodes it as `M`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IoError`] if a read fails or the stream ends mid-frame.
    /// Returns [`Error::ProstDecodeError`] if the length prefix is not a valid
    /// varint or the payload does not decode as `M`.
    pub fn recv<M: Message + Default>(&mut self) -> Result<M> {
        // A varint is at most 10 bytes long and ends at the first byte whose
        // continuation (high) bit is clear. Reading it one byte at a time means
        // we never consume payload bytes or bytes from the next frame. It also
        // means we never depend on the prefix being canonically encoded.
        const MAX_VARINT_LEN: usize = 10;
        let mut header = [0u8; MAX_VARINT_LEN];
        let mut header_len = 0;
        while header_len < MAX_VARINT_LEN {
            self.stream
                .read_exact(&mut header[header_len..=header_len])?;
            header_len += 1;
            if header[header_len - 1] & 0x80 == 0 {
                break;
            }
        }
        // If all 10 bytes had the continuation bit set, this returns a decode error.
        let sz = decode_length_delimiter(&header[..header_len])?;

        // NOTE(review): `sz` comes straight from the peer, so a hostile peer can
        // force a large allocation here. Consider capping it.
        if sz > self.buf.len() {
            self.buf.resize(sz, 0);
        }
        self.stream.read_exact(&mut self.buf[..sz])?;
        Ok(M::decode(&self.buf[..sz])?)
    }
}
#[cfg(feature = "async")]
/// Async (tokio) framer for length-delimited protobuf messages.
///
/// Only compiled when the `async` feature is enabled.
pub struct AsyncStream<T> {
    /// The wrapped async transport.
    stream: T,
    /// Scratch space that received payloads are read into before decoding.
    buf: Vec<u8>,
}
#[cfg(feature = "async")]
impl<T: AsyncReadExt + AsyncWriteExt + Unpin> AsyncStream<T> {
    /// Wraps `stream` so that protobuf messages can be sent and received as
    /// varint-length-prefixed frames.
    ///
    /// The receive buffer starts at 1 KiB and grows on demand.
    pub fn new(stream: T) -> Self {
        Self {
            stream,
            buf: vec![0u8; 1024],
        }
    }

    /// Encodes `msg` with a varint length prefix and writes the whole frame.
    ///
    /// Each call allocates a fresh `Vec` for the encoded frame.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IoError`] if writing to the stream fails.
    pub async fn send(&mut self, msg: &impl Message) -> Result<()> {
        self.stream
            .write_all(&msg.encode_length_delimited_to_vec())
            .await
            .map_err(Into::into)
    }

    /// Reads one length-delimited frame from the stream and decodes it as `M`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IoError`] if a read fails or the stream ends mid-frame.
    /// Returns [`Error::ProstDecodeError`] if the length prefix is not a valid
    /// varint or the payload does not decode as `M`.
    pub async fn recv<M: Message + Default>(&mut self) -> Result<M> {
        // A varint is at most 10 bytes long and ends at the first byte whose
        // continuation (high) bit is clear. Reading it one byte at a time means
        // we never consume payload bytes or bytes from the next frame. It also
        // means we never depend on the prefix being canonically encoded.
        const MAX_VARINT_LEN: usize = 10;
        let mut header = [0u8; MAX_VARINT_LEN];
        let mut header_len = 0;
        while header_len < MAX_VARINT_LEN {
            self.stream
                .read_exact(&mut header[header_len..=header_len])
                .await?;
            header_len += 1;
            if header[header_len - 1] & 0x80 == 0 {
                break;
            }
        }
        // If all 10 bytes had the continuation bit set, this returns a decode error.
        let sz = decode_length_delimiter(&header[..header_len])?;

        // NOTE(review): `sz` comes straight from the peer, so a hostile peer can
        // force a large allocation here. Consider capping it.
        if sz > self.buf.len() {
            self.buf.resize(sz, 0);
        }
        self.stream.read_exact(&mut self.buf[..sz]).await?;
        Ok(M::decode(&self.buf[..sz])?)
    }
}