use crate::compressed_block::encode_compressed_block;
use crate::lz77::{LevelConfig, MatchFinder};
use crate::xxhash::xxhash64_checksum;
use crate::{MAX_BLOCK_SIZE, ZSTD_MAGIC};
use oxiarc_core::error::Result;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Block-level strategy applied on the level-0 (store) encoding path.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionStrategy {
    /// Always emit raw (stored) blocks, even for runs of a single byte.
    Raw,
    /// Emit an RLE block when a whole block is one repeated byte;
    /// otherwise fall back to a raw block. This is the default.
    #[default]
    RleOnly,
}
/// Configurable Zstandard frame encoder.
///
/// Construct with [`ZstdEncoder::new`], tune via the chainable `set_*`
/// methods, then call `compress` (or `compress_parallel` with the
/// `parallel` feature) to produce a complete frame.
#[derive(Debug, Clone)]
pub struct ZstdEncoder {
    // Append an xxHash-based content checksum after the last block.
    include_checksum: bool,
    // Write the frame content size into the frame header.
    include_content_size: bool,
    // Block strategy used on the level-0 path (raw vs. RLE detection).
    strategy: CompressionStrategy,
    // Compression level, clamped to 0..=22; 0 means store/RLE only,
    // >0 enables the LZ77 match finder.
    level: i32,
    // Optional raw dictionary bytes, passed to the match finder when level > 0.
    dictionary: Option<Vec<u8>>,
    // Dictionary ID derived from hashing the dictionary; written to the
    // frame header when present.
    dict_id: Option<u32>,
}
impl ZstdEncoder {
    /// Creates an encoder with the default configuration: checksum and
    /// content size enabled, RLE-only block strategy, level 0 (store/RLE,
    /// no LZ77 matching) and no dictionary.
    pub fn new() -> Self {
        Self {
            include_checksum: true,
            include_content_size: true,
            strategy: CompressionStrategy::default(),
            level: 0,
            dictionary: None,
            dict_id: None,
        }
    }

    /// Enables or disables the trailing content checksum.
    pub fn set_checksum(&mut self, include: bool) -> &mut Self {
        self.include_checksum = include;
        self
    }

    /// Enables or disables writing the frame content size.
    ///
    /// Note: the encoder always sets Single_Segment_Flag, which makes the
    /// FCS field mandatory, so disabling this only expresses intent — a
    /// correctly sized field is still written (see `write_frame_header`).
    pub fn set_content_size(&mut self, include: bool) -> &mut Self {
        self.include_content_size = include;
        self
    }

    /// Selects the block strategy used on the level-0 path.
    pub fn set_strategy(&mut self, strategy: CompressionStrategy) -> &mut Self {
        self.strategy = strategy;
        self
    }

    /// Sets the compression level, clamped to `0..=22`. Level 0 stores
    /// blocks raw (with optional RLE); levels above 0 run the LZ77 match
    /// finder and emit compressed blocks.
    pub fn set_level(&mut self, level: i32) -> &mut Self {
        self.level = level.clamp(0, 22);
        self
    }

    /// Installs a dictionary used by the match finder on the level > 0
    /// path; an empty slice clears it. The dictionary ID written to the
    /// frame header is derived from the xxHash64 of the dictionary bytes.
    pub fn set_dictionary(&mut self, dict: &[u8]) -> &mut Self {
        if dict.is_empty() {
            self.dictionary = None;
            self.dict_id = None;
        } else {
            let id = crate::xxhash::xxhash64(dict) as u32;
            self.dictionary = Some(dict.to_vec());
            self.dict_id = Some(id);
        }
        self
    }

