use bytes::{Buf, BytesMut};
use std::io::Cursor;
use infrarust_protocol::{
packet::{
CompressionControl, CompressionState, EncryptionControl, EncryptionState, PacketCodec,
PacketDataAccess, PacketError, PacketSerialization, PacketValidation, Result,
},
types::{ProtocolRead, ProtocolWrite, VarInt, WriteToBytes},
version::Version,
};
use super::PacketBuilder;
/// Largest frame length (packet ID + payload) accepted when decoding a raw frame:
/// 2^21 - 1, the largest value a three-byte VarInt can encode.
pub const MAX_PACKET_LENGTH: usize = (1 << 21) - 1;

/// Largest payload (bytes after the packet ID) accepted when decoding a raw frame.
// NOTE(review): this is one byte larger than MAX_PACKET_LENGTH, so the frame limit is the
// tighter bound in practice — confirm whether 2^21 - 1 was intended here.
pub const MAX_PACKET_DATA_LENGTH: usize = 1 << 21;

/// 8 MiB (2^23 bytes) limit for uncompressed lengths.
// Not referenced by the code in this module; presumably consumed by the compression layer.
pub const MAX_UNCOMPRESSED_LENGTH: usize = 1 << 23;
/// A protocol packet: its numeric ID, the raw payload bytes that follow the ID,
/// and the compression / encryption / protocol-version state recorded on it.
#[derive(Clone)]
pub struct Packet {
    /// Packet ID; serialized as a VarInt immediately before `data`.
    pub id: i32,
    /// Raw payload bytes (everything after the packet ID).
    pub data: BytesMut,
    /// Compression setting recorded on this packet.
    pub compression: CompressionState,
    /// Encryption setting recorded on this packet.
    pub encryption: EncryptionState,
    /// Protocol version this packet is associated with.
    pub protocol_version: Version,
}
impl std::fmt::Debug for Packet {
    /// Formats the packet without dumping its payload: the ID is rendered as a
    /// `0x`-prefixed hex string and only the payload length is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            id,
            data,
            compression,
            encryption,
            protocol_version,
        } = self;

        let mut out = f.debug_struct("Packet");
        out.field("id", &format!("0x{id:02x}"));
        out.field("data_len", &data.len());
        out.field("compression", compression);
        out.field("encryption", encryption);
        out.field("protocol_version", protocol_version);
        out.finish()
    }
}
impl Packet {
    /// Creates a packet with the given `id` and an empty payload.
    ///
    /// Compression and encryption start disabled and the protocol version is
    /// `Version::V1_20_2`.
    pub fn new(id: i32) -> Self {
        Self::with_capacity(id, 0)
    }

    /// Like [`Packet::new`], but pre-allocates `capacity` bytes for the payload.
    pub fn with_capacity(id: i32, capacity: usize) -> Self {
        let data = BytesMut::with_capacity(capacity);
        Self {
            protocol_version: Version::V1_20_2,
            encryption: EncryptionState::Disabled,
            compression: CompressionState::Disabled,
            data,
            id,
        }
    }
}
impl PacketDataAccess for Packet {
    /// Returns the packet ID.
    fn id(&self) -> i32 {
        let Self { id, .. } = self;
        *id
    }

    /// Borrows the payload bytes (everything after the packet ID).
    fn data(&self) -> &[u8] {
        &self.data[..]
    }

    /// Returns the protocol version recorded on this packet.
    fn protocol_version(&self) -> Version {
        self.protocol_version
    }

    /// Overwrites the protocol version recorded on this packet.
    fn set_protocol_version(&mut self, version: Version) {
        let slot = &mut self.protocol_version;
        *slot = version;
    }
}
impl CompressionControl for Packet {
    /// Returns a copy of the current compression setting.
    fn compression_state(&self) -> CompressionState {
        self.compression
    }

    /// Turns compression on with `threshold`. The value is stored unchecked;
    /// negative thresholds are only rejected by `validate_compression`.
    fn enable_compression(&mut self, threshold: i32) {
        let enabled = CompressionState::Enabled { threshold };
        self.compression = enabled;
    }

