1use prost::decode_length_delimiter;
2use prost::length_delimiter_len;
3use prost::Message;
4use std::io::Read;
5use std::io::Write;
6use thiserror::Error;
7#[cfg(feature = "async")]
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9
10#[derive(Error, Debug)]
11pub enum Error {
12 #[error("io error: {0}")]
13 IoError(#[from] std::io::Error),
14 #[error("prost decode error: {0}")]
15 ProstDecodeError(#[from] prost::DecodeError),
16 #[error("prost encode error: {0}")]
17 ProstEncodeError(#[from] prost::EncodeError),
18}
19
20pub type Result<T> = std::result::Result<T, Error>;
21
/// A blocking wrapper that sends and receives length-delimited protobuf
/// messages over an underlying byte stream.
pub struct Stream<T> {
    /// The wrapped byte stream.
    stream: T,
    /// Receive buffer reused across `recv` calls; grown when a larger message arrives.
    buf: Vec<u8>,
    /// Scratch buffer reused across `send` calls to hold one encoded frame.
    send_buf: Vec<u8>,
}
27
28impl<T: Read + Write> Stream<T> {
29 pub fn new(stream: T) -> Self {
30 Self {
31 stream,
32 buf: vec![0; 1024],
33 send_buf: Vec::with_capacity(1024),
34 }
35 }
36
37 pub fn into_inner(self) -> T {
38 self.stream
39 }
40
41 pub fn send(&mut self, msg: &impl Message) -> Result<()> {
42 let buf = &mut self.send_buf;
43 buf.clear();
44 let sz = msg.encoded_len() + 10;
45 buf.reserve(sz);
46
47 msg.encode_length_delimited(buf)?;
48 self.stream.write_all(buf)?;
49 Ok(())
50 }
51
52 pub fn recv<M: Message + Default>(&mut self) -> Result<M> {
53 let buf = &mut self.buf;
54 let stream = &mut self.stream;
55
56 stream.read_exact(&mut buf[..1])?;
59
60 match decode_length_delimiter(&buf[..1]) {
61 Ok(sz) => {
62 if sz > buf.len() {
63 buf.resize(sz, 0);
64 }
65 stream.read_exact(&mut buf[..sz])?;
66 Ok(M::decode(&buf[..sz])?)
67 }
68 Err(_) => {
69 stream.read_exact(&mut buf[1..10])?;
71 let sz = decode_length_delimiter(&buf[..10])?;
72 let delimiter_len = length_delimiter_len(sz);
73 let idx = delimiter_len;
74 let left = sz - (10 - idx);
75
76 if 10 + left > buf.len() {
77 buf.resize(10 + left, 0);
78 }
79
80 stream.read_exact(&mut buf[10..left])?;
81 Ok(M::decode(&buf[idx..idx + sz])?)
82 }
83 }
84 }
85}
86
/// An async (tokio) wrapper that sends and receives length-delimited protobuf
/// messages over an underlying byte stream.
#[cfg(feature = "async")]
pub struct AsyncStream<T> {
    /// The wrapped async byte stream.
    stream: T,
    /// Receive buffer reused across `recv` calls; grown when a larger message arrives.
    buf: Vec<u8>,
    /// Scratch buffer reused across `send` calls to hold one encoded frame.
    send_buf: Vec<u8>,
}
93
#[cfg(feature = "async")]
impl<T: AsyncReadExt + AsyncWriteExt + Unpin> AsyncStream<T> {
    /// Wraps `stream`, pre-allocating 1 KiB receive and send buffers.
    pub fn new(stream: T) -> Self {
        Self {
            stream,
            buf: vec![0u8; 1024],
            send_buf: Vec::with_capacity(1024),
        }
    }

    /// Consumes the wrapper and returns the underlying stream.
    pub fn into_inner(self) -> T {
        self.stream
    }

    /// Encodes `msg` with a varint length prefix and writes the whole frame
    /// to the stream.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ProstEncodeError`] if prost fails to encode the message,
    /// or [`Error::IoError`] if writing to the stream fails.
    pub async fn send(&mut self, msg: &impl Message) -> Result<()> {
        let buf = &mut self.send_buf;
        buf.clear();
        // Reserve exactly prefix + body so encoding never reallocates mid-frame.
        let len = msg.encoded_len();
        buf.reserve(length_delimiter_len(len) + len);

        msg.encode_length_delimited(buf)?;
        self.stream.write_all(buf).await?;
        Ok(())
    }

    /// Reads one length-delimited message from the stream and decodes it as `M`.
    ///
    /// Exactly the bytes of one frame (varint prefix + body) are consumed, so
    /// subsequent calls stay aligned on frame boundaries. The length prefix is
    /// trusted: the receive buffer is grown to whatever size the peer announces.
    ///
    /// NOTE: the frame is read across several awaits, so dropping this future
    /// mid-message leaves the stream positioned inside a frame.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IoError`] if reading fails (including an unexpected EOF
    /// mid-frame), or [`Error::ProstDecodeError`] if the length prefix or the
    /// message body is malformed.
    pub async fn recv<M: Message + Default>(&mut self) -> Result<M> {
        let sz = self.read_length_delimiter().await?;
        if sz > self.buf.len() {
            self.buf.resize(sz, 0);
        }
        self.stream.read_exact(&mut self.buf[..sz]).await?;
        Ok(M::decode(&self.buf[..sz])?)
    }

    /// Reads a varint length prefix one byte at a time, stopping at the byte
    /// that ends the varint so no payload bytes are consumed.
    async fn read_length_delimiter(&mut self) -> Result<usize> {
        // A varint-encoded length occupies at most 10 bytes.
        let mut header = [0u8; 10];
        for len in 1..=header.len() {
            let byte = self.stream.read_u8().await?;
            header[len - 1] = byte;
            // A clear high bit marks the final byte of the varint.
            if byte & 0x80 == 0 {
                return Ok(decode_length_delimiter(&header[..len])?);
            }
        }
        // Ten continuation bytes in a row is not a valid varint; let prost
        // report the corresponding decode error.
        Ok(decode_length_delimiter(&header[..])?)
    }
}