use std::io::Cursor;
use aes::cipher::BlockDecryptMut;
use async_trait::async_trait;
use bytes::BytesMut;
use libdeflater::Decompressor;
use tokio::io::{AsyncRead, AsyncReadExt};
use super::super::{
base::{CompressionState, EncryptionState, Packet},
error::{PacketError, PacketResult},
};
use super::RawPacketIO;
use crate::version::Version;
use crate::{ProtocolRead, protocol::types::VarInt};
use crate::{
network::packet::MAX_UNCOMPRESSED_LENGTH,
security::encryption::{Aes128Cfb8Dec, Cfb8Closure},
};
/// Reads protocol packets from an asynchronous byte source.
///
/// The reader holds the per-connection state that `read_packet` consults when
/// decoding a frame: an optional decryption cipher, the compression state, and
/// the protocol version stamped onto every returned `Packet`.
#[derive(Debug)]
pub struct PacketReader<R> {
/// Underlying byte source; also reachable through `get_ref` / `get_mut`.
pub reader: R,
/// Decryption cipher applied to every byte read by `read_packet`; `None`
/// while encryption is disabled.
encryption: Option<Aes128Cfb8Dec>,
/// Whether frames carry the extra "data length" prefix used by compression.
compression: CompressionState,
/// Protocol version copied into each `Packet` produced by `read_packet`.
protocol_version: Version,
/// Scratch buffer reused by `read_raw` for each raw read.
buffer: BytesMut,
}
impl<R: AsyncRead + Unpin> PacketReader<R> {
pub fn new(reader: R) -> Self {
Self {
reader,
encryption: None,
compression: CompressionState::Disabled,
protocol_version: Version::new(0),
buffer: BytesMut::with_capacity(8192),
}
}
pub fn is_encryption_enabled(&self) -> bool {
self.encryption.is_some()
}
pub fn enable_encryption(&mut self, cipher: Aes128Cfb8Dec) {
self.encryption = Some(cipher);
}
pub fn disable_encryption(&mut self) {
self.encryption = None;
}
pub fn enable_compression(&mut self, threshold: i32) {
self.compression = CompressionState::Enabled { threshold };
}
pub fn disable_compression(&mut self) {
self.compression = CompressionState::Disabled;
}
pub fn is_compressing(&self) -> bool {
matches!(self.compression, CompressionState::Enabled { .. })
}
pub fn set_protocol_version(&mut self, version: Version) {
self.protocol_version = version;
}
pub async fn read_packet(&mut self) -> PacketResult<Packet> {
let packet_length = {
let mut length_bytes = BytesMut::new();
loop {
let mut byte = [0u8; 1];
self.reader.read_exact(&mut byte).await?;
if let Some(cipher) = &mut self.encryption {
cipher.decrypt_with_backend_mut(Cfb8Closure { data: &mut byte });
}
length_bytes.extend_from_slice(&byte);
if byte[0] & 0x80 == 0 {
break;
}
if length_bytes.len() >= 3 {
return Err(PacketError::VarIntDecoding("VarInt too long".to_string()));
}
}
let mut cursor = Cursor::new(&length_bytes);
let (VarInt(length), _) = VarInt::read_from(&mut cursor)?;
length
};
let mut encrypted_data = vec![0u8; packet_length as usize];
self.reader.read_exact(&mut encrypted_data).await?;
if let Some(cipher) = &mut self.encryption {
cipher.decrypt_with_backend_mut(Cfb8Closure {
data: &mut encrypted_data,
});
}
let packet_data = if let CompressionState::Enabled { threshold: _ } = self.compression {
let mut cursor = Cursor::new(&encrypted_data);
let (VarInt(data_length), bytes_read) = VarInt::read_from(&mut cursor)?;
if data_length == 0 {
BytesMut::from(&encrypted_data[bytes_read..])
} else {
if data_length > MAX_UNCOMPRESSED_LENGTH as i32 {
return Err(PacketError::InvalidLength {
length: data_length as usize,
max: MAX_UNCOMPRESSED_LENGTH,
});
}
let mut decompressor = Decompressor::new();
let mut outbuf = vec![0; data_length as usize];
decompressor
.zlib_decompress(&encrypted_data[bytes_read..], &mut outbuf)
.unwrap();
if outbuf.len() != data_length as usize {
return Err(PacketError::compression("Decompressed length mismatch"));
}
BytesMut::from(&outbuf[..])
}
} else {
BytesMut::from(&encrypted_data[..])
};
let mut cursor = Cursor::new(&packet_data);
let (VarInt(id), id_size) = VarInt::read_from(&mut cursor)?;
Ok(Packet {
id,
data: BytesMut::from(&packet_data[id_size..]),
compression: self.compression,
encryption: if self.encryption.is_some() {
EncryptionState::Enabled {
encrypted_data: true,
}
} else {
EncryptionState::Disabled
},
protocol_version: self.protocol_version,
})
}
pub fn get_ref(&self) -> &R {
&self.reader
}
pub fn get_mut(&mut self) -> &mut R {
&mut self.reader
}
}
#[async_trait]
impl<R> RawPacketIO for PacketReader<R>
where
    R: AsyncRead + Unpin + Send,
{
    /// Performs a single raw read from the underlying reader.
    ///
    /// The internal scratch buffer is cleared, filled by one `read_buf` call,
    /// and a copy of its contents is returned. `Ok(None)` is returned when the
    /// read yields zero bytes. No decryption, decompression or framing is
    /// applied here.
    ///
    /// # Errors
    ///
    /// I/O failures are wrapped in `PacketError::Io`.
    async fn read_raw(&mut self) -> PacketResult<Option<BytesMut>> {
        self.buffer.clear();

        let bytes_read = self
            .reader
            .read_buf(&mut self.buffer)
            .await
            .map_err(PacketError::Io)?;

        if bytes_read == 0 {
            return Ok(None);
        }

        let chunk = self.buffer.clone();
        Ok(Some(chunk))
    }

    /// Always fails: a `PacketReader` has no write half.
    async fn write_raw(&mut self, _data: &[u8]) -> PacketResult<()> {
        let rejection = PacketError::invalid_format("Readers cannot write");
        Err(rejection)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tokio::io::BufReader;

    /// Builds a `PacketReader` that reads the given wire bytes from memory.
    fn reader_over(wire: Vec<u8>) -> PacketReader<BufReader<Cursor<Vec<u8>>>> {
        PacketReader::new(BufReader::new(Cursor::new(wire)))
    }

    #[tokio::test]
    async fn test_read_simple_packet() {
        // length = 3, packet id = 0, payload = [1, 2]
        let mut reader = reader_over(vec![3, 0, 1, 2]);
        let packet = reader.read_packet().await.unwrap();
        assert_eq!(packet.id, 0);
        assert_eq!(&packet.data[..], &[1, 2]);
    }

    #[tokio::test]
    async fn test_read_compressed_packet() {
        // Uncompressed content: packet id 0x2A followed by payload [1, 2, 3].
        // It is wrapped in a zlib stream made of a single *stored* deflate
        // block, so the bytes can be written out by hand:
        //   78 01          zlib header (deflate, 32K window, no dictionary)
        //   01             final block, type "stored"
        //   04 00 FB FF    LEN = 4 and its one's complement
        //   2A 01 02 03    the raw bytes
        //   00 B6 00 31    Adler-32 of the raw bytes, big-endian
        let zlib_stream = [
            0x78, 0x01, 0x01, 0x04, 0x00, 0xFB, 0xFF, 0x2A, 0x01, 0x02, 0x03, 0x00, 0xB6, 0x00,
            0x31,
        ];

        let mut wire = Vec::new();
        wire.push(1 + zlib_stream.len() as u8); // frame length
        wire.push(4); // data length: size after decompression
        wire.extend_from_slice(&zlib_stream);

        let mut reader = reader_over(wire);
        reader.enable_compression(1);
        assert!(reader.is_compressing());

        let packet = reader.read_packet().await.unwrap();
        assert_eq!(packet.id, 0x2A);
        assert_eq!(&packet.data[..], &[1, 2, 3]);
    }

    #[tokio::test]
    async fn test_read_uncompressed_packet_with_compression_enabled() {
        // frame length = 3, data length = 0 (body stored as-is),
        // packet id = 5, payload = [0xAA]
        let mut reader = reader_over(vec![3, 0, 5, 0xAA]);
        reader.enable_compression(256);

        let packet = reader.read_packet().await.unwrap();
        assert_eq!(packet.id, 5);
        assert_eq!(&packet.data[..], &[0xAA]);
    }

    #[tokio::test]
    async fn test_invalid_packet_length() {
        // A zero-length frame leaves no room for the packet id.
        let mut reader = reader_over(vec![0]);
        let result = reader.read_packet().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_overlong_length_prefix() {
        // The third prefix byte still has its continuation bit set.
        let mut reader = reader_over(vec![0x80, 0x80, 0x80, 0x01]);
        let result = reader.read_packet().await;
        assert!(result.is_err());
    }
}