/// Largest `original_size` that `decompress_message` will accept (16 MiB = 2^24 bytes).
const MAX_DECOMPRESSED_SIZE: u32 = 1 << 24;
/// A message split into a verbatim leading region and an LZ4 block-compressed remainder.
///
/// Values are built by `compress_message` and expanded again by `decompress_message`.
#[derive(Clone, Debug)]
pub struct CompressedMessage {
    /// Length in bytes of the payload (`message[offset..]`) before compression.
    pub original_size: u32,
    /// The first `offset` bytes of the message, copied without compression.
    pub uncompressed_prefix: Vec<u8>,
    /// LZ4 block-format encoding of the payload.
    pub compressed_data: Vec<u8>,
    /// Index in the original message where the compressed payload begins.
    pub offset: u32,
}
/// Splits `message` at `offset` and LZ4-compresses everything from `offset` onward.
///
/// The bytes before `offset` are copied verbatim into
/// [`CompressedMessage::uncompressed_prefix`]. The remainder is encoded with
/// `lz4_flex::block::compress`.
///
/// Returns `None` when:
/// - `offset >= message.len()`, so there is no payload to compress;
/// - the payload length or `offset` does not fit in a `u32` (the width of the
///   corresponding [`CompressedMessage`] fields);
/// - the compressed payload is not strictly smaller than the original payload.
///
/// NOTE(review): payloads larger than `MAX_DECOMPRESSED_SIZE` are still compressed
/// here, even though `decompress_message` refuses to expand them. Confirm that callers
/// cap the message size.
pub fn compress_message(message: &[u8], offset: usize) -> Option<CompressedMessage> {
    // Explicit import so the checked conversions compile on pre-2021 editions too.
    use std::convert::TryFrom;

    if offset >= message.len() {
        return None;
    }
    let (prefix, to_compress) = message.split_at(offset);
    // Checked conversions: a plain `as u32` silently truncates values above `u32::MAX`.
    // That would yield a header that no longer describes the payload.
    let original_size = u32::try_from(to_compress.len()).ok()?;
    let offset = u32::try_from(offset).ok()?;
    let compressed = lz4_flex::block::compress(to_compress);
    // Only worth sending compressed if it actually saves space.
    if compressed.len() >= to_compress.len() {
        return None;
    }
    Some(CompressedMessage {
        original_size,
        uncompressed_prefix: prefix.to_vec(),
        compressed_data: compressed,
        offset,
    })
}
/// Rebuilds a message from its verbatim prefix and LZ4 block-compressed payload.
///
/// The returned buffer is `uncompressed_prefix` followed by the decompressed payload.
/// The payload is exactly `original_size` bytes long.
///
/// # Errors
///
/// Returns `crate::Error::invalid_data` when:
/// - `original_size` exceeds [`MAX_DECOMPRESSED_SIZE`]. This is checked before any
///   allocation, so a hostile size field cannot force a huge buffer.
/// - LZ4 decoding fails (corrupt input, or output larger than `original_size`).
/// - the payload decodes to a length other than `original_size`.
pub fn decompress_message(
    uncompressed_prefix: &[u8],
    compressed_data: &[u8],
    original_size: u32,
) -> Result<Vec<u8>, crate::Error> {
    if original_size > MAX_DECOMPRESSED_SIZE {
        return Err(crate::Error::invalid_data(format!(
            "decompressed size {} exceeds maximum allowed size {}",
            original_size, MAX_DECOMPRESSED_SIZE
        )));
    }
    let expected_len = original_size as usize;
    let decompressed = lz4_flex::block::decompress(compressed_data, expected_len)
        .map_err(|e| crate::Error::invalid_data(format!("LZ4 decompression failed: {e}")))?;
    // `lz4_flex::block::decompress` treats its size argument as a minimum output
    // capacity (it only has to be >= the real size). A stream that decodes to fewer
    // bytes therefore still returns `Ok`, so enforce the exact length here.
    if decompressed.len() != expected_len {
        return Err(crate::Error::invalid_data(format!(
            "LZ4 decompression failed: produced {} bytes, expected {}",
            decompressed.len(),
            expected_len
        )));
    }
    let mut result = Vec::with_capacity(uncompressed_prefix.len() + decompressed.len());
    result.extend_from_slice(uncompressed_prefix);
    result.extend_from_slice(&decompressed);
    Ok(result)
}
#[cfg(test)]
mod tests {
    // Unit tests covering round-trips, the `None` cases of `compress_message`,
    // and the error paths of `decompress_message`.
    use super::*;

    /// Feeds a `CompressedMessage` back through `decompress_message`, panicking on failure.
    fn roundtrip(compressed: &CompressedMessage) -> Vec<u8> {
        decompress_message(
            &compressed.uncompressed_prefix,
            &compressed.compressed_data,
            compressed.original_size,
        )
        .expect("should decompress")
    }

    /// Returns `count` bytes made by repeating `pattern` end to end.
    fn repeat_pattern(pattern: &[u8], count: usize) -> Vec<u8> {
        pattern.iter().copied().cycle().take(count).collect()
    }

