use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use std::io::{Read, Write};
use zamsync_core::{ZamError, ZamResult};
/// Largest value accepted in a frame's 4-byte length prefix (64 MiB).
///
/// The prefix counts the flag byte plus the body, so `write_frame` only accepts
/// payloads strictly smaller than this, and `read_frame` rejects any prefix above it.
pub const MAX_FRAME_SIZE: u32 = 64 << 20;
/// Payloads shorter than this many bytes are always sent raw; compression is not
/// attempted for them.
const COMPRESS_THRESHOLD: usize = 64;
/// Flag byte: the frame body is the payload, uncompressed.
const FLAG_RAW: u8 = 0;
/// Flag byte: the frame body is a zstd-compressed payload.
const FLAG_ZSTD: u8 = 1;
/// Writes `payload` to `writer` as a single length-prefixed frame.
///
/// Wire layout: a big-endian `u32` length covering the flag byte plus the body,
/// then a one-byte flag (`FLAG_RAW` or `FLAG_ZSTD`), then the body. Payloads of at
/// least `COMPRESS_THRESHOLD` bytes are zstd-compressed (level 3), but the
/// compressed form is only sent when it is strictly smaller than the original.
///
/// Returns the total number of bytes written, header included.
///
/// # Errors
///
/// Returns `ZamError::Protocol` if `payload` is `MAX_FRAME_SIZE` bytes or larger
/// (nothing is written in that case) or if zstd compression fails. Errors from
/// `writer` are propagated with `?`.
pub fn write_frame(writer: &mut impl Write, payload: &[u8]) -> ZamResult<usize> {
    if payload.len() as u64 >= MAX_FRAME_SIZE as u64 {
        return Err(ZamError::Protocol(format!(
            "frame payload too large: {} bytes (max {})",
            payload.len(),
            MAX_FRAME_SIZE - 1
        )));
    }
    // Keep the compressed bytes only when they actually save space. Otherwise the
    // caller's slice is written directly, with no intermediate copy.
    let compressed: Option<Vec<u8>> = if payload.len() >= COMPRESS_THRESHOLD {
        let compressed = zstd::encode_all(payload, 3)
            .map_err(|e| ZamError::Protocol(format!("zstd compress: {e}")))?;
        if compressed.len() < payload.len() {
            Some(compressed)
        } else {
            None
        }
    } else {
        None
    };
    let (flag, body): (u8, &[u8]) = match compressed.as_deref() {
        Some(bytes) => (FLAG_ZSTD, bytes),
        None => (FLAG_RAW, payload),
    };
    // `body.len() < MAX_FRAME_SIZE`: the check above bounds the payload, and the
    // compressed form is only chosen when it is shorter. So the cast cannot
    // truncate and the `+ 1` cannot overflow.
    let total_len = 1u32 + body.len() as u32;
    writer.write_u32::<BigEndian>(total_len)?;
    writer.write_u8(flag)?;
    writer.write_all(body)?;
    Ok(4 + 1 + body.len())
}
/// Reads one frame from `reader` and returns its payload.
///
/// Expects a big-endian `u32` length prefix (flag byte + body), then a one-byte
/// flag, then the body. A zero length prefix yields an empty payload without
/// reading a flag byte. `FLAG_RAW` bodies are returned as-is; `FLAG_ZSTD` bodies are
/// decompressed.
///
/// # Errors
///
/// Returns `ZamError::Protocol` in these cases:
/// - the length prefix exceeds `MAX_FRAME_SIZE`;
/// - the flag byte is unknown;
/// - zstd decompression fails;
/// - the decompressed payload would exceed `MAX_FRAME_SIZE - 1` bytes, the largest
///   payload `write_frame` accepts.
///
/// Errors from `reader` (for example a truncated frame) are propagated with `?`.
pub fn read_frame(reader: &mut impl Read) -> ZamResult<Vec<u8>> {
    let total_len = reader.read_u32::<BigEndian>()?;
    if total_len as u64 > MAX_FRAME_SIZE as u64 {
        return Err(ZamError::Protocol(format!(
            "received frame too large: {} bytes (max {})",
            total_len, MAX_FRAME_SIZE
        )));
    }
    if total_len == 0 {
        return Ok(vec![]);
    }
    let flag = reader.read_u8()?;
    let body_len = (total_len - 1) as usize;
    let mut body = vec![0u8; body_len];
    reader.read_exact(&mut body)?;
    match flag {
        FLAG_RAW => Ok(body),
        // The length check above only bounds the *compressed* size. The decoded
        // size must be capped separately, or a small frame from a peer could
        // inflate to gigabytes in memory (a decompression bomb).
        FLAG_ZSTD => decompress_bounded(&body),
        other => Err(ZamError::Protocol(format!(
            "unknown frame flag: 0x{other:02x}"
        ))),
    }
}