    /// Compresses `data` into a complete Zstandard frame
    /// (magic + header + blocks + optional checksum).
    pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
        let mut output = Vec::with_capacity(data.len() + 32);
        output.extend_from_slice(&ZSTD_MAGIC);
        self.write_frame_header(&mut output, data.len());
        if self.level > 0 {
            self.write_compressed_blocks(&mut output, data)?;
        } else {
            self.write_blocks(&mut output, data);
        }
        if self.include_checksum {
            let checksum = xxhash64_checksum(data);
            output.extend_from_slice(&checksum.to_le_bytes());
        }
        Ok(output)
    }

    /// Like [`ZstdEncoder::compress`] at level 0, but encodes blocks in
    /// parallel with rayon. The LZ77 path is not used here; only raw and
    /// RLE blocks are produced.
    #[cfg(feature = "parallel")]
    pub fn compress_parallel(&self, data: &[u8]) -> Result<Vec<u8>> {
        let mut output = Vec::with_capacity(data.len() + 32);
        output.extend_from_slice(&ZSTD_MAGIC);
        self.write_frame_header(&mut output, data.len());
        if data.is_empty() {
            write_empty_block(&mut output);
        } else {
            let chunks: Vec<&[u8]> = data.chunks(MAX_BLOCK_SIZE).collect();
            let last_idx = chunks.len() - 1;
            // Each chunk becomes exactly one block; blocks are encoded
            // independently and concatenated in order afterwards.
            let blocks: Vec<Vec<u8>> = chunks
                .par_iter()
                .enumerate()
                .map(|(idx, chunk)| {
                    let is_last = idx == last_idx;
                    let mut block = Vec::new();
                    if self.strategy == CompressionStrategy::RleOnly {
                        if let Some(rle_byte) = detect_rle(chunk) {
                            write_rle_block_to(&mut block, rle_byte, chunk.len(), is_last);
                            return block;
                        }
                    }
                    write_raw_block_to(&mut block, chunk, is_last);
                    block
                })
                .collect();
            for block in blocks {
                output.extend_from_slice(&block);
            }
        }
        if self.include_checksum {
            let checksum = xxhash64_checksum(data);
            output.extend_from_slice(&checksum.to_le_bytes());
        }
        Ok(output)
    }

    /// Writes the frame header: descriptor byte, optional 4-byte
    /// dictionary ID, and the frame content size (FCS) field.
    fn write_frame_header(&self, output: &mut Vec<u8>, content_size: usize) {
        let mut descriptor: u8 = 0;
        if self.include_checksum {
            descriptor |= 0x04; // Content_Checksum_Flag
        }
        // Single_Segment_Flag: no window descriptor byte; the FCS field
        // becomes mandatory.
        descriptor |= 0x20;
        // Dictionary_ID_Flag = 3 -> a 4-byte dictionary ID follows.
        let dict_id_flag = if self.dict_id.is_some() { 3u8 } else { 0u8 };
        descriptor |= dict_id_flag;
        // FCS field width. Because Single_Segment_Flag is set above, the
        // field cannot be omitted even when `include_content_size` is
        // false; the width must always be chosen from the actual size.
        // (The previous code forced a 1-byte field in that case, which
        // truncated any size above 255 into a corrupt header.)
        let _ = self.include_content_size; // retained for API compatibility
        let (fcs_flag, fcs_bytes) = if content_size <= 255 {
            (0u8, 1)
        } else if content_size <= 65535 + 256 {
            // The 2-byte field stores `size - 256`, extending its range.
            (1u8, 2)
        } else if content_size <= u32::MAX as usize {
            (2u8, 4)
        } else {
            (3u8, 8)
        };
        descriptor |= fcs_flag << 6;
        output.push(descriptor);
        if let Some(id) = self.dict_id {
            output.extend_from_slice(&id.to_le_bytes());
        }
        match fcs_bytes {
            1 => {
                output.push(content_size as u8);
            }
            2 => {
                // Guarded by the branch above: 256 <= content_size <= 65791.
                let adjusted = (content_size - 256) as u16;
                output.extend_from_slice(&adjusted.to_le_bytes());
            }
            4 => {
                output.extend_from_slice(&(content_size as u32).to_le_bytes());
            }
            8 => {
                output.extend_from_slice(&(content_size as u64).to_le_bytes());
            }
            _ => unreachable!(),
        }
    }

    /// Writes `data` as a sequence of raw/RLE blocks (level-0 path).
    fn write_blocks(&self, output: &mut Vec<u8>, data: &[u8]) {
        if data.is_empty() {
            write_empty_block(output);
            return;
        }
        let mut offset = 0;
        while offset < data.len() {
            let remaining = data.len() - offset;
            let block_size = remaining.min(MAX_BLOCK_SIZE);
            let is_last = offset + block_size >= data.len();
            let block_data = &data[offset..offset + block_size];
            if self.strategy == CompressionStrategy::RleOnly {
                if let Some(rle_byte) = detect_rle(block_data) {
                    write_rle_block_to(output, rle_byte, block_size, is_last);
                    offset += block_size;
                    continue;
                }
            }
            write_raw_block_to(output, block_data, is_last);
            offset += block_size;
        }
    }

    /// Writes `data` through the LZ77 match finder (level > 0 path),
    /// falling back to RLE or raw blocks whenever they are smaller or
    /// compressed encoding fails.
    fn write_compressed_blocks(&self, output: &mut Vec<u8>, data: &[u8]) -> Result<()> {
        if data.is_empty() {
            write_empty_block(output);
            return Ok(());
        }
        let config = LevelConfig::for_level(self.level);
        let mut finder = MatchFinder::new(&config);
        let dict = self.dictionary.as_deref().unwrap_or(&[]);
        let mut offset = 0;
        while offset < data.len() {
            let remaining = data.len() - offset;
            let block_size = remaining.min(config.target_block_size);
            let is_last = offset + block_size >= data.len();
            let block_data = &data[offset..offset + block_size];
            // A single-byte run is always smallest as an RLE block.
            if let Some(rle_byte) = detect_rle(block_data) {
                write_rle_block_to(output, rle_byte, block_size, is_last);
                offset += block_size;
                continue;
            }
            let sequences = finder.find_sequences(block_data, dict)?;
            match encode_compressed_block(&sequences) {
                Ok(compressed_content) => {
                    // Keep the compressed form only when it actually saves
                    // space; otherwise store the block raw.
                    if compressed_content.len() < block_data.len() {
                        write_compressed_block_to(output, &compressed_content, is_last);
                    } else {
                        write_raw_block_to(output, block_data, is_last);
                    }
                }
                Err(_) => {
                    write_raw_block_to(output, block_data, is_last);
                }
            }
            // Blocks are encoded independently; clear match history.
            finder.reset();
            offset += block_size;
        }
        Ok(())
    }
}
impl Default for ZstdEncoder {
fn default() -> Self {
Self::new()
}
}
/// Emits a zero-length raw block with the last-block flag set, which
/// terminates a frame carrying no payload.
fn write_empty_block(output: &mut Vec<u8>) {
    // Header value 1 = last_block flag | Raw block type | size 0,
    // serialized as a 3-byte little-endian field.
    let header: u32 = 1;
    output.extend_from_slice(&header.to_le_bytes()[..3]);
}
/// Appends a raw (stored) block: 3-byte header followed by `data` verbatim.
/// Header layout (little-endian): bit 0 = last_block, bits 1-2 = block
/// type (0 for raw), bits 3.. = block size.
fn write_raw_block_to(output: &mut Vec<u8>, data: &[u8], is_last: bool) {
    let header = ((data.len() as u32) << 3) | u32::from(is_last);
    output.extend_from_slice(&header.to_le_bytes()[..3]);
    output.extend_from_slice(data);
}
/// Appends an RLE block: 3-byte header plus the single repeated byte.
/// The size stored in the header is the regenerated (decompressed) length.
fn write_rle_block_to(output: &mut Vec<u8>, byte: u8, size: usize, is_last: bool) {
    // Block type 1 (RLE) occupies bits 1-2 of the header.
    let header = ((size as u32) << 3) | (1 << 1) | u32::from(is_last);
    output.extend_from_slice(&header.to_le_bytes()[..3]);
    output.push(byte);
}
/// Appends a compressed block: 3-byte header followed by the
/// already-encoded block content. The size stored in the header is the
/// compressed (on-wire) length of `content`.
fn write_compressed_block_to(output: &mut Vec<u8>, content: &[u8], is_last: bool) {
    // Block type 2 (compressed) occupies bits 1-2 of the header.
    let header = ((content.len() as u32) << 3) | (2 << 1) | u32::from(is_last);
    output.extend_from_slice(&header.to_le_bytes()[..3]);
    output.extend_from_slice(content);
}
/// Returns `Some(first_byte)` when every byte of `data` equals the first
/// one, i.e. the slice can be represented as a single RLE block; returns
/// `None` for mixed content or empty input.
///
/// The previous implementation walked the slice in arbitrary 16-byte
/// chunks, which added indirection without changing the result; this is
/// the direct all-bytes-equal scan.
fn detect_rle(data: &[u8]) -> Option<u8> {
    let (&first, rest) = data.split_first()?;
    rest.iter().all(|&b| b == first).then_some(first)
}
/// Compresses `data` into a Zstandard frame using the default encoder
/// configuration (checksum and content size enabled, level 0).
pub fn compress(data: &[u8]) -> Result<Vec<u8>> {
    ZstdEncoder::default().compress(data)
}
/// Compresses `data` at the given level (clamped to `0..=22`); levels
/// above 0 enable the LZ77 match finder.
pub fn compress_with_level(data: &[u8], level: i32) -> Result<Vec<u8>> {
    let mut encoder = ZstdEncoder::default();
    encoder.set_level(level);
    encoder.compress(data)
}
/// Compresses `data` without appending a trailing content checksum.
pub fn compress_no_checksum(data: &[u8]) -> Result<Vec<u8>> {
    let mut encoder = ZstdEncoder::default();
    encoder.set_checksum(false);
    encoder.compress(data)
}
/// Parallel counterpart of [`compress`], available with the `parallel`
/// feature; blocks are encoded concurrently via rayon.
#[cfg(feature = "parallel")]
pub fn compress_parallel(data: &[u8]) -> Result<Vec<u8>> {
    ZstdEncoder::default().compress_parallel(data)
}
/// Convenience alias for [`compress_with_level`].
pub fn encode_all(data: &[u8], level: i32) -> Result<Vec<u8>> {
    compress_with_level(data, level)
}
/// Convenience alias for frame decompression via [`crate::decompress`].
pub fn decode_all(data: &[u8]) -> Result<Vec<u8>> {
    crate::decompress(data)
}
#[cfg(test)]
mod tests {
    //! Encoder tests. Most are round-trips: every frame produced here must
    //! be accepted and exactly inverted by `crate::decompress`.
    //!
    //! Fix: `test_parallel_vs_serial` contained the mis-encoded token
    //! `¶llel` (mojibake for `&parallel`), which did not compile.
    use super::*;
    use crate::decompress;

