use crate::block::{decode_raw_block, decode_rle_block, LiteralsSection, SequencesSection};
use crate::frame::{xxhash64, BlockHeader, BlockType, FrameHeader, ZSTD_MAGIC};
use haagenti_core::{Error, Result};
/// Mutable state shared by the block decoders while one frame is decoded.
#[derive(Debug)]
pub struct DecompressContext {
    /// Bytes produced so far; decoders append to this buffer.
    output: Vec<u8>,
    /// Window size the context was created with. It is stored but not read
    /// by the code in this file, hence the lint allowance.
    #[allow(dead_code)]
    window_size: usize,
    /// The three remembered offsets, most recently recorded first.
    repeat_offsets: [u32; 3],
}

impl DecompressContext {
    /// Creates a context for a frame with the given window size.
    ///
    /// The output buffer is pre-allocated to `window_size`, capped at 1 MiB,
    /// and the repeat offsets start out as `[1, 4, 8]`.
    pub fn new(window_size: usize) -> Self {
        const INITIAL_CAPACITY_LIMIT: usize = 1024 * 1024;
        let initial_capacity = window_size.min(INITIAL_CAPACITY_LIMIT);
        Self {
            output: Vec::with_capacity(initial_capacity),
            window_size,
            repeat_offsets: [1, 4, 8],
        }
    }

    /// Borrows the bytes decoded so far.
    pub fn output(&self) -> &[u8] {
        self.output.as_slice()
    }

    /// Consumes the context and hands back the decoded bytes.
    pub fn into_output(self) -> Vec<u8> {
        let Self { output, .. } = self;
        output
    }

    /// Records `offset` as the newest repeat offset.
    ///
    /// When `offset` already equals the newest entry the history is left
    /// untouched; otherwise the two newest entries shift down one slot and
    /// the oldest entry is dropped.
    pub fn update_offsets(&mut self, offset: u32) {
        let [newest, middle, _oldest] = self.repeat_offsets;
        if offset == newest {
            return;
        }
        self.repeat_offsets = [offset, newest, middle];
    }

