use super::{CompressResult, CompressionError, Compressor};
/// Whole-buffer zstd [`Compressor`] backend.
///
/// Each call to `compress` yields zstd output for the entire input and records no
/// block offsets, so range decompression is not supported by this backend.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ZstdCompressor {
    /// Compression level handed unchanged to zstd (see the zstd documentation
    /// for the accepted range and the meaning of `0`).
    pub level: i32,
    /// Worker-thread count. `0` selects the single-threaded `zstd::encode_all`
    /// path; any other value is passed to zstd as its `NbWorkers` parameter.
    pub nb_workers: u32,
}
/// Wraps any displayable zstd-side error as `CompressionError::Zstd`,
/// keeping only its rendered message text.
fn map_zstd_err<E>(err: E) -> CompressionError
where
    E: std::fmt::Display,
{
    CompressionError::Zstd(err.to_string())
}
impl Compressor for ZstdCompressor {
    /// Compresses all of `data` with zstd at `self.level`.
    ///
    /// When `nb_workers` is non-zero, a bulk compression context is created and
    /// zstd's `NbWorkers` parameter is set to that value before compressing.
    /// Otherwise the one-shot `zstd::encode_all` helper is used. No block offsets
    /// are recorded on either path, so `block_offsets` is always `None`.
    ///
    /// # Errors
    /// Any zstd failure, including an error from setting `NbWorkers`, is returned
    /// as `CompressionError::Zstd` carrying the error's message.
    fn compress(&self, data: &[u8]) -> Result<CompressResult, CompressionError> {
        let compressed = if self.nb_workers > 0 {
            let mut ctx = zstd::bulk::Compressor::new(self.level).map_err(map_zstd_err)?;
            ctx.set_parameter(zstd::zstd_safe::CParameter::NbWorkers(self.nb_workers))
                .map_err(map_zstd_err)?;
            ctx.compress(data).map_err(map_zstd_err)?
        } else {
            zstd::encode_all(data, self.level).map_err(map_zstd_err)?
        };

        Ok(CompressResult {
            data: compressed,
            block_offsets: None,
        })
    }

    /// Decodes zstd-compressed `data` back into the original bytes.
    ///
    /// `_expected_size` is ignored.
    ///
    /// # Errors
    /// Returns `CompressionError::Zstd` if zstd rejects the input.
    fn decompress(&self, data: &[u8], _expected_size: usize) -> Result<Vec<u8>, CompressionError> {
        let decoded = zstd::decode_all(data);
        decoded.map_err(map_zstd_err)
    }

    /// Random-access decompression is not offered by this backend (`compress`
    /// never records block offsets), so this always returns
    /// `CompressionError::RangeNotSupported`.
    fn decompress_range(
        &self,
        _data: &[u8],
        _block_offsets: &[u64],
        _byte_pos: usize,
        _byte_size: usize,
    ) -> Result<Vec<u8>, CompressionError> {
        Err(CompressionError::RangeNotSupported)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn zstd_round_trip() {
        let data: Vec<u8> = (0..4096).map(|i| (i % 256) as u8).collect();
        let compressor = ZstdCompressor {
            level: 3,
            nb_workers: 0,
        };
        let result = compressor.compress(&data).unwrap();
        assert!(result.block_offsets.is_none());
        assert!(result.data.len() < data.len());
        let decompressed = compressor.decompress(&result.data, data.len()).unwrap();
        assert_eq!(decompressed, data);
    }
    #[test]
    fn zstd_range_not_supported() {
        let compressor = ZstdCompressor {
            level: 3,
            nb_workers: 0,
        };
        let result = compressor.decompress_range(&[0], &[], 0, 1);
        assert!(matches!(result, Err(CompressionError::RangeNotSupported)));
    }
    #[test]
    fn zstd_nb_workers_zero_matches_encode_all() {
        let data: Vec<u8> = (0..32 * 1024).map(|i| ((i * 7) % 256) as u8).collect();
        let via_struct = ZstdCompressor {
            level: 3,
            nb_workers: 0,
        }
        .compress(&data)
        .unwrap()
        .data;
        let via_helper = zstd::encode_all(data.as_slice(), 3).unwrap();
        assert_eq!(via_struct, via_helper);
    }
    #[test]
    fn zstd_nb_workers_round_trip_lossless() {
        let data: Vec<u8> = (0..256 * 1024).map(|i| ((i * 31) % 256) as u8).collect();
        for n in [1u32, 2, 4, 8] {
            let c = ZstdCompressor {
                level: 3,
                nb_workers: n,
            };
            let out = c.compress(&data).unwrap();
            let rt = c.decompress(&out.data, data.len()).unwrap();
            assert_eq!(rt, data, "zstd nb_workers={n} round-trip failure");
        }
    }
    /// Empty input must survive a round trip on both the single-threaded
    /// (`nb_workers == 0`) and the multi-worker code paths.
    #[test]
    fn zstd_empty_input_round_trip() {
        for n in [0u32, 1, 2] {
            let c = ZstdCompressor {
                level: 3,
                nb_workers: n,
            };
            let out = c.compress(&[]).unwrap();
            assert!(out.block_offsets.is_none());
            let rt = c.decompress(&out.data, 0).unwrap();
            assert!(rt.is_empty(), "zstd nb_workers={n} empty round-trip failure");
        }
    }
}