// rbdc/io/buf_stream.rs

1#![allow(dead_code)]
2
3use std::io;
4use std::ops::{Deref, DerefMut};
5
6use crate::Error;
7use crate::io::write_and_flush::WriteAndFlush;
8use crate::io::{decode::Decode, encode::Encode};
9use crate::rt::{AsyncRead, AsyncReadExt, AsyncWrite};
10use bytes::BytesMut;
11use std::io::Cursor;
12
13pub struct BufStream<S>
14where
15    S: AsyncRead + AsyncWrite + Unpin,
16{
17    pub stream: S,
18
19    // writes with `write` to the underlying stream are buffered
20    // this can be flushed with `flush`
21    pub wbuf: Vec<u8>,
22
23    // we read into the read buffer using 100% safe code
24    pub rbuf: BytesMut,
25}
26
27impl<S> BufStream<S>
28where
29    S: AsyncRead + AsyncWrite + Unpin,
30{
31    pub fn new(stream: S) -> Self {
32        Self {
33            stream,
34            wbuf: Vec::with_capacity(512),
35            rbuf: BytesMut::with_capacity(4096),
36        }
37    }
38
39    pub fn write<'en, T>(&mut self, value: T)
40    where
41        T: Encode<'en, ()>,
42    {
43        self.write_with(value, ())
44    }
45
46    pub fn write_with<'en, T, C>(&mut self, value: T, context: C)
47    where
48        T: Encode<'en, C>,
49    {
50        value.encode_with(&mut self.wbuf, context);
51    }
52
53    pub fn flush(&mut self) -> WriteAndFlush<'_, S> {
54        WriteAndFlush {
55            stream: &mut self.stream,
56            buf: Cursor::new(&mut self.wbuf),
57        }
58    }
59
60    pub async fn read<'de, T>(&mut self, cnt: usize) -> Result<T, Error>
61    where
62        T: Decode<'de, ()>,
63    {
64        self.read_with(cnt, ()).await
65    }
66
67    pub async fn read_with<'de, T, C>(&mut self, cnt: usize, context: C) -> Result<T, Error>
68    where
69        T: Decode<'de, C>,
70    {
71        T::decode_with(self.read_raw(cnt).await?.freeze(), context)
72    }
73
74    pub async fn read_raw(&mut self, cnt: usize) -> Result<BytesMut, Error> {
75        read_raw_into(&mut self.stream, &mut self.rbuf, cnt).await?;
76        let buf = self.rbuf.split_to(cnt);
77
78        Ok(buf)
79    }
80
81    pub async fn read_raw_into(&mut self, buf: &mut BytesMut, cnt: usize) -> Result<(), Error> {
82        read_raw_into(&mut self.stream, buf, cnt).await
83    }
84}
85
86impl<S> Deref for BufStream<S>
87where
88    S: AsyncRead + AsyncWrite + Unpin,
89{
90    type Target = S;
91
92    fn deref(&self) -> &Self::Target {
93        &self.stream
94    }
95}
96
97impl<S> DerefMut for BufStream<S>
98where
99    S: AsyncRead + AsyncWrite + Unpin,
100{
101    fn deref_mut(&mut self) -> &mut Self::Target {
102        &mut self.stream
103    }
104}
105
106// Holds a buffer which has been temporarily extended, so that
107// we can read into it. Automatically shrinks the buffer back
108// down if the read is cancelled.
109struct BufTruncator<'a> {
110    buf: &'a mut BytesMut,
111    filled_len: usize,
112}
113
114impl<'a> BufTruncator<'a> {
115    fn new(buf: &'a mut BytesMut) -> Self {
116        let filled_len = buf.len();
117        Self { buf, filled_len }
118    }
119    fn reserve(&mut self, space: usize) {
120        self.buf.resize(self.filled_len + space, 0);
121    }
122    async fn read<S: AsyncRead + Unpin>(&mut self, stream: &mut S) -> Result<usize, Error> {
123        let n = stream.read(&mut self.buf[self.filled_len..]).await?;
124        self.filled_len += n;
125        Ok(n)
126    }
127    fn is_full(&self) -> bool {
128        self.filled_len >= self.buf.len()
129    }
130}
131
132impl Drop for BufTruncator<'_> {
133    fn drop(&mut self) {
134        self.buf.truncate(self.filled_len);
135    }
136}
137
138async fn read_raw_into<S: AsyncRead + Unpin>(
139    stream: &mut S,
140    buf: &mut BytesMut,
141    cnt: usize,
142) -> Result<(), Error> {
143    let mut buf = BufTruncator::new(buf);
144    buf.reserve(cnt);
145
146    while !buf.is_full() {
147        let n = buf.read(stream).await?;
148
149        if n == 0 {
150            // a zero read when we had space in the read buffer
151            // should be treated as an EOF
152
153            // and an unexpected EOF means the server told us to go away
154
155            return Err(io::Error::from(io::ErrorKind::ConnectionAborted).into());
156        }
157    }
158
159    Ok(())
160}