use std::io::Read;

use super::TransportError;
/// Payloads shorter than this many bytes are always framed uncompressed.
pub const COMPRESSION_THRESHOLD: usize = 256;
/// Largest decompressed payload, in bytes, that decoding accepts (256 MiB).
pub const MAX_DECOMPRESSED_SIZE: usize = 256 << 20;
/// zstd compression level used when encoding frames.
const ZSTD_LEVEL: i32 = 3;
/// Bit in the leading flags byte marking the frame body as zstd-compressed.
const FLAG_COMPRESSED: u8 = 1 << 0;
/// Frames `data` for transmission as a one-byte flags header followed by a body.
///
/// Payloads of at least [`COMPRESSION_THRESHOLD`] bytes are zstd-compressed at
/// [`ZSTD_LEVEL`]. The compressed form is used only if it is strictly smaller
/// than the input, in which case the header is [`FLAG_COMPRESSED`]. Otherwise
/// (small input, compression error, or no size win) the raw bytes are sent with
/// a `0x00` header. Compression failure is deliberately non-fatal.
///
/// Note: the decoder in this module rejects compressed frames whose decompressed
/// size exceeds [`MAX_DECOMPRESSED_SIZE`]. Very large compressible payloads
/// produced here will therefore not round-trip through `decode_framed`.
pub fn encode_framed(data: &[u8]) -> Vec<u8> {
    if data.len() >= COMPRESSION_THRESHOLD {
        if let Ok(compressed) = zstd::stream::encode_all(data, ZSTD_LEVEL) {
            if compressed.len() < data.len() {
                return frame_with_flags(FLAG_COMPRESSED, &compressed);
            }
        }
    }
    frame_with_flags(0x00, data)
}

/// Builds a frame: the single `flags` byte followed by a copy of `body`.
fn frame_with_flags(flags: u8, body: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(1 + body.len());
    out.push(flags);
    out.extend_from_slice(body);
    out
}
pub fn decode_framed(data: &[u8]) -> Result<Vec<u8>, TransportError> {
if data.is_empty() {
return Err(TransportError::ReceiveFailed(
"empty framed payload".to_string(),
));
}
let flags = data[0];
let body = &data[1..];
if flags & FLAG_COMPRESSED != 0 {
let decompressed = zstd::stream::decode_all(body)
.map_err(|e| TransportError::ReceiveFailed(format!("zstd decompress: {}", e)))?;
if decompressed.len() > MAX_DECOMPRESSED_SIZE {
return Err(TransportError::PayloadTooLarge {
size: decompressed.len(),
max: MAX_DECOMPRESSED_SIZE,
});
}
Ok(decompressed)
} else {
Ok(body.to_vec())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_small_payload_no_compression() {
        let data = b"hello";
        let framed = encode_framed(data);
        assert_eq!(framed[0], 0x00, "small payload should not be compressed");
        assert_eq!(&framed[1..], data);
        let decoded = decode_framed(&framed).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_large_compressible_payload() {
        let data = vec![0x42u8; 4096];
        let framed = encode_framed(&data);
        assert_eq!(
            framed[0] & FLAG_COMPRESSED,
            FLAG_COMPRESSED,
            "large compressible payload should be compressed"
        );
        assert!(framed.len() < data.len(), "compressed should be smaller");
        let decoded = decode_framed(&framed).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_large_incompressible_payload() {
        // Pseudo-random bytes from xorshift64, which zstd cannot shrink. This
        // exercises the "compression did not help" fallback to raw framing.
        // (Sequential integers would compress well and miss that path.)
        let mut state: u64 = 0x9E37_79B9_7F4A_7C15;
        let data: Vec<u8> = (0..4096)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                (state >> 56) as u8
            })
            .collect();
        let framed = encode_framed(&data);
        assert_eq!(
            framed[0], 0x00,
            "incompressible payload should fall back to raw framing"
        );
        assert_eq!(&framed[1..], &data[..]);
        let decoded = decode_framed(&framed).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_empty_framed_payload_error() {
        assert!(decode_framed(&[]).is_err());
    }

    #[test]
    fn test_corrupt_compressed_payload_error() {
        // Compressed flag set, but the body is not a valid zstd stream.
        assert!(decode_framed(&[FLAG_COMPRESSED, 0xDE, 0xAD, 0xBE, 0xEF]).is_err());
    }

    #[test]
    fn test_roundtrip_at_threshold_boundary() {
        let data = vec![0xAA; COMPRESSION_THRESHOLD];
        let framed = encode_framed(&data);
        assert_eq!(
            framed[0] & FLAG_COMPRESSED,
            FLAG_COMPRESSED,
            "compressible payload exactly at threshold should be compressed"
        );
        let decoded = decode_framed(&framed).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_just_below_threshold() {
        let data = vec![0xBB; COMPRESSION_THRESHOLD - 1];
        let framed = encode_framed(&data);
        assert_eq!(framed[0], 0x00, "below threshold should not compress");
        let decoded = decode_framed(&framed).unwrap();
        assert_eq!(decoded, data);
    }
}