/// Decompresses the zstd-encoded `body`. Fails rather than producing more than
/// `MAX_FRAME_SIZE - 1` bytes of output.
fn decompress_bounded(body: &[u8]) -> ZamResult<Vec<u8>> {
    // Same ceiling a raw frame can carry, and the largest payload `write_frame` accepts.
    let max_payload = u64::from(MAX_FRAME_SIZE - 1);
    let decoder = zstd::stream::read::Decoder::new(body)
        .map_err(|e| ZamError::Protocol(format!("zstd decompress: {e}")))?;
    let mut payload = Vec::new();
    // Pull at most one byte beyond the ceiling. That is enough to detect an
    // oversized stream without ever materialising it in full.
    decoder
        .take(max_payload + 1)
        .read_to_end(&mut payload)
        .map_err(|e| ZamError::Protocol(format!("zstd decompress: {e}")))?;
    if payload.len() as u64 > max_payload {
        return Err(ZamError::Protocol(format!(
            "decompressed frame too large: more than {max_payload} bytes"
        )));
    }
    Ok(payload)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_frame_roundtrip_small() {
        let payload = b"hello world";
        let mut buf = Vec::new();
        let written = write_frame(&mut buf, payload).unwrap();
        // The reported size must match what actually reached the writer.
        assert_eq!(written, buf.len());
        let decoded = read_frame(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(decoded, payload);
    }

    #[test]
    fn test_frame_roundtrip_empty() {
        let mut buf = Vec::new();
        write_frame(&mut buf, &[]).unwrap();
        let decoded = read_frame(&mut Cursor::new(&buf)).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_frame_compression_roundtrip() {
        let payload: Vec<u8> = (0..512).map(|i| b"abcdefghij"[i % 10]).collect();
        let mut buf = Vec::new();
        write_frame(&mut buf, &payload).unwrap();
        assert!(
            buf.len() < payload.len(),
            "compressed frame ({} bytes) should be smaller than raw payload ({} bytes)",
            buf.len(),
            payload.len()
        );
        let decoded = read_frame(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(decoded, payload);
    }

    #[test]
    fn test_frame_compression_flag_raw() {
        let payload = b"hi";
        let mut buf = Vec::new();
        write_frame(&mut buf, payload).unwrap();
        assert_eq!(buf[4], FLAG_RAW);
    }

    #[test]
    fn test_frame_compression_flag_zstd() {
        let payload: Vec<u8> = vec![b'x'; 1024];
        let mut buf = Vec::new();
        write_frame(&mut buf, &payload).unwrap();
        assert_eq!(buf[4], FLAG_ZSTD);
    }

    #[test]
    fn test_write_frame_rejects_payload_at_max_size() {
        let huge = vec![0u8; MAX_FRAME_SIZE as usize];
        let mut buf = Vec::new();
        let result = write_frame(&mut buf, &huge);
        assert!(
            result.is_err(),
            "payload at MAX_FRAME_SIZE must be rejected"
        );
        assert!(buf.is_empty(), "no bytes must be written on rejection");
    }

    #[test]
    fn test_try_consume_frame_rejects_oversized_length_field() {
        use super::super::frame_buf::FrameBuffer;
        let oversized_len = (MAX_FRAME_SIZE as u64 + 1) as u32;
        let mut wire = Vec::new();
        wire.extend_from_slice(&oversized_len.to_be_bytes());
        wire.push(0x00);
        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(&wire));
        assert!(result.is_err(), "oversized length field must be rejected");
    }

    #[test]
    fn test_read_frame_rejects_unknown_flag() {
        // Length prefix 1 (flag byte only), flag 0x7f is neither FLAG_RAW nor FLAG_ZSTD.
        let wire = [0x00, 0x00, 0x00, 0x01, 0x7f];
        let result = read_frame(&mut Cursor::new(&wire[..]));
        assert!(result.is_err(), "unknown flag byte must be rejected");
    }

    #[test]
    fn test_read_frame_rejects_truncated_body() {
        // Header announces 1 flag byte + 4 body bytes, but only 2 body bytes follow.
        let wire = [0x00, 0x00, 0x00, 0x05, FLAG_RAW, b'a', b'b'];
        let result = read_frame(&mut Cursor::new(&wire[..]));
        assert!(result.is_err(), "truncated body must be rejected");
    }
}