1#![allow(dead_code)]
2
3use std::io;
4use std::ops::{Deref, DerefMut};
5
6use crate::Error;
7use crate::io::write_and_flush::WriteAndFlush;
8use crate::io::{decode::Decode, encode::Encode};
9use crate::rt::{AsyncRead, AsyncReadExt, AsyncWrite};
10use bytes::BytesMut;
11use std::io::Cursor;
12
13pub struct BufStream<S>
14where
15 S: AsyncRead + AsyncWrite + Unpin,
16{
17 pub stream: S,
18
19 pub wbuf: Vec<u8>,
22
23 pub rbuf: BytesMut,
25}
26
27impl<S> BufStream<S>
28where
29 S: AsyncRead + AsyncWrite + Unpin,
30{
31 pub fn new(stream: S) -> Self {
32 Self {
33 stream,
34 wbuf: Vec::with_capacity(512),
35 rbuf: BytesMut::with_capacity(4096),
36 }
37 }
38
39 pub fn write<'en, T>(&mut self, value: T)
40 where
41 T: Encode<'en, ()>,
42 {
43 self.write_with(value, ())
44 }
45
46 pub fn write_with<'en, T, C>(&mut self, value: T, context: C)
47 where
48 T: Encode<'en, C>,
49 {
50 value.encode_with(&mut self.wbuf, context);
51 }
52
53 pub fn flush(&mut self) -> WriteAndFlush<'_, S> {
54 WriteAndFlush {
55 stream: &mut self.stream,
56 buf: Cursor::new(&mut self.wbuf),
57 }
58 }
59
60 pub async fn read<'de, T>(&mut self, cnt: usize) -> Result<T, Error>
61 where
62 T: Decode<'de, ()>,
63 {
64 self.read_with(cnt, ()).await
65 }
66
67 pub async fn read_with<'de, T, C>(&mut self, cnt: usize, context: C) -> Result<T, Error>
68 where
69 T: Decode<'de, C>,
70 {
71 T::decode_with(self.read_raw(cnt).await?.freeze(), context)
72 }
73
74 pub async fn read_raw(&mut self, cnt: usize) -> Result<BytesMut, Error> {
75 read_raw_into(&mut self.stream, &mut self.rbuf, cnt).await?;
76 let buf = self.rbuf.split_to(cnt);
77
78 Ok(buf)
79 }
80
81 pub async fn read_raw_into(&mut self, buf: &mut BytesMut, cnt: usize) -> Result<(), Error> {
82 read_raw_into(&mut self.stream, buf, cnt).await
83 }
84}
85
86impl<S> Deref for BufStream<S>
87where
88 S: AsyncRead + AsyncWrite + Unpin,
89{
90 type Target = S;
91
92 fn deref(&self) -> &Self::Target {
93 &self.stream
94 }
95}
96
97impl<S> DerefMut for BufStream<S>
98where
99 S: AsyncRead + AsyncWrite + Unpin,
100{
101 fn deref_mut(&mut self) -> &mut Self::Target {
102 &mut self.stream
103 }
104}
105
106struct BufTruncator<'a> {
110 buf: &'a mut BytesMut,
111 filled_len: usize,
112}
113
114impl<'a> BufTruncator<'a> {
115 fn new(buf: &'a mut BytesMut) -> Self {
116 let filled_len = buf.len();
117 Self { buf, filled_len }
118 }
119 fn reserve(&mut self, space: usize) {
120 self.buf.resize(self.filled_len + space, 0);
121 }
122 async fn read<S: AsyncRead + Unpin>(&mut self, stream: &mut S) -> Result<usize, Error> {
123 let n = stream.read(&mut self.buf[self.filled_len..]).await?;
124 self.filled_len += n;
125 Ok(n)
126 }
127 fn is_full(&self) -> bool {
128 self.filled_len >= self.buf.len()
129 }
130}
131
132impl Drop for BufTruncator<'_> {
133 fn drop(&mut self) {
134 self.buf.truncate(self.filled_len);
135 }
136}
137
138async fn read_raw_into<S: AsyncRead + Unpin>(
139 stream: &mut S,
140 buf: &mut BytesMut,
141 cnt: usize,
142) -> Result<(), Error> {
143 let mut buf = BufTruncator::new(buf);
144 buf.reserve(cnt);
145
146 while !buf.is_full() {
147 let n = buf.read(stream).await?;
148
149 if n == 0 {
150 return Err(io::Error::from(io::ErrorKind::ConnectionAborted).into());
156 }
157 }
158
159 Ok(())
160}