    /// Runs `decompress_message` with an empty prefix on input that must fail,
    /// and returns the error's display text.
    fn decompress_error(compressed: &[u8], original_size: u32) -> String {
        let result = decompress_message(&[], compressed, original_size);
        assert!(result.is_err());
        result.unwrap_err().to_string()
    }

    #[test]
    fn compress_and_decompress_roundtrip() {
        let message = repeat_pattern(b"ABCDEFGH", 1024);
        let compressed = compress_message(&message, 0).expect("should compress");
        assert!(compressed.compressed_data.len() < message.len());
        assert_eq!(compressed.original_size, message.len() as u32);
        assert!(compressed.uncompressed_prefix.is_empty());
        assert_eq!(compressed.offset, 0);
        assert_eq!(roundtrip(&compressed), message);
    }

    #[test]
    fn compress_with_offset_preserves_prefix() {
        const HEADER_LEN: usize = 64;
        let payload = repeat_pattern(b"HelloWorld", 2048);
        let mut message = vec![0xFE; HEADER_LEN];
        message.extend_from_slice(&payload);

        let compressed = compress_message(&message, HEADER_LEN).expect("should compress");
        assert_eq!(compressed.offset, 64);
        assert_eq!(compressed.uncompressed_prefix, &message[..HEADER_LEN]);
        assert_eq!(compressed.original_size, payload.len() as u32);
        assert!(compressed.compressed_data.len() < payload.len());
        assert_eq!(roundtrip(&compressed), message);
    }

    #[test]
    fn compress_with_offset_zero_compresses_entire_message() {
        let message = vec![42u8; 4096];
        let compressed = compress_message(&message, 0).expect("should compress");
        assert_eq!(compressed.offset, 0);
        assert!(compressed.uncompressed_prefix.is_empty());
        assert_eq!(compressed.original_size, 4096);
        assert_eq!(roundtrip(&compressed), message);
    }

    #[test]
    fn compress_empty_message_returns_none() {
        let empty: &[u8] = &[];
        assert!(compress_message(empty, 0).is_none());
    }

    #[test]
    fn compress_offset_at_end_returns_none() {
        let message: &[u8] = b"short";
        for &offset in &[5, 100] {
            assert!(compress_message(message, offset).is_none());
        }
    }

    #[test]
    fn incompressible_data_returns_none() {
        // 137 is odd, so this is a permutation of 0..=255 with no repeated 4-byte runs.
        let message: Vec<u8> = (0u16..256)
            .map(|i| ((i.wrapping_mul(137).wrapping_add(53)) & 0xFF) as u8)
            .collect();
        assert!(
            compress_message(&message, 0).is_none(),
            "incompressible data should return None"
        );
    }

    #[test]
    fn large_message_compresses_well() {
        let message = repeat_pattern(b"SMB2 compression test data! ", 1024 * 1024);
        let compressed = compress_message(&message, 0).expect("should compress large message");
        let ratio = message.len() as f64 / compressed.compressed_data.len() as f64;
        assert!(
            ratio > 4.0,
            "compression ratio {ratio:.1} is too low for repetitive data"
        );
        let decompressed = roundtrip(&compressed);
        assert_eq!(decompressed.len(), message.len());
        assert_eq!(decompressed, message);
    }

    #[test]
    fn decompress_with_wrong_original_size_fails() {
        let message: Vec<u8> = vec![0xAA; 1024];
        let compressed = compress_message(&message, 0).expect("should compress");
        let result = decompress_message(&[], &compressed.compressed_data, 512);
        assert!(result.is_err(), "wrong original_size should cause an error");
    }

    #[test]
    fn decompress_rejects_oversized_original_size() {
        let err_msg = decompress_error(&[0u8; 10], MAX_DECOMPRESSED_SIZE + 1);
        assert!(
            err_msg.contains("exceeds maximum"),
            "error should mention size limit, got: {err_msg}"
        );
    }

    #[test]
    fn decompress_with_exact_max_size_is_allowed() {
        // The size guard must let MAX_DECOMPRESSED_SIZE through, so the failure must come from LZ4.
        let err_msg = decompress_error(&[0u8; 10], MAX_DECOMPRESSED_SIZE);
        assert!(
            err_msg.contains("decompression failed"),
            "should fail on decompression, not size check, got: {err_msg}"
        );
    }

    #[test]
    fn decompress_corrupt_data_fails() {
        let err_msg = decompress_error(&[0xFF, 0xFE, 0xFD, 0xFC, 0xFB], 1024);
        assert!(
            err_msg.contains("decompression failed"),
            "error should mention decompression failure, got: {err_msg}"
        );
    }

    #[test]
    fn decompress_preserves_prefix_in_output() {
        let prefix = b"PREFIX_DATA";
        let payload: Vec<u8> = vec![0x42; 2048];
        let compressed_payload = compress_message(&payload, 0).expect("should compress payload");
        let result = decompress_message(
            prefix,
            &compressed_payload.compressed_data,
            compressed_payload.original_size,
        )
        .expect("should decompress");
        let (head, tail) = result.split_at(prefix.len());
        assert_eq!(head, prefix);
        assert_eq!(tail, &payload[..]);
    }
}