    /// Resolves a repeat-offset code.
    ///
    /// Codes 1, 2 and 3 select the newest, middle and oldest remembered
    /// offset respectively; any other value is returned unchanged.
    pub fn get_repeat_offset(&self, code: u32) -> u32 {
        if (1..=3).contains(&code) {
            self.repeat_offsets[(code - 1) as usize]
        } else {
            code
        }
    }
}
pub fn decompress_frame(input: &[u8]) -> Result<Vec<u8>> {
if input.len() < 4 {
return Err(Error::corrupted("Input too short for Zstd frame"));
}
let magic = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
if magic != ZSTD_MAGIC {
return Err(Error::corrupted(format!(
"Invalid Zstd magic: expected 0x{:08X}, got 0x{:08X}",
ZSTD_MAGIC, magic
)));
}
let header = FrameHeader::parse(&input[4..])?;
let mut ctx = DecompressContext::new(header.window_size);
let mut pos = header.header_size;
loop {
if pos + BlockHeader::SIZE > input.len() {
return Err(Error::corrupted("Frame truncated at block header"));
}
let block_header = BlockHeader::parse(&input[pos..])?;
pos += BlockHeader::SIZE;
let compressed_size = block_header.compressed_size();
if pos + compressed_size > input.len() {
return Err(Error::corrupted("Frame truncated at block data"));
}
let block_data = &input[pos..pos + compressed_size];
pos += compressed_size;
match block_header.block_type {
BlockType::Raw => {
decode_raw_block(block_data, &mut ctx.output)?;
}
BlockType::Rle => {
decode_rle_block(
block_data,
block_header.decompressed_size(),
&mut ctx.output,
)?;
}
BlockType::Compressed => {
decode_compressed_block(block_data, &mut ctx)?;
}
BlockType::Reserved => {
return Err(Error::corrupted("Reserved block type"));
}
}
if block_header.last_block {
break;
}
}
if header.has_checksum {
if pos + 4 > input.len() {
return Err(Error::corrupted("Frame truncated at checksum"));
}
let expected =
u32::from_le_bytes([input[pos], input[pos + 1], input[pos + 2], input[pos + 3]]);
let actual = (xxhash64(&ctx.output, 0) & 0xFFFFFFFF) as u32;
if expected != actual {
return Err(Error::corrupted(format!(
"Checksum mismatch: expected 0x{:08X}, got 0x{:08X}",
expected, actual
)));
}
}
if let Some(expected_size) = header.frame_content_size {
if ctx.output.len() as u64 != expected_size {
return Err(Error::corrupted(format!(
"Content size mismatch: expected {}, got {}",
expected_size,
ctx.output.len()
)));
}
}
Ok(ctx.into_output())
}
/// Decompresses a single Zstd frame, optionally using a dictionary.
///
/// With `dict == None` this is exactly [`decompress_frame`]. With a
/// dictionary, its content is placed in front of the output buffer before
/// decoding so that match offsets can reach back into it; that prefix is
/// excluded from the checksum, from the content-size check and from the
/// returned bytes.
///
/// # Errors
///
/// Returns a corrupted-data error when the input is shorter than the magic
/// number, the magic number is wrong, the frame ends before a block header,
/// block payload or checksum is complete, a block has the reserved type, or
/// the checksum / declared content size does not match the decoded content.
/// Errors from the header parsers and block decoders are propagated.
pub fn decompress_frame_with_dict(
    input: &[u8],
    dict: Option<&crate::dictionary::ZstdDictionary>,
) -> Result<Vec<u8>> {
    let dictionary = match dict {
        Some(dictionary) => dictionary,
        None => return decompress_frame(input),
    };

    if input.len() < 4 {
        return Err(Error::corrupted("Input too short for Zstd frame"));
    }
    let magic = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
    if magic != ZSTD_MAGIC {
        return Err(Error::corrupted(format!(
            "Invalid Zstd magic: expected 0x{:08X}, got 0x{:08X}",
            ZSTD_MAGIC, magic
        )));
    }

    let header = FrameHeader::parse(&input[4..])?;
    let mut ctx = DecompressContext::new(header.window_size);

    // Seed the output with the dictionary so back-references can resolve into it.
    let dict_content = dictionary.content();
    let dict_len = dict_content.len();
    ctx.output.extend_from_slice(dict_content);

    let mut pos = header.header_size;
    loop {
        if pos + BlockHeader::SIZE > input.len() {
            return Err(Error::corrupted("Frame truncated at block header"));
        }
        let block_header = BlockHeader::parse(&input[pos..])?;
        pos += BlockHeader::SIZE;

        let compressed_size = block_header.compressed_size();
        if pos + compressed_size > input.len() {
            return Err(Error::corrupted("Frame truncated at block data"));
        }
        let block_data = &input[pos..pos + compressed_size];
        pos += compressed_size;

        match block_header.block_type {
            BlockType::Raw => {
                decode_raw_block(block_data, &mut ctx.output)?;
            }
            BlockType::Rle => {
                decode_rle_block(
                    block_data,
                    block_header.decompressed_size(),
                    &mut ctx.output,
                )?;
            }
            BlockType::Compressed => {
                decode_compressed_block(block_data, &mut ctx)?;
            }
            BlockType::Reserved => {
                return Err(Error::corrupted("Reserved block type"));
            }
        }

        if block_header.last_block {
            break;
        }
    }

    if header.has_checksum {
        if pos + 4 > input.len() {
            return Err(Error::corrupted("Frame truncated at checksum"));
        }
        let expected =
            u32::from_le_bytes([input[pos], input[pos + 1], input[pos + 2], input[pos + 3]]);
        // The checksum covers only the frame's own content, not the dictionary prefix.
        let actual = (xxhash64(&ctx.output[dict_len..], 0) & 0xFFFFFFFF) as u32;
        if expected != actual {
            return Err(Error::corrupted(format!(
                "Checksum mismatch: expected 0x{:08X}, got 0x{:08X}",
                expected, actual
            )));
        }
    }

    if let Some(expected_size) = header.frame_content_size {
        let actual_size = (ctx.output.len() - dict_len) as u64;
        if actual_size != expected_size {
            return Err(Error::corrupted(format!(
                "Content size mismatch: expected {}, got {}",
                expected_size, actual_size
            )));
        }
    }

    // Strip the dictionary prefix in place rather than copying the decoded
    // content into a second allocation. The buffer only ever grows past
    // `dict_len`, so the range is always in bounds.
    let mut output = ctx.into_output();
    output.drain(..dict_len);
    Ok(output)
}
/// Decodes one compressed block and appends its content to `ctx`.
///
/// The block is parsed as a literals section followed by a sequences section
/// (which starts where the literals section ended); the parsed sequences are
/// then executed against the context's output buffer.
///
/// # Errors
///
/// Returns a corrupted-data error for an empty block and propagates any
/// error from section parsing or sequence execution.
fn decode_compressed_block(input: &[u8], ctx: &mut DecompressContext) -> Result<()> {
    if input.is_empty() {
        return Err(Error::corrupted("Empty compressed block"));
    }

    let (literals, literals_consumed) = LiteralsSection::parse(input)?;
    let sequences = SequencesSection::parse(&input[literals_consumed..], &literals)?;
    execute_sequences(&literals, &sequences, ctx)
}
/// Replays decoded sequences into the context's output buffer.
///
/// For every sequence, `literal_length` bytes are copied from the literals
/// section, then `match_length` bytes are copied from `offset` bytes back in
/// the output. Matches whose offset is shorter than their length overlap
/// their own output and are handled by `copy_match_overlapping`. Literals
/// left over after the final sequence are appended at the end.
///
/// Sequences with a zero match length or a zero offset contribute only their
/// literals.
///
/// # Errors
///
/// Returns a corrupted-data error when a sequence asks for more literals
/// than remain, or when a match offset reaches before the start of the
/// output buffer.
fn execute_sequences(
    literals: &LiteralsSection,
    sequences: &SequencesSection,
    ctx: &mut DecompressContext,
) -> Result<()> {
    // Upper bound for the speculative reservation below. The sequence
    // lengths come from untrusted input, so they must not be allowed to
    // drive an arbitrarily large allocation before anything is validated.
    // 128 KiB is the Zstd format's maximum block content size; anything
    // beyond that falls back to the vector's normal growth.
    const MAX_UPFRONT_RESERVE: usize = 128 * 1024;

    let literal_bytes = literals.data();
    let mut pending_literals: &[u8] = &literal_bytes[..];

    // Expected output is every literal byte plus every match byte. Literal
    // lengths are not added on top because they are already covered by the
    // literals section itself.
    let match_total = sequences
        .sequences
        .iter()
        .fold(0usize, |acc, seq| acc.saturating_add(seq.match_length as usize));
    let size_hint = pending_literals.len().saturating_add(match_total);
    ctx.output.reserve(size_hint.min(MAX_UPFRONT_RESERVE));

    for seq in &sequences.sequences {
        let literal_length = seq.literal_length as usize;
        // Comparing against the remaining slice avoids any index addition
        // that could overflow on 32-bit targets.
        if literal_length > pending_literals.len() {
            return Err(Error::corrupted(
                "Literal length exceeds available literals",
            ));
        }
        let (head, rest) = pending_literals.split_at(literal_length);
        ctx.output.extend_from_slice(head);
        pending_literals = rest;

        let offset = seq.offset as usize;
        let match_length = seq.match_length as usize;
        if match_length == 0 || offset == 0 {
            continue;
        }

        let out_len = ctx.output.len();
        if offset > out_len {
            return Err(Error::corrupted(format!(
                "Match offset {} exceeds output size {}",
                offset, out_len
            )));
        }

        let match_start = out_len - offset;
        if offset >= match_length {
            // Source and destination do not overlap: one bulk copy suffices.
            ctx.output
                .extend_from_within(match_start..match_start + match_length);
        } else {
            copy_match_overlapping(&mut ctx.output, match_start, offset, match_length);
        }
    }

    // Literals not claimed by any sequence trail the block's output.
    ctx.output.extend_from_slice(pending_literals);
    Ok(())
}
/// Appends `match_length` bytes to `output` by repeating the `offset`-byte
/// pattern that starts at index `match_start`.
///
/// This is the LZ77 "overlapping match" case: when the match is longer than
/// its offset, later bytes of the match are copies of bytes the match itself
/// has just produced, so the result is the pattern repeated (and truncated)
/// to `match_length` bytes.
///
/// The implementation is entirely safe: the pattern is copied once, then the
/// already-written region is doubled until enough bytes exist. No length is
/// ever set ahead of initialised data.
///
/// # Panics
///
/// Panics if `offset` is zero, or if `match_start + min(offset, match_length)`
/// lies beyond the current end of `output`.
#[inline]
fn copy_match_overlapping(
    output: &mut Vec<u8>,
    match_start: usize,
    offset: usize,
    match_length: usize,
) {
    // A zero offset has no pattern to repeat and would never make progress.
    assert!(offset > 0, "match offset must be non-zero");
    if match_length == 0 {
        return;
    }

    output.reserve(match_length);

    // Seed: one copy of the pattern (or less, if the match is shorter).
    let copy_base = output.len();
    let seed_len = offset.min(match_length);
    output.extend_from_within(match_start..match_start + seed_len);

    // Everything written so far is a whole number of pattern periods, so
    // re-copying a prefix of it continues the repetition seamlessly. Each
    // round at most doubles the written amount.
    let mut written = seed_len;
    while written < match_length {
        let chunk = written.min(match_length - written);
        output.extend_from_within(copy_base..copy_base + chunk);
        written += chunk;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The Zstd magic number as it appears on the wire (little-endian).
    const MAGIC_BYTES: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];
    /// Frame descriptor: single-segment frame with a 1-byte content size.
    const SINGLE_SEGMENT: u8 = 0x20;
    /// Same as [`SINGLE_SEGMENT`] with the content-checksum flag set.
    const SINGLE_SEGMENT_WITH_CHECKSUM: u8 = 0x24;

    /// Starts a frame: magic number, descriptor byte, 1-byte content size.
    fn frame_start(descriptor: u8, content_size: u8) -> Vec<u8> {
        let mut frame = MAGIC_BYTES.to_vec();
        frame.push(descriptor);
        frame.push(content_size);
        frame
    }

    /// Appends `block` to `frame` as a compressed block flagged as the last one.
    fn push_last_compressed_block(frame: &mut Vec<u8>, block: &[u8]) {
        // 3-byte little-endian header: size << 3 | type (2 = compressed) << 1 | last (1).
        let header = (block.len() << 3) | 5;
        frame.push((header & 0xFF) as u8);
        frame.push(((header >> 8) & 0xFF) as u8);
        frame.push(((header >> 16) & 0xFF) as u8);
        frame.extend_from_slice(block);
    }

    /// Encodes a literals-section header for the given literals type
    /// (0 = raw, 1 = RLE) and regenerated size, using the 1-byte form for
    /// sizes up to 31 and the 2-byte form for sizes up to 4095.
    fn literals_header(literals_type: usize, size: usize) -> Vec<u8> {
        if size <= 31 {
            vec![((size << 3) | literals_type) as u8]
        } else if size <= 4095 {
            let byte0 = ((size & 0xF) << 4) | (1 << 2) | literals_type;
            let byte1 = (size >> 4) & 0xFF;
            vec![byte0 as u8, byte1 as u8]
        } else {
            unreachable!("Size too large for test")
        }
    }

    /// Builds a compressed block holding raw literals and zero sequences.
    fn build_compressed_block_literals_only(literals: &[u8]) -> Vec<u8> {
        let mut block = literals_header(0, literals.len());
        block.extend_from_slice(literals);
        block.push(0); // sequence count
        block
    }

    /// Builds a compressed block holding RLE literals and zero sequences.
    fn build_compressed_block_rle_literals(byte: u8, repeat_count: usize) -> Vec<u8> {
        let mut block = literals_header(1, repeat_count);
        block.push(byte);
        block.push(0); // sequence count
        block
    }

    #[test]
    fn test_decompress_context_creation() {
        let ctx = DecompressContext::new(1024);
        assert_eq!(ctx.window_size, 1024);
        assert!(ctx.output.is_empty());
    }

    #[test]
    fn test_repeat_offsets() {
        let mut ctx = DecompressContext::new(1024);
        assert_eq!(ctx.get_repeat_offset(1), 1);
        assert_eq!(ctx.get_repeat_offset(2), 4);
        assert_eq!(ctx.get_repeat_offset(3), 8);

        ctx.update_offsets(100);
        assert_eq!(ctx.get_repeat_offset(1), 100);
        assert_eq!(ctx.get_repeat_offset(2), 1);
        assert_eq!(ctx.get_repeat_offset(3), 4);

        ctx.update_offsets(200);
        assert_eq!(ctx.get_repeat_offset(1), 200);
        assert_eq!(ctx.get_repeat_offset(2), 100);
        assert_eq!(ctx.get_repeat_offset(3), 1);
    }

    #[test]
    fn test_repeat_offset_same_value() {
        let mut ctx = DecompressContext::new(1024);
        ctx.update_offsets(100);
        // Recording the same offset again must not shift the history.
        ctx.update_offsets(100);
        assert_eq!(ctx.get_repeat_offset(1), 100);
        assert_eq!(ctx.get_repeat_offset(2), 1);
    }

    #[test]
    fn test_magic_validation() {
        // Wrong magic number.
        let result = decompress_frame(&[0x00, 0x00, 0x00, 0x00]);
        assert!(result.is_err());
        // Too short to contain a magic number at all.
        let result = decompress_frame(&[0x28, 0xB5]);
        assert!(result.is_err());
    }

    #[test]
    fn test_valid_magic() {
        // The magic number is correct, but what follows is not a complete frame.
        let data = [0x28, 0xB5, 0x2F, 0xFD, 0x00];
        let result = decompress_frame(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_simple_raw_frame() {
        let mut frame = frame_start(SINGLE_SEGMENT, 5);
        // Raw block, last, size 5.
        frame.extend_from_slice(&[0x29, 0x00, 0x00]);
        frame.extend_from_slice(b"Hello");

        let result = decompress_frame(&frame).unwrap();
        assert_eq!(result, b"Hello");
    }

    #[test]
    fn test_rle_frame() {
        let mut frame = frame_start(SINGLE_SEGMENT, 10);
        // RLE block, last, regenerated size 10.
        frame.extend_from_slice(&[0x53, 0x00, 0x00]);
        frame.push(b'X');

        let result = decompress_frame(&frame).unwrap();
        assert_eq!(result, vec![b'X'; 10]);
    }

    #[test]
    fn test_multi_block_frame() {
        let mut frame = frame_start(SINGLE_SEGMENT, 8);
        // Raw block, not last, size 5.
        frame.extend_from_slice(&[0x28, 0x00, 0x00]);
        frame.extend_from_slice(b"Hello");
        // Raw block, last, size 3.
        frame.extend_from_slice(&[0x19, 0x00, 0x00]);
        frame.extend_from_slice(b"!!!");

        let result = decompress_frame(&frame).unwrap();
        assert_eq!(result, b"Hello!!!");
    }

    #[test]
    fn test_content_size_mismatch() {
        // Header declares 10 bytes of content but the only block carries 5.
        let mut frame = frame_start(SINGLE_SEGMENT, 10);
        frame.extend_from_slice(&[0x29, 0x00, 0x00]);
        frame.extend_from_slice(b"Hello");

        let result = decompress_frame(&frame);
        assert!(result.is_err());
    }

    #[test]
    fn test_frame_with_checksum() {
        let mut frame = frame_start(SINGLE_SEGMENT_WITH_CHECKSUM, 5);
        frame.extend_from_slice(&[0x29, 0x00, 0x00]);
        frame.extend_from_slice(b"Hello");
        let hash = xxhash64(b"Hello", 0);
        let checksum = (hash & 0xFFFFFFFF) as u32;
        frame.extend_from_slice(&checksum.to_le_bytes());

        let result = decompress_frame(&frame).unwrap();
        assert_eq!(result, b"Hello");
    }

    #[test]
    fn test_checksum_mismatch() {
        let mut frame = frame_start(SINGLE_SEGMENT_WITH_CHECKSUM, 5);
        frame.extend_from_slice(&[0x29, 0x00, 0x00]);
        frame.extend_from_slice(b"Hello");
        // Deliberately wrong checksum.
        frame.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]);

        let result = decompress_frame(&frame);
        assert!(result.is_err());
    }

    #[test]
    fn test_compressed_block_literals_only() {
        let literals = b"Hello";
        let mut frame = frame_start(SINGLE_SEGMENT, 5);
        push_last_compressed_block(&mut frame, &build_compressed_block_literals_only(literals));

        let result = decompress_frame(&frame).unwrap();
        assert_eq!(result, b"Hello");
    }

    #[test]
    fn test_compressed_block_with_rle_literals() {
        let mut frame = frame_start(SINGLE_SEGMENT, 10);
        push_last_compressed_block(&mut frame, &build_compressed_block_rle_literals(b'A', 10));

        let result = decompress_frame(&frame).unwrap();
        assert_eq!(result, vec![b'A'; 10]);
    }

    #[test]
    fn test_compressed_block_multi_literals() {
        // 100 literals exercise the 2-byte literals-section header.
        let literals: Vec<u8> = (0..100u8).collect();
        let mut frame = frame_start(SINGLE_SEGMENT, 100);
        push_last_compressed_block(&mut frame, &build_compressed_block_literals_only(&literals));

        let result = decompress_frame(&frame).unwrap();
        assert_eq!(result, literals);
    }
}