use std::{
io::{Read, Write},
marker::PhantomData,
};
use bytes::{Buf, BufMut, BytesMut};
use prost::Message;
use tendermint_proto::v0_38::abci::{Request, Response};
use crate::error::Error;
/// Upper bound on the number of buffered bytes for which a failed length-prefix
/// decode is treated as "prefix not fully received yet" rather than as an
/// error (see `decode_length_delimited`).
///
/// NOTE(review): a protobuf varint encoding a `u64` is at most 10 bytes, so 16
/// appears to be a conservative margin — confirm before tightening.
pub const MAX_VARINT_LENGTH: usize = 16;
/// Server-side codec: decodes incoming `Request`s and encodes outgoing
/// `Response`s.
pub type ServerCodec<S> = Codec<S, Request, Response>;
/// Client-side codec: decodes incoming `Response`s and encodes outgoing
/// `Request`s. Only compiled with the `client` feature.
#[cfg(feature = "client")]
pub type ClientCodec<S> = Codec<S, Response, Request>;
/// Length-delimited protobuf codec over a byte stream.
///
/// Incoming messages of type `I` are produced by the [`Iterator`]
/// implementation; outgoing messages of type `O` are written with
/// [`Codec::send`]. Each message on the wire is preceded by its encoded length
/// as a varint.
pub struct Codec<S, I, O> {
    /// Underlying byte stream.
    stream: S,
    /// Bytes read from `stream` that have not yet been decoded into a message.
    read_buf: BytesMut,
    /// Scratch window handed to each `read` call on `stream`.
    read_window: Vec<u8>,
    /// Encoded bytes waiting to be written to `stream`; bytes left over after a
    /// failed write stay here and are sent first on the next `send`.
    write_buf: BytesMut,
    /// Marks the type of messages decoded from the stream.
    _incoming: PhantomData<I>,
    /// Marks the type of messages encoded onto the stream.
    _outgoing: PhantomData<O>,
}
impl<S, I, O> Codec<S, I, O>
where
S: Read + Write,
I: Message + Default,
O: Message,
{
pub fn new(stream: S, read_buf_size: usize) -> Self {
Self {
stream,
read_buf: BytesMut::new(),
read_window: vec![0_u8; read_buf_size],
write_buf: BytesMut::new(),
_incoming: Default::default(),
_outgoing: Default::default(),
}
}
}
impl<S, I, O> Iterator for Codec<S, I, O>
where
    S: Read,
    I: Message + Default,
{
    type Item = Result<I, Error>;

    /// Returns the next decoded message, reading from the stream as needed.
    ///
    /// Yields `Some(Err(_))` on an I/O or decode error, and `None` once a read
    /// returns zero bytes (end of stream). Bytes of a partially received
    /// message still buffered at that point are not reported. Reads that fail
    /// with `ErrorKind::Interrupted` are retried, as the `Read` contract
    /// requires.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // First try to decode a complete frame from what is already buffered.
            match decode_length_delimited::<I>(&mut self.read_buf) {
                Ok(Some(incoming)) => return Some(Ok(incoming)),
                Err(e) => return Some(Err(e)),
                Ok(None) => {},
            }

            let bytes_read = match self.stream.read(&mut self.read_window) {
                Ok(n) => n,
                // A signal interrupted the read before any data arrived; retry.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Some(Err(Error::io(e))),
            };
            if bytes_read == 0 {
                // The underlying stream has ended.
                return None;
            }
            self.read_buf
                .extend_from_slice(&self.read_window[..bytes_read]);
        }
    }
}
impl<S, I, O> Codec<S, I, O>
where
S: Write,
O: Message,
{
pub fn send(&mut self, message: O) -> Result<(), Error> {
encode_length_delimited(message, &mut self.write_buf)?;
while !self.write_buf.is_empty() {
let bytes_written = self
.stream
.write(self.write_buf.as_ref())
.map_err(Error::io)?;
if bytes_written == 0 {
return Err(Error::io(std::io::Error::new(
std::io::ErrorKind::WriteZero,
"failed to write to underlying stream",
)));
}
self.write_buf.advance(bytes_written);
}
self.stream.flush().map_err(Error::io)?;
Ok(())
}
}
pub fn encode_length_delimited<M, B>(message: M, mut dst: &mut B) -> Result<(), Error>
where
M: Message,
B: BufMut,
{
let mut buf = BytesMut::new();
message.encode(&mut buf).map_err(Error::encode)?;
let buf = buf.freeze();
prost::encoding::encode_varint(buf.len() as u64, &mut dst);
dst.put(buf);
Ok(())
}
pub fn decode_length_delimited<M>(src: &mut BytesMut) -> Result<Option<M>, Error>
where
M: Message + Default,
{
let src_len = src.len();
let mut tmp = src.clone().freeze();
let encoded_len = match prost::encoding::decode_varint(&mut tmp) {
Ok(len) => len,
Err(_) if src_len <= MAX_VARINT_LENGTH => return Ok(None),
Err(e) => return Err(Error::decode(e)),
};
let remaining = tmp.remaining() as u64;
if remaining < encoded_len {
Ok(None)
} else {
let delim_len = src_len - tmp.remaining();
src.advance(delim_len + (encoded_len as usize));
let mut result_bytes = BytesMut::from(tmp.split_to(encoded_len as usize).as_ref());
let res = M::decode(&mut result_bytes).map_err(Error::decode)?;
Ok(Some(res))
}
}