1use std::{
8 io::{Read, Write},
9 marker::PhantomData,
10};
11
12use bytes::{Buf, BufMut, BytesMut};
13use prost::Message;
14use tendermint_proto::v0_38::abci::{Request, Response};
15
16use crate::error::Error;
17
/// Buffer-length threshold used by `decode_length_delimited`: if the varint
/// length prefix fails to decode while at most this many bytes are buffered,
/// the failure is treated as "need more data" (`Ok(None)`) rather than an error.
///
/// NOTE(review): a protobuf varint encoding a `u64` is at most 10 bytes, so this
/// bound is looser than strictly necessary — an overlong prefix of 11..=16 bytes
/// is reported as incomplete until more bytes arrive. Confirm 16 is intentional.
pub const MAX_VARINT_LENGTH: usize = 16;

/// Codec that decodes incoming `Request`s and encodes outgoing `Response`s.
pub type ServerCodec<S> = Codec<S, Request, Response>;

/// Codec that decodes incoming `Response`s and encodes outgoing `Request`s.
/// Only available with the `client` feature.
#[cfg(feature = "client")]
pub type ClientCodec<S> = Codec<S, Response, Request>;
28
/// Length-delimited protobuf codec over a byte stream `S`.
///
/// Incoming messages of type `I` are decoded through the `Iterator` impl and
/// outgoing messages of type `O` are encoded through `send`. On the wire, each
/// message is a varint length prefix followed by the encoded message body.
pub struct Codec<S, I, O> {
    // Underlying stream that messages are read from and written to.
    stream: S,
    // Bytes read from `stream` that have not yet been decoded into a message.
    read_buf: BytesMut,
    // Scratch buffer handed to `Read::read`; its length (set in `new`) caps how
    // many bytes a single read can return.
    read_window: Vec<u8>,
    // Encoded outgoing bytes that have not yet been written to `stream`.
    write_buf: BytesMut,
    // Type markers only: the codec never stores `I` or `O` values.
    _incoming: PhantomData<I>,
    _outgoing: PhantomData<O>,
}
41
42impl<S, I, O> Codec<S, I, O>
43where
44 S: Read + Write,
45 I: Message + Default,
46 O: Message,
47{
48 pub fn new(stream: S, read_buf_size: usize) -> Self {
50 Self {
51 stream,
52 read_buf: BytesMut::new(),
53 read_window: vec![0_u8; read_buf_size],
54 write_buf: BytesMut::new(),
55 _incoming: Default::default(),
56 _outgoing: Default::default(),
57 }
58 }
59}
60
61impl<S, I, O> Iterator for Codec<S, I, O>
63where
64 S: Read,
65 I: Message + Default,
66{
67 type Item = Result<I, Error>;
68
69 fn next(&mut self) -> Option<Self::Item> {
70 loop {
71 match decode_length_delimited::<I>(&mut self.read_buf) {
73 Ok(Some(incoming)) => return Some(Ok(incoming)),
74 Err(e) => return Some(Err(e)),
75 _ => (), }
77
78 let bytes_read = match self.stream.read(self.read_window.as_mut()) {
81 Ok(br) => br,
82 Err(e) => return Some(Err(Error::io(e))),
83 };
84 if bytes_read == 0 {
85 return None;
87 }
88 self.read_buf
89 .extend_from_slice(&self.read_window[..bytes_read]);
90 }
91 }
92}
93
94impl<S, I, O> Codec<S, I, O>
95where
96 S: Write,
97 O: Message,
98{
99 pub fn send(&mut self, message: O) -> Result<(), Error> {
101 encode_length_delimited(message, &mut self.write_buf)?;
102 while !self.write_buf.is_empty() {
103 let bytes_written = self
104 .stream
105 .write(self.write_buf.as_ref())
106 .map_err(Error::io)?;
107
108 if bytes_written == 0 {
109 return Err(Error::io(std::io::Error::new(
110 std::io::ErrorKind::WriteZero,
111 "failed to write to underlying stream",
112 )));
113 }
114 self.write_buf.advance(bytes_written);
115 }
116
117 self.stream.flush().map_err(Error::io)?;
118
119 Ok(())
120 }
121}
122
123pub fn encode_length_delimited<M, B>(message: M, mut dst: &mut B) -> Result<(), Error>
125where
126 M: Message,
127 B: BufMut,
128{
129 let mut buf = BytesMut::new();
130 message.encode(&mut buf).map_err(Error::encode)?;
131
132 let buf = buf.freeze();
133 prost::encoding::encode_varint(buf.len() as u64, &mut dst);
134 dst.put(buf);
135 Ok(())
136}
137
138pub fn decode_length_delimited<M>(src: &mut BytesMut) -> Result<Option<M>, Error>
140where
141 M: Message + Default,
142{
143 let src_len = src.len();
144 let mut tmp = src.clone().freeze();
145 let encoded_len = match prost::encoding::decode_varint(&mut tmp) {
146 Ok(len) => len,
147 Err(_) if src_len <= MAX_VARINT_LENGTH => return Ok(None),
149 Err(e) => return Err(Error::decode(e)),
150 };
151 let remaining = tmp.remaining() as u64;
152 if remaining < encoded_len {
153 Ok(None)
155 } else {
156 let delim_len = src_len - tmp.remaining();
157 src.advance(delim_len + (encoded_len as usize));
160
161 let mut result_bytes = BytesMut::from(tmp.split_to(encoded_len as usize).as_ref());
162 let res = M::decode(&mut result_bytes).map_err(Error::decode)?;
163
164 Ok(Some(res))
165 }
166}