infrarust 1.6.1

A Rust universal Minecraft proxy
Documentation
use std::io::Cursor;

use aes::cipher::BlockDecryptMut;
use async_trait::async_trait;
use bytes::BytesMut;
use infrarust_protocol::{
    ProtocolRead,
    packet::{CompressionState, EncryptionState},
    types::VarInt,
    version::Version,
};
use libdeflater::Decompressor;
use tokio::io::{AsyncRead, AsyncReadExt};

use super::super::{
    base::Packet,
    error::{PacketError, PacketResult},
};
use super::RawPacketIO;
use super::buffer_pool::{get_buffer_with_capacity, return_buffer};
use crate::{
    network::packet::MAX_UNCOMPRESSED_LENGTH,
    security::encryption::{Aes128Cfb8Dec, Cfb8Closure},
};

/// Reads length-prefixed packet frames from an async byte source, optionally
/// decrypting and decompressing them (see `read_packet`).
#[derive(Debug)]
pub struct PacketReader<R> {
    /// The underlying byte source.
    pub reader: R,
    /// Decryption cipher; `Some` once `enable_encryption` has been called.
    encryption: Option<Aes128Cfb8Dec>,
    /// Whether incoming frames carry a compression header.
    compression: CompressionState,
    /// Protocol version copied onto every packet produced by `read_packet`.
    protocol_version: Version,
    /// Scratch buffer used for unframed reads via `RawPacketIO::read_raw`.
    buffer: BytesMut,
}

impl<R: AsyncRead + Unpin> PacketReader<R> {
    pub fn new(reader: R) -> Self {
        Self {
            reader,
            encryption: None,
            compression: CompressionState::Disabled,
            protocol_version: Version::new(0),
            buffer: BytesMut::with_capacity(8192),
        }
    }

    pub fn is_encryption_enabled(&self) -> bool {
        self.encryption.is_some()
    }

    pub fn enable_encryption(&mut self, cipher: Aes128Cfb8Dec) {
        self.encryption = Some(cipher);
    }

    pub fn disable_encryption(&mut self) {
        self.encryption = None;
    }

    pub fn enable_compression(&mut self, threshold: i32) {
        self.compression = CompressionState::Enabled { threshold };
    }

    pub fn disable_compression(&mut self) {
        self.compression = CompressionState::Disabled;
    }

    pub fn is_compressing(&self) -> bool {
        matches!(self.compression, CompressionState::Enabled { .. })
    }

    pub fn set_protocol_version(&mut self, version: Version) {
        self.protocol_version = version;
    }

