craftio_rs/
reader.rs

1#[cfg(feature = "encryption")]
2use crate::cfb8::{setup_craft_cipher, CipherError, CraftCipher};
3use crate::util::{get_sized_buf, VAR_INT_BUF_SIZE};
4use crate::wrapper::{CraftIo, CraftWrapper};
5#[cfg(feature = "compression")]
6use flate2::{DecompressError, FlushDecompress, Status};
7use mcproto_rs::protocol::{Id, PacketDirection, RawPacket, State};
8#[cfg(feature = "gat")]
9use mcproto_rs::protocol::PacketKind;
10use mcproto_rs::types::VarInt;
11use mcproto_rs::{Deserialize, Deserialized};
12#[cfg(feature = "backtrace")]
13use std::backtrace::Backtrace;
14use std::io;
15use thiserror::Error;
16#[cfg(any(feature = "futures-io", feature = "tokio-io"))]
17use async_trait::async_trait;
18
/// Default upper bound, in bytes, on a single packet: 32 MB (decimal, i.e. 32,000,000 bytes).
pub const DEFAULT_MAX_PACKET_SIZE: usize = 32 * 1000 * 1000;

/// Misspelled name of [`DEFAULT_MAX_PACKET_SIZE`], kept so existing users keep compiling.
/// Prefer `DEFAULT_MAX_PACKET_SIZE` in new code.
pub const DEAFULT_MAX_PACKET_SIZE: usize = DEFAULT_MAX_PACKET_SIZE;
20
/// Errors that can occur while reading a packet from a `CraftReader`.
#[derive(Debug, Error)]
pub enum ReadError {
    /// The underlying stream returned an I/O error. An `UnexpectedEof` from the stream is not
    /// reported through this variant by `CraftReader`; it yields `Ok(None)` instead.
    #[error("i/o failure during read")]
    IoFailure {
        #[from]
        err: io::Error,
        #[cfg(feature = "backtrace")]
        backtrace: Backtrace,
    },
    /// An `mcproto_rs` value could not be deserialized. Despite the message, every
    /// `DeserializeErr` converted into `ReadError` lands here. In this file that includes the
    /// length prefix, the compressed data-length prefix and the packet id VarInt.
    #[error("failed to read header VarInt")]
    PacketHeaderErr {
        #[from]
        err: mcproto_rs::DeserializeErr,
        #[cfg(feature = "backtrace")]
        backtrace: Backtrace,
    },
    /// Conversion target for `mcproto_rs::protocol::PacketErr`. Presumably produced when
    /// building or deserializing a typed packet (`RawPacket::create` / `deserialize`).
    /// NOTE(review): confirm against mcproto_rs.
    #[error("failed to read packet")]
    PacketErr {
        #[from]
        err: mcproto_rs::protocol::PacketErr,
        #[cfg(feature = "backtrace")]
        backtrace: Backtrace,
    },
    /// A compressed packet body could not be inflated.
    #[cfg(feature = "compression")]
    #[error("failed to decompress packet")]
    DecompressFailed {
        #[from]
        err: DecompressErr,
        #[cfg(feature = "backtrace")]
        backtrace: Backtrace,
    },
    /// A declared packet length (or decompressed data length) exceeded the configured maximum.
    #[error("{size} exceeds max size of {max_size}")]
    PacketTooLarge {
        size: usize,
        max_size: usize,
        #[cfg(feature = "backtrace")]
        backtrace: Backtrace,
    }
}
60
/// Reasons inflating a compressed packet body can fail.
#[cfg(feature = "compression")]
#[derive(Debug, Error)]
pub enum DecompressErr {
    /// flate2 reported `Status::BufError` (no progress was possible) while inflating.
    #[error("buf error")]
    BufError,
    /// flate2 rejected the compressed data.
    #[error("failure while decompressing")]
    Failure(#[from] DecompressError),
}
69
/// Outcome of a read attempt.
///
/// `Ok(Some(p))` is a successfully read value. `Ok(None)` means no complete value could be
/// produced; for `CraftReader`, the stream hit end-of-file (`UnexpectedEof`). `Err` is a
/// read failure.
pub type ReadResult<P> = Result<Option<P>, ReadError>;
71
/// Asynchronous packet reading.
///
/// Implementors provide the raw-packet methods; `read_packet_async` is provided on top of them
/// and additionally deserializes the raw packet into its typed form.
#[cfg(any(feature = "futures-io", feature = "tokio-io"))]
#[async_trait]
pub trait CraftAsyncReader {
    /// Reads the next packet as raw packet type `P` and deserializes it into `P::Packet`.
    #[cfg(not(feature = "gat"))]
    async fn read_packet_async<'a, P>(&'a mut self) -> ReadResult<<P as RawPacket<'a>>::Packet>
    where
        P: RawPacket<'a>,
    {
        deserialize_raw_packet(self.read_raw_packet_async::<P>().await)
    }

