use crate::Result;
use bytes::{Buf, BufMut, BytesMut};
use prost::Message;
use std::io::{Read, Write};
use std::marker::PhantomData;
use tendermint_proto::abci::{Request, Response};
/// Buffer-length threshold used by [`decode_length_delimited`]: while the
/// receive buffer holds at most this many bytes, a failure to decode the
/// length prefix is treated as "need more data" instead of as an error.
pub const MAX_VARINT_LENGTH: usize = 16;
/// A [`Codec`] whose incoming messages are [`Request`]s and whose outgoing
/// messages are [`Response`]s.
pub type ServerCodec<S> = Codec<S, Request, Response>;
/// A [`Codec`] whose incoming messages are [`Response`]s and whose outgoing
/// messages are [`Request`]s. Only compiled when the `client` feature is
/// enabled.
#[cfg(feature = "client")]
pub type ClientCodec<S> = Codec<S, Response, Request>;
/// Length-delimited message codec layered over a byte stream `S`.
///
/// Incoming messages of type `I` are produced by the codec's `Iterator`
/// implementation; outgoing messages of type `O` are written with `send`.
pub struct Codec<S, I, O> {
/// Underlying stream that bytes are read from and written to.
stream: S,
/// Bytes received from the stream that have not yet been decoded.
read_buf: BytesMut,
/// Scratch buffer handed to `stream.read`; its length is the
/// `read_buf_size` given to `new`.
read_window: Vec<u8>,
/// Encoded outgoing bytes that have not yet been written to the stream.
write_buf: BytesMut,
/// Zero-sized marker for the incoming message type.
_incoming: PhantomData<I>,
/// Zero-sized marker for the outgoing message type.
_outgoing: PhantomData<O>,
}
impl<S, I, O> Codec<S, I, O>
where
S: Read + Write,
I: Message + Default,
O: Message,
{
pub fn new(stream: S, read_buf_size: usize) -> Self {
Self {
stream,
read_buf: BytesMut::new(),
read_window: vec![0_u8; read_buf_size],
write_buf: BytesMut::new(),
_incoming: Default::default(),
_outgoing: Default::default(),
}
}
}
/// Yields decoded incoming messages, reading from the stream as needed.
impl<S, I, O> Iterator for Codec<S, I, O>
where
    S: Read,
    I: Message + Default,
{
    type Item = Result<I>;

    /// Returns the next incoming message.
    ///
    /// * `Some(Ok(msg))` - a complete message was decoded from the buffer.
    /// * `Some(Err(e))` - decoding failed, or the stream returned an I/O error
    ///   other than `ErrorKind::Interrupted`.
    /// * `None` - the stream reported end-of-file (a read of zero bytes). Any
    ///   incomplete message still sitting in the receive buffer at that point
    ///   is not reported.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Drain whatever is already buffered before touching the stream.
            match decode_length_delimited::<I>(&mut self.read_buf) {
                Ok(Some(incoming)) => return Some(Ok(incoming)),
                Err(e) => return Some(Err(e)),
                // Not enough bytes for a full message yet: read more below.
                Ok(None) => (),
            }

            // `Interrupted` is a transient condition (e.g. a signal arrived
            // during the system call); retry the read instead of surfacing it
            // to the caller as a fatal error.
            let bytes_read = loop {
                match self.stream.read(&mut self.read_window) {
                    Ok(br) => break br,
                    Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                    Err(e) => return Some(Err(e.into())),
                }
            };
            if bytes_read == 0 {
                return None;
            }
            self.read_buf
                .extend_from_slice(&self.read_window[..bytes_read]);
        }
    }
}
impl<S, I, O> Codec<S, I, O>
where
    S: Write,
    O: Message,
{
    /// Encodes `message` with a length prefix, writes it to the stream and
    /// flushes the stream.
    ///
    /// The encoded bytes are appended to the internal send buffer, which is
    /// only advanced past bytes the stream has accepted. If an earlier call
    /// failed part-way through, the bytes it left behind are therefore written
    /// before this message.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding fails, if the stream accepts zero bytes
    /// (`ErrorKind::WriteZero`), if a write fails with anything other than
    /// `ErrorKind::Interrupted`, or if the final flush fails.
    pub fn send(&mut self, message: O) -> Result<()> {
        encode_length_delimited(message, &mut self.write_buf)?;
        while !self.write_buf.is_empty() {
            let bytes_written = match self.stream.write(&self.write_buf) {
                Ok(n) => n,
                // Transient condition: nothing was written, so try again.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e.into()),
            };
            if bytes_written == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write to underlying stream",
                )
                .into());
            }
            self.write_buf.advance(bytes_written);
        }
        self.stream.flush()?;
        Ok(())
    }
}
/// Appends `message` to `dst` as a length-prefixed frame.
///
/// The message is first encoded into a temporary buffer; the byte length of
/// that encoding is then written with [`encode_varint`], followed by the
/// encoded bytes themselves.
///
/// # Errors
///
/// Returns an error if encoding the message fails.
pub fn encode_length_delimited<M, B>(message: M, mut dst: &mut B) -> Result<()>
where
    M: Message,
    B: BufMut,
{
    let mut payload = BytesMut::new();
    message.encode(&mut payload)?;

    let payload_len = payload.len() as u64;
    encode_varint(payload_len, &mut dst);
    dst.put_slice(&payload);
    Ok(())
}
/// Tries to decode one length-prefixed message from the front of `src`.
///
/// * `Ok(Some(msg))` - a complete frame was present; its prefix and payload
///   have been removed from `src`.
/// * `Ok(None)` - `src` does not yet hold a complete frame; `src` is left
///   untouched so the caller can append more bytes and retry.
/// * `Err(_)` - the length prefix could not be decoded even though `src`
///   holds more than [`MAX_VARINT_LENGTH`] bytes, or the payload failed to
///   decode as `M`. In the latter case the frame has already been removed
///   from `src`.
///
/// The length prefix is peeked through a borrowed slice and the payload is
/// split off in place, so a call that returns `Ok(None)` does not copy the
/// buffer. This matters because callers typically retry after every read.
pub fn decode_length_delimited<M>(src: &mut BytesMut) -> Result<Option<M>>
where
    M: Message + Default,
{
    let src_len = src.len();

    // Peek at the prefix without consuming anything from `src`.
    let mut peek: &[u8] = &src[..];
    let encoded_len = match decode_varint(&mut peek) {
        Ok(len) => len,
        // With this few bytes the prefix may simply be incomplete.
        Err(_) if src_len <= MAX_VARINT_LENGTH => return Ok(None),
        Err(e) => return Err(e),
    };

    // Bytes available after the prefix.
    let remaining = peek.len();
    if (remaining as u64) < encoded_len {
        return Ok(None);
    }

    let delim_len = src_len - remaining;
    // Lossless: `encoded_len <= remaining`, and `remaining` is a `usize`.
    let body_len = encoded_len as usize;

    // Consume the frame from `src` before decoding, so a malformed payload
    // is not decoded again on the next call.
    src.advance(delim_len);
    let body = src.split_to(body_len);
    Ok(Some(M::decode(body)?))
}
/// Writes `val`, shifted left by one bit, to `buf` as a protobuf varint.
///
/// The shift discards the most significant bit of `val`, so values of
/// `2^63` and above are not representable.
// NOTE(review): the one-bit shift presumably matches a signed (zigzag-style)
// varint framing expected by the peer - confirm against the wire protocol.
pub fn encode_varint<B: BufMut>(val: u64, mut buf: &mut B) {
    let shifted = val << 1;
    prost::encoding::encode_varint(shifted, &mut buf);
}
/// Reads a protobuf varint from `buf` and returns it shifted right by one
/// bit, the inverse of the shift applied by [`encode_varint`].
///
/// # Errors
///
/// Returns an error if `buf` does not contain a decodable varint.
pub fn decode_varint<B: Buf>(mut buf: &mut B) -> Result<u64> {
    Ok(prost::encoding::decode_varint(&mut buf)? >> 1)
}