use crate::error::{RconError, Result};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::debug;
/// Packet type for an authentication request (`SERVERDATA_AUTH` in the
/// Source RCON protocol).
const PACKET_TYPE_AUTH: i32 = 3;
/// Packet type for a command request (`SERVERDATA_EXECCOMMAND`).
const PACKET_TYPE_COMMAND: i32 = 2;
/// Largest packet body length, in bytes, this module accepts (4 MiB).
const MAX_PACKET_SIZE: i32 = 4 * 1024 * 1024;
/// Smallest valid body: id (4) + type (4) + two NUL terminators (2).
const MIN_PACKET_SIZE: i32 = 4 + 4 + 2;
#[derive(Debug, Clone)]
pub(crate) struct Packet {
pub id: i32,
pub packet_type: i32,
pub payload: String,
}
impl Packet {
pub fn auth(id: i32, password: &str) -> Self {
Self {
id,
packet_type: PACKET_TYPE_AUTH,
payload: password.to_string(),
}
}
pub fn command(id: i32, command: &str) -> Self {
Self {
id,
packet_type: PACKET_TYPE_COMMAND,
payload: command.to_string(),
}
}
pub fn encode(&self) -> Result<Vec<u8>> {
let payload_bytes = self.payload.as_bytes();
let body_length = 4i32
.checked_add(4)
.and_then(|v| v.checked_add(payload_bytes.len() as i32))
.and_then(|v| v.checked_add(2))
.ok_or_else(|| {
RconError::InvalidPacket(format!(
"Payload too large: {} bytes",
payload_bytes.len()
))
})?;
let mut buffer = Vec::with_capacity(4 + body_length as usize);
buffer.extend_from_slice(&body_length.to_le_bytes());
buffer.extend_from_slice(&self.id.to_le_bytes());
buffer.extend_from_slice(&self.packet_type.to_le_bytes());
buffer.extend_from_slice(payload_bytes);
buffer.push(0);
buffer.push(0);
debug!(
id = self.id,
packet_type = self.packet_type,
payload_len = payload_bytes.len(),
total_len = buffer.len(),
"Encoded packet"
);
Ok(buffer)
}
}
pub(crate) async fn read_packet(stream: &mut TcpStream) -> Result<Packet> {
let mut length_buf = [0u8; 4];
stream.read_exact(&mut length_buf).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
RconError::ConnectionLost(e)
} else {
RconError::Io(e)
}
})?;
let length = i32::from_le_bytes(length_buf);
if length < MIN_PACKET_SIZE {
return Err(RconError::InvalidPacket(format!(
"Packet body too small: {} bytes (minimum {})",
length, MIN_PACKET_SIZE
)));
}
if length > MAX_PACKET_SIZE {
return Err(RconError::InvalidPacket(format!(
"Packet body too large: {} bytes (maximum {})",
length, MAX_PACKET_SIZE
)));
}
let length = length as usize;
debug!(body_length = length, "Reading packet");
let mut packet_buf = vec![0u8; length];
stream.read_exact(&mut packet_buf).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
RconError::ConnectionLost(e)
} else {
RconError::Io(e)
}
})?;
parse_packet_buffer(&packet_buf)
}
/// Parses a packet body, i.e. everything after the 4-byte length prefix.
///
/// Expected layout: little-endian `i32` id, little-endian `i32` type, the
/// payload bytes, then two NUL terminator bytes. Only bytes that really are
/// NUL are stripped from the end. NUL bytes inside the payload are kept, and
/// invalid UTF-8 is replaced with U+FFFD rather than rejected.
///
/// # Errors
///
/// Returns `RconError::InvalidPacket` if `buffer` is shorter than
/// `MIN_PACKET_SIZE` bytes.
fn parse_packet_buffer(buffer: &[u8]) -> Result<Packet> {
    if buffer.len() < MIN_PACKET_SIZE as usize {
        return Err(RconError::InvalidPacket(format!(
            "Buffer too small: {} bytes",
            buffer.len()
        )));
    }
    let (header, body) = buffer.split_at(8);
    let id = i32::from_le_bytes([header[0], header[1], header[2], header[3]]);
    let packet_type = i32::from_le_bytes([header[4], header[5], header[6], header[7]]);

    // Strip the terminators only if they are actually NULs. Dropping the last
    // two bytes unconditionally would silently discard payload data whenever
    // a peer omits one or both terminators.
    let payload_bytes = match body.strip_suffix(b"\0\0") {
        Some(stripped) => stripped,
        None => {
            debug!(id, "Packet body missing double NUL terminator");
            body.strip_suffix(b"\0").unwrap_or(body)
        }
    };
    // `into_owned` reuses the allocation when lossy decoding already made one.
    let payload = String::from_utf8_lossy(payload_bytes).into_owned();
    debug!(
        id,
        packet_type,
        payload_len = payload.len(),
        "Parsed packet"
    );
    Ok(Packet {
        id,
        packet_type,
        payload,
    })
}
/// Encodes `packet` into its wire format and writes it to `stream`, flushing
/// afterwards.
///
/// # Errors
///
/// Propagates any error from `Packet::encode`, and any I/O error from the
/// write or the flush.
pub(crate) async fn write_packet(stream: &mut TcpStream, packet: &Packet) -> Result<()> {
    let frame: Vec<u8> = packet.encode()?;
    stream.write_all(frame.as_slice()).await?;
    stream.flush().await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a packet body (no length prefix): id, type, raw payload bytes,
    /// and the two NUL terminators.
    fn body(id: i32, packet_type: i32, payload: &[u8]) -> Vec<u8> {
        let mut buffer = Vec::with_capacity(8 + payload.len() + 2);
        buffer.extend_from_slice(&id.to_le_bytes());
        buffer.extend_from_slice(&packet_type.to_le_bytes());
        buffer.extend_from_slice(payload);
        buffer.extend_from_slice(&[0, 0]);
        buffer
    }

    /// Reads the little-endian `i32` that starts at `offset` in `bytes`.
    fn le_i32(bytes: &[u8], offset: usize) -> i32 {
        i32::from_le_bytes([
            bytes[offset],
            bytes[offset + 1],
            bytes[offset + 2],
            bytes[offset + 3],
        ])
    }

    #[test]
    fn test_encode_auth_packet() {
        let packet = Packet::auth(1, "password");
        let encoded = packet.encode().unwrap();
        assert_eq!(le_i32(&encoded, 0), 18);
        assert_eq!(le_i32(&encoded, 4), 1);
        assert_eq!(le_i32(&encoded, 8), 3);
        assert_eq!(&encoded[12..20], b"password");
        assert_eq!(&encoded[20..22], &[0, 0]);
    }

    #[test]
    fn test_encode_command_packet() {
        let packet = Packet::command(2, "/version");
        let encoded = packet.encode().unwrap();
        assert_eq!(le_i32(&encoded, 0), 18);
        assert_eq!(le_i32(&encoded, 4), 2);
        assert_eq!(le_i32(&encoded, 8), 2);
        assert_eq!(&encoded[12..20], b"/version");
        assert_eq!(&encoded[20..22], &[0, 0]);
        assert_eq!(encoded.len(), 22);
    }

    #[test]
    fn test_encode_empty_payload() {
        let packet = Packet::command(1, "");
        let encoded = packet.encode().unwrap();
        assert_eq!(le_i32(&encoded, 0), 10);
        assert_eq!(encoded.len(), 14);
    }

    #[test]
    fn test_encode_length_prefix_matches_body() {
        for payload in ["", "a", "list players", "ünïcödé"] {
            let encoded = Packet::command(9, payload).encode().unwrap();
            assert_eq!(le_i32(&encoded, 0) as usize, encoded.len() - 4);
        }
    }

    #[test]
    fn test_parse_valid_packet() {
        let buffer = body(5, 0, b"test response");
        let packet = parse_packet_buffer(&buffer).unwrap();
        assert_eq!(packet.id, 5);
        assert_eq!(packet.packet_type, 0);
        assert_eq!(packet.payload, "test response");
    }

    #[test]
    fn test_parse_empty_payload() {
        let buffer = body(1, 0, b"");
        let packet = parse_packet_buffer(&buffer).unwrap();
        assert_eq!(packet.id, 1);
        assert_eq!(packet.payload, "");
    }

    #[test]
    fn test_reject_too_small_buffer() {
        let buffer = vec![0u8; 8];
        let result = parse_packet_buffer(&buffer);
        assert!(matches!(result, Err(RconError::InvalidPacket(_))));
    }

    #[test]
    fn test_parse_payload_with_embedded_nulls() {
        let buffer = body(1, 0, b"hello\x00world");
        let packet = parse_packet_buffer(&buffer).unwrap();
        assert_eq!(packet.payload, "hello\0world");
    }

    #[test]
    fn test_parse_invalid_utf8_is_replaced() {
        let buffer = body(7, 0, &[b'o', b'k', 0xFF]);
        let packet = parse_packet_buffer(&buffer).unwrap();
        assert_eq!(packet.payload, "ok\u{FFFD}");
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        for original in [Packet::command(42, "test command"), Packet::auth(-7, "secret")] {
            let encoded = original.encode().unwrap();
            let parsed = parse_packet_buffer(&encoded[4..]).unwrap();
            assert_eq!(parsed.id, original.id);
            assert_eq!(parsed.packet_type, original.packet_type);
            assert_eq!(parsed.payload, original.payload);
        }
    }

    #[test]
    fn test_negative_id_roundtrip() {
        let buffer = body(-1, 2, b"");
        let packet = parse_packet_buffer(&buffer).unwrap();
        assert_eq!(packet.id, -1);
    }
}