    /// Reads the next packet of kind `P` and deserializes it into its typed form.
    #[cfg(feature = "gat")]
    async fn read_packet_async<P>(&mut self) -> ReadResult<<P::RawPacket<'_> as RawPacket<'_>>::Packet>
    where
        P: PacketKind
    {
        deserialize_raw_packet(self.read_raw_packet_async::<P>().await)
    }

    /// Reads the next packet and wraps its id and body as raw packet type `P`, without
    /// deserializing the body.
    #[cfg(not(feature = "gat"))]
    async fn read_raw_packet_async<'a, P>(&'a mut self) -> ReadResult<P>
    where
        P: RawPacket<'a>;

    /// Reads the next packet and wraps its id and body as `P`'s raw packet type, without
    /// deserializing the body.
    #[cfg(feature = "gat")]
    async fn read_raw_packet_async<P>(&mut self) -> ReadResult<P::RawPacket<'_>>
    where
        P: PacketKind;

    /// Reads the next packet and returns its `Id` and body bytes, borrowed from the reader.
    async fn read_raw_untyped_packet_async(&mut self) -> ReadResult<(Id, &[u8])>;
}
103
104pub trait CraftSyncReader {
105    #[cfg(not(feature = "gat"))]
106    fn read_packet<'a, P>(&'a mut self) -> ReadResult<<P as RawPacket<'a>>::Packet>
107    where
108        P: RawPacket<'a>,
109    {
110        deserialize_raw_packet(self.read_raw_packet::<'a, P>())
111    }
112
113    #[cfg(feature = "gat")]
114    fn read_packet<P>(&mut self) -> ReadResult<<P::RawPacket<'_> as RawPacket>::Packet>
115    where
116        P: PacketKind
117    {
118        deserialize_raw_packet(self.read_raw_packet::<P>())
119    }
120
121    #[cfg(not(feature = "gat"))]
122    fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P>
123    where
124        P: RawPacket<'a>;
125
126    #[cfg(feature = "gat")]
127    fn read_raw_packet<P>(&mut self) -> ReadResult<P::RawPacket<'_>>
128    where
129        P: PacketKind;
130
131    fn read_raw_untyped_packet(&mut self) -> ReadResult<(Id, &[u8])>;
132}
133
///
/// Wraps some stream of type `R`, and implements either `CraftSyncReader` or `CraftAsyncReader` (or both)
/// based on what types `R` implements: `CraftSyncReader` when `R: std::io::Read`, and
/// `CraftAsyncReader` when `R: AsyncReadExact` (requires the `futures-io` or `tokio-io` feature).
///
/// You can construct this type by calling `wrap` (which starts in the `Handshaking` state) or
/// `wrap_with_state`, which requires you to specify a packet direction (are the packets being
/// read server-bound or client-bound?) and a state (`handshaking`? `login`? `status`? `play`?).
///
/// This type holds some internal buffers but only allocates them when they are required.
///
pub struct CraftReader<R> {
    // The wrapped stream that bytes are read from.
    inner: R,
    // Bytes read from `inner`; allocated lazily.
    raw_buf: Option<Vec<u8>>,
    // Number of bytes, starting at `raw_offset`, that were read into `raw_buf` but not yet consumed.
    raw_ready: usize,
    // Index in `raw_buf` where the unconsumed (`raw_ready`) bytes begin.
    raw_offset: usize,
    // Upper bound used to reject oversized packet / decompressed lengths with `PacketTooLarge`.
    max_packet_size: usize,
    // Scratch buffer that compressed packet bodies are inflated into; allocated lazily.
    #[cfg(feature = "compression")]
    decompress_buf: Option<Vec<u8>>,
    // When `Some`, packets are parsed in the compressed format (a data-length VarInt precedes
    // the id); the reader only checks `Some` vs `None`.
    #[cfg(feature = "compression")]
    compression_threshold: Option<i32>,
    // Protocol state and direction stamped onto the `Id` of every packet read.
    state: State,
    direction: PacketDirection,
    // Active cipher once encryption is enabled; packet bytes are decrypted in place with it.
    #[cfg(feature = "encryption")]
    encryption: Option<CraftCipher>,
}
159
160impl<R> CraftWrapper<R> for CraftReader<R> {
161    fn into_inner(self) -> R {
162        self.inner
163    }
164}
165
166impl<R> CraftIo for CraftReader<R> {
167    fn set_state(&mut self, next: State) {
168        self.state = next;
169    }
170
171    #[cfg(feature = "compression")]
172    fn set_compression_threshold(&mut self, threshold: Option<i32>) {
173        self.compression_threshold = threshold;
174    }
175
176    #[cfg(feature = "encryption")]
177    fn enable_encryption(&mut self, key: &[u8], iv: &[u8]) -> Result<(), CipherError> {
178        setup_craft_cipher(&mut self.encryption, key, iv)
179    }
180
181    fn set_max_packet_size(&mut self, max_size: usize) {
182        debug_assert!(max_size > 5);
183        self.max_packet_size = max_size;
184    }
185
186    fn ensure_buf_capacity(&mut self, capacity: usize) {
187        let alloc_to = if capacity > self.max_packet_size {
188            self.max_packet_size
189        } else {
190            capacity
191        };
192        self.move_ready_data_to_front();
193        get_sized_buf(&mut self.raw_buf, 0, alloc_to);
194    }
195
196    #[cfg(feature = "compression")]
197    fn ensure_compression_buf_capacity(&mut self, capacity: usize) {
198        let alloc_to = if capacity > self.max_packet_size {
199            self.max_packet_size
200        } else {
201            capacity
202        };
203        get_sized_buf(&mut self.decompress_buf, 0, alloc_to);
204    }
205}
206
/// Unwraps a `Result<Option<T>, E>` (the `ReadResult` shape) to its inner `T`.
///
/// If the value is `Err(e)`, the enclosing function returns `Err(e)` (no conversion). If it is
/// `Ok(None)`, the enclosing function returns `Ok(None)`.
macro_rules! rr_unwrap {
    ($result: expr) => {
        match $result {
            Err(e) => return Err(e),
            Ok(None) => return Ok(None),
            Ok(Some(value)) => value,
        }
    };
}
216
/// Inspects an `io::Result` from a fill-the-buffer read.
///
/// Success is discarded. An `UnexpectedEof` error makes the enclosing function return
/// `Ok(None)`, signalling end-of-stream. Any other error is returned as `Err(err.into())`.
macro_rules! check_unexpected_eof {
    ($result: expr) => {
        match $result {
            Ok(_) => {}
            Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
            Err(err) => return Err(err.into()),
        }
    };
}
228
impl<R> CraftSyncReader for CraftReader<R>
where
    R: io::Read,
{
    /// Reads one length-prefixed packet from the wrapped `io::Read` and wraps it as `P`.
    #[cfg(not(feature = "gat"))]
    fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P>
    where
        P: RawPacket<'a>,
    {
        self.read_raw_packet_inner::<P>()
    }

    /// Reads one length-prefixed packet from the wrapped `io::Read` and wraps it as
    /// `P::RawPacket`.
    #[cfg(feature = "gat")]
    fn read_raw_packet<P>(&mut self) -> ReadResult<P::RawPacket<'_>>
    where
        P: PacketKind
    {
        self.read_raw_packet_inner::<P::RawPacket<'_>>()
    }

    /// Reads one length-prefixed packet and returns its `Id` and body bytes.
    fn read_raw_untyped_packet(&mut self) -> ReadResult<(Id, &[u8])> {
        self.read_untyped_packet_inner()
    }
}
253
#[cfg(any(feature = "futures-io", feature = "tokio-io"))]
#[async_trait]
impl<R> CraftAsyncReader for CraftReader<R>
where
    R: AsyncReadExact,
{
    /// Reads one length-prefixed packet from the wrapped async stream and wraps it as `P`.
    #[cfg(not(feature = "gat"))]
    async fn read_raw_packet_async<'a, P>(&'a mut self) -> ReadResult<P>
    where
        P: RawPacket<'a>,
    {
        self.read_raw_packet_inner_async().await
    }

    /// Reads one length-prefixed packet from the wrapped async stream and wraps it as
    /// `P::RawPacket`.
    #[cfg(feature = "gat")]
    async fn read_raw_packet_async<P>(&mut self) -> ReadResult<P::RawPacket<'_>>
    where
        P: PacketKind,
    {
        self.read_raw_packet_inner_async::<P::RawPacket<'_>>().await
    }

    /// Reads one length-prefixed packet and returns its `Id` and body bytes.
    async fn read_raw_untyped_packet_async(&mut self) -> ReadResult<(Id, &[u8])> {
        self.read_raw_untyped_packet_inner_async().await
    }
}
280
281impl<R> CraftReader<R>
282where
283    R: io::Read,
284{
285    fn read_untyped_packet_inner(&mut self) -> ReadResult<(Id, &[u8])> {
286        if let Some(primary_packet_len) = self.read_raw_inner()? {
287            self.read_untyped_packet_in_buf(primary_packet_len)
288        } else {
289            Ok(None)
290        }
291    }
292
293    fn read_raw_packet_inner<'a, P>(&'a mut self) -> ReadResult<P>
294    where
295        P: RawPacket<'a>
296    {
297        if let Some(primary_packet_len) = self.read_raw_inner()? {
298            self.read_packet_in_buf(primary_packet_len)
299        } else {
300            Ok(None)
301        }
302    }
303
304    fn read_raw_inner(&mut self) -> ReadResult<usize> {
305        self.move_ready_data_to_front();
306        let primary_packet_len = rr_unwrap!(self.read_packet_len_sync()).0 as usize;
307        if primary_packet_len > self.max_packet_size {
308            return Err(ReadError::PacketTooLarge {
309                size: primary_packet_len,
310                max_size: self.max_packet_size,
311                #[cfg(feature="backtrace")]
312                backtrace: Backtrace::capture(),
313            });
314        }
315
316        if self.ensure_n_ready_sync(primary_packet_len)?.is_none() {
317            return Ok(None);
318        }
319
320        Ok(Some(primary_packet_len))
321    }
322
323    fn read_packet_len_sync(&mut self) -> ReadResult<VarInt> {
324        let buf = rr_unwrap!(self.ensure_n_ready_sync(VAR_INT_BUF_SIZE));
325        let (v, size) = rr_unwrap!(deserialize_varint(buf));
326        self.raw_ready -= size;
327        self.raw_offset += size;
328        Ok(Some(v))
329    }
330
331    fn ensure_n_ready_sync(&mut self, n: usize) -> ReadResult<&[u8]> {
332        if self.raw_ready < n {
333            let to_read = n - self.raw_ready;
334            let target =
335                get_sized_buf(&mut self.raw_buf, self.raw_offset + self.raw_ready, to_read);
336            check_unexpected_eof!(self.inner.read_exact(target));
337            self.raw_ready = n;
338        }
339
340        let ready = get_sized_buf(&mut self.raw_buf, self.raw_offset, n);
341        Ok(Some(ready))
342    }
343}
344
#[cfg(any(feature = "futures-io", feature = "tokio-io"))]
impl<R> CraftReader<R>
where
    R: AsyncReadExact,
{
    /// Async counterpart of `read_raw_packet_inner`: reads the next packet and wraps it as `P`.
    async fn read_raw_packet_inner_async<'a, P>(&'a mut self) -> ReadResult<P>
    where
        P: RawPacket<'a>,
    {
        if let Some(primary_packet_len) = self.read_raw_inner_async().await? {
            self.read_packet_in_buf(primary_packet_len)
        } else {
            Ok(None)
        }
    }

    /// Reads the next packet and returns its `Id` together with its body bytes.
    async fn read_raw_untyped_packet_inner_async(&mut self) -> ReadResult<(Id, &[u8])> {
        if let Some(primary_packet_len) = self.read_raw_inner_async().await? {
            self.read_untyped_packet_in_buf(primary_packet_len)
        } else {
            Ok(None)
        }
    }

    /// Reads the length prefix, then buffers exactly that many bytes of packet data.
    ///
    /// On success the packet bytes start at `raw_offset` and the returned value is their count.
    /// Returns `Ok(None)` on end-of-stream.
    ///
    /// # Errors
    /// `PacketTooLarge` if the declared length exceeds `max_packet_size`.
    async fn read_raw_inner_async(&mut self) -> ReadResult<usize> {
        self.move_ready_data_to_front();
        let primary_packet_len = rr_unwrap!(self.read_packet_len_async().await).0 as usize;
        if primary_packet_len > self.max_packet_size {
            return Err(ReadError::PacketTooLarge {
                size: primary_packet_len,
                max_size: self.max_packet_size,
                #[cfg(feature = "backtrace")]
                backtrace: Backtrace::capture(),
            });
        }

        if self.ensure_n_ready_async(primary_packet_len).await?.is_none() {
            return Ok(None);
        }

        debug_assert!(self.raw_ready >= primary_packet_len, "{} packet len bytes are ready (actual: {})", primary_packet_len, self.raw_ready);
        Ok(Some(primary_packet_len))
    }

    /// Reads the packet-length VarInt prefix, consuming only the bytes that belong to it.
    ///
    /// Bytes are pulled from the stream one at a time. Reading stops at the first byte without
    /// the VarInt continuation bit (`0x80`), or after `VAR_INT_BUF_SIZE` bytes. Awaiting a fixed
    /// `VAR_INT_BUF_SIZE` bytes would never complete when a whole packet is shorter than that
    /// and the peer waits for a reply.
    ///
    /// With encryption enabled, each prefix byte is decrypted in place as it is examined. The
    /// body is decrypted later by `read_untyped_packet_in_buf`, so every byte is decrypted
    /// once, in stream order.
    async fn read_packet_len_async(&mut self) -> ReadResult<VarInt> {
        let mut prefix_len = 0;
        while prefix_len < VAR_INT_BUF_SIZE {
            prefix_len += 1;
            if self.ensure_n_ready_async(prefix_len).await?.is_none() {
                return Ok(None);
            }

            let newest = get_sized_buf(&mut self.raw_buf, self.raw_offset + prefix_len - 1, 1);
            #[cfg(feature = "encryption")]
            handle_decryption(self.encryption.as_mut(), newest);
            if newest[0] & 0x80 == 0 {
                break;
            }
        }

        // The prefix bytes are already buffered, so this performs no further I/O.
        let buf = rr_unwrap!(self.ensure_n_ready_async(prefix_len).await);
        let (v, size) = rr_unwrap!(deserialize_varint(buf));
        self.raw_ready -= size;
        self.raw_offset += size;
        Ok(Some(v))
    }

    /// Makes sure at least `n` bytes are buffered starting at `raw_offset` and returns them.
    ///
    /// Only the missing bytes are read from `inner`. Returns `Ok(None)` if the stream hit EOF
    /// first.
    async fn ensure_n_ready_async(&mut self, n: usize) -> ReadResult<&[u8]> {
        if self.raw_ready < n {
            let to_read = n - self.raw_ready;
            let target =
                get_sized_buf(&mut self.raw_buf, self.raw_offset + self.raw_ready, to_read);
            debug_assert_eq!(target.len(), to_read);
            check_unexpected_eof!(self.inner.read_exact(target).await);
            self.raw_ready = n;
        }

        let ready = get_sized_buf(&mut self.raw_buf, self.raw_offset, n);
        debug_assert_eq!(ready.len(), n);
        Ok(Some(ready))
    }
}
412
/// Minimal "fill this whole buffer" abstraction over the supported async I/O ecosystems.
///
/// This file implements it for `futures::AsyncReadExt` types (only `futures-io` enabled) and
/// for `tokio::io::AsyncRead` types (`tokio-io` enabled).
#[cfg(any(feature = "futures-io", feature = "tokio-io"))]
#[async_trait]
pub trait AsyncReadExact: Unpin + Sync + Send {
    /// Reads exactly `to.len()` bytes into `to`, or fails. `CraftReader` treats an
    /// `UnexpectedEof` error as end-of-stream.
    async fn read_exact(&mut self, to: &mut [u8]) -> Result<(), io::Error>;
}
418
#[cfg(all(feature = "futures-io", not(feature = "tokio-io")))]
#[async_trait]
impl<R> AsyncReadExact for R
where
    R: futures::AsyncReadExt + Unpin + Sync + Send,
{
    /// Delegates to `futures::AsyncReadExt::read_exact`, passing its error through unchanged.
    async fn read_exact(&mut self, to: &mut [u8]) -> Result<(), io::Error> {
        futures::AsyncReadExt::read_exact(self, to).await?;
        Ok(())
    }
}
429
#[cfg(feature = "tokio-io")]
#[async_trait]
impl<R> AsyncReadExact for R
where
    R: tokio::io::AsyncRead + Unpin + Sync + Send,
{
    /// Delegates to `tokio::io::AsyncReadExt::read_exact`.
    ///
    /// The byte count that call returns on success is discarded; only success or failure is
    /// reported.
    async fn read_exact(&mut self, to: &mut [u8]) -> Result<(), io::Error> {
        tokio::io::AsyncReadExt::read_exact(self, to)
            .await
            .map(|_| ())
    }
}
441
/// Deserializes a `$k` from the front of `$bnam`, evaluating to `(value, remaining_bytes)`.
///
/// On failure the enclosing function returns `Err(err.into())`.
///
/// The expansion is a single `match` expression with no trailing `;`. The macro is used in
/// expression position (`let (a, rest) = dsz_unwrap!(buf, VarInt);`), where a trailing
/// semicolon triggers the `semicolon_in_expressions_from_macros` future-compatibility lint.
macro_rules! dsz_unwrap {
    ($bnam: expr, $k: ty) => {
        match <$k>::mc_deserialize($bnam) {
            Ok(Deserialized {
                value: val,
                data: rest,
            }) => (val, rest),
            Err(err) => {
                return Err(err.into());
            }
        }
    };
}
455
456impl<R> CraftReader<R> {
457    pub fn wrap(inner: R, direction: PacketDirection) -> Self {
458        Self::wrap_with_state(inner, direction, State::Handshaking)
459    }
460
461    pub fn wrap_with_state(inner: R, direction: PacketDirection, state: State) -> Self {
462        Self {
463            inner,
464            raw_buf: None,
465            raw_ready: 0,
466            raw_offset: 0,
467            #[cfg(feature = "compression")]
468            decompress_buf: None,
469            #[cfg(feature = "compression")]
470            compression_threshold: None,
471            state,
472            direction,
473            #[cfg(feature = "encryption")]
474            encryption: None,
475            max_packet_size: DEAFULT_MAX_PACKET_SIZE
476        }
477    }
478
479    fn read_untyped_packet_in_buf(&mut self, size: usize) -> ReadResult<(Id, &[u8])>
480    {
481        // find data in buf
482        let offset = self.raw_offset;
483        if self.raw_ready < size {
484            panic!("not enough data is ready, got {} ready and {} desired ready!", self.raw_ready, size);
485        }
486        self.raw_ready -= size;
487        self.raw_offset += size;
488        let buf =
489            &mut self.raw_buf.as_mut().expect("should exist right now")[offset..offset + size];
490        // decrypt the packet if encryption is enabled
491        #[cfg(feature = "encryption")]
492        handle_decryption(self.encryption.as_mut(), buf);
493
494        // try to get the packet body bytes... this boils down to:
495        // * check if compression enabled,
496        //    * read data len (VarInt) which isn't compressed
497        //    * if data len is 0, then rest of packet is not compressed, remaining data is body
498        //    * otherwise, data len is decompressed length, so prepare a decompression buf and decompress from
499        //      the buffer into the decompression buffer, and return the slice of the decompression buffer
500        //      which contains this packet's data
501        // * if compression not enabled, then the buf contains only the packet body bytes
502
503        #[cfg(feature = "compression")]
504        let packet_buf = if let Some(_) = self.compression_threshold {
505            let (data_len, rest) = dsz_unwrap!(buf, VarInt);
506            let data_len = data_len.0 as usize;
507            if data_len == 0 {
508                rest
509            } else if data_len >= self.max_packet_size {
510                return Err(ReadError::PacketTooLarge {
511                    size: data_len,
512                    max_size: self.max_packet_size,
513                    #[cfg(feature = "backtrace")]
514                    backtrace: Backtrace::capture()
515                })
516            } else {
517                decompress(rest, &mut self.decompress_buf, data_len)?
518            }
519        } else {
520            buf
521        };
522
523        #[cfg(not(feature = "compression"))]
524        let packet_buf = buf;
525
526        let (raw_id, body_buf) = dsz_unwrap!(packet_buf, VarInt);
527        let id = Id {
528            id: raw_id.0,
529            state: self.state.clone(),
530            direction: self.direction.clone()
531        };
532
533        Ok(Some((id, body_buf)))
534    }
535
536    fn read_packet_in_buf<'a, P>(&'a mut self, size: usize) -> ReadResult<P>
537    where
538        P: RawPacket<'a>,
539    {
540        if let Some((id, body_buf)) = self.read_untyped_packet_in_buf(size)? {
541            match P::create(id, body_buf) {
542                Ok(raw) => Ok(Some(raw)),
543                Err(err) => Err(err.into()),
544            }
545        } else {
546            Ok(None)
547        }
548    }
549
550    fn move_ready_data_to_front(&mut self) {
551        // if there's data that's ready which isn't at the front of the buf, move it to the front
552        if self.raw_ready > 0 && self.raw_offset > 0 {
553            let raw_buf = self
554                .raw_buf
555                .as_mut()
556                .expect("if raw_ready > 0 and raw_offset > 0 then a raw_buf should exist!");
557
558            raw_buf.copy_within(self.raw_offset..(self.raw_offset+self.raw_ready), 0);
559        }
560
561        self.raw_offset = 0;
562    }
563}
564
/// Decrypts `buf` in place with `cipher`; does nothing when no cipher is active.
#[cfg(feature = "encryption")]
fn handle_decryption(cipher: Option<&mut CraftCipher>, buf: &mut [u8]) {
    let active = match cipher {
        Some(active) => active,
        None => return,
    };
    active.decrypt(buf);
}
571
572fn deserialize_raw_packet<'a, P>(raw: ReadResult<P>) -> ReadResult<P::Packet>
573where
574    P: RawPacket<'a>,
575{
576    match raw {
577        Ok(Some(raw)) => match raw.deserialize() {
578            Ok(deserialized) => Ok(Some(deserialized)),
579            Err(err) => Err(err.into()),
580        },
581        Ok(None) => Ok(None),
582        Err(err) => Err(err),
583    }
584}
585
586fn deserialize_varint(buf: &[u8]) -> ReadResult<(VarInt, usize)> {
587    match VarInt::mc_deserialize(buf) {
588        Ok(v) => Ok(Some((v.value, buf.len() - v.data.len()))),
589        Err(err) => Err(err.into()),
590    }
591}
592
/// Inflates the zlib stream `src` into `target`, which is sized to `decompressed_len` bytes.
///
/// Returns the prefix of the decompression buffer that was actually written.
///
/// # Errors
/// - `DecompressErr::BufError` if inflation stops making progress. For example, the data
///   inflates to more than `decompressed_len` bytes or `src` is truncated.
/// - `DecompressErr::Failure` if flate2 rejects the data.
#[cfg(feature = "compression")]
fn decompress<'a>(
    src: &'a [u8],
    target: &'a mut Option<Vec<u8>>,
    decompressed_len: usize,
) -> Result<&'a mut [u8], ReadError> {
    let mut decompress = flate2::Decompress::new(true);
    let decompress_buf = get_sized_buf(target, 0, decompressed_len);
    loop {
        // `Decompress` tracks how much input it consumed and output it produced. Each call must
        // resume from those positions rather than re-feeding the whole input and output, which
        // would repeat consumed bytes and overwrite earlier output.
        let in_pos = decompress.total_in() as usize;
        let out_pos = decompress.total_out() as usize;
        let status = decompress
            .decompress(
                &src[in_pos..],
                &mut decompress_buf[out_pos..],
                FlushDecompress::Finish,
            )
            .map_err(DecompressErr::Failure)?;

        match status {
            Status::StreamEnd => break,
            Status::Ok => {
                // Guard against spinning forever if a call neither consumes nor produces bytes.
                let progressed = decompress.total_in() as usize != in_pos
                    || decompress.total_out() as usize != out_pos;
                if !progressed {
                    return Err(DecompressErr::BufError.into());
                }
            }
            Status::BufError => return Err(DecompressErr::BufError.into()),
        }
    }

    let decompressed_size = decompress.total_out() as usize;
    Ok(&mut decompress_buf[..decompressed_size])
}