mid_net/
reader.rs

1use std::{
2    future::{
3        poll_fn,
4        Future,
5    },
6    io,
7    pin::Pin,
8};
9
10use mid_compression::{
11    error::SizeRetrievalError,
12    interface::IDecompressor,
13};
14use tokio::io::{
15    AsyncRead,
16    AsyncReadExt,
17    BufReader,
18    ReadBuf,
19};
20
21use crate::{
22    compression::{
23        DecompressionConstraint,
24        DecompressionStrategy,
25    },
26    error,
27    utils::{
28        self,
29        flags,
30    },
31};
32
/// Protocol reader combining an async byte source `R` with a decompressor
/// `D` that is used for compressed payloads.
pub struct MidReader<R, D> {
    /// The byte source everything is read from (a raw socket, or a
    /// `BufReader` around one).
    inner: R,
    /// Decompressor used by `read_compressed`.
    decompressor: D,
}
37
38// Actual reading
39
40impl<R, D> MidReader<R, D>
41where
42    R: AsyncReadExt + Unpin,
43    D: IDecompressor,
44{
45    /// Reads compressed pile of bytes from the stream
46    pub async fn read_compressed(
47        &mut self,
48        size: usize,
49        strategy: DecompressionStrategy,
50    ) -> Result<Vec<u8>, error::CompressedReadError> {
51        let buffer = self.read_buffer(size).await?;
52        let dec_size = self.decompressor.try_decompressed_size(&buffer);
53        if matches!(dec_size, Err(SizeRetrievalError::InvalidData)) {
54            return Err(error::CompressedReadError::InvalidData);
55        }
56
57        let mut output = Vec::new();
58
59        match strategy {
60            DecompressionStrategy::ConstrainedConst { constraint } => {
61                match &constraint {
62                    ty @ (DecompressionConstraint::Max(m)
63                    | DecompressionConstraint::MaxSizeMultiplier(m)) => {
64                        let max_size =
65                            if matches!(ty, DecompressionConstraint::Max(..)) {
66                                *m
67                            } else {
68                                size * *m
69                            };
70
71                        if let Ok(dec_size) = dec_size {
72                            if dec_size > max_size {
73                                Err(error::CompressedReadError::ConstraintFailed { constraint: ty.clone() })
74                            } else {
75                                output.reserve(dec_size);
76                                self.decompressor
77                                    .try_decompress(&buffer, &mut output)
78                                    .map_err(|_| {
79                                        error::CompressedReadError::InvalidData
80                                    })
81                                    .map(move |_| output)
82                            }
83                        } else {
84                            output.reserve(size);
85                            while output.capacity() < max_size {
86                                if self
87                                    .decompressor
88                                    .try_decompress(&buffer, &mut output)
89                                    .is_ok()
90                                {
91                                    return Ok(output);
92                                }
93
94                                output.reserve(output.capacity());
95                            }
96
97                            Err(error::CompressedReadError::ConstraintFailed {
98                                constraint: ty.clone(),
99                            })
100                        }
101                    }
102                }
103            }
104
105            DecompressionStrategy::Unconstrained => {
106                if let Ok(size) = dec_size {
107                    output.reserve(size);
108                    self.decompressor
109                        .try_decompress(&buffer, &mut output)
110                        .unwrap_or_else(|_| unreachable!());
111                    return Ok(output);
112                }
113
114                output.reserve(size << 1);
115                loop {
116                    if self
117                        .decompressor
118                        .try_decompress(&buffer, &mut output)
119                        .is_ok()
120                    {
121                        return Ok(buffer);
122                    }
123
124                    output.reserve(output.capacity());
125                }
126            }
127        }
128    }
129}
130
131impl<R, D> MidReader<R, D>
132where
133    R: AsyncReadExt + Unpin,
134{
135    /// Skips `nbytes` bytes from the underlying stream.
136    pub async fn skip_n_bytes(&mut self, nbytes: usize) -> io::Result<()> {
137        const CHUNK_SIZE: usize = 128;
138        let mut buf = [0; CHUNK_SIZE];
139        let mut read = 0;
140
141        while read < nbytes {
142            let remaining = (nbytes - read).min(CHUNK_SIZE);
143            let current_read = self.inner.read(&mut buf[..remaining]).await?;
144
145            read += current_read;
146        }
147
148        Ok(())
149    }
150
151    /// Reads packet type and decodes it returning pair of
152    /// `u8`'s
153    pub async fn read_raw_packet_type(&mut self) -> io::Result<(u8, u8)> {
154        self.read_u8().await.map(utils::decode_type)
155    }
156
157    /// Reads string of prefixed size with max size of
158    /// `u8::MAX`, uses lossy utf8 decoding.
159    pub async fn read_string_prefixed(&mut self) -> io::Result<String> {
160        let size = self.read_u8().await?;
161        self.read_string(size as usize).await
162    }
163
164    /// Reads string of size `bytes_size` with lossy utf8
165    /// decoding.
166    pub async fn read_string(
167        &mut self,
168        bytes_size: usize,
169    ) -> io::Result<String> {
170        self.read_buffer(bytes_size)
171            .await
172            .map(|buf| String::from_utf8_lossy(&buf).into_owned())
173    }
174
175    /// Reads prefixed buffer with max size of `u8::MAX`.
176    pub async fn read_bytes_prefixed(&mut self) -> io::Result<Vec<u8>> {
177        let size = self.read_u8().await?;
178        self.read_buffer(size as usize).await
179    }
180
181    /// Read `u8` from the underlying stream
182    pub fn read_u8(&mut self) -> impl Future<Output = io::Result<u8>> + '_ {
183        self.inner.read_u8()
184    }
185
186    /// Read `u16` from the underlying stream (little
187    /// endian)
188    pub fn read_u16(&mut self) -> impl Future<Output = io::Result<u16>> + '_ {
189        self.inner.read_u16_le()
190    }
191
192    /// Read `u32` from the underlying stream (little
193    /// endian)
194    pub fn read_u32(&mut self) -> impl Future<Output = io::Result<u32>> + '_ {
195        self.inner.read_u32_le()
196    }
197
198    /// Read `size` bytes from the socket without buffer
199    /// pre-filling.
200    pub async fn read_buffer(&mut self, size: usize) -> io::Result<Vec<u8>> {
201        let mut buffer: Vec<u8> = Vec::with_capacity(size);
202        {
203            let mut read_buf =
204                ReadBuf::uninit(&mut buffer.spare_capacity_mut()[..size]);
205
206            while read_buf.filled().len() < size {
207                poll_fn(|cx| {
208                    Pin::new(&mut self.inner).poll_read(cx, &mut read_buf)
209                })
210                .await?;
211            }
212        }
213
214        // SAFETY: this is safe since we passed
215        // `read_buf.filled().len() >= size` condition,
216        // so `buffer` initialized with exactly `size` items.
217        unsafe { buffer.set_len(size) }
218        Ok(buffer)
219    }
220
221    /// Reads variadic length of payload from the stream
222    pub fn read_length(
223        &mut self,
224        flags: u8,
225    ) -> impl Future<Output = io::Result<u16>> + '_ {
226        self.read_variadic(flags, flags::SHORT)
227    }
228
229    /// Reads variadic client id from the stream
230    pub fn read_client_id(
231        &mut self,
232        flags: u8,
233    ) -> impl Future<Output = io::Result<u16>> + '_ {
234        self.read_variadic(flags, flags::SHORT_CLIENT)
235    }
236
237    /// Reads `u8` or `u16` from the stream, depending on
238    /// flags.
239    pub async fn read_variadic(
240        &mut self,
241        current_flags: u8,
242        needed: u8,
243    ) -> io::Result<u16> {
244        if (current_flags & needed) == needed {
245            self.read_u8().await.map(|o| o as u16)
246        } else {
247            self.read_u16().await
248        }
249    }
250}
251
252// Bufferization & creation related stuff
253
254impl<R, D> MidReader<R, D>
255where
256    R: AsyncRead,
257{
258    /// Create buffered reader (wraps R with `BufReader<R>`
259    /// with specified capacity)
260    pub fn make_buffered(
261        self,
262        buffer_size: usize,
263        decompressor: D,
264    ) -> MidReader<BufReader<R>, D> {
265        MidReader::new_buffered(self.inner, decompressor, buffer_size)
266    }
267}
268
269impl<R, D> MidReader<BufReader<R>, D>
270where
271    R: AsyncRead,
272{
273    /// Create buffered version of the reader
274    pub fn new_buffered(
275        socket: R,
276        decompressor: D,
277        buffer_size: usize,
278    ) -> Self {
279        Self {
280            inner: BufReader::with_capacity(buffer_size, socket),
281            decompressor,
282        }
283    }
284
285    /// Remove underlying buffer.
286    ///
287    /// WARNING: buffered data can be lost!
288    pub fn unbuffer(self) -> MidReader<R, D> {
289        MidReader {
290            inner: self.inner.into_inner(),
291            decompressor: self.decompressor,
292        }
293    }
294}
295
296impl<R, D> MidReader<R, D> {
297    /// Get shared access to the underlying socket
298    pub const fn socket(&self) -> &R {
299        &self.inner
300    }
301
302    /// Get exclusive access to the underlying socket
303    pub fn socket_mut(&mut self) -> &mut R {
304        &mut self.inner
305    }
306
307    /// Simply create reader from the underlying socket type
308    pub const fn new(socket: R, decompressor: D) -> Self {
309        Self {
310            inner: socket,
311            decompressor,
312        }
313    }
314}