async_abci/
codec.rs

1use bytes::{Buf, BufMut, BytesMut};
2use prost::Message;
3#[cfg(feature = "smol-backend")]
4use smol::io::{AsyncReadExt, AsyncWriteExt};
5#[cfg(feature = "smol-backend")]
6use smol::prelude::{AsyncRead, AsyncWrite};
7use tm_protos::abci::{Request, Response};
8#[cfg(feature = "tokio-backend")]
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10
11use crate::error::Error;
12
/// Threshold, in bytes, below which an undecodable length prefix is treated
/// as "not fully received yet" rather than as an error.
///
/// [`decode_length_delimited`] returns `Ok(None)` when the prefix fails to
/// decode and the buffer holds at most this many bytes.
///
/// NOTE(review): a protobuf varint for a `u64` takes at most 10 bytes, so 16
/// looks like a generous margin — confirm whether this value is intentional.
pub const MAX_VARINT_LENGTH: usize = 16;
16
17pub struct ICodec<R> {
18    stream: R,
19    // Long-running read buffer
20    read_buf: BytesMut,
21    // Fixed-length read window
22    read_window: Vec<u8>,
23}
24
25impl<R> ICodec<R> {
26    /// Constructor.
27    pub fn new(stream: R, read_buf_size: usize) -> Self {
28        Self {
29            stream,
30            read_buf: BytesMut::new(),
31            read_window: vec![0_u8; read_buf_size],
32        }
33    }
34}
35
36// Iterating over a codec produces instances of `Result<I>`.
37impl<R> ICodec<R>
38where
39    R: AsyncRead + Unpin,
40{
41    pub async fn next(&mut self) -> Option<Result<Request, Error>> {
42        loop {
43            // Try to decode an incoming message from our buffer first
44            match decode_length_delimited::<Request>(&mut self.read_buf) {
45                Ok(Some(incoming)) => return Some(Ok(incoming)),
46                Err(e) => return Some(Err(e)),
47                _ => (), // not enough data to decode a message, let's continue.
48            }
49
50            // If we don't have enough data to decode a message, try to read
51            // more
52            let bytes_read = match self.stream.read(self.read_window.as_mut()).await {
53                Ok(br) => br,
54                Err(e) => return Some(Err(Error::StdIoError(e))),
55            };
56            if bytes_read == 0 {
57                // The underlying stream terminated
58                return None;
59            }
60            self.read_buf
61                .extend_from_slice(&self.read_window[..bytes_read]);
62        }
63    }
64}
65
66pub struct OCodec<W> {
67    stream: W,
68    write_buf: BytesMut,
69}
70
71impl<W> OCodec<W> {
72    /// Constructor.
73    pub fn new(stream: W) -> Self {
74        Self {
75            stream,
76            write_buf: BytesMut::default(),
77        }
78    }
79}
80
81impl<W> OCodec<W>
82where
83    W: AsyncWrite + Unpin,
84{
85    /// Send a message using this codec.
86    pub async fn send(&mut self, message: Response) -> Result<(), Error> {
87        encode_length_delimited(message, &mut self.write_buf)?;
88        while !self.write_buf.is_empty() {
89            let bytes_written = self
90                .stream
91                .write(self.write_buf.as_ref())
92                .await
93                .map_err(Error::StdIoError)?;
94
95            if bytes_written == 0 {
96                return Err(Error::StdIoError(std::io::Error::new(
97                    std::io::ErrorKind::WriteZero,
98                    "failed to write to underlying stream",
99                )));
100            }
101            self.write_buf.advance(bytes_written);
102        }
103
104        self.stream.flush().await.map_err(Error::StdIoError)?;
105
106        Ok(())
107    }
108}
109
110/// Encode the given message with a length prefix.
111pub fn encode_length_delimited<M, B>(message: M, mut dst: &mut B) -> Result<(), Error>
112where
113    M: Message,
114    B: BufMut,
115{
116    let mut buf = BytesMut::new();
117    message.encode(&mut buf).map_err(Error::ProstEncodeError)?;
118
119    let buf = buf.freeze();
120    encode_varint(buf.len() as u64, &mut dst);
121    dst.put(buf);
122    Ok(())
123}
124
125/// Attempt to decode a message of type `M` from the given source buffer.
126pub fn decode_length_delimited<M>(src: &mut BytesMut) -> Result<Option<M>, Error>
127where
128    M: Message + Default,
129{
130    let src_len = src.len();
131    let mut tmp = src.clone().freeze();
132    let encoded_len = match decode_varint(&mut tmp) {
133        Ok(len) => len,
134        // We've potentially only received a partial length delimiter
135        Err(_) if src_len <= MAX_VARINT_LENGTH => return Ok(None),
136        Err(e) => return Err(e),
137    };
138    let remaining = tmp.remaining() as u64;
139    if remaining < encoded_len {
140        // We don't have enough data yet to decode the entire message
141        Ok(None)
142    } else {
143        let delim_len = src_len - tmp.remaining();
144        // We only advance the source buffer once we're sure we have enough
145        // data to try to decode the result.
146        src.advance(delim_len + (encoded_len as usize));
147
148        let mut result_bytes = BytesMut::from(tmp.split_to(encoded_len as usize).as_ref());
149        let res = M::decode(&mut result_bytes).map_err(Error::ProstDecodeError)?;
150
151        Ok(Some(res))
152    }
153}
154
155pub fn encode_varint<B: BufMut>(val: u64, mut buf: &mut B) {
156    prost::encoding::encode_varint(val << 1, &mut buf);
157}
158
159pub fn decode_varint<B: Buf>(mut buf: &mut B) -> Result<u64, Error> {
160    let len = prost::encoding::decode_varint(&mut buf)?;
161    Ok(len >> 1)
162}