use crate::tools::BytesGatherer;
use bytes::{BufMut, Bytes, BytesMut};
use std::io::{Read, Write};
/// Leading tag byte of a payload whose remaining bytes are the input stored verbatim.
const COMPRESSION_VERSION_NONE: u8 = 0x00;
/// Leading tag byte of a payload holding a little-endian `u32` uncompressed size
/// followed by an LZ4 block.
const COMPRESSION_VERSION_LZ4: u8 = 0x01;
/// Leading tag byte of a payload holding a brotli stream.
const COMPRESSION_VERSION_BROTLI: u8 = 0x02;
/// Inputs shorter than this many bytes are stored uncompressed instead of LZ4-encoded.
const MIN_BYTES_FOR_LZ4: usize = 64;
/// Inputs shorter than this many bytes are stored uncompressed instead of brotli-encoded.
const MIN_BYTES_FOR_BROTLI: usize = 128;
/// Largest decompressed size (32 MiB) that the lz4 and brotli decoders will accept.
const MAX_DECOMPRESSED_SIZE: usize = 32 << 20;
/// Frames `input` as an uncompressed payload.
///
/// The payload is the [`COMPRESSION_VERSION_NONE`] tag byte followed by `input` verbatim.
fn compress_passthrough(input: &[u8]) -> anyhow::Result<BytesGatherer> {
    let mut framed = BytesGatherer::default();
    // Tag first, so `decompress` can dispatch on byte 0.
    framed.put_u8(COMPRESSION_VERSION_NONE);
    framed.put_slice(input);
    Ok(framed)
}
/// Returns an owned copy of everything after the version tag byte.
///
/// # Panics
/// Panics if `input` is empty. [`decompress`] rejects empty input before dispatching here.
fn decompress_passthrough(input: &[u8]) -> anyhow::Result<BytesGatherer> {
    let (_tag, body) = input.split_at(1);
    let owned = Bytes::copy_from_slice(body);
    Ok(BytesGatherer::from_bytes(owned))
}
fn compress_lz4(input: &[u8]) -> anyhow::Result<BytesGatherer> {
if input.len() < MIN_BYTES_FOR_LZ4 {
return compress_passthrough(input);
}
let max_out = lz4_flex::block::get_maximum_output_size(input.len());
let mut result = BytesMut::with_capacity(5 + max_out);
result.put_u8(COMPRESSION_VERSION_LZ4);
result.put_slice(&(input.len() as u32).to_le_bytes());
let data_start = result.len(); result.resize(data_start + max_out, 0);
let n = lz4_flex::block::compress_into(input, &mut result[data_start..]).map_err(|e| anyhow::anyhow!("lz4 compression failed: {}", e))?;
result.truncate(data_start + n);
Ok(BytesGatherer::from_bytes(result.freeze()))
}
/// Decodes an LZ4 payload.
///
/// Expected layout: tag byte, then a 4-byte little-endian uncompressed size, then the
/// LZ4 block.
///
/// # Errors
/// Returns an error if any of the following holds:
/// - The size prefix is missing.
/// - The declared size exceeds [`MAX_DECOMPRESSED_SIZE`]. This is checked before any
///   output buffer is allocated.
/// - The block fails to decode.
fn decompress_lz4(input: &[u8]) -> anyhow::Result<BytesGatherer> {
    let payload = &input[1..];
    let declared_len = match payload.get(..4) {
        Some(prefix) => {
            let mut le_bytes = [0u8; 4];
            le_bytes.copy_from_slice(prefix);
            u32::from_le_bytes(le_bytes) as usize
        }
        None => anyhow::bail!("lz4 decompression failed: missing size prefix"),
    };
    if declared_len > MAX_DECOMPRESSED_SIZE {
        anyhow::bail!(
            "lz4 decompressed size {} exceeds limit {}",
            declared_len,
            MAX_DECOMPRESSED_SIZE
        );
    }
    // `payload` still starts with the size prefix, as `decompress_size_prepended` expects.
    let decoded = lz4_flex::decompress_size_prepended(payload)
        .map_err(|e| anyhow::anyhow!("lz4 decompression failed: {}", e))?;
    Ok(BytesGatherer::from_bytes(Bytes::from(decoded)))
}
fn compress_brotli(input: &[u8]) -> anyhow::Result<BytesGatherer> {
if input.len() < MIN_BYTES_FOR_BROTLI {
return compress_passthrough(input);
}
let mut result = vec![COMPRESSION_VERSION_BROTLI];
{
let mut writer = brotli::CompressorWriter::new(&mut result, 4096, 11, 22);
writer.write_all(input)?;
}
Ok(BytesGatherer::from_bytes(Bytes::from(result)))
}
/// Decodes a brotli payload (tag byte followed by the brotli stream).
///
/// At most `MAX_DECOMPRESSED_SIZE + 1` bytes are pulled from the decoder. This detects
/// an oversized stream without buffering all of its output.
///
/// # Errors
/// Returns an error in either of these cases:
/// - The stream cannot be read or decoded.
/// - The decoded size exceeds [`MAX_DECOMPRESSED_SIZE`].
fn decompress_brotli(input: &[u8]) -> anyhow::Result<BytesGatherer> {
    let stream = &input[1..];
    // One byte past the limit is enough to tell "exactly at limit" from "over limit".
    let read_cap = MAX_DECOMPRESSED_SIZE as u64 + 1;
    let mut limited = brotli::Decompressor::new(stream, 4096).take(read_cap);
    let mut decoded = Vec::new();
    let produced = limited.read_to_end(&mut decoded)?;
    if produced <= MAX_DECOMPRESSED_SIZE {
        return Ok(BytesGatherer::from_bytes(Bytes::from(decoded)));
    }
    anyhow::bail!(
        "brotli decompressed size {} exceeds limit {}",
        produced,
        MAX_DECOMPRESSED_SIZE
    );
}
/// Compresses `input` with LZ4, favouring speed over ratio.
///
/// The LZ4 payload (including its tag byte and size prefix) is kept only when it is
/// strictly shorter than `input`. Otherwise an uncompressed payload is returned. The
/// result can be decoded with [`decompress`].
///
/// # Errors
/// Returns an error if LZ4 compression fails.
pub fn compress_for_speed(input: &[u8]) -> anyhow::Result<BytesGatherer> {
    let candidate = compress_lz4(input)?;
    if candidate.len() >= input.len() {
        return compress_passthrough(input);
    }
    Ok(candidate)
}
/// Compresses `input` with brotli, favouring ratio over speed.
///
/// The brotli payload (including its tag byte) is kept only when it is strictly shorter
/// than `input`. Otherwise an uncompressed payload is returned. The result can be
/// decoded with [`decompress`].
///
/// # Errors
/// Returns an error if brotli compression fails.
pub fn compress_for_size(input: &[u8]) -> anyhow::Result<BytesGatherer> {
    match compress_brotli(input)? {
        packed if packed.len() < input.len() => Ok(packed),
        _ => compress_passthrough(input),
    }
}
/// Decodes a payload produced by [`compress_for_speed`] or [`compress_for_size`].
///
/// Dispatches on the leading version byte.
///
/// # Errors
/// Returns an error in any of these cases:
/// - `input` is empty.
/// - The version byte is not recognised.
/// - The selected decoder fails, including when its size limit is exceeded.
pub fn decompress(input: &[u8]) -> anyhow::Result<BytesGatherer> {
    let tag = match input.first() {
        Some(&tag) => tag,
        None => anyhow::bail!("missing compression version byte"),
    };
    match tag {
        COMPRESSION_VERSION_NONE => decompress_passthrough(input),
        COMPRESSION_VERSION_LZ4 => decompress_lz4(input),
        COMPRESSION_VERSION_BROTLI => decompress_brotli(input),
        other => anyhow::bail!("unsupported compression version byte {}", other),
    }
}
#[cfg(test)]
mod tests {
    use crate::tools::compression::{compress_for_size, compress_for_speed, decompress};
    use crate::tools::tools;
    #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
    extern crate wasm_bindgen_test;
    #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
    use wasm_bindgen_test::*;
    #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    /// Round-trips `input` through `compress_for_speed` and `decompress`, asserting equality.
    fn roundtrip_speed(input: &[u8], msg: &str) -> anyhow::Result<()> {
        let compressed = compress_for_speed(input)?.to_bytes();
        let output = decompress(&compressed)?.to_bytes();
        assert_eq!(input, output.as_ref(), "{}", msg);
        Ok(())
    }

