use alloc::vec::Vec;
use crate::error::Error;
use crate::zstd::bitreader::RevBitReader;
use crate::zstd::decoder::MAX_WINDOW_SIZE;
use crate::zstd::huffman::{HuffTable, decode_huffman_tree};
/// Decoder state carried from one literals section to the next.
///
/// The derived `Default` starts with no Huffman table stored.
#[derive(Default)]
pub struct LiteralsState {
/// Huffman table stored by `decode_compressed_literals` after a literals
/// section has been decoded successfully. Sections that reuse a previous
/// table (literals block type 3) read it from here and are rejected with
/// `Error::Corrupt` when it is `None`.
pub huff_tree: Option<HuffTable>,
}
/// Outcome of decoding one literals section.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so that results (and
/// `Result`s wrapping them) can be printed, compared and asserted on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LiteralsResult {
    /// The regenerated literal bytes.
    pub literals: Vec<u8>,
    /// Number of bytes of the input block taken up by the literals section:
    /// the section header plus its payload.
    pub consumed: usize,
}
/// Decodes the literals section that begins at the first byte of `block`.
///
/// The low two bits of the first header byte select the literals block type
/// and the next two bits are the size format. Types 0 and 1 are handled by
/// `decode_raw_or_rle` (type 1 being RLE); types 2 and 3 are handled by
/// `decode_compressed_literals` (type 3 reusing the table held in `state`).
///
/// # Errors
///
/// Returns `Error::Corrupt` when `block` is empty, and otherwise whatever the
/// selected decoder returns.
pub fn decode_literals(block: &[u8], state: &mut LiteralsState) -> Result<LiteralsResult, Error> {
    let Some(&header) = block.first() else {
        return Err(Error::Corrupt);
    };
    let size_format = (header >> 2) & 0b11;
    // The type is masked to two bits, so the four arms below are exhaustive.
    match header & 0b11 {
        0 => decode_raw_or_rle(block, false, size_format),
        1 => decode_raw_or_rle(block, true, size_format),
        2 => decode_compressed_literals(block, size_format, state, false),
        _ => decode_compressed_literals(block, size_format, state, true),
    }
}
fn decode_raw_or_rle(block: &[u8], is_rle: bool, sf: u8) -> Result<LiteralsResult, Error> {
let (regen_size, header_bytes) = match sf {
0b00 | 0b10 => {
let rs = (block[0] >> 3) as usize;
(rs, 1)
}
0b01 => {
if block.len() < 2 {
return Err(Error::Corrupt);
}
let rs = ((block[0] >> 4) as usize) | ((block[1] as usize) << 4);
(rs, 2)
}
0b11 => {
if block.len() < 3 {
return Err(Error::Corrupt);
}
let rs = ((block[0] >> 4) as usize)
| ((block[1] as usize) << 4)
| ((block[2] as usize) << 12);
(rs, 3)
}
_ => unreachable!(),
};
if regen_size as u64 > MAX_WINDOW_SIZE {
return Err(Error::Corrupt);
}
let mut literals = Vec::with_capacity(regen_size.min(128 * 1024));
if is_rle {
if block.len() < header_bytes + 1 {
return Err(Error::Corrupt);
}
let byte = block[header_bytes];
literals.resize(regen_size, byte);
Ok(LiteralsResult {
literals,
consumed: header_bytes + 1,
})
} else {
if block.len() < header_bytes + regen_size {
return Err(Error::Corrupt);
}
literals.extend_from_slice(&block[header_bytes..header_bytes + regen_size]);
Ok(LiteralsResult {
literals,
consumed: header_bytes + regen_size,
})
}
}
fn decode_compressed_literals(
block: &[u8],
sf: u8,
state: &mut LiteralsState,
reuse: bool,
) -> Result<LiteralsResult, Error> {
let (regen_size, comp_size, header_bytes, four_streams) = match sf {
0b00 | 0b01 => {
if block.len() < 3 {
return Err(Error::Corrupt);
}
let h0 = block[0] as u32;
let h1 = block[1] as u32;
let h2 = block[2] as u32;
let regen = ((h0 >> 4) | ((h1 & 0x3F) << 4)) as usize;
let comp = ((h1 >> 6) | (h2 << 2)) as usize;
(regen, comp, 3, sf != 0b00)
}
0b10 => {
if block.len() < 4 {
return Err(Error::Corrupt);
}
let h0 = block[0] as u32;
let h1 = block[1] as u32;
let h2 = block[2] as u32;
let h3 = block[3] as u32;
let regen = ((h0 >> 4) | (h1 << 4) | ((h2 & 0x03) << 12)) as usize;
let comp = ((h2 >> 2) | (h3 << 6)) as usize;
(regen, comp, 4, true)
}
0b11 => {
if block.len() < 5 {
return Err(Error::Corrupt);
}
let bits: u64 = (block[0] as u64)
| ((block[1] as u64) << 8)
| ((block[2] as u64) << 16)
| ((block[3] as u64) << 24)
| ((block[4] as u64) << 32);
let regen = ((bits >> 4) & 0x3_FFFF) as usize;
let comp = ((bits >> 22) & 0x3_FFFF) as usize;
(regen, comp, 5, true)
}
_ => unreachable!(),
};
if regen_size as u64 > MAX_WINDOW_SIZE {
return Err(Error::Corrupt);
}
if block.len() < header_bytes + comp_size {
return Err(Error::Corrupt);
}
let payload = &block[header_bytes..header_bytes + comp_size];
let (tree, tree_bytes) = if reuse {
match state.huff_tree.take() {
Some(t) => (t, 0),
None => return Err(Error::Corrupt),
}
} else {
let (t, used) = decode_huffman_tree(payload)?;
(t, used)
};
let bitstreams = &payload[tree_bytes..];
let mut literals = Vec::with_capacity(regen_size.min(128 * 1024));
if !four_streams {
decode_huff_stream(bitstreams, &tree, regen_size, &mut literals)?;
} else {
if bitstreams.len() < 6 {
return Err(Error::Corrupt);
}
let l1 = (bitstreams[0] as usize) | ((bitstreams[1] as usize) << 8);
let l2 = (bitstreams[2] as usize) | ((bitstreams[3] as usize) << 8);
let l3 = (bitstreams[4] as usize) | ((bitstreams[5] as usize) << 8);
if 6 + l1 + l2 + l3 > bitstreams.len() {
return Err(Error::Corrupt);
}
let l4 = bitstreams.len() - 6 - l1 - l2 - l3;
let s1 = &bitstreams[6..6 + l1];
let s2 = &bitstreams[6 + l1..6 + l1 + l2];
let s3 = &bitstreams[6 + l1 + l2..6 + l1 + l2 + l3];
let s4 = &bitstreams[6 + l1 + l2 + l3..];
let _ = l4;
if regen_size < 4 {
return Err(Error::Corrupt);
}
let per = regen_size.div_ceil(4);
let last = regen_size.checked_sub(3 * per).ok_or(Error::Corrupt)?;
decode_huff_stream(s1, &tree, per, &mut literals)?;
decode_huff_stream(s2, &tree, per, &mut literals)?;
decode_huff_stream(s3, &tree, per, &mut literals)?;
decode_huff_stream(s4, &tree, last, &mut literals)?;
}
state.huff_tree = Some(tree);
Ok(LiteralsResult {
literals,
consumed: header_bytes + comp_size,
})
}
/// Decodes `n` symbols from one Huffman bitstream and appends them to `out`.
///
/// A `RevBitReader` is created over `data` and each symbol is obtained from
/// `tree.decode` on that reader. When `n` is zero, `data` is not inspected and
/// no reader is created.
///
/// # Errors
///
/// Propagates any error from `RevBitReader::new` or `tree.decode`. Symbols
/// pushed before a failure remain in `out`.
fn decode_huff_stream(
    data: &[u8],
    tree: &HuffTable,
    n: usize,
    out: &mut Vec<u8>,
) -> Result<(), Error> {
    if n > 0 {
        let mut reader = RevBitReader::new(data)?;
        let mut remaining = n;
        while remaining > 0 {
            out.push(tree.decode(&mut reader)?);
            remaining -= 1;
        }
    }
    Ok(())
}