#![deny(clippy::cast_possible_truncation)]
#[cfg(feature = "zstd")]
use std::io::Read;
#[cfg(test)]
use crate::delta::{DeltaDecoder, DeltaEncoder, MAX_DELTA_OUTPUT_SIZE};
/// Length of the object header: one compression-type byte followed by the
/// uncompressed length encoded as a big-endian `u64`.
const COMPRESSED_HEADER_LEN: usize = 9;
/// Upper bound (256 MiB) on any recorded or produced uncompressed size;
/// enforced by `validate_size`.
const MAX_DECOMPRESSED_SIZE: u64 = 256 * 1024 * 1024;
/// zstd frame magic number (0xFD2FB528) in its on-the-wire little-endian byte
/// order; a zstd object is only recognized when these bytes immediately follow
/// the header.
const ZSTD_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];
/// Tag stored in the first byte of an object header, identifying how the
/// payload after the header is encoded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum CompressionType {
    /// Payload follows the header unchanged.
    None = 0,
    /// Payload is a zstd stream.
    Zstd = 1,
    /// Payload is a delta that must be applied to a separate base object.
    Delta = 2,
}

impl CompressionType {
    /// Every known tag, in discriminant order.
    const ALL: [Self; 3] = [Self::None, Self::Zstd, Self::Delta];

    /// Maps a raw header byte back to its tag; unknown bytes yield `None`.
    fn from_u8(value: u8) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|kind| *kind as u8 == value)
    }
}
/// Tuning knobs for `compress` (and delta compression).
#[derive(Debug, Clone, Copy)]
pub struct CompressionConfig {
    /// Whether compression is attempted at all.
    pub enabled: bool,
    /// zstd compression level handed to the encoder.
    pub level: i32,
    /// Inputs shorter than this many bytes are never compressed.
    pub min_size: usize,
    /// Delta compression is skipped when the base object is larger than this
    /// many bytes.
    pub max_delta_size: usize,
}

impl Default for CompressionConfig {
    /// Compression on iff the crate is built with the `zstd` feature,
    /// level 3, a 256-byte minimum input, and a 10,000,000-byte delta-base cap.
    fn default() -> Self {
        Self {
            enabled: cfg!(feature = "zstd"),
            level: 3,
            min_size: 256,
            max_delta_size: 10_000_000,
        }
    }
}

impl CompressionConfig {
    /// Builds a config from [`Default`], overridden by environment variables:
    ///
    /// - `HEDDLE_COMPRESSION`: `0`, `false`, `off` or `no` (ASCII
    ///   case-insensitive) disables compression; any other value requests it.
    ///   Compression is only ever enabled when built with the `zstd` feature.
    /// - `HEDDLE_COMPRESSION_LEVEL`: integer level, clamped to `1..=22`.
    /// - `HEDDLE_COMPRESSION_MIN_SIZE`: minimum input size in bytes.
    ///
    /// Surrounding whitespace is ignored. Numeric values that fail to parse
    /// (or variables that are unset / not valid UTF-8) leave the default.
    pub fn from_env() -> Self {
        let mut config = Self::default();
        if let Ok(val) = std::env::var("HEDDLE_COMPRESSION") {
            config.enabled = Self::parse_enabled(&val) && cfg!(feature = "zstd");
        }
        if let Some(level) = Self::parse_env::<i32>("HEDDLE_COMPRESSION_LEVEL") {
            config.level = level.clamp(1, 22);
        }
        if let Some(size) = Self::parse_env::<usize>("HEDDLE_COMPRESSION_MIN_SIZE") {
            config.min_size = size;
        }
        config
    }

