1use std::{
2 future::{
3 poll_fn,
4 Future,
5 },
6 io,
7 pin::Pin,
8};
9
10use mid_compression::{
11 error::SizeRetrievalError,
12 interface::IDecompressor,
13};
14use tokio::io::{
15 AsyncRead,
16 AsyncReadExt,
17 BufReader,
18 ReadBuf,
19};
20
21use crate::{
22 compression::{
23 DecompressionConstraint,
24 DecompressionStrategy,
25 },
26 error,
27 utils::{
28 self,
29 flags,
30 },
31};
32
/// Wraps an async byte reader `R` together with a decompressor `D`.
///
/// The `impl` blocks in this module add helpers for reading little-endian integers,
/// length-prefixed strings/byte buffers and compressed payloads on top of `R`.
pub struct MidReader<R, D> {
    /// The underlying reader (exposed through `socket`/`socket_mut`; may be a `BufReader`).
    inner: R,
    /// Used by `read_compressed` to decompress payloads read from `inner`.
    decompressor: D,
}
37
38impl<R, D> MidReader<R, D>
41where
42 R: AsyncReadExt + Unpin,
43 D: IDecompressor,
44{
45 pub async fn read_compressed(
47 &mut self,
48 size: usize,
49 strategy: DecompressionStrategy,
50 ) -> Result<Vec<u8>, error::CompressedReadError> {
51 let buffer = self.read_buffer(size).await?;
52 let dec_size = self.decompressor.try_decompressed_size(&buffer);
53 if matches!(dec_size, Err(SizeRetrievalError::InvalidData)) {
54 return Err(error::CompressedReadError::InvalidData);
55 }
56
57 let mut output = Vec::new();
58
59 match strategy {
60 DecompressionStrategy::ConstrainedConst { constraint } => {
61 match &constraint {
62 ty @ (DecompressionConstraint::Max(m)
63 | DecompressionConstraint::MaxSizeMultiplier(m)) => {
64 let max_size =
65 if matches!(ty, DecompressionConstraint::Max(..)) {
66 *m
67 } else {
68 size * *m
69 };
70
71 if let Ok(dec_size) = dec_size {
72 if dec_size > max_size {
73 Err(error::CompressedReadError::ConstraintFailed { constraint: ty.clone() })
74 } else {
75 output.reserve(dec_size);
76 self.decompressor
77 .try_decompress(&buffer, &mut output)
78 .map_err(|_| {
79 error::CompressedReadError::InvalidData
80 })
81 .map(move |_| output)
82 }
83 } else {
84 output.reserve(size);
85 while output.capacity() < max_size {
86 if self
87 .decompressor
88 .try_decompress(&buffer, &mut output)
89 .is_ok()
90 {
91 return Ok(output);
92 }
93
94 output.reserve(output.capacity());
95 }
96
97 Err(error::CompressedReadError::ConstraintFailed {
98 constraint: ty.clone(),
99 })
100 }
101 }
102 }
103 }
104
105 DecompressionStrategy::Unconstrained => {
106 if let Ok(size) = dec_size {
107 output.reserve(size);
108 self.decompressor
109 .try_decompress(&buffer, &mut output)
110 .unwrap_or_else(|_| unreachable!());
111 return Ok(output);
112 }
113
114 output.reserve(size << 1);
115 loop {
116 if self
117 .decompressor
118 .try_decompress(&buffer, &mut output)
119 .is_ok()
120 {
121 return Ok(buffer);
122 }
123
124 output.reserve(output.capacity());
125 }
126 }
127 }
128 }
129}
130
131impl<R, D> MidReader<R, D>
132where
133 R: AsyncReadExt + Unpin,
134{
135 pub async fn skip_n_bytes(&mut self, nbytes: usize) -> io::Result<()> {
137 const CHUNK_SIZE: usize = 128;
138 let mut buf = [0; CHUNK_SIZE];
139 let mut read = 0;
140
141 while read < nbytes {
142 let remaining = (nbytes - read).min(CHUNK_SIZE);
143 let current_read = self.inner.read(&mut buf[..remaining]).await?;
144
145 read += current_read;
146 }
147
148 Ok(())
149 }
150
151 pub async fn read_raw_packet_type(&mut self) -> io::Result<(u8, u8)> {
154 self.read_u8().await.map(utils::decode_type)
155 }
156
157 pub async fn read_string_prefixed(&mut self) -> io::Result<String> {
160 let size = self.read_u8().await?;
161 self.read_string(size as usize).await
162 }
163
164 pub async fn read_string(
167 &mut self,
168 bytes_size: usize,
169 ) -> io::Result<String> {
170 self.read_buffer(bytes_size)
171 .await
172 .map(|buf| String::from_utf8_lossy(&buf).into_owned())
173 }
174
175 pub async fn read_bytes_prefixed(&mut self) -> io::Result<Vec<u8>> {
177 let size = self.read_u8().await?;
178 self.read_buffer(size as usize).await
179 }
180
181 pub fn read_u8(&mut self) -> impl Future<Output = io::Result<u8>> + '_ {
183 self.inner.read_u8()
184 }
185
186 pub fn read_u16(&mut self) -> impl Future<Output = io::Result<u16>> + '_ {
189 self.inner.read_u16_le()
190 }
191
192 pub fn read_u32(&mut self) -> impl Future<Output = io::Result<u32>> + '_ {
195 self.inner.read_u32_le()
196 }
197
198 pub async fn read_buffer(&mut self, size: usize) -> io::Result<Vec<u8>> {
201 let mut buffer: Vec<u8> = Vec::with_capacity(size);
202 {
203 let mut read_buf =
204 ReadBuf::uninit(&mut buffer.spare_capacity_mut()[..size]);
205
206 while read_buf.filled().len() < size {
207 poll_fn(|cx| {
208 Pin::new(&mut self.inner).poll_read(cx, &mut read_buf)
209 })
210 .await?;
211 }
212 }
213
214 unsafe { buffer.set_len(size) }
218 Ok(buffer)
219 }
220
221 pub fn read_length(
223 &mut self,
224 flags: u8,
225 ) -> impl Future<Output = io::Result<u16>> + '_ {
226 self.read_variadic(flags, flags::SHORT)
227 }
228
229 pub fn read_client_id(
231 &mut self,
232 flags: u8,
233 ) -> impl Future<Output = io::Result<u16>> + '_ {
234 self.read_variadic(flags, flags::SHORT_CLIENT)
235 }
236
237 pub async fn read_variadic(
240 &mut self,
241 current_flags: u8,
242 needed: u8,
243 ) -> io::Result<u16> {
244 if (current_flags & needed) == needed {
245 self.read_u8().await.map(|o| o as u16)
246 } else {
247 self.read_u16().await
248 }
249 }
250}
251
252impl<R, D> MidReader<R, D>
255where
256 R: AsyncRead,
257{
258 pub fn make_buffered(
261 self,
262 buffer_size: usize,
263 decompressor: D,
264 ) -> MidReader<BufReader<R>, D> {
265 MidReader::new_buffered(self.inner, decompressor, buffer_size)
266 }
267}
268
269impl<R, D> MidReader<BufReader<R>, D>
270where
271 R: AsyncRead,
272{
273 pub fn new_buffered(
275 socket: R,
276 decompressor: D,
277 buffer_size: usize,
278 ) -> Self {
279 Self {
280 inner: BufReader::with_capacity(buffer_size, socket),
281 decompressor,
282 }
283 }
284
285 pub fn unbuffer(self) -> MidReader<R, D> {
289 MidReader {
290 inner: self.inner.into_inner(),
291 decompressor: self.decompressor,
292 }
293 }
294}
295
296impl<R, D> MidReader<R, D> {
297 pub const fn socket(&self) -> &R {
299 &self.inner
300 }
301
302 pub fn socket_mut(&mut self) -> &mut R {
304 &mut self.inner
305 }
306
307 pub const fn new(socket: R, decompressor: D) -> Self {
309 Self {
310 inner: socket,
311 decompressor,
312 }
313 }
314}