use crate::{
protocol::CompressionMethod,
Error,
Result,
};
use bytes::{
Buf,
BufMut,
Bytes,
BytesMut,
};
use cityhash_rs::cityhash_102_128;
/// Size in bytes of the block header that follows the checksum: one method
/// byte plus two little-endian `u32` fields (compressed size, uncompressed
/// size).
const HEADER_SIZE: usize = 9;
/// Size in bytes of the checksum prefix at the start of every block; the
/// encoders fill it with the 128-bit CityHash of the header and payload.
const CHECKSUM_SIZE: usize = 16;
/// Method byte stored as the first byte of the block header.
///
/// The encoders write these values with `as u8`; `decompress` dispatches on
/// the same three byte values.
#[repr(u8)]
enum CompressionMethodByte {
/// Payload is stored verbatim, without compression.
None = 0x02,
/// Payload is an LZ4 block.
Lz4 = 0x82,
/// Payload is Zstandard-compressed.
Zstd = 0x90,
}
/// Upper bound (0x40000000 bytes, i.e. 1 GiB) that `decompress` accepts for
/// both the compressed-size and the uncompressed-size header fields.
const MAX_COMPRESSED_SIZE: usize = 0x40000000;
pub fn compress(method: CompressionMethod, data: &[u8]) -> Result<Bytes> {
match method {
CompressionMethod::None => {
compress_none(data)
}
CompressionMethod::Lz4 => compress_lz4(data),
CompressionMethod::Zstd => compress_zstd(data),
}
}
pub fn decompress(data: &[u8]) -> Result<Bytes> {
if data.len() < CHECKSUM_SIZE + HEADER_SIZE {
return Err(Error::Compression(
"Data too small for checksum and compression header".to_string(),
));
}
let data_without_checksum = &data[CHECKSUM_SIZE..];
let method = data_without_checksum[0];
let mut reader = &data_without_checksum[1..];
let compressed_size = reader.get_u32_le() as usize;
let uncompressed_size = reader.get_u32_le() as usize;
if compressed_size > MAX_COMPRESSED_SIZE {
return Err(Error::Compression(format!(
"Compressed size too large: {}",
compressed_size
)));
}
if uncompressed_size > MAX_COMPRESSED_SIZE {
return Err(Error::Compression(format!(
"Uncompressed size too large: {}",
uncompressed_size
)));
}
let compressed_data = &data_without_checksum[HEADER_SIZE..];
match method {
0x02 => {
if compressed_data.len() != uncompressed_size {
return Err(Error::Compression(format!(
"Uncompressed data size mismatch: expected {}, got {}",
uncompressed_size,
compressed_data.len()
)));
}
Ok(Bytes::copy_from_slice(compressed_data))
}
0x82 => {
decompress_lz4(compressed_data, uncompressed_size)
}
0x90 => {
decompress_zstd(compressed_data, uncompressed_size)
}
_ => Err(Error::Compression(format!(
"Unknown compression method: 0x{:02x}",
method
))),
}
}
/// Encodes `data` as an LZ4 block wrapped in the checksummed frame.
///
/// Output layout: 16-byte checksum (high 64 bits of the CityHash first, then
/// the low 64 bits, each little-endian), method byte, compressed size
/// including the header, uncompressed size, LZ4 payload.
///
/// # Errors
///
/// Propagates errors from `lz4::block::compress_bound` and
/// `lz4::block::compress_to_buffer`.
fn compress_lz4(data: &[u8]) -> Result<Bytes> {
    // Compress into a worst-case sized scratch buffer, then trim it to the
    // number of bytes actually produced.
    let worst_case = lz4::block::compress_bound(data.len())?;
    let mut payload = vec![0u8; worst_case];
    let written = lz4::block::compress_to_buffer(data, None, false, &mut payload)?;
    payload.truncate(written);

    // Header + payload: this is the region the checksum covers.
    let mut block = BytesMut::with_capacity(HEADER_SIZE + written);
    block.put_u8(CompressionMethodByte::Lz4 as u8);
    block.put_u32_le((HEADER_SIZE + written) as u32);
    block.put_u32_le(data.len() as u32);
    block.put_slice(&payload);

    let digest = cityhash_102_128(&block);
    let (high, low) = ((digest >> 64) as u64, digest as u64);

    let mut framed = BytesMut::with_capacity(CHECKSUM_SIZE + block.len());
    framed.put_u64_le(high);
    framed.put_u64_le(low);
    framed.put_slice(&block);
    Ok(framed.freeze())
}
fn decompress_lz4(data: &[u8], uncompressed_size: usize) -> Result<Bytes> {
let decompressed =
lz4::block::decompress(data, Some(uncompressed_size as i32))?;
if decompressed.len() != uncompressed_size {
return Err(Error::Compression(format!(
"LZ4 decompression size mismatch: expected {}, got {}",
uncompressed_size,
decompressed.len()
)));
}
Ok(Bytes::from(decompressed))
}
fn compress_zstd(data: &[u8]) -> Result<Bytes> {
let compressed = zstd::bulk::compress(data, 3) .map_err(|e| {
Error::Compression(format!("ZSTD compression failed: {}", e))
})?;
let mut header_and_data =
BytesMut::with_capacity(HEADER_SIZE + compressed.len());
header_and_data.put_u8(CompressionMethodByte::Zstd as u8);
header_and_data.put_u32_le((HEADER_SIZE + compressed.len()) as u32); header_and_data.put_u32_le(data.len() as u32);
header_and_data.put_slice(&compressed);
let checksum = cityhash_102_128(&header_and_data);
let mut output =
BytesMut::with_capacity(CHECKSUM_SIZE + header_and_data.len());
output.put_u64_le((checksum >> 64) as u64); output.put_u64_le(checksum as u64); output.put_slice(&header_and_data);
Ok(output.freeze())
}
fn decompress_zstd(data: &[u8], uncompressed_size: usize) -> Result<Bytes> {
let decompressed = zstd::bulk::decompress(data, uncompressed_size)
.map_err(|e| {
Error::Compression(format!("ZSTD decompression failed: {}", e))
})?;
if decompressed.len() != uncompressed_size {
return Err(Error::Compression(format!(
"ZSTD decompression size mismatch: expected {}, got {}",
uncompressed_size,
decompressed.len()
)));
}
Ok(Bytes::from(decompressed))
}
fn compress_none(data: &[u8]) -> Result<Bytes> {
let mut header_and_data =
BytesMut::with_capacity(HEADER_SIZE + data.len());
header_and_data.put_u8(CompressionMethodByte::None as u8);
header_and_data.put_u32_le((HEADER_SIZE + data.len()) as u32); header_and_data.put_u32_le(data.len() as u32);
header_and_data.put_slice(data);
let checksum = cityhash_102_128(&header_and_data);
let mut output =
BytesMut::with_capacity(CHECKSUM_SIZE + header_and_data.len());
output.put_u128_le(checksum); output.put_slice(&header_and_data);
Ok(output.freeze())
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// Builds a syntactically well-formed block by hand: a zeroed checksum
    /// (the decoder does not verify it), the given method byte, a
    /// compressed-size field covering header + payload, the given
    /// uncompressed-size field, and the payload.
    fn raw_block(method: u8, payload: &[u8], uncompressed_size: u32) -> Vec<u8> {
        let mut block = vec![0u8; CHECKSUM_SIZE];
        block.push(method);
        block.extend_from_slice(
            &((HEADER_SIZE + payload.len()) as u32).to_le_bytes(),
        );
        block.extend_from_slice(&uncompressed_size.to_le_bytes());
        block.extend_from_slice(payload);
        block
    }

    #[test]
    fn test_compress_decompress_none() {
        let original = b"Hello, ClickHouse!";
        let compressed = compress(CompressionMethod::None, original).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], original);
    }

    #[test]
    fn test_compress_decompress_lz4() {
        let original = b"Hello, ClickHouse! ".repeat(100);
        let compressed = compress(CompressionMethod::Lz4, &original).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
        assert!(compressed.len() < original.len());
    }

    #[test]
    fn test_compress_decompress_zstd() {
        let original =
            b"ClickHouse is a fast open-source column-oriented database"
                .repeat(50);
        let compressed = compress(CompressionMethod::Zstd, &original).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
        assert!(compressed.len() < original.len());
    }

    #[test]
    fn test_empty_data() {
        let original = b"";
        let compressed = compress(CompressionMethod::Lz4, original).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], original);
    }

    #[test]
    fn test_large_data_lz4() {
        let original = vec![42u8; 100_000];
        let compressed = compress(CompressionMethod::Lz4, &original).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
        assert!(compressed.len() < original.len() / 10);
    }

    #[test]
    fn test_invalid_compression_method() {
        // A full-length, otherwise valid frame, so the decoder gets past the
        // length guard and actually reaches the method dispatch.
        let bad_data = raw_block(0xFF, b"abcd", 4);
        assert!(bad_data.len() >= CHECKSUM_SIZE + HEADER_SIZE);

        let result = decompress(&bad_data);
        assert!(matches!(
            result,
            Err(Error::Compression(ref msg))
                if msg.contains("Unknown compression method: 0xff")
        ));
    }

    #[test]
    fn test_uncompressed_size_mismatch() {
        // Stored (method `None`) payload of 4 bytes that claims to be 5.
        let bad_data =
            raw_block(CompressionMethodByte::None as u8, b"abcd", 5);
        let result = decompress(&bad_data);
        assert!(matches!(
            result,
            Err(Error::Compression(ref msg))
                if msg.contains("Uncompressed data size mismatch")
        ));
    }

    #[test]
    fn test_header_too_small() {
        let bad_data = vec![0x82, 1, 2, 3];
        let result = decompress(&bad_data);
        assert!(result.is_err());

        // One byte short of checksum + header is still rejected.
        let almost = vec![0u8; CHECKSUM_SIZE + HEADER_SIZE - 1];
        assert!(decompress(&almost).is_err());
    }
}