1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
use prost::decode_length_delimiter;
use prost::length_delimiter_len;
use prost::Message;
use std::io::Read;
use std::io::Write;
use thiserror::Error;
#[cfg(feature = "async")]
use tokio::io::{AsyncReadExt, AsyncWriteExt};

/// Errors produced while sending or receiving length-delimited protobuf messages.
#[derive(Error, Debug)]
pub enum Error {
    /// Failure reported by the underlying stream's I/O operations.
    #[error("io error: {0}")]
    IoError(#[from] std::io::Error),
    /// Received bytes could not be decoded (either the length delimiter or the message body).
    #[error("prost decode error: {0}")]
    ProstDecodeError(#[from] prost::DecodeError),
    /// An outgoing message could not be encoded.
    #[error("prost encode error: {0}")]
    ProstEncodeError(#[from] prost::EncodeError),
}

/// Result alias whose error type is [`Error`].
pub type Result<T> = std::result::Result<T, Error>;

/// Blocking wrapper that sends and receives length-delimited protobuf messages over `T`.
pub struct Stream<T> {
    /// Underlying byte stream.
    stream: T,
    /// Receive buffer; grown on demand when a message does not fit.
    buf: Vec<u8>,
    /// Scratch buffer used to encode outgoing messages before writing them.
    send_buf: Vec<u8>,
}

impl<T: Read + Write> Stream<T> {
    /// Wraps `stream` with a 1 KiB receive buffer and a send buffer of 1 KiB initial capacity.
    pub fn new(stream: T) -> Self {
        Self {
            stream,
            buf: vec![0; 1024],
            send_buf: Vec::with_capacity(1024),
        }
    }

    /// Encodes `msg` with a varint length prefix and writes it to the underlying stream.
    ///
    /// The internal send buffer is reused across calls to avoid reallocating.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ProstEncodeError`] if encoding fails and [`Error::IoError`] if the
    /// write fails.
    pub fn send(&mut self, msg: &impl Message) -> Result<()> {
        let buf = &mut self.send_buf;
        // The buffer is reused between calls: without clearing it, every previously sent
        // message would be written again in front of this one.
        buf.clear();
        // A length delimiter occupies at most 10 bytes.
        buf.reserve(msg.encoded_len() + 10);

        msg.encode_length_delimited(buf)?;
        self.stream.write_all(buf)?;
        Ok(())
    }

    /// Reads one length-delimited message from the underlying stream and decodes it as `M`.
    ///
    /// NOTE: the length prefix is trusted as-is; a peer can make this grow the receive
    /// buffer to whatever length it announces.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IoError`] on read failure or a malformed length delimiter, and
    /// [`Error::ProstDecodeError`] if the delimiter or message body cannot be decoded.
    pub fn recv<M: Message + Default>(&mut self) -> Result<M> {
        let buf = &mut self.buf;
        let stream = &mut self.stream;

        // A protobuf length delimiter is a varint of 1 to 10 bytes; it takes a single byte
        // when the message length is below 128.
        stream.read_exact(&mut buf[..1])?;

        match decode_length_delimiter(&buf[..1]) {
            Ok(sz) => {
                if sz > buf.len() {
                    buf.resize(sz, 0);
                }
                stream.read_exact(&mut buf[..sz])?;
                Ok(M::decode(&buf[..sz])?)
            }
            Err(_) => {
                // Multi-byte delimiter: read up to the 10-byte maximum in one go. For a
                // canonical encoding the payload is then at least 128 bytes, so these extra
                // bytes belong to this message's payload, not to the next message.
                stream.read_exact(&mut buf[1..10])?;
                let sz = decode_length_delimiter(&buf[..10])?;
                // Offset where the payload starts.
                let idx = length_delimiter_len(sz);
                // One past the last payload byte. 10 bytes are already in `buf`, so a valid
                // frame must end at or after offset 10; anything else is malformed input.
                let end = idx
                    .checked_add(sz)
                    .filter(|&end| end >= 10)
                    .ok_or_else(|| {
                        std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            "invalid protobuf length delimiter",
                        )
                    })?;

                if end > buf.len() {
                    buf.resize(end, 0);
                }

                // Read the rest of the payload that follows the 10 bytes already consumed.
                stream.read_exact(&mut buf[10..end])?;
                Ok(M::decode(&buf[idx..end])?)
            }
        }
    }
}

/// Async (tokio) wrapper that sends and receives length-delimited protobuf messages over `T`.
#[cfg(feature = "async")]
pub struct AsyncStream<T> {
    /// Underlying async byte stream.
    stream: T,
    /// Receive buffer; grown on demand when a message does not fit.
    buf: Vec<u8>,
}

#[cfg(feature = "async")]
impl<T: AsyncReadExt + AsyncWriteExt + Unpin> AsyncStream<T> {
    /// Wraps `stream` with a 1 KiB receive buffer.
    pub fn new(stream: T) -> Self {
        Self {
            stream,
            buf: vec![0u8; 1024],
        }
    }

    /// Encodes `msg` with a varint length prefix and writes it to the underlying stream.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IoError`] if the write fails.
    pub async fn send(&mut self, msg: &impl Message) -> Result<()> {
        self.stream
            .write_all(&msg.encode_length_delimited_to_vec())
            .await
            .map_err(Into::into)
    }

    /// Reads one length-delimited message from the underlying stream and decodes it as `M`.
    ///
    /// NOTE: the length prefix is trusted as-is; a peer can make this grow the receive
    /// buffer to whatever length it announces.
    ///
    /// # Errors
    ///
    /// Returns [`Error::IoError`] on read failure or a malformed length delimiter, and
    /// [`Error::ProstDecodeError`] if the delimiter or message body cannot be decoded.
    pub async fn recv<M: Message + Default>(&mut self) -> Result<M> {
        let buf = &mut self.buf;
        let stream = &mut self.stream;

        // A protobuf length delimiter is a varint of 1 to 10 bytes; it takes a single byte
        // when the message length is below 128.
        stream.read_exact(&mut buf[..1]).await?;

        match decode_length_delimiter(&buf[..1]) {
            Ok(sz) => {
                if sz > buf.len() {
                    buf.resize(sz, 0);
                }
                stream.read_exact(&mut buf[..sz]).await?;
                Ok(M::decode(&buf[..sz])?)
            }
            Err(_) => {
                // Multi-byte delimiter: read up to the 10-byte maximum in one go. For a
                // canonical encoding the payload is then at least 128 bytes, so these extra
                // bytes belong to this message's payload, not to the next message.
                stream.read_exact(&mut buf[1..10]).await?;
                let sz = decode_length_delimiter(&buf[..10])?;
                // Offset where the payload starts.
                let idx = length_delimiter_len(sz);
                // One past the last payload byte. 10 bytes are already in `buf`, so a valid
                // frame must end at or after offset 10; anything else is malformed input.
                let end = idx
                    .checked_add(sz)
                    .filter(|&end| end >= 10)
                    .ok_or_else(|| {
                        std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            "invalid protobuf length delimiter",
                        )
                    })?;

                if end > buf.len() {
                    buf.resize(end, 0);
                }

                // Read the rest of the payload that follows the 10 bytes already consumed.
                stream.read_exact(&mut buf[10..end]).await?;
                Ok(M::decode(&buf[idx..end])?)
            }
        }
    }
}