    /// Turns compression off.
    fn disable_compression(&mut self) {
        let disabled = CompressionState::Disabled;
        self.compression = disabled;
    }

    /// `true` when compression is enabled, whatever the threshold.
    fn is_compressing(&self) -> bool {
        matches!(self.compression_state(), CompressionState::Enabled { .. })
    }
}
impl EncryptionControl for Packet {
    /// Returns a clone of the current encryption setting.
    fn encryption_state(&self) -> EncryptionState {
        EncryptionState::clone(&self.encryption)
    }

    /// Turns encryption on; the payload starts out flagged as not yet encrypted.
    fn enable_encryption(&mut self) {
        let pending = EncryptionState::Enabled {
            encrypted_data: false,
        };
        self.encryption = pending;
    }

    /// Turns encryption off.
    fn disable_encryption(&mut self) {
        let disabled = EncryptionState::Disabled;
        self.encryption = disabled;
    }

    /// Flags the payload as encrypted. Does nothing while encryption is disabled.
    fn mark_as_encrypted(&mut self) {
        if let EncryptionState::Enabled { encrypted_data } = &mut self.encryption {
            *encrypted_data = true;
        }
    }

    /// `true` whenever encryption is enabled.
    ///
    /// NOTE(review): this does not consult the `encrypted_data` flag set by
    /// `mark_as_encrypted`; confirm callers expect "enabled" rather than
    /// "payload already encrypted".
    fn is_encrypted(&self) -> bool {
        matches!(&self.encryption, EncryptionState::Enabled { .. })
    }
}
impl PacketValidation for Packet {
    /// Rejects payloads longer than `MAX_PACKET_LENGTH`.
    ///
    /// The check is skipped entirely while encryption is enabled.
    fn validate_length(&self) -> Result<()> {
        let encryption_on = matches!(self.encryption, EncryptionState::Enabled { .. });
        let payload_len = self.data.len();
        if !encryption_on && payload_len > MAX_PACKET_LENGTH {
            Err(PacketError::InvalidLength {
                length: payload_len,
                max: MAX_PACKET_LENGTH,
            })
        } else {
            Ok(())
        }
    }

    /// Fails when encryption is enabled, the payload has not been marked as
    /// encrypted, and the payload is non-empty.
    fn validate_encryption(&self) -> Result<()> {
        match self.encryption {
            EncryptionState::Enabled {
                encrypted_data: false,
            } if !self.data.is_empty() => Err(PacketError::Encryption(
                "Non-encrypted data when encryption is enabled".to_string(),
            )),
            _ => Ok(()),
        }
    }

    /// Fails when compression is enabled with a negative threshold.
    fn validate_compression(&self) -> Result<()> {
        match self.compression {
            CompressionState::Enabled { threshold } if threshold < 0 => Err(
                PacketError::Compression("Invalid compression threshold".to_string()),
            ),
            _ => Ok(()),
        }
    }
}
impl PacketCodec for Packet {
    /// Serializes `value` and appends the resulting bytes to the end of the payload.
    ///
    /// I/O errors from the writer are wrapped in `PacketError::Io`.
    fn encode<T: ProtocolWrite>(&mut self, value: &T) -> Result<()> {
        let mut scratch = Cursor::new(Vec::new());
        value.write_to(&mut scratch).map_err(PacketError::Io)?;
        let written = scratch.into_inner();
        self.data.extend_from_slice(&written);
        Ok(())
    }

