use std::io::Cursor;
use aes::cipher::BlockDecryptMut;
use async_trait::async_trait;
use bytes::BytesMut;
use infrarust_protocol::{
ProtocolRead,
packet::{CompressionState, EncryptionState},
types::VarInt,
version::Version,
};
use libdeflater::Decompressor;
use tokio::io::{AsyncRead, AsyncReadExt};
use super::super::{
base::Packet,
error::{PacketError, PacketResult},
};
use super::RawPacketIO;
use super::buffer_pool::{get_buffer_with_capacity, return_buffer};
use crate::{
network::packet::MAX_UNCOMPRESSED_LENGTH,
security::encryption::{Aes128Cfb8Dec, Cfb8Closure},
};
/// Reads length-prefixed packets from an async byte stream.
///
/// Each frame can optionally be decrypted with the configured
/// `Aes128Cfb8Dec` cipher and decompressed with zlib (see `read_packet`).
/// The reader can also hand out raw, unframed bytes through `RawPacketIO::read_raw`.
#[derive(Debug)]
pub struct PacketReader<R> {
    /// Underlying byte source the packets are read from.
    pub reader: R,
    /// Active decryption cipher; `None` while encryption is disabled.
    encryption: Option<Aes128Cfb8Dec>,
    /// Whether incoming frames carry the compression header.
    compression: CompressionState,
    /// Protocol version stamped onto every packet produced by `read_packet`.
    protocol_version: Version,
    /// Scratch buffer filled by `read_raw` (created with 8 KiB capacity in `new`).
    buffer: BytesMut,
}
impl<R: AsyncRead + Unpin> PacketReader<R> {
pub fn new(reader: R) -> Self {
Self {
reader,
encryption: None,
compression: CompressionState::Disabled,
protocol_version: Version::new(0),
buffer: BytesMut::with_capacity(8192),
}
}
pub fn is_encryption_enabled(&self) -> bool {
self.encryption.is_some()
}
pub fn enable_encryption(&mut self, cipher: Aes128Cfb8Dec) {
self.encryption = Some(cipher);
}
pub fn disable_encryption(&mut self) {
self.encryption = None;
}
pub fn enable_compression(&mut self, threshold: i32) {
self.compression = CompressionState::Enabled { threshold };
}
pub fn disable_compression(&mut self) {
self.compression = CompressionState::Disabled;
}
pub fn is_compressing(&self) -> bool {
matches!(self.compression, CompressionState::Enabled { .. })
}
pub fn set_protocol_version(&mut self, version: Version) {
self.protocol_version = version;
}
pub async fn read_packet(&mut self) -> PacketResult<Packet> {
let packet_length = if self.encryption.is_some() {
let mut length_bytes = BytesMut::with_capacity(3);
loop {
let mut byte = [0u8; 1];
self.reader.read_exact(&mut byte).await?;
if let Some(cipher) = &mut self.encryption {
cipher.decrypt_with_backend_mut(Cfb8Closure { data: &mut byte });
}
length_bytes.extend_from_slice(&byte);
if byte[0] & 0x80 == 0 {
break;
}
if length_bytes.len() >= 3 {
return Err(PacketError::VarIntDecoding("VarInt too long".to_string()));
}
}
let mut cursor = Cursor::new(&length_bytes);
let (VarInt(length), _) = VarInt::read_from(&mut cursor)?;
length
} else {
let mut varint_buf = [0u8; 3];
let length_bytes_count;
self.reader.read_exact(&mut varint_buf[0..1]).await?;
if varint_buf[0] & 0x80 == 0 {
length_bytes_count = 1;
} else {
self.reader.read_exact(&mut varint_buf[1..2]).await?;
if varint_buf[1] & 0x80 == 0 {
length_bytes_count = 2;
} else {
self.reader.read_exact(&mut varint_buf[2..3]).await?;
length_bytes_count = 3;
if varint_buf[2] & 0x80 != 0 {
return Err(PacketError::VarIntDecoding("VarInt too long".to_string()));
}
}
}
let mut result = 0i32;
for (i, &byte) in varint_buf[..length_bytes_count].iter().enumerate() {
result |= ((byte & 0x7F) as i32) << (7 * i);
}
result
};
let mut encrypted_buffer = get_buffer_with_capacity(packet_length as usize);
unsafe {
encrypted_buffer.set_len(packet_length as usize);
}
self.reader.read_exact(&mut encrypted_buffer).await?;
if let Some(cipher) = &mut self.encryption {
cipher.decrypt_with_backend_mut(Cfb8Closure {
data: &mut encrypted_buffer,
});
}
let packet_data = if let CompressionState::Enabled { threshold: _ } = self.compression {
let mut cursor = Cursor::new(&encrypted_buffer[..]);
let (VarInt(data_length), bytes_read) = VarInt::read_from(&mut cursor)?;
if data_length == 0 {
let result = BytesMut::from(&encrypted_buffer[bytes_read..]);
return_buffer(encrypted_buffer);
result
} else {
if data_length > MAX_UNCOMPRESSED_LENGTH as i32 {
return_buffer(encrypted_buffer);
return Err(PacketError::InvalidLength {
length: data_length as usize,
max: MAX_UNCOMPRESSED_LENGTH,
});
}
let mut decompressor = Decompressor::new();
let mut outbuf = get_buffer_with_capacity(data_length as usize);
unsafe {
outbuf.set_len(data_length as usize);
}
let decompress_result = decompressor
.zlib_decompress(&encrypted_buffer[bytes_read..], &mut outbuf)
.map_err(|e| {
PacketError::Compression(format!("Decompression failed: {:?}", e))
});
return_buffer(encrypted_buffer);
decompress_result?;
if outbuf.len() != data_length as usize {
return_buffer(outbuf);
return Err(PacketError::Compression(
"Decompressed length mismatch".to_string(),
));
}
let result = BytesMut::from(&outbuf[..]);
return_buffer(outbuf);
result
}
} else {
let result = BytesMut::from(&encrypted_buffer[..]);
return_buffer(encrypted_buffer);
result
};
let mut cursor = Cursor::new(&packet_data);
let (VarInt(id), id_size) = VarInt::read_from(&mut cursor)?;
Ok(Packet {
id,
data: BytesMut::from(&packet_data[id_size..]),
compression: self.compression,
encryption: if self.encryption.is_some() {
EncryptionState::Enabled {
encrypted_data: true,
}
} else {
EncryptionState::Disabled
},
protocol_version: self.protocol_version,
})
}
pub fn get_ref(&self) -> &R {
&self.reader
}
pub fn get_mut(&mut self) -> &mut R {
&mut self.reader
}
pub fn into_inner(self) -> R {
self.reader
}
pub fn buffer(&self) -> &BytesMut {
&self.buffer
}
}
#[async_trait]
impl<R> RawPacketIO for PacketReader<R>
where
    R: AsyncRead + Unpin + Send,
{
    /// Performs a single read from the underlying stream into the scratch buffer.
    ///
    /// Returns `Ok(None)` when the read yields zero bytes. Otherwise it returns
    /// `Ok(Some(bytes))` with everything that was read; those bytes are split off
    /// the scratch buffer, so the buffer is left empty. The bytes are neither
    /// decrypted nor decompressed.
    async fn read_raw(&mut self) -> PacketResult<Option<BytesMut>> {
        self.buffer.clear();
        let filled = self
            .reader
            .read_buf(&mut self.buffer)
            .await
            .map_err(PacketError::Io)?;
        if filled == 0 {
            return Ok(None);
        }
        Ok(Some(self.buffer.split()))
    }

    /// Always fails: a `PacketReader` is read-only.
    async fn write_raw(&mut self, _data: &[u8]) -> PacketResult<()> {
        let reason = "Readers cannot write".to_string();
        Err(PacketError::InvalidFormat(reason))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use libdeflater::{CompressionLvl, Compressor};
    use std::io::Cursor;
    use tokio::io::BufReader;

    /// Appends `value` to `out` using the protocol's VarInt encoding.
    fn write_varint(mut value: i32, out: &mut Vec<u8>) {
        loop {
            let mut byte = (value & 0x7F) as u8;
            // Logical shift so negative values terminate after 5 bytes.
            value = ((value as u32) >> 7) as i32;
            if value != 0 {
                byte |= 0x80;
            }
            out.push(byte);
            if value == 0 {
                break;
            }
        }
    }

    /// Prefixes `body` with its VarInt length, producing one wire frame.
    fn frame(body: &[u8]) -> Vec<u8> {
        let mut out = Vec::with_capacity(body.len() + 3);
        write_varint(body.len() as i32, &mut out);
        out.extend_from_slice(body);
        out
    }

    /// Builds a `PacketReader` over an in-memory byte stream.
    fn reader_over(bytes: Vec<u8>) -> PacketReader<BufReader<Cursor<Vec<u8>>>> {
        PacketReader::new(BufReader::new(Cursor::new(bytes)))
    }

    #[tokio::test]
    async fn test_read_simple_packet() {
        // length 3, packet id 0, body [1, 2]
        let mut reader = reader_over(vec![3, 0, 1, 2]);
        let packet = reader.read_packet().await.unwrap();
        assert_eq!(packet.id, 0);
        assert_eq!(&packet.data[..], &[1, 2]);
    }

    #[tokio::test]
    async fn test_read_multi_byte_length_prefix() {
        let mut body = vec![0x01];
        body.extend((0..199u32).map(|i| i as u8));
        let bytes = frame(&body);
        // 200 needs two VarInt bytes.
        assert_eq!(&bytes[..2], &[0xC8, 0x01]);

        let mut reader = reader_over(bytes);
        let packet = reader.read_packet().await.unwrap();
        assert_eq!(packet.id, 1);
        assert_eq!(&packet.data[..], &body[1..]);
    }

    #[tokio::test]
    async fn test_length_prefix_too_long() {
        let mut reader = reader_over(vec![0x80, 0x80, 0x80, 0x01]);
        assert!(reader.read_packet().await.is_err());
    }

    #[tokio::test]
    async fn test_read_uncompressed_frame_with_compression_enabled() {
        // data_length 0 => the rest of the frame is stored uncompressed.
        let mut reader = reader_over(frame(&[0x00, 0x02, 0xAA, 0xBB]));
        reader.enable_compression(256);
        let packet = reader.read_packet().await.unwrap();
        assert_eq!(packet.id, 2);
        assert_eq!(&packet.data[..], &[0xAA, 0xBB]);
    }

    #[tokio::test]
    async fn test_read_compressed_packet() {
        let mut payload = vec![0x05];
        payload.extend((0..300u32).map(|i| (i % 7) as u8));

        let mut compressor = Compressor::new(CompressionLvl::default());
        let mut compressed = vec![0u8; compressor.zlib_compress_bound(payload.len())];
        let compressed_len = compressor.zlib_compress(&payload, &mut compressed).unwrap();
        compressed.truncate(compressed_len);

        let mut body = Vec::new();
        write_varint(payload.len() as i32, &mut body);
        body.extend_from_slice(&compressed);

        let mut reader = reader_over(frame(&body));
        reader.enable_compression(256);
        let packet = reader.read_packet().await.unwrap();
        assert_eq!(packet.id, 5);
        assert_eq!(&packet.data[..], &payload[1..]);
    }

    #[tokio::test]
    async fn test_invalid_packet_length() {
        // A zero-length frame has no room for a packet id.
        let mut reader = reader_over(vec![0]);
        assert!(reader.read_packet().await.is_err());
    }
}