    pub async fn read_packet(&mut self) -> PacketResult<Packet> {
        // Read total packet length (may be encrypted)
        let packet_length = if self.encryption.is_some() {
            let mut length_bytes = BytesMut::with_capacity(3);
            loop {
                let mut byte = [0u8; 1];
                self.reader.read_exact(&mut byte).await?;

                if let Some(cipher) = &mut self.encryption {
                    cipher.decrypt_with_backend_mut(Cfb8Closure { data: &mut byte });
                }

                length_bytes.extend_from_slice(&byte);
                if byte[0] & 0x80 == 0 {
                    break;
                }
                if length_bytes.len() >= 3 {
                    return Err(PacketError::VarIntDecoding("VarInt too long".to_string()));
                }
            }

            let mut cursor = Cursor::new(&length_bytes);
            let (VarInt(length), _) = VarInt::read_from(&mut cursor)?;
            length
        } else {
            let mut varint_buf = [0u8; 3];
            let length_bytes_count;
            self.reader.read_exact(&mut varint_buf[0..1]).await?;
            if varint_buf[0] & 0x80 == 0 {
                length_bytes_count = 1;
            } else {
                self.reader.read_exact(&mut varint_buf[1..2]).await?;
                if varint_buf[1] & 0x80 == 0 {
                    length_bytes_count = 2;
                } else {
                    self.reader.read_exact(&mut varint_buf[2..3]).await?;
                    length_bytes_count = 3;
                    if varint_buf[2] & 0x80 != 0 {
                        return Err(PacketError::VarIntDecoding("VarInt too long".to_string()));
                    }
                }
            }
            let mut result = 0i32;
            for (i, &byte) in varint_buf[..length_bytes_count].iter().enumerate() {
                result |= ((byte & 0x7F) as i32) << (7 * i);
            }
            result
        };

        // debug!("Reading packet with total length: {}", packet_length);

        // Read packet data using pooled buffer (may be encrypted)
        let mut encrypted_buffer = get_buffer_with_capacity(packet_length as usize);
        // SAFETY: We're about to overwrite this memory with read_exact.
        // Using resize with 0 would unnecessarily zero memory that gets overwritten.
        unsafe {
            encrypted_buffer.set_len(packet_length as usize);
        }
        self.reader.read_exact(&mut encrypted_buffer).await?;

        // Handle decryption if enabled
        if let Some(cipher) = &mut self.encryption {
            cipher.decrypt_with_backend_mut(Cfb8Closure {
                data: &mut encrypted_buffer,
            });
        }

        // Handle decompression if enabled
        let packet_data = if let CompressionState::Enabled { threshold: _ } = self.compression {
            let mut cursor = Cursor::new(&encrypted_buffer[..]);
            let (VarInt(data_length), bytes_read) = VarInt::read_from(&mut cursor)?;
            // debug!("Data length (uncompressed): {}", data_length);

            if data_length == 0 {
                let result = BytesMut::from(&encrypted_buffer[bytes_read..]);
                return_buffer(encrypted_buffer);
                result
            } else {
                if data_length > MAX_UNCOMPRESSED_LENGTH as i32 {
                    return_buffer(encrypted_buffer);
                    return Err(PacketError::InvalidLength {
                        length: data_length as usize,
                        max: MAX_UNCOMPRESSED_LENGTH,
                    });
                }

                let mut decompressor = Decompressor::new();
                // Use pooled buffer for decompression output
                let mut outbuf = get_buffer_with_capacity(data_length as usize);
                // SAFETY: We're about to overwrite this memory with zlib_decompress.
                // Using resize with 0 would unnecessarily zero memory that gets overwritten.
                unsafe {
                    outbuf.set_len(data_length as usize);
                }

                let decompress_result = decompressor
                    .zlib_decompress(&encrypted_buffer[bytes_read..], &mut outbuf)
                    .map_err(|e| {
                        PacketError::Compression(format!("Decompression failed: {:?}", e))
                    });

                // Return encrypted buffer early as we don't need it anymore
                return_buffer(encrypted_buffer);

                decompress_result?;

                if outbuf.len() != data_length as usize {
                    return_buffer(outbuf);
                    return Err(PacketError::Compression(
                        "Decompressed length mismatch".to_string(),
                    ));
                }

                let result = BytesMut::from(&outbuf[..]);
                return_buffer(outbuf);
                result
            }
        } else {
            let result = BytesMut::from(&encrypted_buffer[..]);
            return_buffer(encrypted_buffer);
            result
        };

        // Read packet ID and create final packet
        let mut cursor = Cursor::new(&packet_data);
        let (VarInt(id), id_size) = VarInt::read_from(&mut cursor)?;

        Ok(Packet {
            id,
            data: BytesMut::from(&packet_data[id_size..]),
            compression: self.compression,
            encryption: if self.encryption.is_some() {
                EncryptionState::Enabled {
                    encrypted_data: true,
                }
            } else {
                EncryptionState::Disabled
            },
            protocol_version: self.protocol_version,
        })
    }

    pub fn get_ref(&self) -> &R {
        &self.reader
    }

    pub fn get_mut(&mut self) -> &mut R {
        &mut self.reader
    }

    pub fn into_inner(self) -> R {
        self.reader
    }

    pub fn buffer(&self) -> &BytesMut {
        &self.buffer
    }
}

