use crate::xxhash::xxhash64_checksum;
use crate::{MAX_BLOCK_SIZE, ZSTD_MAGIC};
use oxiarc_core::error::Result;
/// Selects which block types the encoder is allowed to emit.
///
/// The default strategy is [`CompressionStrategy::RleOnly`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionStrategy {
/// Store every block verbatim as a raw block; no run-length detection is
/// attempted.
Raw,
/// Emit a block as an RLE block when all of its bytes are equal, and fall
/// back to a raw block otherwise.
#[default]
RleOnly,
}
/// Zstandard frame encoder that produces frames made of raw and RLE blocks.
///
/// Configure it with the `set_*` methods and call `compress` to obtain a
/// complete frame.
#[derive(Debug, Clone)]
pub struct ZstdEncoder {
// When true, the checksum flag is set in the frame header and the checksum of
// the input is appended after the last block.
include_checksum: bool,
// Set via `set_content_size`; consulted by `write_frame_header` when it
// chooses how the content-size field of the frame header is encoded.
include_content_size: bool,
// Decides whether blocks are probed for single-byte runs before being
// written (see `CompressionStrategy`).
strategy: CompressionStrategy,
}
impl ZstdEncoder {
    /// Block_Type value of a raw (stored) block.
    const BLOCK_TYPE_RAW: u32 = 0;
    /// Block_Type value of an RLE (single repeated byte) block.
    const BLOCK_TYPE_RLE: u32 = 1;

    /// Creates an encoder with the checksum and the content size enabled and
    /// the default [`CompressionStrategy`].
    pub fn new() -> Self {
        Self {
            include_checksum: true,
            include_content_size: true,
            strategy: CompressionStrategy::default(),
        }
    }

    /// Enables or disables the content checksum appended after the last block.
    ///
    /// Returns `self` so that setters can be chained.
    pub fn set_checksum(&mut self, include: bool) -> &mut Self {
        self.include_checksum = include;
        self
    }

    /// Enables or disables the Frame_Content_Size field of the frame header.
    ///
    /// When disabled, the frame is written without the Single_Segment flag and
    /// carries a Window_Descriptor instead, so the header makes no statement
    /// about the decompressed size.
    ///
    /// Returns `self` so that setters can be chained.
    pub fn set_content_size(&mut self, include: bool) -> &mut Self {
        self.include_content_size = include;
        self
    }

    /// Selects which block types may be emitted.
    ///
    /// Returns `self` so that setters can be chained.
    pub fn set_strategy(&mut self, strategy: CompressionStrategy) -> &mut Self {
        self.strategy = strategy;
        self
    }

