prost_stream/
stream.rs

1use prost::decode_length_delimiter;
2use prost::length_delimiter_len;
3use prost::Message;
4use std::io::Read;
5use std::io::Write;
6use thiserror::Error;
7#[cfg(feature = "async")]
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9
10#[derive(Error, Debug)]
11pub enum Error {
12    #[error("io error: {0}")]
13    IoError(#[from] std::io::Error),
14    #[error("prost decode error: {0}")]
15    ProstDecodeError(#[from] prost::DecodeError),
16    #[error("prost encode error: {0}")]
17    ProstEncodeError(#[from] prost::EncodeError),
18}
19
20pub type Result<T> = std::result::Result<T, Error>;
21
22pub struct Stream<T> {
23    stream: T,
24    buf: Vec<u8>,
25    send_buf: Vec<u8>,
26}
27
28impl<T: Read + Write> Stream<T> {
29    pub fn new(stream: T) -> Self {
30        Self {
31            stream,
32            buf: vec![0; 1024],
33            send_buf: Vec::with_capacity(1024),
34        }
35    }
36
37    pub fn into_inner(self) -> T {
38        self.stream
39    }
40
41    pub fn send(&mut self, msg: &impl Message) -> Result<()> {
42        let buf = &mut self.send_buf;
43        buf.clear();
44        let sz = msg.encoded_len() + 10;
45        buf.reserve(sz);
46
47        msg.encode_length_delimited(buf)?;
48        self.stream.write_all(buf)?;
49        Ok(())
50    }
51
52    pub fn recv<M: Message + Default>(&mut self) -> Result<M> {
53        let buf = &mut self.buf;
54        let stream = &mut self.stream;
55
56        // protobuf 消息的长度信息最少占有 1 byte, 最多占有 10 bytes
57        // 当消息本身的长度小于 128 时占用 1 byte
58        stream.read_exact(&mut buf[..1])?;
59
60        match decode_length_delimiter(&buf[..1]) {
61            Ok(sz) => {
62                if sz > buf.len() {
63                    buf.resize(sz, 0);
64                }
65                stream.read_exact(&mut buf[..sz])?;
66                Ok(M::decode(&buf[..sz])?)
67            }
68            Err(_) => {
69                // protobuf 消息的长度信息最少占有 1 byte, 最多占有 10 bytes
70                stream.read_exact(&mut buf[1..10])?;
71                let sz = decode_length_delimiter(&buf[..10])?;
72                let delimiter_len = length_delimiter_len(sz);
73                let idx = delimiter_len;
74                let left = sz - (10 - idx);
75
76                if 10 + left > buf.len() {
77                    buf.resize(10 + left, 0);
78                }
79
80                stream.read_exact(&mut buf[10..left])?;
81                Ok(M::decode(&buf[idx..idx + sz])?)
82            }
83        }
84    }
85}
86
/// Asynchronous transport that exchanges length-delimited protobuf messages.
///
/// Each frame on the wire is the encoded message length as a protobuf varint
/// followed by the encoded message bytes, i.e. the format written by
/// `Message::encode_length_delimited`.
#[cfg(feature = "async")]
pub struct AsyncStream<T> {
    /// Underlying byte stream that frames are written to and read from.
    stream: T,
    /// Scratch buffer for incoming payloads; grown on demand and reused across calls.
    buf: Vec<u8>,
    /// Scratch buffer for outgoing frames; cleared and refilled by every `send`.
    send_buf: Vec<u8>,
}

#[cfg(feature = "async")]
impl<T: AsyncReadExt + AsyncWriteExt + Unpin> AsyncStream<T> {
    /// Wraps `stream`, pre-allocating 1 KiB for each of the receive and send buffers.
    pub fn new(stream: T) -> Self {
        Self {
            stream,
            buf: vec![0u8; 1024],
            send_buf: Vec::with_capacity(1024),
        }
    }

    /// Consumes the wrapper and returns the underlying stream.
    pub fn into_inner(self) -> T {
        self.stream
    }

    /// Encodes `msg` with a varint length prefix and writes the whole frame.
    ///
    /// The stream is not flushed here; callers wrapping a buffered writer must
    /// flush it themselves.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ProstEncodeError`] if encoding fails and
    /// [`Error::IoError`] if the write fails.
    pub async fn send(&mut self, msg: &impl Message) -> Result<()> {
        let payload_len = msg.encoded_len();

        self.send_buf.clear();
        // Reserve exactly one frame: the varint prefix plus the payload.
        self.send_buf
            .reserve(payload_len + length_delimiter_len(payload_len));

        msg.encode_length_delimited(&mut self.send_buf)?;
        self.stream.write_all(&self.send_buf).await?;
        Ok(())
    }

    /// Reads one length-delimited frame and decodes it as `M`.
    ///
    /// Exactly one frame is consumed from the stream: the varint prefix and
    /// then the number of payload bytes it announces. The announced length is
    /// not capped, so the receive buffer grows to whatever the peer declares.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IoError`] if the stream fails or ends mid-frame and
    /// [`Error::ProstDecodeError`] if the prefix or the payload is malformed.
    pub async fn recv<M: Message + Default>(&mut self) -> Result<M> {
        let len = self.read_length_delimiter().await?;

        if len > self.buf.len() {
            self.buf.resize(len, 0);
        }

        let payload = &mut self.buf[..len];
        self.stream.read_exact(payload).await?;
        Ok(M::decode(&*payload)?)
    }

    /// Reads the varint length prefix of the next frame and returns the payload length.
    async fn read_length_delimiter(&mut self) -> Result<usize> {
        // A protobuf length prefix occupies at least 1 byte and at most 10 bytes;
        // it is a single byte when the message is shorter than 128 bytes.
        const MAX_DELIMITER_LEN: usize = 10;
        // Set on every varint byte except the last one.
        const CONTINUATION_BIT: u8 = 0x80;

        let mut header = [0u8; MAX_DELIMITER_LEN];
        let mut used = 0;

        // Read one byte at a time so that no payload byte (or byte of the next
        // frame) is consumed before the prefix length is known.
        for byte in header.iter_mut() {
            self.stream.read_exact(std::slice::from_mut(byte)).await?;
            used += 1;
            if *byte & CONTINUATION_BIT == 0 {
                break;
            }
        }

        // A prefix that never terminates within 10 bytes is rejected here.
        Ok(decode_length_delimiter(&header[..used])?)
    }
}