#[async_trait]
impl<R> RawPacketIO for PacketReader<R>
where
    R: AsyncRead + Unpin + Send,
{
    /// Performs a single unframed read from the underlying stream.
    ///
    /// Returns `Ok(None)` when the read yields zero bytes (end of stream);
    /// otherwise returns exactly the bytes produced by that one read.
    async fn read_raw(&mut self) -> PacketResult<Option<BytesMut>> {
        self.buffer.clear();
        let bytes_read = self
            .reader
            .read_buf(&mut self.buffer)
            .await
            .map_err(PacketError::Io)?;
        if bytes_read == 0 {
            return Ok(None);
        }
        Ok(Some(self.buffer.split()))
    }

    /// Always fails: a `PacketReader` only reads.
    async fn write_raw(&mut self, _data: &[u8]) -> PacketResult<()> {
        let message = String::from("Readers cannot write");
        Err(PacketError::InvalidFormat(message))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use libdeflater::{CompressionLvl, Compressor};
    use std::io::Cursor;
    use tokio::io::BufReader;

    /// Appends `value` to `out` using VarInt encoding (7 bits per byte,
    /// high bit set on every byte except the last).
    fn write_varint(mut value: u32, out: &mut Vec<u8>) {
        loop {
            let byte = (value & 0x7F) as u8;
            value >>= 7;
            if value == 0 {
                out.push(byte);
                return;
            }
            out.push(byte | 0x80);
        }
    }

    /// Prefixes `body` with its VarInt length.
    fn frame(body: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        write_varint(body.len() as u32, &mut out);
        out.extend_from_slice(body);
        out
    }

    /// Builds a reader over an in-memory byte stream.
    fn reader_over(bytes: Vec<u8>) -> PacketReader<BufReader<Cursor<Vec<u8>>>> {
        PacketReader::new(BufReader::new(Cursor::new(bytes)))
    }

    #[tokio::test]
    async fn test_read_simple_packet() {
        // Frame layout: [total length][id][data]
        let mut reader = reader_over(frame(&[0, 1, 2]));

        let packet = reader.read_packet().await.unwrap();
        assert_eq!(packet.id, 0);
        assert_eq!(&packet.data[..], &[1, 2]);
    }

    #[tokio::test]
    async fn test_read_compressed_packet() {
        // Uncompressed payload: packet ID 0x05 followed by 299 data bytes.
        let data: Vec<u8> = (0..299u32).map(|i| (i % 7) as u8).collect();
        let mut payload = vec![0x05];
        payload.extend_from_slice(&data);

        let mut compressor = Compressor::new(CompressionLvl::default());
        let mut compressed = vec![0u8; compressor.zlib_compress_bound(payload.len())];
        let compressed_len = compressor
            .zlib_compress(&payload, &mut compressed)
            .unwrap();
        compressed.truncate(compressed_len);

        // Frame body: [uncompressed length][zlib data]
        let mut body = Vec::new();
        write_varint(payload.len() as u32, &mut body);
        body.extend_from_slice(&compressed);

        let mut reader = reader_over(frame(&body));
        reader.enable_compression(256);

        let packet = reader.read_packet().await.unwrap();
        assert_eq!(packet.id, 0x05);
        assert_eq!(&packet.data[..], &data[..]);
    }

    #[tokio::test]
    async fn test_read_uncompressed_packet_with_compression_enabled() {
        // An uncompressed length of 0 marks a body sent without compression.
        let mut reader = reader_over(frame(&[0, 0x02, 9, 8, 7]));
        reader.enable_compression(256);

        let packet = reader.read_packet().await.unwrap();
        assert_eq!(packet.id, 0x02);
        assert_eq!(&packet.data[..], &[9, 8, 7]);
    }

    #[tokio::test]
    async fn test_invalid_packet_length() {
        // Invalid length (0): the frame has no room for a packet ID.
        let mut reader = reader_over(vec![0]);

        let result = reader.read_packet().await;
        assert!(result.is_err());
    }
}