    /// Encodes `data` as one complete Zstandard frame.
    ///
    /// The frame consists of the magic number, the frame header, one block per
    /// `MAX_BLOCK_SIZE` bytes of input (a single empty block for empty input)
    /// and, if enabled, the content checksum.
    ///
    /// # Errors
    ///
    /// The current implementation never fails; the `Result` is part of the
    /// public signature.
    pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
        // Room for magic, header, one block header and the checksum. Inputs
        // that span several blocks may still make the vector grow.
        let mut output = Vec::with_capacity(data.len() + 32);
        output.extend_from_slice(&ZSTD_MAGIC);
        self.write_frame_header(&mut output, data.len());
        self.write_blocks(&mut output, data);
        if self.include_checksum {
            // NOTE(review): the Zstandard format stores a 4-byte content
            // checksum; this relies on `xxhash64_checksum` returning a 32-bit
            // value - confirm against its definition.
            let checksum = xxhash64_checksum(data);
            output.extend_from_slice(&checksum.to_le_bytes());
        }
        Ok(output)
    }

    /// Appends the Frame_Header_Descriptor and the field that follows it.
    ///
    /// With the content size enabled the frame is a single-segment frame whose
    /// Frame_Content_Size uses the smallest encoding that fits `content_size`.
    /// With the content size disabled the Single_Segment flag stays clear and
    /// a Window_Descriptor is written instead, because a single-segment frame
    /// cannot omit its content size.
    fn write_frame_header(&self, output: &mut Vec<u8>, content_size: usize) {
        const CONTENT_CHECKSUM_FLAG: u8 = 0x04;
        const SINGLE_SEGMENT_FLAG: u8 = 0x20;

        let mut descriptor: u8 = 0;
        if self.include_checksum {
            descriptor |= CONTENT_CHECKSUM_FLAG;
        }

        if !self.include_content_size {
            // Frame_Content_Size_flag = 0 without Single_Segment means "no
            // size field"; the Window_Descriptor becomes mandatory.
            output.push(descriptor);
            output.push(Self::window_descriptor());
            return;
        }
        descriptor |= SINGLE_SEGMENT_FLAG;

        // Widening cast: `usize` is never wider than 64 bits.
        let size = content_size as u64;
        let (fcs_flag, fcs_len): (u8, usize) = match size {
            0..=0xFF => (0, 1),
            0x100..=0x1_00FF => (1, 2),
            0x1_0100..=0xFFFF_FFFF => (2, 4),
            _ => (3, 8),
        };
        // The 2-byte encoding stores `size - 256`, which is what extends its
        // range up to 65535 + 256.
        let stored = if fcs_flag == 1 { size - 256 } else { size };

        output.push(descriptor | (fcs_flag << 6));
        // `stored` fits in `fcs_len` bytes by construction, so its low
        // little-endian bytes are the complete field.
        output.extend_from_slice(&stored.to_le_bytes()[..fcs_len]);
    }

    /// Returns a Window_Descriptor byte (mantissa 0) for the smallest
    /// power-of-two window that is at least `MAX_BLOCK_SIZE`, so that no block
    /// produced by this encoder exceeds the declared window.
    fn window_descriptor() -> u8 {
        // Window_Size = 1 << (10 + exponent), with a 5-bit exponent.
        const MIN_WINDOW_LOG: u32 = 10;
        const MAX_WINDOW_LOG: u32 = MIN_WINDOW_LOG + 31;

        let window_log = MAX_BLOCK_SIZE
            .next_power_of_two()
            .trailing_zeros()
            .clamp(MIN_WINDOW_LOG, MAX_WINDOW_LOG);
        // The exponent occupies the upper five bits; the mantissa stays zero.
        ((window_log - MIN_WINDOW_LOG) as u8) << 3
    }

    /// Splits `data` into blocks of at most `MAX_BLOCK_SIZE` bytes and appends
    /// each one as an RLE or raw block, marking the final block as last.
    fn write_blocks(&self, output: &mut Vec<u8>, data: &[u8]) {
        if data.is_empty() {
            // A frame needs at least one block: an empty raw block flagged as
            // the last one.
            Self::write_block_header(output, Self::BLOCK_TYPE_RAW, 0, true);
            return;
        }

        let mut blocks = data.chunks(MAX_BLOCK_SIZE).peekable();
        while let Some(block) = blocks.next() {
            let is_last = blocks.peek().is_none();
            let run_byte = match self.strategy {
                CompressionStrategy::Raw => None,
                CompressionStrategy::RleOnly => Self::detect_rle(block),
            };
            match run_byte {
                Some(byte) => self.write_rle_block(output, byte, block.len(), is_last),
                None => self.write_raw_block(output, block, is_last),
            }
        }
    }

    /// Returns the repeated byte if `data` is non-empty and consists of a
    /// single byte value, and `None` otherwise (including for empty input).
    fn detect_rle(data: &[u8]) -> Option<u8> {
        let (&first, rest) = data.split_first()?;
        rest.iter().all(|&b| b == first).then_some(first)
    }

    /// Appends an RLE block that expands to `size` copies of `byte`.
    fn write_rle_block(&self, output: &mut Vec<u8>, byte: u8, size: usize, is_last: bool) {
        Self::write_block_header(output, Self::BLOCK_TYPE_RLE, size, is_last);
        output.push(byte);
    }

    /// Appends a raw block holding `data` verbatim.
    fn write_raw_block(&self, output: &mut Vec<u8>, data: &[u8], is_last: bool) {
        Self::write_block_header(output, Self::BLOCK_TYPE_RAW, data.len(), is_last);
        output.extend_from_slice(data);
    }

    /// Appends the 3-byte little-endian block header: bit 0 is the last-block
    /// flag, bits 1-2 the block type and the remaining 21 bits the size.
    fn write_block_header(output: &mut Vec<u8>, block_type: u32, size: usize, is_last: bool) {
        debug_assert!(size < (1 << 21), "block size does not fit the 21-bit field");
        // Truncating cast is fine: sizes are bounded by MAX_BLOCK_SIZE.
        let header = u32::from(is_last) | (block_type << 1) | ((size as u32) << 3);
        output.extend_from_slice(&header.to_le_bytes()[..3]);
    }
}
impl Default for ZstdEncoder {
fn default() -> Self {
Self::new()
}
}
/// Compresses `data` into a Zstandard frame with a freshly constructed
/// [`ZstdEncoder`] in its initial configuration.
///
/// # Errors
///
/// Forwards whatever error [`ZstdEncoder::compress`] reports.
pub fn compress(data: &[u8]) -> Result<Vec<u8>> {
    let encoder = ZstdEncoder::new();
    encoder.compress(data)
}
/// Compresses `data` into a Zstandard frame that carries no content checksum.
///
/// All other encoder settings keep the values chosen by [`ZstdEncoder::new`].
///
/// # Errors
///
/// Forwards whatever error [`ZstdEncoder::compress`] reports.
pub fn compress_no_checksum(data: &[u8]) -> Result<Vec<u8>> {
    ZstdEncoder::new().set_checksum(false).compress(data)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decompress;

    /// An empty input still yields a frame that starts with the magic number
    /// and decodes back to nothing.
    #[test]
    fn test_compress_empty() {
        let input: &[u8] = &[];
        let frame = compress(input).unwrap();
        assert_eq!(&frame[0..4], &ZSTD_MAGIC);
        let restored = decompress(&frame).unwrap();
        assert_eq!(restored, input);
    }

    /// A short text survives a round trip.
    #[test]
    fn test_compress_small() {
        let input = b"Hello, Zstandard!";
        let frame = compress(input).unwrap();
        let restored = decompress(&frame).unwrap();
        assert_eq!(restored, input.as_slice());
    }

    /// A kilobyte of identical bytes survives a round trip.
    #[test]
    fn test_compress_larger() {
        let input = vec![0x42u8; 1000];
        let frame = compress(&input).unwrap();
        let restored = decompress(&frame).unwrap();
        assert_eq!(restored, input);
    }

    /// Input larger than one block is split and reassembled correctly.
    #[test]
    fn test_compress_multi_block() {
        let input = vec![0xABu8; MAX_BLOCK_SIZE + 1000];
        let frame = compress(&input).unwrap();
        let restored = decompress(&frame).unwrap();
        assert_eq!(restored, input);
    }

    /// Frames written without a checksum still decode.
    #[test]
    fn test_compress_no_checksum() {
        let input = b"Test without checksum";
        let frame = compress_no_checksum(input).unwrap();
        let restored = decompress(&frame).unwrap();
        assert_eq!(restored, input.as_slice());
    }

    /// Chained setters configure the encoder used for compression.
    #[test]
    fn test_encoder_builder() {
        let input = b"Builder pattern test";
        let mut encoder = ZstdEncoder::new();
        encoder.set_checksum(true).set_content_size(true);
        let frame = encoder.compress(input).unwrap();
        let restored = decompress(&frame).unwrap();
        assert_eq!(restored, input.as_slice());
    }

    /// Round trips across the boundaries of the content-size encodings.
    #[test]
    fn test_various_sizes() {
        let sizes = [0, 1, 10, 100, 255, 256, 257, 1000, 65535, 65536, 100000];
        for size in sizes {
            let input = vec![0x55u8; size];
            let frame = compress(&input).unwrap();
            let restored = decompress(&frame).unwrap();
            assert_eq!(restored, input, "Failed for size {}", size);
        }
    }

    /// A single-byte run shrinks to well under a tenth of its size.
    #[test]
    fn test_rle_compression() {
        let input = vec![0xAAu8; 10000];
        let frame = compress(&input).unwrap();
        let limit = input.len() / 10;
        assert!(
            frame.len() < limit,
            "RLE compression failed: {} vs {}",
            frame.len(),
            input.len()
        );
        let restored = decompress(&frame).unwrap();
        assert_eq!(restored, input);
    }

    /// Three full blocks of one byte value collapse to a tiny frame.
    #[test]
    fn test_rle_multi_block() {
        let input = vec![0xBBu8; MAX_BLOCK_SIZE * 3];
        let frame = compress(&input).unwrap();
        assert!(
            frame.len() < 100,
            "Expected small output, got {}",
            frame.len()
        );
        let restored = decompress(&frame).unwrap();
        assert_eq!(restored, input);
    }

    /// Data that mixes runs with literal text still round-trips.
    #[test]
    fn test_rle_mixed_data() {
        let mut input = vec![0xCCu8; 1000];
        input.extend_from_slice(b"Hello, World!");
        input.extend_from_slice(&[0xDDu8; 1000]);
        let frame = compress(&input).unwrap();
        let restored = decompress(&frame).unwrap();
        assert_eq!(restored, input);
    }

    /// `detect_rle` reports the byte of a uniform slice and `None` otherwise.
    #[test]
    fn test_detect_rle() {
        let cases: [(&[u8], Option<u8>); 6] = [
            (&[0xAA; 100], Some(0xAA)),
            (&[0x00; 50], Some(0x00)),
            (&[0xFF], Some(0xFF)),
            (&[0xAA, 0xAA, 0xBB], None),
            (&[0x00, 0x01], None),
            (&[], None),
        ];
        for (input, expected) in cases {
            assert_eq!(ZstdEncoder::detect_rle(input), expected);
        }
    }

    /// The raw strategy stores uniform data verbatim, so the frame is larger
    /// than the input.
    #[test]
    fn test_raw_strategy() {
        let input = vec![0xEEu8; 1000];
        let mut encoder = ZstdEncoder::new();
        encoder.set_strategy(CompressionStrategy::Raw);
        let frame = encoder.compress(&input).unwrap();
        assert!(frame.len() > input.len());
        let restored = decompress(&frame).unwrap();
        assert_eq!(restored, input);
    }
}