1use bytes::{Buf, BufMut, BytesMut};
2use prost::Message;
3#[cfg(feature = "smol-backend")]
4use smol::io::{AsyncReadExt, AsyncWriteExt};
5#[cfg(feature = "smol-backend")]
6use smol::prelude::{AsyncRead, AsyncWrite};
7use tm_protos::abci::{Request, Response};
8#[cfg(feature = "tokio-backend")]
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10
11use crate::error::Error;
12
/// Upper bound, in bytes, on how much buffered input may still be an incomplete length
/// prefix (a protobuf varint for a `u64` needs at most 10 bytes, so this is conservative).
///
/// While no more than this many bytes are buffered, a failure to decode the length prefix
/// is treated as "need more data" rather than as a framing error.
pub const MAX_VARINT_LENGTH: usize = 16;
16
17pub struct ICodec<R> {
18 stream: R,
19 read_buf: BytesMut,
21 read_window: Vec<u8>,
23}
24
25impl<R> ICodec<R> {
26 pub fn new(stream: R, read_buf_size: usize) -> Self {
28 Self {
29 stream,
30 read_buf: BytesMut::new(),
31 read_window: vec![0_u8; read_buf_size],
32 }
33 }
34}
35
36impl<R> ICodec<R>
38where
39 R: AsyncRead + Unpin,
40{
41 pub async fn next(&mut self) -> Option<Result<Request, Error>> {
42 loop {
43 match decode_length_delimited::<Request>(&mut self.read_buf) {
45 Ok(Some(incoming)) => return Some(Ok(incoming)),
46 Err(e) => return Some(Err(e)),
47 _ => (), }
49
50 let bytes_read = match self.stream.read(self.read_window.as_mut()).await {
53 Ok(br) => br,
54 Err(e) => return Some(Err(Error::StdIoError(e))),
55 };
56 if bytes_read == 0 {
57 return None;
59 }
60 self.read_buf
61 .extend_from_slice(&self.read_window[..bytes_read]);
62 }
63 }
64}
65
66pub struct OCodec<W> {
67 stream: W,
68 write_buf: BytesMut,
69}
70
71impl<W> OCodec<W> {
72 pub fn new(stream: W) -> Self {
74 Self {
75 stream,
76 write_buf: BytesMut::default(),
77 }
78 }
79}
80
81impl<W> OCodec<W>
82where
83 W: AsyncWrite + Unpin,
84{
85 pub async fn send(&mut self, message: Response) -> Result<(), Error> {
87 encode_length_delimited(message, &mut self.write_buf)?;
88 while !self.write_buf.is_empty() {
89 let bytes_written = self
90 .stream
91 .write(self.write_buf.as_ref())
92 .await
93 .map_err(Error::StdIoError)?;
94
95 if bytes_written == 0 {
96 return Err(Error::StdIoError(std::io::Error::new(
97 std::io::ErrorKind::WriteZero,
98 "failed to write to underlying stream",
99 )));
100 }
101 self.write_buf.advance(bytes_written);
102 }
103
104 self.stream.flush().await.map_err(Error::StdIoError)?;
105
106 Ok(())
107 }
108}
109
110pub fn encode_length_delimited<M, B>(message: M, mut dst: &mut B) -> Result<(), Error>
112where
113 M: Message,
114 B: BufMut,
115{
116 let mut buf = BytesMut::new();
117 message.encode(&mut buf).map_err(Error::ProstEncodeError)?;
118
119 let buf = buf.freeze();
120 encode_varint(buf.len() as u64, &mut dst);
121 dst.put(buf);
122 Ok(())
123}
124
125pub fn decode_length_delimited<M>(src: &mut BytesMut) -> Result<Option<M>, Error>
127where
128 M: Message + Default,
129{
130 let src_len = src.len();
131 let mut tmp = src.clone().freeze();
132 let encoded_len = match decode_varint(&mut tmp) {
133 Ok(len) => len,
134 Err(_) if src_len <= MAX_VARINT_LENGTH => return Ok(None),
136 Err(e) => return Err(e),
137 };
138 let remaining = tmp.remaining() as u64;
139 if remaining < encoded_len {
140 Ok(None)
142 } else {
143 let delim_len = src_len - tmp.remaining();
144 src.advance(delim_len + (encoded_len as usize));
147
148 let mut result_bytes = BytesMut::from(tmp.split_to(encoded_len as usize).as_ref());
149 let res = M::decode(&mut result_bytes).map_err(Error::ProstDecodeError)?;
150
151 Ok(Some(res))
152 }
153}
154
155pub fn encode_varint<B: BufMut>(val: u64, mut buf: &mut B) {
156 prost::encoding::encode_varint(val << 1, &mut buf);
157}
158
159pub fn decode_varint<B: Buf>(mut buf: &mut B) -> Result<u64, Error> {
160 let len = prost::encoding::decode_varint(&mut buf)?;
161 Ok(len >> 1)
162}