use crate::config::MAX_PAYLOAD_SIZE;
use crate::error::{ProtocolError, Result};
/// Compression algorithm selector used by the helpers in this module.
///
/// The type is a plain tag with no payload, so it is cheap to copy. It also
/// derives `Debug`, `PartialEq` and `Eq` so that it can be logged and compared
/// by callers.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum CompressionKind {
    /// LZ4 block compression via `lz4_flex`, with the uncompressed size
    /// prepended to the output.
    Lz4,
    /// Zstandard stream compression via the `zstd` crate.
    Zstd,
}
/// Upper bound, in bytes, on decompressed output. `decompress` rejects any
/// payload that claims or produces more than this; it mirrors
/// `MAX_PAYLOAD_SIZE` from the crate configuration.
const MAX_DECOMPRESSION_SIZE: usize = MAX_PAYLOAD_SIZE;
/// Entropy cut-off in bits per byte. For inputs of 1024 bytes or more,
/// `should_compress_adaptive` only recommends compression when the sampled
/// entropy is strictly below this value.
const MIN_ENTROPY_THRESHOLD: f64 = 4.0;
/// Computes the Shannon entropy of `data`, in bits per byte.
///
/// A histogram of the 256 possible byte values is built and the usual
/// `-sum(p * log2(p))` is accumulated over the values that actually occur.
/// The result therefore lies between `0.0` (a single repeated byte) and `8.0`
/// (all byte values equally frequent). An empty slice yields `0.0`.
fn calculate_entropy(data: &[u8]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }

    // Occurrence count for every possible byte value.
    let mut histogram = [0u32; 256];
    data.iter()
        .for_each(|&value| histogram[usize::from(value)] += 1);

    let total = data.len() as f64;

    // Fold in histogram order, subtracting each term from the running total.
    histogram
        .iter()
        .filter(|&&occurrences| occurrences > 0)
        .map(|&occurrences| f64::from(occurrences) / total)
        .fold(0.0_f64, |acc, probability| {
            acc - probability * probability.log2()
        })
}
/// Decides whether `data` is worth handing to the compressor.
///
/// * Inputs shorter than `threshold_bytes` are never compressed.
/// * Remaining inputs shorter than 1024 bytes are always compressed.
/// * Larger inputs are compressed only when the entropy of their leading
///   512 bytes is strictly below `MIN_ENTROPY_THRESHOLD`.
fn should_compress_adaptive(data: &[u8], threshold_bytes: usize) -> bool {
    // Below this length the entropy probe is skipped entirely.
    const PROBE_MIN_LEN: usize = 1024;
    // Number of leading bytes inspected by the entropy probe.
    const PROBE_SAMPLE_LEN: usize = 512;

    match data.len() {
        len if len < threshold_bytes => false,
        len if len < PROBE_MIN_LEN => true,
        len => {
            let sample = &data[..len.min(PROBE_SAMPLE_LEN)];
            calculate_entropy(sample) < MIN_ENTROPY_THRESHOLD
        }
    }
}
/// Compresses `data` with the algorithm selected by `kind`.
///
/// * `Lz4` uses `lz4_flex::compress_prepend_size`, so the output starts with
///   the uncompressed length.
/// * `Zstd` stream-encodes the input at compression level 1.
///
/// # Errors
///
/// Returns `ProtocolError::CompressionFailure` when the Zstd encoder reports
/// an error. The LZ4 path does not return an error.
pub fn compress(data: &[u8], kind: &CompressionKind) -> Result<Vec<u8>> {
    match *kind {
        CompressionKind::Lz4 => {
            let framed = lz4_flex::compress_prepend_size(data);
            Ok(framed)
        }
        CompressionKind::Zstd => {
            let mut encoded = Vec::new();
            match zstd::stream::copy_encode(data, &mut encoded, 1) {
                Ok(_) => Ok(encoded),
                Err(_) => Err(ProtocolError::CompressionFailure),
            }
        }
    }
}
/// Decompresses `data` that was produced by [`compress`] with the same `kind`,
/// refusing to produce more than `MAX_DECOMPRESSION_SIZE` bytes.
///
/// * `Lz4`: the first four bytes are read as a little-endian `u32` holding the
///   claimed output size. That claim is checked against the limit *before*
///   the decoder runs, so an oversized claim is rejected up front.
/// * `Zstd`: the stream is decoded in 8 KiB chunks and aborted as soon as the
///   accumulated output exceeds the limit.
///
/// # Errors
///
/// Returns `ProtocolError::DecompressionFailure` when the input is shorter
/// than the LZ4 size header, when the claimed or actual output exceeds
/// `MAX_DECOMPRESSION_SIZE`, or when the underlying decoder fails.
pub fn decompress(data: &[u8], kind: &CompressionKind) -> Result<Vec<u8>> {
    use std::io::Read;

    match kind {
        CompressionKind::Lz4 => {
            // Pull out the 4-byte size header; anything shorter is invalid.
            let claimed_size = match data {
                [b0, b1, b2, b3, ..] => u32::from_le_bytes([*b0, *b1, *b2, *b3]) as usize,
                _ => return Err(ProtocolError::DecompressionFailure),
            };
            if claimed_size > MAX_DECOMPRESSION_SIZE {
                return Err(ProtocolError::DecompressionFailure);
            }

            let plain = lz4_flex::decompress_size_prepended(data)
                .map_err(|_| ProtocolError::DecompressionFailure)?;

            // Re-check the real length after decoding.
            if plain.len() <= MAX_DECOMPRESSION_SIZE {
                Ok(plain)
            } else {
                Err(ProtocolError::DecompressionFailure)
            }
        }
        CompressionKind::Zstd => {
            let mut decoder = zstd::stream::Decoder::new(data)
                .map_err(|_| ProtocolError::DecompressionFailure)?;
            let mut plain = Vec::new();
            let mut chunk = [0u8; 8192];

            loop {
                let filled = decoder
                    .read(&mut chunk)
                    .map_err(|_| ProtocolError::DecompressionFailure)?;
                if filled == 0 {
                    // End of stream.
                    return Ok(plain);
                }

                plain.extend_from_slice(&chunk[..filled]);
                // Stop once the output has grown past the limit.
                if plain.len() > MAX_DECOMPRESSION_SIZE {
                    return Err(ProtocolError::DecompressionFailure);
                }
            }
        }
    }
}
/// Compresses `data` only when it is at least `threshold_bytes` long.
///
/// Returns the resulting bytes together with a flag telling whether they are
/// compressed (`true`) or an unmodified copy of the input (`false`). Inputs at
/// or above the threshold are always returned compressed, even if that makes
/// them larger.
///
/// # Errors
///
/// Propagates any error returned by [`compress`].
pub fn maybe_compress(
    data: &[u8],
    kind: &CompressionKind,
    threshold_bytes: usize,
) -> Result<(Vec<u8>, bool)> {
    let below_threshold = data.len() < threshold_bytes;
    if below_threshold {
        return Ok((data.to_vec(), false));
    }

    let packed = compress(data, kind)?;
    Ok((packed, true))
}
/// Compresses `data` only when doing so looks worthwhile *and* actually
/// shrinks it.
///
/// The decision is delegated to `should_compress_adaptive`. When it approves,
/// the compressed form is kept only if it is strictly shorter than the input.
/// In every other case an unmodified copy of `data` is returned. The boolean
/// in the result tells whether the returned bytes are compressed.
///
/// # Errors
///
/// Propagates any error returned by [`compress`].
pub fn maybe_compress_adaptive(
    data: &[u8],
    kind: &CompressionKind,
    threshold_bytes: usize,
) -> Result<(Vec<u8>, bool)> {
    if !should_compress_adaptive(data, threshold_bytes) {
        return Ok((data.to_vec(), false));
    }

    let candidate = compress(data, kind)?;
    // Keep the compressed form only when it saves space.
    let outcome = if candidate.len() < data.len() {
        (candidate, true)
    } else {
        (data.to_vec(), false)
    };
    Ok(outcome)
}
/// Reverses [`maybe_compress`] / [`maybe_compress_adaptive`].
///
/// When `was_compressed` is `false` the input is returned as an owned copy;
/// otherwise it is passed to [`decompress`] with the given `kind`.
///
/// # Errors
///
/// Propagates any error returned by [`decompress`].
pub fn maybe_decompress(
    data: &[u8],
    kind: &CompressionKind,
    was_compressed: bool,
) -> Result<Vec<u8>> {
    if !was_compressed {
        return Ok(data.to_vec());
    }

    decompress(data, kind)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the compression helpers: round-trips, rejection of
    //! hostile or malformed LZ4 input, threshold handling and the
    //! entropy-based adaptive path.
    //!
    //! `clippy::unwrap_used` is allowed on individual tests because a failed
    //! `unwrap` there is simply a test failure.

    use super::*;

    /// LZ4 compress followed by decompress returns the original bytes.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_lz4_compression_roundtrip() {
        let original = b"Hello, World! This is a test of LZ4 compression.";
        let compressed = compress(original, &CompressionKind::Lz4).unwrap();
        let decompressed = decompress(&compressed, &CompressionKind::Lz4).unwrap();
        assert_eq!(original.as_slice(), decompressed.as_slice());
    }

    /// Zstd compress followed by decompress returns the original bytes.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_zstd_compression_roundtrip() {
        let original = b"Hello, World! This is a test of Zstd compression.";
        let compressed = compress(original, &CompressionKind::Zstd).unwrap();
        let decompressed = decompress(&compressed, &CompressionKind::Zstd).unwrap();
        assert_eq!(original.as_slice(), decompressed.as_slice());
    }

    /// A four-byte payload whose header claims a multi-gigabyte output and
    /// carries no compressed data must be rejected.
    #[test]
    fn test_lz4_oom_attack_prevention() {
        let malicious_payload = vec![0x2b, 0x60, 0xbb, 0xbb];
        let result = decompress(&malicious_payload, &CompressionKind::Lz4);
        assert!(
            result.is_err(),
            "Should reject malicious payload claiming huge output size"
        );
    }

    /// A header claiming exactly one byte more than the limit is rejected.
    #[test]
    fn test_lz4_size_limit_enforcement() {
        let claimed_size = (MAX_DECOMPRESSION_SIZE + 1) as u32;
        let mut malicious = claimed_size.to_le_bytes().to_vec();
        malicious.extend_from_slice(&[0u8; 16]);
        let result = decompress(&malicious, &CompressionKind::Lz4);
        assert!(
            result.is_err(),
            "Should reject payload claiming size > MAX_DECOMPRESSION_SIZE"
        );
    }

    /// Input too short to contain the 4-byte size header is rejected.
    #[test]
    fn test_lz4_short_input_rejection() {
        let short_input = vec![0x2b, 0x60];
        let result = decompress(&short_input, &CompressionKind::Lz4);
        assert!(result.is_err(), "Should reject input shorter than 4 bytes");
    }

    /// A plausible header followed by a garbage LZ4 body is rejected.
    #[test]
    fn test_malformed_compressed_data() {
        let malformed = vec![0x10, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff];
        let result = decompress(&malformed, &CompressionKind::Lz4);
        assert!(result.is_err(), "Should reject malformed compressed data");
    }

    /// Data shorter than the threshold passes through untouched.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_maybe_compress_below_threshold() {
        let data = b"tiny";
        let (out, compressed) = maybe_compress(data, &CompressionKind::Lz4, 512).unwrap();
        assert!(!compressed);
        assert_eq!(out, data);
        let roundtrip = maybe_decompress(&out, &CompressionKind::Lz4, compressed).unwrap();
        assert_eq!(roundtrip, data);
    }

    /// Data at or above the threshold is compressed and round-trips.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_maybe_compress_above_threshold() {
        let data = vec![1u8; 1024];
        let (out, compressed) = maybe_compress(&data, &CompressionKind::Lz4, 512).unwrap();
        assert!(compressed);
        let roundtrip = maybe_decompress(&out, &CompressionKind::Lz4, compressed).unwrap();
        assert_eq!(roundtrip, data);
    }

    /// Entropy is near zero for constant data, near eight for a uniform byte
    /// distribution and low for a two-symbol pattern.
    #[test]
    fn test_entropy_calculation() {
        let zeros = vec![0u8; 100];
        assert!(calculate_entropy(&zeros) < 0.1);
        let random: Vec<u8> = (0..=255).cycle().take(1000).collect();
        assert!(calculate_entropy(&random) > 7.0);
        let pattern = vec![0, 1, 0, 1, 0, 1, 0, 1];
        assert!(calculate_entropy(&pattern) < 2.0);
    }

    /// Large low-entropy input is compressed by the adaptive path.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_adaptive_compression_low_entropy() {
        let data = vec![0u8; 2048];
        let (out, compressed) = maybe_compress_adaptive(&data, &CompressionKind::Lz4, 512).unwrap();
        assert!(compressed);
        assert!(out.len() < data.len());
    }

    /// Large high-entropy input is left uncompressed by the adaptive path.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_adaptive_compression_high_entropy() {
        let data: Vec<u8> = (0..=255).cycle().take(2048).collect();
        let (out, compressed) = maybe_compress_adaptive(&data, &CompressionKind::Lz4, 512).unwrap();
        assert!(!compressed);
        assert_eq!(out.len(), data.len());
    }

    /// Exercises the "only keep the compressed form if it is smaller" check
    /// for inputs below 1024 bytes, where the entropy probe is skipped and
    /// compression is always attempted.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_adaptive_compression_size_check() {
        // Highly compressible: the compressed form is smaller and is kept.
        let data = vec![0u8; 100];
        let (out, compressed) = maybe_compress_adaptive(&data, &CompressionKind::Lz4, 50).unwrap();
        assert!(compressed);
        assert!(out.len() < data.len());
        let roundtrip = maybe_decompress(&out, &CompressionKind::Lz4, compressed).unwrap();
        assert_eq!(roundtrip, data);

        // Incompressible: 100 distinct bytes contain no repeats, so the
        // framed output is larger than the input and the original is returned.
        let incompressible: Vec<u8> = (0..100u8).collect();
        let (out, compressed) =
            maybe_compress_adaptive(&incompressible, &CompressionKind::Lz4, 50).unwrap();
        assert!(!compressed);
        assert_eq!(out, incompressible);
    }
}