    #[test]
    fn test_compress_empty() {
        let data: &[u8] = &[];
        let compressed = compress(data).unwrap();
        // Even an empty frame must start with the magic number.
        assert_eq!(&compressed[0..4], &ZSTD_MAGIC);
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_compress_small() {
        let data = b"Hello, Zstandard!";
        let compressed = compress(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    fn test_compress_larger() {
        let data = vec![0x42u8; 1000];
        let compressed = compress(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_compress_multi_block() {
        // Forces at least two blocks on the serial path.
        let data = vec![0xABu8; MAX_BLOCK_SIZE + 1000];
        let compressed = compress(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_compress_no_checksum() {
        let data = b"Test without checksum";
        let compressed = compress_no_checksum(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    fn test_encoder_builder() {
        let data = b"Builder pattern test";
        let mut encoder = ZstdEncoder::new();
        encoder.set_checksum(true).set_content_size(true);
        let compressed = encoder.compress(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    fn test_various_sizes() {
        // Sizes chosen to straddle the 1-, 2- and 4-byte FCS boundaries.
        for size in [0, 1, 10, 100, 255, 256, 257, 1000, 65535, 65536, 100000] {
            let data = vec![0x55u8; size];
            let compressed = compress(&data).unwrap();
            let decompressed = decompress(&compressed).unwrap();
            assert_eq!(decompressed, data, "Failed for size {}", size);
        }
    }

    #[test]
    fn test_rle_compression() {
        // A single repeated byte should collapse to an RLE block.
        let data = vec![0xAAu8; 10000];
        let compressed = compress(&data).unwrap();
        assert!(
            compressed.len() < data.len() / 10,
            "RLE compression failed: {} vs {}",
            compressed.len(),
            data.len()
        );
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_rle_multi_block() {
        let data = vec![0xBBu8; MAX_BLOCK_SIZE * 3];
        let compressed = compress(&data).unwrap();
        // Three RLE blocks plus framing overhead stay tiny.
        assert!(
            compressed.len() < 100,
            "Expected small output, got {}",
            compressed.len()
        );
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_rle_mixed_data() {
        // RLE-able runs surrounding a non-RLE segment.
        let mut data = vec![0xCCu8; 1000];
        data.extend_from_slice(b"Hello, World!");
        data.extend_from_slice(&vec![0xDDu8; 1000]);
        let compressed = compress(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_detect_rle() {
        assert_eq!(detect_rle(&[0xAA; 100]), Some(0xAA));
        assert_eq!(detect_rle(&[0x00; 50]), Some(0x00));
        assert_eq!(detect_rle(&[0xFF]), Some(0xFF));
        assert_eq!(detect_rle(&[0xAA, 0xAA, 0xBB]), None);
        assert_eq!(detect_rle(&[0x00, 0x01]), None);
        assert_eq!(detect_rle(&[]), None);
    }

    #[test]
    fn test_raw_strategy() {
        let data = vec![0xEEu8; 1000];
        let mut encoder = ZstdEncoder::new();
        encoder.set_strategy(CompressionStrategy::Raw);
        let compressed = encoder.compress(&data).unwrap();
        // Raw strategy stores the payload verbatim, so framing overhead
        // makes the output strictly larger than the input.
        assert!(compressed.len() > data.len());
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_compress_with_level() {
        let data = b"The quick brown fox jumps over the lazy dog. \
                The quick brown fox jumps over the lazy dog. \
                The quick brown fox jumps over the lazy dog.";
        for level in [1, 3, 6, 9, 15, 22] {
            let compressed = compress_with_level(data, level).unwrap();
            let decompressed = decompress(&compressed).unwrap();
            assert_eq!(
                decompressed,
                data.as_slice(),
                "Roundtrip failed for level {}",
                level
            );
        }
    }

    #[test]
    fn test_encode_all_decode_all() {
        let data = b"Testing encode_all and decode_all convenience functions";
        let compressed = encode_all(data, 3).unwrap();
        let decompressed = decode_all(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    fn test_level_compression_ratio() {
        // Repetitive-but-not-RLE data: the LZ77 path should never do
        // worse than raw storage.
        let mut data = Vec::new();
        for _ in 0..100 {
            data.extend_from_slice(b"ABCDEFGHIJKLMNOP");
        }
        let raw = compress(&data).unwrap();
        let level3 = compress_with_level(&data, 3).unwrap();
        assert!(
            level3.len() <= raw.len(),
            "Level 3 ({}) should be <= raw ({}) for repetitive data",
            level3.len(),
            raw.len()
        );
        assert_eq!(decompress(&raw).unwrap(), data);
        assert_eq!(decompress(&level3).unwrap(), data);
    }

    #[test]
    fn test_large_data_roundtrip() {
        let mut data = Vec::with_capacity(16384);
        let pattern = b"RDF triple: <http://example.org/subject> <http://example.org/predicate> \"value\" .\n";
        while data.len() < 16384 {
            data.extend_from_slice(pattern);
        }
        data.truncate(16384);
        for level in [1, 3] {
            let compressed = encode_all(&data, level).unwrap();
            let decompressed = decode_all(&compressed).unwrap();
            assert_eq!(
                decompressed, data,
                "Large roundtrip failed for level {}",
                level
            );
        }
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_roundtrip_basic() {
        let data = b"Hello, World! Parallel Zstandard compression.";
        let compressed = compress_parallel(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data.as_slice());
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_roundtrip_large() {
        let data = vec![0xABu8; 5_000_000];
        let compressed = compress_parallel(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_rle_compression() {
        let data = vec![0xCCu8; 2_000_000];
        let compressed = compress_parallel(&data).unwrap();
        assert!(compressed.len() < data.len() / 100);
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_empty() {
        let data: &[u8] = &[];
        let compressed = compress_parallel(data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_vs_serial() {
        // Both paths must produce decodable frames for the same input
        // (byte-identical output is not required).
        let data = b"Testing parallel vs serial compression output.";
        let serial = compress(data).unwrap();
        let parallel = compress_parallel(data).unwrap();
        let serial_decompressed = decompress(&serial).unwrap();
        let parallel_decompressed = decompress(&parallel).unwrap();
        assert_eq!(serial_decompressed, data.as_slice());
        assert_eq!(parallel_decompressed, data.as_slice());
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_encoder_options() {
        let data = vec![0xFFu8; 1_000_000];
        let mut encoder = ZstdEncoder::new();
        encoder
            .set_checksum(false)
            .set_strategy(CompressionStrategy::RleOnly);
        let compressed = encoder.compress_parallel(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    #[cfg(feature = "parallel")]
    fn test_parallel_multi_block() {
        // More chunks than a typical thread count, with a ragged tail.
        let data = vec![0x55u8; MAX_BLOCK_SIZE * 3 + 5000];
        let compressed = compress_parallel(&data).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }
}