    /// A config under which nothing is ever compressed: `enabled` is false
    /// and `min_size` is `usize::MAX`.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            level: 0,
            min_size: usize::MAX,
            max_delta_size: 0,
        }
    }

    /// Interprets a `HEDDLE_COMPRESSION` value: returns `false` only for a
    /// recognized "off" spelling, `true` for anything else.
    fn parse_enabled(raw: &str) -> bool {
        let value = raw.trim();
        !["0", "false", "off", "no"]
            .iter()
            .any(|off| value.eq_ignore_ascii_case(off))
    }

    /// Reads environment variable `name` and parses its trimmed value.
    /// Returns `None` if the variable is unset, not UTF-8, or unparsable.
    fn parse_env<T: std::str::FromStr>(name: &str) -> Option<T> {
        std::env::var(name).ok()?.trim().parse().ok()
    }
}
/// Errors raised while compressing, decompressing, or validating object
/// payloads and their headers.
#[derive(Debug, thiserror::Error)]
pub enum CompressionError {
/// The zstd or delta decoder reported an error, or a delta's recorded output
/// size exceeded the delta output limit.
#[error("decompression failed: {0}")]
DecompressionFailed(String),
/// The zstd encoder reported an error.
#[error("compression failed: {0}")]
CompressionFailed(String),
/// The header's type byte is not a known compression type.
#[error("invalid compression type: {0}")]
InvalidType(u8),
/// Header and payload are inconsistent: truncated header, size overflow or
/// platform overflow, or a decoded length that disagrees with the recorded size.
#[error("corrupted data: {0}")]
CorruptedData(String),
/// The operation cannot be performed here, e.g. decoding delta data without a
/// base object, passing a non-delta object to the delta decoder, or zstd
/// support not being compiled into this build.
#[error("invalid operation: {0}")]
InvalidOperation(String),
/// A size is larger than `MAX_DECOMPRESSED_SIZE` (reported as `max`).
#[error("object size {size} exceeds maximum {max}")]
SizeLimitExceeded { size: u64, max: u64 },
}
/// Encodes `data` as a single zstd stream at `level` (no header framing).
///
/// # Errors
/// Returns `CompressionFailed` carrying the encoder's error text.
#[cfg(feature = "zstd")]
fn compress_zstd_impl(data: &[u8], level: i32) -> Result<Vec<u8>, CompressionError> {
    match zstd::encode_all(data, level) {
        Ok(encoded) => Ok(encoded),
        Err(err) => Err(CompressionError::CompressionFailed(err.to_string())),
    }
}
/// Stand-in used when the crate is built without the `zstd` feature:
/// compression is always refused.
///
/// # Errors
/// Always returns `InvalidOperation`.
#[cfg(not(feature = "zstd"))]
fn compress_zstd_impl(_data: &[u8], _level: i32) -> Result<Vec<u8>, CompressionError> {
    let reason = "zstd compression support not compiled into this build";
    Err(CompressionError::InvalidOperation(reason.to_owned()))
}
/// Benchmark-only entry point exposing the raw zstd encoder (no header).
#[cfg(feature = "bench")]
pub fn compress_zstd(data: &[u8], level: i32) -> Result<Vec<u8>, CompressionError> {
    let encoded = compress_zstd_impl(data, level)?;
    Ok(encoded)
}
/// Decodes a zstd stream whose header recorded `expected_size` bytes.
///
/// `expected_size` is checked against `MAX_DECOMPRESSED_SIZE` first. Output
/// is then read in 8 KiB chunks, so a stream that decodes to more than
/// `expected_size` bytes is rejected as soon as it overshoots and the excess
/// is never buffered. Shorter output is not rejected here.
///
/// # Errors
/// - `SizeLimitExceeded` if `expected_size` exceeds the global limit.
/// - `DecompressionFailed` if the decoder cannot be created or fails mid-stream.
/// - `CorruptedData` if the output exceeds `expected_size` or a size does not
///   fit the platform's integer types.
#[cfg(feature = "zstd")]
fn decompress_zstd_impl(data: &[u8], expected_size: u64) -> Result<Vec<u8>, CompressionError> {
    // The recorded size comes from an untrusted header. Reserving all of it up
    // front would let a tiny payload with a forged header force an allocation
    // of up to MAX_DECOMPRESSED_SIZE before a single byte is decoded, so only
    // reserve up to this much; genuinely large outputs grow the Vec as needed.
    const MAX_INITIAL_CAPACITY: usize = 8 * 1024 * 1024;

    validate_size(expected_size)?;
    let expected_capacity = checked_size_to_usize("zstd expected size", expected_size)?;
    let mut decoder = zstd::stream::read::Decoder::new(data)
        .map_err(|e| CompressionError::DecompressionFailed(e.to_string()))?;
    let mut decompressed = Vec::with_capacity(expected_capacity.min(MAX_INITIAL_CAPACITY));
    let mut buffer = [0u8; 8192];
    loop {
        let bytes_read = decoder
            .read(&mut buffer)
            .map_err(|e| CompressionError::DecompressionFailed(e.to_string()))?;
        if bytes_read == 0 {
            break;
        }
        let next_size = decompressed.len().checked_add(bytes_read).ok_or_else(|| {
            CompressionError::CorruptedData("decompressed size overflows".to_string())
        })?;
        let next_size = u64::try_from(next_size).map_err(|_| {
            CompressionError::CorruptedData("decompressed size exceeds platform limits".to_string())
        })?;
        // Refuse to buffer past the size the header promised.
        if next_size > expected_size {
            return Err(CompressionError::CorruptedData(format!(
                "decompressed size exceeds recorded header size: expected {expected_size}, got at least {next_size}",
            )));
        }
        decompressed.extend_from_slice(&buffer[..bytes_read]);
    }
    Ok(decompressed)
}
/// Stand-in used when built without the `zstd` feature.
///
/// The recorded size is still validated first, so an oversized header
/// reports `SizeLimitExceeded` rather than the missing-feature error.
///
/// # Errors
/// `SizeLimitExceeded` for oversized headers, otherwise `InvalidOperation`.
#[cfg(not(feature = "zstd"))]
fn decompress_zstd_impl(_data: &[u8], expected_size: u64) -> Result<Vec<u8>, CompressionError> {
    validate_size(expected_size).and_then(|()| {
        Err(CompressionError::InvalidOperation(
            "zstd-compressed data is unsupported in this build".to_string(),
        ))
    })
}
/// Benchmark-only entry point exposing the raw zstd decoder (no header);
/// `expected_size` bounds the output.
#[cfg(feature = "bench")]
pub fn decompress_zstd(data: &[u8], expected_size: u64) -> Result<Vec<u8>, CompressionError> {
    let decoded = decompress_zstd_impl(data, expected_size)?;
    Ok(decoded)
}
/// Compresses `data` with zstd and frames it behind the 9-byte header
/// (`CompressionType::Zstd` byte + big-endian original length).
///
/// Returns `Ok(None)` when compression is disabled, the input is shorter
/// than `config.min_size`, or the zstd output is not smaller than `data`.
///
/// NOTE(review): the "smaller" comparison ignores the 9-byte header, so the
/// framed result can be up to 8 bytes larger than `data`. That is only a win
/// if uncompressed objects are also stored with a header — confirm with callers.
///
/// # Errors
/// `SizeLimitExceeded` if `data` exceeds `MAX_DECOMPRESSED_SIZE`; any error
/// from the zstd encoder.
pub fn compress(
    data: &[u8],
    config: &CompressionConfig,
) -> Result<Option<Vec<u8>>, CompressionError> {
    let worth_trying = config.enabled && data.len() >= config.min_size;
    if !worth_trying {
        return Ok(None);
    }
    let original_len = data.len() as u64;
    validate_size(original_len)?;

    let body = compress_zstd_impl(data, config.level)?;
    if body.len() < data.len() {
        let mut framed = Vec::with_capacity(COMPRESSED_HEADER_LEN + body.len());
        framed.push(CompressionType::Zstd as u8);
        framed.extend_from_slice(&original_len.to_be_bytes());
        framed.extend_from_slice(&body);
        Ok(Some(framed))
    } else {
        Ok(None)
    }
}
/// Encodes `data` as a delta against `base` and frames it behind the 9-byte
/// header (`CompressionType::Delta` byte + big-endian target length).
///
/// Returns `Ok(None)` when compression is disabled, `data` is shorter than
/// `config.min_size`, `base` exceeds `config.max_delta_size`, or the delta is
/// not smaller than `data`.
#[cfg(test)]
fn compress_delta(
    data: &[u8],
    base: &[u8],
    config: &CompressionConfig,
) -> Result<Option<Vec<u8>>, CompressionError> {
    let eligible = config.enabled
        && data.len() >= config.min_size
        && base.len() <= config.max_delta_size;
    if !eligible {
        return Ok(None);
    }
    let target_len = data.len() as u64;
    validate_size(target_len)?;

    let delta = DeltaEncoder::encode(base, data);
    if delta.len() >= data.len() {
        return Ok(None);
    }
    let mut framed = Vec::with_capacity(COMPRESSED_HEADER_LEN + delta.len());
    framed.push(CompressionType::Delta as u8);
    framed.extend_from_slice(&target_len.to_be_bytes());
    framed.extend_from_slice(&delta);
    Ok(Some(framed))
}
/// Decodes a stored object according to its header type byte.
///
/// - Inputs shorter than the 9-byte header are returned unchanged.
/// - `None`: the header is stripped after checking the payload length
///   matches the recorded size.
/// - `Zstd`: decoded when the zstd magic follows the header; otherwise the
///   input is returned unchanged.
/// - `Delta`: rejected, since a base object is required.
///
/// # Errors
/// `InvalidType` for an unknown type byte, `InvalidOperation` for delta data,
/// and size/corruption errors from header or payload validation.
pub fn decompress(data: &[u8]) -> Result<Vec<u8>, CompressionError> {
    if data.len() < COMPRESSED_HEADER_LEN {
        return Ok(data.to_vec());
    }
    let tag = data[0];
    let Some(kind) = CompressionType::from_u8(tag) else {
        return Err(CompressionError::InvalidType(tag));
    };
    match kind {
        CompressionType::None => {
            let recorded = read_u64_size(data)?;
            let body = &data[COMPRESSED_HEADER_LEN..];
            validate_decompressed_len(recorded, body.len())?;
            Ok(body.to_vec())
        }
        CompressionType::Zstd => {
            if zstd_header_len(data).is_some() {
                decompress_zstd_with_header(data)
            } else {
                Ok(data.to_vec())
            }
        }
        CompressionType::Delta => Err(CompressionError::InvalidOperation(
            "Delta compression requires base object".to_string(),
        )),
    }
}
/// Reconstructs an object from a delta-framed payload and its `base`.
///
/// # Errors
/// `CorruptedData` if shorter than the header, `InvalidType` for an unknown
/// type byte, `InvalidOperation` if the object is not delta-encoded, plus any
/// error from the delta decoder.
#[cfg(test)]
fn decompress_delta(delta_data: &[u8], base: &[u8]) -> Result<Vec<u8>, CompressionError> {
    if delta_data.len() < COMPRESSED_HEADER_LEN {
        return Err(CompressionError::CorruptedData(
            "Delta data too short".to_string(),
        ));
    }
    match CompressionType::from_u8(delta_data[0]) {
        Some(CompressionType::Delta) => decompress_delta_with_header(delta_data, base),
        Some(_) => Err(CompressionError::InvalidOperation(
            "Expected delta compression".to_string(),
        )),
        None => Err(CompressionError::InvalidType(delta_data[0])),
    }
}
/// Reports whether `data` is a header-framed zstd object: a `Zstd` type byte
/// with the zstd magic immediately after the 9-byte header.
pub fn is_compressed(data: &[u8]) -> bool {
    let tagged_zstd = data.len() >= COMPRESSED_HEADER_LEN
        && CompressionType::from_u8(data[0]) == Some(CompressionType::Zstd);
    tagged_zstd && zstd_header_len(data).is_some()
}
/// Returns the uncompressed size recorded in a zstd object's header.
///
/// Yields `None` unless `data` is a header-framed zstd object (type byte
/// `Zstd` followed by the zstd magic); other types are not reported.
pub fn header_uncompressed_size(data: &[u8]) -> Option<u64> {
    if data.len() < COMPRESSED_HEADER_LEN {
        return None;
    }
    let is_zstd_frame = CompressionType::from_u8(data[0]) == Some(CompressionType::Zstd)
        && zstd_header_len(data).is_some();
    if !is_zstd_frame {
        return None;
    }
    let size_field: [u8; 8] = data[1..COMPRESSED_HEADER_LEN].try_into().ok()?;
    Some(u64::from_be_bytes(size_field))
}
/// Splits a header into its compression type and recorded size, without
/// validating either against the payload.
#[cfg(test)]
fn compression_info(data: &[u8]) -> Option<(CompressionType, u64)> {
    let header = data.get(..COMPRESSED_HEADER_LEN)?;
    let (&tag, size_bytes) = header.split_first()?;
    let kind = CompressionType::from_u8(tag)?;
    let size = u64::from_be_bytes(size_bytes.try_into().ok()?);
    Some((kind, size))
}
/// Decodes an object laid out as `[type byte][u64 BE size][zstd stream]`.
fn decompress_zstd_with_header(data: &[u8]) -> Result<Vec<u8>, CompressionError> {
    let size_reader = read_u64_size;
    try_decompress_zstd(data, COMPRESSED_HEADER_LEN, size_reader)
}
/// Returns the header length when the zstd magic sits right after the
/// header, or `None` if it does not.
fn zstd_header_len(data: &[u8]) -> Option<usize> {
    has_magic_at(data, COMPRESSED_HEADER_LEN, ZSTD_MAGIC).then_some(COMPRESSED_HEADER_LEN)
}
fn try_decompress_zstd<F>(
data: &[u8],
header_len: usize,
read_size: F,
) -> Result<Vec<u8>, CompressionError>
where
F: Fn(&[u8]) -> Result<u64, CompressionError>,
{
let uncompressed_size = read_size(data)?;
let decompressed = decompress_zstd_impl(&data[header_len..], uncompressed_size)?;
validate_decompressed_len(uncompressed_size, decompressed.len())?;
Ok(decompressed)
}
/// Applies a delta laid out as `[type byte][u64 BE size][delta]` to `base`.
#[cfg(test)]
fn decompress_delta_with_header(
    delta_data: &[u8],
    base: &[u8],
) -> Result<Vec<u8>, CompressionError> {
    let size_reader = read_u64_size;
    try_decompress_delta(delta_data, base, COMPRESSED_HEADER_LEN, size_reader)
}
/// Reads the recorded size with `read_size`, applies the delta found after
/// the first `header_len` bytes to `base`, and checks the output length.
///
/// # Errors
/// - `DecompressionFailed` if the recorded size exceeds `MAX_DELTA_OUTPUT_SIZE`
///   or the delta decoder fails.
/// - `CorruptedData` if `delta_data` is shorter than `header_len`, the size
///   does not fit `usize`, or the output length differs from the record.
/// - Any error from `read_size`.
#[cfg(test)]
fn try_decompress_delta<F>(
    delta_data: &[u8],
    base: &[u8],
    header_len: usize,
    read_size: F,
) -> Result<Vec<u8>, CompressionError>
where
    F: Fn(&[u8]) -> Result<u64, CompressionError>,
{
    let uncompressed_size = read_size(delta_data)?;
    let uncompressed_size_usize =
        checked_size_to_usize("delta uncompressed size", uncompressed_size)?;
    if uncompressed_size > MAX_DELTA_OUTPUT_SIZE as u64 {
        return Err(CompressionError::DecompressionFailed(format!(
            "delta output size {} exceeds max {}",
            uncompressed_size, MAX_DELTA_OUTPUT_SIZE
        )));
    }
    // `read_size` is not obliged to verify that `header_len` bytes exist, so
    // slice with a checked `get` instead of risking an out-of-bounds panic.
    let delta = delta_data.get(header_len..).ok_or_else(|| {
        CompressionError::CorruptedData("compression header truncated".to_string())
    })?;
    let decompressed = DeltaDecoder::decode(base, delta, uncompressed_size_usize)
        .map_err(|error| CompressionError::DecompressionFailed(error.to_string()))?;
    validate_decompressed_len(uncompressed_size, decompressed.len())?;
    Ok(decompressed)
}
/// Reads the big-endian `u64` size stored in header bytes `1..9` and checks
/// it against the global size limit.
///
/// # Errors
/// `CorruptedData` if the header is truncated; `SizeLimitExceeded` if the
/// recorded size is above `MAX_DECOMPRESSED_SIZE`.
fn read_u64_size(data: &[u8]) -> Result<u64, CompressionError> {
    let truncated =
        || CompressionError::CorruptedData("compression header truncated".to_string());
    let size_bytes: [u8; 8] = data
        .get(1..COMPRESSED_HEADER_LEN)
        .ok_or_else(truncated)?
        .try_into()
        .map_err(|_| truncated())?;
    let recorded_size = u64::from_be_bytes(size_bytes);
    validate_size(recorded_size)?;
    Ok(recorded_size)
}
/// Accepts sizes up to and including `MAX_DECOMPRESSED_SIZE`.
///
/// # Errors
/// `SizeLimitExceeded` carrying both the offending size and the limit.
fn validate_size(size: u64) -> Result<(), CompressionError> {
    if size <= MAX_DECOMPRESSED_SIZE {
        Ok(())
    } else {
        Err(CompressionError::SizeLimitExceeded {
            size,
            max: MAX_DECOMPRESSED_SIZE,
        })
    }
}
/// Converts a recorded `u64` size to `usize`; `field` names the value in the
/// error message.
///
/// # Errors
/// `CorruptedData` when the size does not fit this platform's `usize`.
#[cfg(any(feature = "zstd", test))]
fn checked_size_to_usize(field: &str, size: u64) -> Result<usize, CompressionError> {
    match usize::try_from(size) {
        Ok(converted) => Ok(converted),
        Err(_) => Err(CompressionError::CorruptedData(format!(
            "{field} exceeds platform limits"
        ))),
    }
}
/// Checks that a decoded payload has exactly the length the header recorded.
///
/// # Errors
/// `CorruptedData` describing both lengths when they differ.
fn validate_decompressed_len(expected: u64, actual: usize) -> Result<(), CompressionError> {
    if u64::try_from(actual) == Ok(expected) {
        Ok(())
    } else {
        Err(CompressionError::CorruptedData(format!(
            "decompressed size mismatch: expected {expected}, got {actual}",
        )))
    }
}
/// Returns `true` when `data` contains `magic` starting at byte `offset`.
///
/// Works for a magic of any length `N`. Offsets past the end of `data`,
/// including ones where `offset + N` would overflow `usize`, simply yield
/// `false` instead of panicking.
fn has_magic_at<const N: usize>(data: &[u8], offset: usize, magic: [u8; N]) -> bool {
    // `get(offset..)` never computes `offset + N`, so it cannot overflow.
    data.get(offset..)
        .is_some_and(|tail| tail.starts_with(&magic))
}
#[cfg(test)]
mod compression_tests;