use crate::fse::BitReader;
use crate::huffman::{build_table_from_weights, parse_huffman_weights, HuffmanDecoder};
use haagenti_core::{Error, Result};
/// Kind of literals block, taken from the two lowest bits of the first
/// byte of a literals section header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LiteralsBlockType {
    /// Field value 0: the payload is the literal bytes themselves.
    Raw,
    /// Field value 1: the payload is a single byte repeated
    /// `regenerated_size` times.
    Rle,
    /// Field value 2: Huffman-coded literals whose weight description is
    /// stored at the start of the payload.
    Compressed,
    /// Field value 3: Huffman-coded literals that rely on a table from an
    /// earlier block.
    Treeless,
}

impl LiteralsBlockType {
    /// Converts the 2-bit block-type field into a [`LiteralsBlockType`].
    ///
    /// Only the two least-significant bits of `field` are inspected, so the
    /// conversion is total: callers may pass either the pre-masked field or
    /// the raw header byte, and no input can make this function panic.
    pub fn from_field(field: u8) -> Self {
        match field & 0x03 {
            0 => LiteralsBlockType::Raw,
            1 => LiteralsBlockType::Rle,
            2 => LiteralsBlockType::Compressed,
            // After masking, the only remaining value is 3.
            _ => LiteralsBlockType::Treeless,
        }
    }
}
/// A parsed literals section together with its decoded literal bytes.
#[derive(Debug, Clone)]
pub struct LiteralsSection {
/// Block type read from the section header (`Raw` for sections built with `new_raw`).
pub block_type: LiteralsBlockType,
/// Number of literal bytes the section expands to, as read from the header
/// (or the data length for sections built with `new_raw`).
pub regenerated_size: usize,
/// Size in bytes of the payload that follows the header: equal to
/// `regenerated_size` for raw sections, 1 for RLE sections, and the header's
/// compressed-size field for Huffman-coded sections.
pub compressed_size: usize,
/// Decoded literal bytes; exposed read-only through `data()`.
data: Vec<u8>,
}
impl LiteralsSection {
/// Wraps already-decoded bytes in a `Raw` literals section.
///
/// Both `regenerated_size` and `compressed_size` are set to `data.len()`.
pub fn new_raw(data: Vec<u8>) -> Self {
    let byte_count = data.len();
    Self {
        data,
        compressed_size: byte_count,
        regenerated_size: byte_count,
        block_type: LiteralsBlockType::Raw,
    }
}
pub fn parse(input: &[u8]) -> Result<(Self, usize)> {
if input.is_empty() {
return Err(Error::corrupted("Empty literals section"));
}
let header_byte = input[0];
let block_type = LiteralsBlockType::from_field(header_byte & 0x03);
let size_format = (header_byte >> 2) & 0x03;
match block_type {
LiteralsBlockType::Raw | LiteralsBlockType::Rle => {
Self::parse_raw_rle(input, block_type, size_format)
}
LiteralsBlockType::Compressed | LiteralsBlockType::Treeless => {
Self::parse_compressed(input, block_type, size_format)
}
}
}
fn parse_raw_rle(
input: &[u8],
block_type: LiteralsBlockType,
size_format: u8,
) -> Result<(Self, usize)> {
let (regenerated_size, header_size) = match size_format {
0 | 2 => {
let size = (input[0] >> 3) as usize;
(size, 1)
}
1 => {
if input.len() < 2 {
return Err(Error::corrupted("Literals header truncated"));
}
let size = ((input[0] >> 4) as usize) | ((input[1] as usize) << 4);
(size, 2)
}
3 => {
if input.len() < 3 {
return Err(Error::corrupted("Literals header truncated"));
}
let size = ((input[0] >> 4) as usize)
| ((input[1] as usize) << 4)
| ((input[2] as usize) << 12);
(size, 3)
}
_ => unreachable!(),
};
let data_start = header_size;
let data = match block_type {
LiteralsBlockType::Raw => {
if input.len() < data_start + regenerated_size {
return Err(Error::corrupted("Raw literals truncated"));
}
input[data_start..data_start + regenerated_size].to_vec()
}
LiteralsBlockType::Rle => {
if input.len() < data_start + 1 {
return Err(Error::corrupted("RLE literals missing byte"));
}
vec![input[data_start]; regenerated_size]
}
_ => unreachable!(),
};
let total_size = match block_type {
LiteralsBlockType::Raw => header_size + regenerated_size,
LiteralsBlockType::Rle => header_size + 1,
_ => unreachable!(),
};
Ok((
Self {
block_type,
regenerated_size,
compressed_size: match block_type {
LiteralsBlockType::Raw => regenerated_size,
LiteralsBlockType::Rle => 1,
_ => unreachable!(),
},
data,
},
total_size,
))
}
/// Parses a `Compressed` or `Treeless` literals section and Huffman-decodes
/// its payload.
///
/// The header is read as a little-endian integer (RFC 8878, section
/// 3.1.1.3.1.1): 2 bits of block type, 2 bits of `size_format`, then the
/// regenerated size followed by the compressed size, each `N` bits wide:
///
/// | `size_format` | streams | `N` | header bytes |
/// |---------------|---------|-----|--------------|
/// | 0             | 1       | 10  | 3            |
/// | 1             | 4       | 10  | 3            |
/// | 2             | 4       | 14  | 4            |
/// | 3             | 4       | 18  | 5            |
///
/// Returns the decoded section and the number of input bytes consumed
/// (header plus compressed payload).
///
/// # Errors
///
/// * corruption error if the header or the payload is truncated, or if
///   Huffman decoding fails;
/// * `Error::Unsupported` for `Treeless` blocks, which need the Huffman table
///   of a previous block that this parser does not have.
fn parse_compressed(
    input: &[u8],
    block_type: LiteralsBlockType,
    size_format: u8,
) -> Result<(Self, usize)> {
    let (header_size, size_bits): (usize, u32) = match size_format {
        0 | 1 => (3, 10),
        2 => (4, 14),
        3 => (5, 18),
        _ => unreachable!(),
    };
    // Only format 0 stores a single stream; 1..=3 all use four streams.
    let is_single_stream = size_format == 0;

    if input.len() < header_size {
        return Err(Error::corrupted("Compressed literals header truncated"));
    }

    // Assemble the header as one little-endian word. A u64 is used because
    // the 5-byte header (40 bits) does not fit in a 32-bit usize.
    let header = input[..header_size]
        .iter()
        .rev()
        .fold(0u64, |word, &byte| (word << 8) | u64::from(byte));
    let field_mask = (1u64 << size_bits) - 1;
    // Both fields are at most 18 bits wide, so the casts cannot truncate.
    let regenerated_size = ((header >> 4) & field_mask) as usize;
    let compressed_size = ((header >> (4 + size_bits)) & field_mask) as usize;

    // Treeless blocks can never be decoded here, so report that before
    // looking at the payload at all.
    if block_type == LiteralsBlockType::Treeless {
        return Err(Error::Unsupported(
            "Treeless Huffman literals require previous table state".into(),
        ));
    }

    let total_size = header_size + compressed_size;
    if input.len() < total_size {
        return Err(Error::corrupted("Compressed literals data truncated"));
    }
    let compressed_data = &input[header_size..total_size];

    let data =
        Self::decode_huffman_literals(compressed_data, regenerated_size, is_single_stream)?;

    Ok((
        Self {
            block_type,
            regenerated_size,
            compressed_size,
            data,
        },
        total_size,
    ))
}
/// Decodes the payload of a `Compressed` literals block.
///
/// `data` starts with the Huffman weight description, which is parsed and
/// turned into a decoding table; the bytes after it hold either one
/// bitstream or a four-stream layout, selected by `is_single_stream`.
///
/// # Errors
///
/// Returns a corruption error when `data` is empty and forwards any error
/// from weight parsing, table construction or stream decoding.
fn decode_huffman_literals(
    data: &[u8],
    regenerated_size: usize,
    is_single_stream: bool,
) -> Result<Vec<u8>> {
    if data.is_empty() {
        return Err(Error::corrupted("Empty Huffman literals data"));
    }

    let (weights, tree_len) = parse_huffman_weights(data)?;
    let table = build_table_from_weights(weights)?;
    let decoder = HuffmanDecoder::new(&table);

    // The bitstream(s) begin right after the weight description.
    let streams = &data[tree_len..];

    if !is_single_stream {
        return Self::decode_four_streams(&decoder, streams, regenerated_size);
    }
    Self::decode_single_stream(&decoder, streams, regenerated_size)
}
fn decode_single_stream(
decoder: &HuffmanDecoder,
data: &[u8],
regenerated_size: usize,
) -> Result<Vec<u8>> {
if data.is_empty() {
if regenerated_size == 0 {
return Ok(Vec::new());
}
return Err(Error::corrupted("Empty stream data for Huffman decoding"));
}
let mut output = Vec::with_capacity(regenerated_size);
let mut bits = BitReader::new_reversed(data)?;
for _ in 0..regenerated_size {
let symbol = decoder.decode_symbol(&mut bits)?;
output.push(symbol);
}
Ok(output)
}
/// Decodes a four-stream Huffman literals payload.
///
/// Layout (RFC 8878, section 3.1.1.3.1.6): a 6-byte jump table holding the
/// compressed sizes of streams 1, 2 and 3 as little-endian `u16`s, followed
/// by the four bitstreams back to back. Stream 4 occupies whatever remains
/// after the first three. The first three streams each regenerate
/// `(regenerated_size + 3) / 4` bytes and the fourth regenerates the rest.
///
/// A stream with zero compressed bytes is tolerated only when it is expected
/// to regenerate nothing.
///
/// # Errors
///
/// Returns a corruption error when the jump table is missing, when the
/// listed stream sizes exceed the available data, when `regenerated_size`
/// cannot be split across four streams, when a stream is empty but must
/// produce symbols, or when decoding an individual stream fails.
fn decode_four_streams(
    decoder: &HuffmanDecoder,
    data: &[u8],
    regenerated_size: usize,
) -> Result<Vec<u8>> {
    const JUMP_TABLE_LEN: usize = 6;

    if data.len() < JUMP_TABLE_LEN {
        return Err(Error::corrupted("4-stream header too short"));
    }
    let (jump_table, streams) = data.split_at(JUMP_TABLE_LEN);

    // The jump table stores *sizes* of the first three streams, not offsets.
    let mut compressed_lens = [0usize; 4];
    let mut listed_total = 0usize;
    for (slot, entry) in compressed_lens.iter_mut().zip(jump_table.chunks_exact(2)) {
        *slot = usize::from(u16::from_le_bytes([entry[0], entry[1]]));
        listed_total += *slot;
    }
    // The fourth stream takes the remaining bytes; a negative remainder
    // means the jump table points past the end of the payload.
    compressed_lens[3] = match streams.len().checked_sub(listed_total) {
        Some(rest) => rest,
        None => {
            return Err(Error::corrupted(
                "Invalid stream boundaries in 4-stream literals",
            ));
        }
    };

    // Streams 1..=3 regenerate ceil(n / 4) bytes each; stream 4 the rest.
    let segment = (regenerated_size + 3) / 4;
    let last_segment = match regenerated_size.checked_sub(3 * segment) {
        Some(rest) => rest,
        None => {
            return Err(Error::corrupted(
                "Regenerated size too small for 4-stream literals",
            ));
        }
    };
    let decoded_lens = [segment, segment, segment, last_segment];

    let mut output = Vec::with_capacity(regenerated_size);
    let mut remaining = streams;
    for (i, (&compressed_len, &decoded_len)) in
        compressed_lens.iter().zip(decoded_lens.iter()).enumerate()
    {
        // Cannot panic: the four lengths sum to exactly `streams.len()`.
        let (stream, rest) = remaining.split_at(compressed_len);
        remaining = rest;

        if stream.is_empty() {
            if decoded_len > 0 {
                return Err(Error::corrupted(format!(
                    "Stream {} is empty but expects {} symbols",
                    i, decoded_len
                )));
            }
            continue;
        }
        output.extend(Self::decode_single_stream(decoder, stream, decoded_len)?);
    }
    Ok(output)
}
/// Borrows the decoded literal bytes of this section.
pub fn data(&self) -> &[u8] {
    self.data.as_slice()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each 2-bit field value maps to its block type.
    #[test]
    fn test_literals_block_type_parsing() {
        let expected = [
            LiteralsBlockType::Raw,
            LiteralsBlockType::Rle,
            LiteralsBlockType::Compressed,
            LiteralsBlockType::Treeless,
        ];
        assert_eq!(LiteralsBlockType::from_field(0), expected[0]);
        assert_eq!(LiteralsBlockType::from_field(1), expected[1]);
        assert_eq!(LiteralsBlockType::from_field(2), expected[2]);
        assert_eq!(LiteralsBlockType::from_field(3), expected[3]);
    }

    /// Raw block with a 1-byte header: 0x28 >> 3 == 5 literal bytes.
    #[test]
    fn test_raw_literals_5bit_size() {
        let mut input = vec![0x28];
        input.extend_from_slice(b"Hello");

        let (section, consumed) = LiteralsSection::parse(&input).unwrap();

        assert_eq!(section.block_type, LiteralsBlockType::Raw);
        assert_eq!(section.regenerated_size, 5);
        assert_eq!(section.data, b"Hello");
        // One header byte plus five payload bytes.
        assert_eq!(consumed, 6);
    }

    /// RLE block with a 1-byte header: 0x51 >> 3 == 10 repetitions of 'X'.
    #[test]
    fn test_rle_literals_5bit_size() {
        let input = vec![0x51, b'X'];

        let (section, consumed) = LiteralsSection::parse(&input).unwrap();

        assert_eq!(section.block_type, LiteralsBlockType::Rle);
        assert_eq!(section.regenerated_size, 10);
        assert_eq!(section.data, vec![b'X'; 10]);
        // One header byte plus the repeated byte.
        assert_eq!(consumed, 2);
    }

    /// Raw block with a 2-byte header: (0x04 >> 4) | (0x10 << 4) == 256.
    #[test]
    fn test_raw_literals_12bit_size() {
        let mut input = vec![0x04, 0x10];
        input.resize(2 + 256, b'A');

        let (section, consumed) = LiteralsSection::parse(&input).unwrap();

        assert_eq!(section.block_type, LiteralsBlockType::Raw);
        assert_eq!(section.regenerated_size, 256);
        assert_eq!(consumed, 2 + 256);
    }

    /// An empty slice has no header byte and must be rejected.
    #[test]
    fn test_empty_input_error() {
        assert!(LiteralsSection::parse(&[]).is_err());
    }

    /// Header announces 10 raw bytes (0x50 >> 3) but only 5 follow.
    #[test]
    fn test_truncated_raw_error() {
        let input = vec![0x50, b'H', b'e', b'l', b'l', b'o'];
        assert!(LiteralsSection::parse(&input).is_err());
    }

    /// `new_raw` records the data length and exposes the bytes unchanged.
    #[test]
    fn test_new_raw_helper() {
        let section = LiteralsSection::new_raw(b"test".to_vec());

        assert_eq!(section.block_type, LiteralsBlockType::Raw);
        assert_eq!(section.regenerated_size, 4);
        assert_eq!(section.data(), b"test");
    }

    /// Header bytes whose low two bits are 0b10 denote a Compressed block,
    /// whatever the size-format bits above them are.
    #[test]
    fn test_compressed_header_type_detection() {
        for &header_byte in &[0x0Eu8, 0x02] {
            let block_type = LiteralsBlockType::from_field(header_byte & 0x03);
            assert_eq!(block_type, LiteralsBlockType::Compressed);
        }
    }

    /// A Treeless block (low two header bits 0b11) must be rejected with an
    /// error that explains why.
    #[test]
    fn test_treeless_requires_previous_table() {
        let mut input = vec![0x5F, 0x28, 0x00];
        input.resize(3 + 10, 0x80);

        let err = LiteralsSection::parse(&input)
            .expect_err("treeless literals must not parse without a previous table");

        let msg = format!("{:?}", err);
        assert!(
            msg.contains("previous table") || msg.contains("Treeless"),
            "Expected 'previous table' or 'Treeless' error, got: {}",
            msg
        );
    }

    /// A Compressed block (0xA2 has low bits 0b10) followed by these bytes is
    /// expected to fail to parse.
    #[test]
    fn test_compressed_literals_truncated_data_error() {
        let input = vec![0xA2, 0x50, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05];
        assert!(LiteralsSection::parse(&input).is_err());
    }

    /// The size format occupies bits 2..4 of the header byte.
    #[test]
    fn test_size_format_detection() {
        for size_format in 0..4u8 {
            let header_byte = 0x02 | (size_format << 2);
            let extracted = (header_byte >> 2) & 0x03;
            assert_eq!(extracted, size_format);
        }
    }
}