use bytes::{Buf, BufMut};
use crate::types::ProtocolError;
use crate::varint;
/// Largest frame body length (2 MiB) accepted by `try_read_frame`; larger
/// declared lengths are rejected with `ProtocolError::FrameTooLarge`.
const MAX_FRAME_SIZE: usize = 2 << 20;
/// Largest uncompressed size (8 MiB) accepted by `read_compressed_frame` from a
/// compressed frame header; larger values are rejected with
/// `ProtocolError::CompressionError`.
pub const MAX_UNCOMPRESSED_SIZE: usize = 8 << 20;
/// Checks whether `data` starts with a complete length-prefixed frame.
///
/// A frame is a VarInt body length followed by that many body bytes.
///
/// # Returns
/// - `Ok(None)` when more input is needed. This happens when the length prefix itself is
///   not yet complete (as reported by `varint::peek_var_int`), or when fewer body bytes
///   are present than the prefix declares.
/// - `Ok(Some((varint_size, frame_len)))` when the whole frame is available. The body is
///   `data[varint_size..varint_size + frame_len]`.
///
/// # Errors
/// - [`ProtocolError::InvalidData`] if the declared length is negative.
/// - [`ProtocolError::FrameTooLarge`] if the declared length exceeds `MAX_FRAME_SIZE`.
///   This is checked before the completeness check, so an oversized prefix is rejected
///   without waiting for its body.
/// - Any error propagated from `varint::peek_var_int`.
#[allow(clippy::cast_sign_loss)]
pub fn try_read_frame(data: &[u8]) -> Result<Option<(usize, usize)>, ProtocolError> {
    let (declared, prefix_len) = match varint::peek_var_int(data)? {
        Some(header) => header,
        None => return Ok(None),
    };

    if declared.is_negative() {
        return Err(ProtocolError::InvalidData(
            "negative frame length".to_string(),
        ));
    }

    // The sign was checked above, so this cast cannot lose information about the sign.
    let body_len = declared as usize;
    if body_len > MAX_FRAME_SIZE {
        return Err(ProtocolError::FrameTooLarge {
            size: body_len,
            max: MAX_FRAME_SIZE,
        });
    }

    let is_complete = data.len() >= prefix_len + body_len;
    Ok(is_complete.then_some((prefix_len, body_len)))
}
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
pub fn write_frame(dst: &mut impl BufMut, inner: &[u8]) {
varint::write_var_int(dst, inner.len() as i32);
dst.put_slice(inner);
}
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
pub fn write_compressed_frame(dst: &mut impl BufMut, uncompressed_size: i32, payload: &[u8]) {
let size_varint_len = varint::var_int_bytes(uncompressed_size);
let frame_len = size_varint_len + payload.len();
varint::write_var_int(dst, frame_len as i32);
varint::write_var_int(dst, uncompressed_size);
dst.put_slice(payload);
}
/// Parses the body of a compressed-format frame.
///
/// `data` is expected to be a frame body, e.g. the slice located by `try_read_frame`.
/// It starts with a VarInt declaring the uncompressed size and continues with the payload.
///
/// # Returns
/// `(uncompressed_size, payload)`, where `payload` borrows the bytes of `data` that follow
/// the size VarInt.
///
/// # Errors
/// - [`ProtocolError::InvalidData`] if the declared size is negative.
/// - [`ProtocolError::CompressionError`] if the declared size exceeds
///   [`MAX_UNCOMPRESSED_SIZE`].
/// - Any error propagated from `varint::read_var_int`.
#[allow(clippy::cast_sign_loss)]
pub fn read_compressed_frame(data: &[u8]) -> Result<(usize, &[u8]), ProtocolError> {
    // `read_var_int` advances this slice past the size VarInt, leaving only the payload.
    let mut payload = data;
    let declared = varint::read_var_int(&mut payload)?;

    if declared.is_negative() {
        return Err(ProtocolError::InvalidData(
            "negative uncompressed size".to_string(),
        ));
    }

    // The sign was checked above, so this cast cannot lose information about the sign.
    let uncompressed_size = declared as usize;
    if uncompressed_size <= MAX_UNCOMPRESSED_SIZE {
        Ok((uncompressed_size, payload))
    } else {
        Err(ProtocolError::CompressionError(format!(
            "uncompressed size {uncompressed_size} exceeds maximum {MAX_UNCOMPRESSED_SIZE}"
        )))
    }
}
/// Builds packet data: the VarInt `packet_id` followed by whatever `encode_fn`
/// appends to the same buffer.
///
/// `encode_fn` receives the buffer after the packet ID has been written. It may grow
/// the buffer freely.
pub fn encode_packet_data(packet_id: i32, encode_fn: impl FnOnce(&mut Vec<u8>)) -> Vec<u8> {
    // Initial allocation hint only; this is not a size limit.
    const INITIAL_CAPACITY: usize = 64;

    let mut packet = Vec::with_capacity(INITIAL_CAPACITY);
    varint::write_var_int(&mut packet, packet_id);
    encode_fn(&mut packet);
    packet
}
/// Reads the VarInt packet ID from the front of `buf`.
///
/// # Errors
/// Propagates any error from `varint::read_var_int`.
pub fn read_packet_id(buf: &mut impl Buf) -> Result<i32, ProtocolError> {
    let packet_id = varint::read_var_int(buf)?;
    Ok(packet_id)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// VarInt encoding of -1 (all 32 bits set).
    const VARINT_MINUS_ONE: [u8; 5] = [0xFF, 0xFF, 0xFF, 0xFF, 0x0F];

    #[test]
    fn test_try_read_frame_complete() {
        let data = vec![0x03, 0x01, 0x02, 0x03];
        let (varint_size, frame_len) = try_read_frame(&data).unwrap().unwrap();
        assert_eq!(varint_size + frame_len, 4);
        assert_eq!(
            &data[varint_size..varint_size + frame_len],
            &[0x01, 0x02, 0x03]
        );
    }

    #[test]
    fn test_try_read_frame_incomplete() {
        // The prefix declares 3 body bytes, but only 1 is present.
        let data = vec![0x03, 0x01];
        assert!(try_read_frame(&data).unwrap().is_none());
    }

    #[test]
    fn test_try_read_frame_empty() {
        assert!(try_read_frame(&[]).unwrap().is_none());
    }

    #[test]
    fn test_try_read_frame_negative_length() {
        assert!(matches!(
            try_read_frame(&VARINT_MINUS_ONE),
            Err(ProtocolError::InvalidData(_))
        ));
    }

    #[test]
    fn test_try_read_frame_too_large() {
        // VarInt encoding of MAX_FRAME_SIZE + 1 (2_097_153). No body bytes are supplied:
        // the size limit is enforced before the completeness check.
        let data = [0x81, 0x80, 0x80, 0x01];
        assert!(matches!(
            try_read_frame(&data),
            Err(ProtocolError::FrameTooLarge { size, max })
                if size == MAX_FRAME_SIZE + 1 && max == MAX_FRAME_SIZE
        ));
    }

    #[test]
    fn test_write_frame_roundtrip() {
        let inner = vec![0x00, 0x48, 0x65, 0x6C, 0x6C, 0x6F];
        let mut buf = Vec::new();
        write_frame(&mut buf, &inner);
        let (varint_size, frame_len) = try_read_frame(&buf).unwrap().unwrap();
        assert_eq!(varint_size + frame_len, buf.len());
        assert_eq!(&buf[varint_size..varint_size + frame_len], &inner[..]);
    }

    #[test]
    fn test_compressed_frame_uncompressed() {
        let mut buf = Vec::new();
        let payload = vec![0x00, 0x48, 0x69];
        write_compressed_frame(&mut buf, 0, &payload);
        let (varint_size, frame_len) = try_read_frame(&buf).unwrap().unwrap();
        let frame_data = &buf[varint_size..varint_size + frame_len];
        let (uncompressed_size, data) = read_compressed_frame(frame_data).unwrap();
        assert_eq!(uncompressed_size, 0);
        assert_eq!(data, &payload[..]);
    }

    #[test]
    fn test_compressed_frame_roundtrip_multibyte_size() {
        // 300 needs a two-byte VarInt, exercising the frame-length arithmetic.
        let mut buf = Vec::new();
        let payload = [0x78, 0x9C, 0x01, 0x02];
        write_compressed_frame(&mut buf, 300, &payload);
        let (varint_size, frame_len) = try_read_frame(&buf).unwrap().unwrap();
        assert_eq!(varint_size + frame_len, buf.len());
        let (uncompressed_size, data) =
            read_compressed_frame(&buf[varint_size..varint_size + frame_len]).unwrap();
        assert_eq!(uncompressed_size, 300);
        assert_eq!(data, &payload[..]);
    }

    #[test]
    fn test_read_compressed_frame_negative_size() {
        assert!(matches!(
            read_compressed_frame(&VARINT_MINUS_ONE),
            Err(ProtocolError::InvalidData(_))
        ));
    }

    #[test]
    fn test_read_compressed_frame_oversized() {
        // VarInt encoding of MAX_UNCOMPRESSED_SIZE + 1 (8_388_609), followed by one payload byte.
        let data = [0x81, 0x80, 0x80, 0x04, 0xAA];
        assert!(matches!(
            read_compressed_frame(&data),
            Err(ProtocolError::CompressionError(_))
        ));
    }

    #[test]
    fn test_packet_id_roundtrip() {
        let packet = encode_packet_data(0x2A, |buf| buf.extend_from_slice(&[0xDE, 0xAD]));
        let mut cursor = &packet[..];
        assert_eq!(read_packet_id(&mut cursor).unwrap(), 0x2A);
        assert_eq!(cursor, &[0xDE, 0xAD]);
    }
}