    /// Round-trips `input` through `compress_for_size` and `decompress`, asserting equality.
    fn roundtrip_size(input: &[u8], msg: &str) -> anyhow::Result<()> {
        let compressed = compress_for_size(input)?.to_bytes();
        let output = decompress(&compressed)?.to_bytes();
        assert_eq!(input, output.as_ref(), "{}", msg);
        Ok(())
    }

    /// Decompresses `payload`, panics if that succeeds, and returns the error message.
    fn expect_decompress_error(payload: &[u8]) -> String {
        decompress(payload)
            .err()
            .expect("decompression should have failed")
            .to_string()
    }

    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), tokio::test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    async fn test_compression_is_reversible() -> anyhow::Result<()> {
        let input = b"Some example string to test compression and decompression.";
        roundtrip_speed(input, "lz4 roundtrip")?;
        roundtrip_size(input, "brotli roundtrip")?;
        Ok(())
    }

    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), tokio::test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    async fn test_compression_is_reversible_short() -> anyhow::Result<()> {
        let input = b"Some...";
        roundtrip_speed(input, "lz4 short passthrough")?;
        roundtrip_size(input, "brotli short passthrough")?;
        Ok(())
    }

    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), tokio::test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    async fn test_compression_is_reversible_empty() -> anyhow::Result<()> {
        let input = b"";
        roundtrip_speed(input, "lz4 empty passthrough")?;
        roundtrip_size(input, "brotli empty passthrough")?;
        Ok(())
    }

    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), tokio::test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    async fn test_compression_is_reversible_random() -> anyhow::Result<()> {
        const LENGTH: usize = 8192 * 8192;
        let input: Vec<u8> = tools::random_bytes(LENGTH);
        roundtrip_speed(&input, "lz4 random passthrough")?;
        Ok(())
    }

    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), tokio::test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    async fn test_brotli_roundtrip_random() -> anyhow::Result<()> {
        // Kept small: brotli at quality 11 is slow on large inputs.
        let input: Vec<u8> = tools::random_bytes(4096);
        roundtrip_size(&input, "brotli random roundtrip")?;
        Ok(())
    }

    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), tokio::test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    async fn test_brotli_actually_compresses_html() -> anyhow::Result<()> {
        let input = "<!DOCTYPE html><html><head><title>Test</title></head><body>".repeat(50);
        let compressed = compress_for_size(input.as_bytes())?.to_bytes();
        assert!(
            compressed.len() < input.len(),
            "brotli should compress repetitive HTML: {} -> {}",
            input.len(),
            compressed.len()
        );
        let output = decompress(&compressed)?.to_bytes();
        assert_eq!(input.as_bytes(), output.as_ref());
        Ok(())
    }

    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), tokio::test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    async fn test_lz4_rejects_oversized_decompressed_payload() {
        // Declared size is one byte over the limit. The block bytes are never decoded.
        let fake_size: u32 = (super::MAX_DECOMPRESSED_SIZE as u32) + 1;
        let mut payload = vec![super::COMPRESSION_VERSION_LZ4];
        payload.extend_from_slice(&fake_size.to_le_bytes());
        payload.extend_from_slice(&[0u8; 16]);
        let result = decompress(&payload);
        let error_message = result.err().expect("should have failed").to_string();
        assert!(error_message.contains("exceeds limit"), "unexpected error: {}", error_message);
    }

    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), tokio::test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    async fn test_lz4_rejects_missing_size_prefix() {
        // Tag byte followed by only two of the four size-prefix bytes.
        let payload = [super::COMPRESSION_VERSION_LZ4, 0, 0];
        let error_message = expect_decompress_error(&payload);
        assert!(
            error_message.contains("missing size prefix"),
            "unexpected error: {}",
            error_message
        );
    }

    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), tokio::test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    async fn test_decompress_rejects_empty_input() {
        let error_message = expect_decompress_error(b"");
        assert!(
            error_message.contains("missing compression version byte"),
            "unexpected error: {}",
            error_message
        );
    }

    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), tokio::test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    async fn test_decompress_rejects_unknown_version() {
        let payload: [u8; 4] = [0xFF, 1, 2, 3];
        let error_message = expect_decompress_error(&payload);
        assert!(
            error_message.contains("unsupported compression version byte 255"),
            "unexpected error: {}",
            error_message
        );
    }

    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), tokio::test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    async fn test_lz4_accepts_valid_decompressed_payload() -> anyhow::Result<()> {
        let input = "hello world! ".repeat(100);
        let compressed = compress_for_speed(input.as_bytes())?.to_bytes();
        let output = decompress(&compressed)?.to_bytes();
        assert_eq!(input.as_bytes(), output.as_ref());
        Ok(())
    }

    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), tokio::test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    async fn test_lz4_actually_compresses_text() -> anyhow::Result<()> {
        let input = "The quick brown fox jumps over the lazy dog. ".repeat(100);
        let compressed = compress_for_speed(input.as_bytes())?.to_bytes();
        assert!(
            compressed.len() < input.len(),
            "lz4 should compress repetitive text: {} -> {}",
            input.len(),
            compressed.len()
        );
        let output = decompress(&compressed)?.to_bytes();
        assert_eq!(input.as_bytes(), output.as_ref());
        Ok(())
    }
}