    /// Reads a `T` from the start of the payload.
    ///
    /// The number of bytes consumed is discarded, so any trailing payload bytes
    /// are ignored. I/O errors are wrapped in `PacketError::Io`.
    fn decode<T: ProtocolRead>(&self) -> Result<T> {
        let mut reader = Cursor::new(&self.data[..]);
        T::read_from(&mut reader)
            .map(|(value, _consumed)| value)
            .map_err(PacketError::Io)
    }
}
impl PacketSerialization for Packet {
    /// Encodes the packet as a frame: `VarInt(len) ++ VarInt(id) ++ data`, where
    /// `len` counts the encoded ID plus the payload bytes.
    ///
    /// The packet's compression and encryption state are not consulted here.
    ///
    /// # Errors
    ///
    /// Returns `PacketError::InvalidLength` if the frame body does not fit the
    /// `i32` VarInt length prefix, and propagates any VarInt write error.
    fn into_raw_bytes(self) -> Result<BytesMut> {
        // Encode the ID separately so the payload is copied only once, straight
        // into the output buffer.
        let mut id_bytes = BytesMut::new();
        VarInt(self.id).write_to_bytes(&mut id_bytes)?;

        let body_len = id_bytes.len() + self.data.len();
        // A plain `as i32` would wrap for oversized bodies and emit a bogus
        // (negative) length prefix.
        let frame_len = i32::try_from(body_len).map_err(|_| PacketError::InvalidLength {
            length: body_len,
            max: i32::MAX as usize,
        })?;

        // 5 bytes is the maximum encoded size of a 32-bit VarInt.
        let mut output = BytesMut::with_capacity(5 + body_len);
        VarInt(frame_len).write_to_bytes(&mut output)?;
        output.extend_from_slice(&id_bytes);
        output.extend_from_slice(&self.data);
        Ok(output)
    }

    /// Decodes one frame laid out as `VarInt(len) ++ VarInt(id) ++ data`.
    ///
    /// Exactly `len` bytes after the length prefix belong to the packet. Bytes
    /// beyond the declared frame are discarded instead of being absorbed into
    /// the payload. The packet is built through `PacketBuilder` with `id` and
    /// `data` set, then checked with `validate_length`.
    ///
    /// # Errors
    ///
    /// - `InvalidFormat` if the length or ID VarInt cannot be read, or if fewer
    ///   than `len` bytes follow the prefix (truncated frame).
    /// - `InvalidLength` if `len` is not in `1..=MAX_PACKET_LENGTH`, or the
    ///   payload exceeds `MAX_PACKET_DATA_LENGTH`.
    /// - Any error returned by `PacketBuilder::build` or `validate_length`.
    fn from_raw_bytes(mut bytes: BytesMut) -> Result<Self> {
        let (VarInt(length), length_size) = VarInt::read_from(&mut Cursor::new(&bytes[..]))
            .map_err(|_| PacketError::InvalidFormat("Invalid packet length VarInt".to_string()))?;
        if length <= 0 || length as usize > MAX_PACKET_LENGTH {
            return Err(PacketError::InvalidLength {
                length: length as usize,
                max: MAX_PACKET_LENGTH,
            });
        }
        // Checked above: 0 < length <= MAX_PACKET_LENGTH, so the cast is lossless.
        let frame_len = length as usize;
        bytes.advance(length_size);

        // Enforce the declared frame size. A short buffer is a truncated packet,
        // and bytes past the frame must not leak into this packet's payload (or
        // let the ID VarInt be read across the frame boundary).
        if bytes.len() < frame_len {
            return Err(PacketError::InvalidFormat(format!(
                "Truncated packet: declared length {frame_len}, but only {} bytes available",
                bytes.len()
            )));
        }
        bytes.truncate(frame_len);

        let (VarInt(id), id_size) = VarInt::read_from(&mut Cursor::new(&bytes[..]))
            .map_err(|_| PacketError::InvalidFormat("Invalid packet ID VarInt".to_string()))?;
        bytes.advance(id_size);

        // Defensive: with the current constants this bound is implied by the
        // frame-length check above.
        if bytes.len() > MAX_PACKET_DATA_LENGTH {
            return Err(PacketError::InvalidLength {
                length: bytes.len(),
                max: MAX_PACKET_DATA_LENGTH,
            });
        }

        let packet = PacketBuilder::new().id(id).data(bytes).build()?;
        packet.validate_length()?;
        Ok(packet)
    }
}