use super::super::blocks::literals_section::{LiteralsSection, LiteralsSectionType};
use super::scratch::HuffmanScratch;
use crate::bit_io::BitReaderReversed;
#[cfg(target_arch = "x86_64")]
use crate::cpu_kernel::{Avx2Kernel, Bmi2Kernel, CpuKernelTag, Vbmi2Kernel};
use crate::cpu_kernel::{CpuKernel, ScalarKernel, detect_cpu_kernel};
use crate::decoding::errors::DecompressLiteralsError;
use crate::huff0::HuffmanDecoder;
use alloc::vec::Vec;
#[cfg(test)]
/// Decodes a literals section by appending the regenerated literals to `target`.
///
/// Returns the number of bytes of `source` consumed by the section payload:
/// `regenerated_size` for raw literals, `1` for RLE, and the compressed payload size for
/// Huffman-coded (compressed or treeless) literals.
///
/// # Errors
///
/// Returns `MissingBytesForLiterals` when `source` is too short for a raw or RLE section
/// (matching `decode_literals_zerocopy`), and propagates any Huffman decoding error.
pub fn decode_literals(
    section: &LiteralsSection,
    scratch: &mut HuffmanScratch,
    source: &[u8],
    target: &mut Vec<u8>,
) -> Result<u32, DecompressLiteralsError> {
    match section.ls_type {
        LiteralsSectionType::Raw => {
            let n = section.regenerated_size as usize;
            // The header is untrusted: report a short source instead of panicking.
            let Some(raw) = source.get(..n) else {
                return Err(DecompressLiteralsError::MissingBytesForLiterals {
                    got: source.len(),
                    needed: n,
                });
            };
            target.extend_from_slice(raw);
            Ok(section.regenerated_size)
        }
        LiteralsSectionType::RLE => {
            // RLE needs exactly one byte to repeat.
            let Some(&byte) = source.first() else {
                return Err(DecompressLiteralsError::MissingBytesForLiterals { got: 0, needed: 1 });
            };
            target.resize(target.len() + section.regenerated_size as usize, byte);
            Ok(1)
        }
        LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => {
            decompress_literals(section, scratch, source, target)
        }
    }
}
/// Decoded literals together with how much of the input section they consumed.
///
/// `data` either borrows directly from the section's source bytes (raw literals) or
/// from the freshly appended tail of the caller's output buffer (RLE / Huffman).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LiteralsView<'a> {
    /// The regenerated literal bytes.
    pub data: &'a [u8],
    /// Number of bytes of the section's source that were consumed to produce `data`.
    pub bytes_used: u32,
}
/// Decodes a literals section, borrowing instead of copying whenever possible.
///
/// Raw literals are returned as a view straight into `source`, and `target` is left
/// untouched. RLE and Huffman-coded literals are appended to `target`, and the returned
/// view covers only the newly appended bytes, never anything `target` already held.
///
/// The view's `bytes_used` is the number of bytes of `source` consumed by the section.
///
/// # Errors
///
/// Returns `MissingBytesForLiterals` if `source` is too short for a raw or RLE section,
/// and propagates any Huffman decoding error.
pub fn decode_literals_zerocopy<'a>(
    section: &LiteralsSection,
    scratch: &mut HuffmanScratch,
    source: &'a [u8],
    target: &'a mut Vec<u8>,
) -> Result<LiteralsView<'a>, DecompressLiteralsError> {
    let start = target.len();
    let consumed = match section.ls_type {
        LiteralsSectionType::Raw => {
            let needed = section.regenerated_size as usize;
            let Some(data) = source.get(..needed) else {
                return Err(DecompressLiteralsError::MissingBytesForLiterals {
                    got: source.len(),
                    needed,
                });
            };
            // Raw literals never touch `target`: hand out the source bytes directly.
            return Ok(LiteralsView {
                data,
                bytes_used: section.regenerated_size,
            });
        }
        LiteralsSectionType::RLE => {
            let Some(&fill) = source.first() else {
                return Err(DecompressLiteralsError::MissingBytesForLiterals { got: 0, needed: 1 });
            };
            target.resize(start + section.regenerated_size as usize, fill);
            1
        }
        LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => {
            decompress_literals(section, scratch, source, target)?
        }
    };
    // Only the bytes appended by this call belong to the view.
    Ok(LiteralsView {
        data: &target[start..],
        bytes_used: consumed,
    })
}
/// Decodes a Huffman-coded literals section with the best kernel for the running CPU.
///
/// `detect_cpu_kernel` picks the kernel. On x86_64 the VBMI2, AVX2 and BMI2 tags are
/// routed to the matching `#[target_feature]` entry points. Every other tag, and every
/// other architecture, uses the portable `ScalarKernel` instantiation.
///
/// Returns the number of bytes of `source` consumed by the section.
fn decompress_literals(
    section: &LiteralsSection,
    scratch: &mut HuffmanScratch,
    source: &[u8],
    target: &mut Vec<u8>,
) -> Result<u32, DecompressLiteralsError> {
    let kernel = detect_cpu_kernel();
    match kernel {
        // SAFETY (all x86_64 arms): the running CPU must support every feature the
        // chosen wrapper enables. This relies on `detect_cpu_kernel` only reporting a
        // tag after runtime detection of those features.
        // NOTE(review): confirm this guarantee in `cpu_kernel`.
        #[cfg(target_arch = "x86_64")]
        CpuKernelTag::Vbmi2 => unsafe {
            decompress_literals_vbmi2(section, scratch, source, target)
        },
        #[cfg(target_arch = "x86_64")]
        CpuKernelTag::Avx2 => unsafe {
            decompress_literals_avx2(section, scratch, source, target)
        },
        #[cfg(target_arch = "x86_64")]
        CpuKernelTag::Bmi2 => unsafe {
            decompress_literals_bmi2(section, scratch, source, target)
        },
        _ => decompress_literals_impl::<ScalarKernel>(section, scratch, source, target),
    }
}
/// Calls `decompress_literals_impl::<Avx2Kernel>` from a function compiled with the
/// `bmi2` and `avx2` target features enabled.
///
/// # Safety
///
/// The running CPU must support `bmi2` and `avx2`. Calling a `#[target_feature]`
/// function on a CPU that lacks the enabled features is undefined behaviour.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "bmi2,avx2")]
unsafe fn decompress_literals_avx2(
    section: &LiteralsSection,
    scratch: &mut HuffmanScratch,
    source: &[u8],
    target: &mut Vec<u8>,
) -> Result<u32, DecompressLiteralsError> {
    decompress_literals_impl::<Avx2Kernel>(section, scratch, source, target)
}
/// Calls `decompress_literals_impl::<Bmi2Kernel>` from a function compiled with the
/// `bmi2` target feature enabled.
///
/// # Safety
///
/// The running CPU must support `bmi2`. Calling a `#[target_feature]` function on a
/// CPU that lacks the enabled features is undefined behaviour.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "bmi2")]
unsafe fn decompress_literals_bmi2(
    section: &LiteralsSection,
    scratch: &mut HuffmanScratch,
    source: &[u8],
    target: &mut Vec<u8>,
) -> Result<u32, DecompressLiteralsError> {
    decompress_literals_impl::<Bmi2Kernel>(section, scratch, source, target)
}
/// Calls `decompress_literals_impl::<Vbmi2Kernel>` from a function compiled with the
/// AVX-512 (`avx512vbmi2`, `avx512f`, `avx512vl`, `avx512bw`), `bmi2` and `avx2`
/// target features enabled.
///
/// # Safety
///
/// The running CPU must support every feature listed in the `target_feature`
/// attribute. Calling a `#[target_feature]` function on a CPU that lacks them is
/// undefined behaviour.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512vbmi2,avx512f,avx512vl,avx512bw,bmi2,avx2")]
unsafe fn decompress_literals_vbmi2(
    section: &LiteralsSection,
    scratch: &mut HuffmanScratch,
    source: &[u8],
    target: &mut Vec<u8>,
) -> Result<u32, DecompressLiteralsError> {
    decompress_literals_impl::<Vbmi2Kernel>(section, scratch, source, target)
}
/// Decodes a Huffman-compressed or treeless literals section with CPU kernel `K`.
///
/// Appends exactly `section.regenerated_size` literals to `target` and returns the number
/// of bytes of `source` that belong to the section, i.e. `compressed_size`.
///
/// For 4-stream sections the output is split into segments of
/// `ceil(regenerated_size / 4)` literals, with the last segment taking the remainder. A
/// lockstep burst loop decodes the bulk of all four segments, and a per-symbol tail loop
/// finishes each stream.
///
/// On every error path `target.len()` is left at its value on entry.
///
/// # Errors
///
/// * `MissingCompressedSize` / `MissingNumStreams` if the section header lacks them.
/// * `MissingBytesForLiterals` if `source` is shorter than `compressed_size`, or if the
///   4-stream jump table points past the end of the payload.
/// * `UninitializedHuffmanTable` for a treeless section when the table in `scratch` has
///   `max_num_bits == 0`.
/// * `MissingBytesForJumpHeader`, `ExtraPadding`, `BitstreamReadMismatch` and
///   `DecodedLiteralCountMismatch` for malformed streams.
/// * Any error returned while building the Huffman table.
///
/// # Panics
///
/// Panics if `num_streams` is neither 4 nor 1.
fn decompress_literals_impl<K: CpuKernel>(
    section: &LiteralsSection,
    scratch: &mut HuffmanScratch,
    source: &[u8],
    target: &mut Vec<u8>,
) -> Result<u32, DecompressLiteralsError> {
    use DecompressLiteralsError as err;
    let compressed_size = section.compressed_size.ok_or(err::MissingCompressedSize)? as usize;
    let num_streams = section.num_streams.ok_or(err::MissingNumStreams)?;
    // `compressed_size` comes from the (untrusted) section header: report a short input
    // instead of panicking on the slice below.
    if source.len() < compressed_size {
        return Err(err::MissingBytesForLiterals {
            got: source.len(),
            needed: compressed_size,
        });
    }
    let base = target.len();
    let regen = section.regenerated_size as usize;
    // Reserve once so the raw-pointer writes (4 streams) and pushes (1 stream) below
    // stay inside a single allocation.
    target.reserve(regen);
    let source = &source[..compressed_size];
    let mut bytes_read = 0;
    match section.ls_type {
        LiteralsSectionType::Compressed => {
            bytes_read += scratch.table.build_decoder(source)?;
            vprintln!("Built huffman table using {} bytes", bytes_read);
        }
        // Treeless sections decode with the table already held in `scratch`;
        // `max_num_bits == 0` is treated as "no table built yet".
        LiteralsSectionType::Treeless if scratch.table.max_num_bits == 0 => {
            return Err(err::UninitializedHuffmanTable);
        }
        _ => {}
    }
    let source = &source[bytes_read as usize..];
    if num_streams == 4 {
        // 6-byte jump table: three little-endian u16 sizes for streams 1..=3; stream 4
        // takes whatever is left.
        if source.len() < 6 {
            return Err(err::MissingBytesForJumpHeader { got: source.len() });
        }
        let jump1 = source[0] as usize + ((source[1] as usize) << 8);
        let jump2 = jump1 + source[2] as usize + ((source[3] as usize) << 8);
        let jump3 = jump2 + source[4] as usize + ((source[5] as usize) << 8);
        bytes_read += 6;
        let source = &source[6..];
        if source.len() < jump3 {
            return Err(err::MissingBytesForLiterals {
                got: source.len(),
                needed: jump3,
            });
        }
        let streams: [&[u8]; 4] = [
            &source[..jump1],
            &source[jump1..jump2],
            &source[jump2..jump3],
            &source[jump3..],
        ];
        let mut decoders: [HuffmanDecoder<'_>; 4] = [
            HuffmanDecoder::new(&scratch.table),
            HuffmanDecoder::new(&scratch.table),
            HuffmanDecoder::new(&scratch.table),
            HuffmanDecoder::new(&scratch.table),
        ];
        let mut brs: [BitReaderReversed<'_, K>; 4] = [
            BitReaderReversed::<K>::new(streams[0]),
            BitReaderReversed::<K>::new(streams[1]),
            BitReaderReversed::<K>::new(streams[2]),
            BitReaderReversed::<K>::new(streams[3]),
        ];
        for i in 0..4 {
            skip_stream_padding(&mut brs[i])?;
            decoders[i].init_state(&mut brs[i]);
        }
        let max_bits = scratch.table.max_num_bits as isize;
        // Output segments: the first three streams get `ceil(regen / 4)` literals each,
        // the fourth gets the remainder; all clamped to `base + regen`.
        let seg = regen.div_ceil(4);
        let target_ptr: *mut u8 = target.as_mut_ptr();
        let limit = base + regen;
        let starts: [usize; 4] = [
            base,
            (base + seg).min(limit),
            (base + 2 * seg).min(limit),
            (base + 3 * seg).min(limit),
        ];
        let ends: [usize; 4] = [starts[1], starts[2], starts[3], limit];
        let mut cursors = starts;
        let max_num_bits = scratch.table.max_num_bits;
        // Each burst consumes at most 55 bits per stream, leaving headroom in the 64-bit
        // window for the sentinel bit and up to 7 bits of refill offset.
        let symbols_per_burst: usize = (63 - 8) / max_num_bits as usize;
        let burst_bits = (symbols_per_burst * max_num_bits as usize) as u8;
        let table_shift = (64 - max_num_bits) as u32;
        let packed = scratch.table.packed_decode.as_slice();
        // The cursors advance in lockstep, so the shortest segment bounds how far the
        // burst loop may run for every stream.
        let min_seg_len = (ends[0] - starts[0])
            .min(ends[1] - starts[1])
            .min(ends[2] - starts[2])
            .min(ends[3] - starts[3]);
        let burst_eligible = symbols_per_burst >= 1 && min_seg_len >= symbols_per_burst;
        let cursor_burst_ceil = (starts[0] + min_seg_len).saturating_sub(symbols_per_burst);
        let bounds = LoopBounds {
            symbols_per_burst,
            burst_bits,
            table_shift,
            cursor_burst_ceil,
            burst_eligible,
            alloc_upper_bound: base + regen,
        };
        // SAFETY: `target` has capacity for `base + regen` bytes (reserved above), and
        // `cursor_burst_ceil` keeps every lockstep cursor inside its own segment for a
        // whole burst. The unchecked source reads rely on `BitReaderReversed` keeping
        // `index + 8` within its slice.
        unsafe {
            run_4stream_burst_loop(
                &mut decoders,
                &mut brs,
                target_ptr,
                packed,
                &mut cursors,
                &bounds,
            );
        }
        // Scalar tail: finish each stream symbol by symbol, never past its segment end.
        for i in 0..4 {
            while brs[i].bits_remaining() > -max_bits && cursors[i] < ends[i] {
                let byte = decoders[i].decode_symbol_and_advance(&mut brs[i]);
                // SAFETY: `cursors[i] < ends[i] <= base + regen <= target.capacity()`.
                unsafe {
                    target_ptr.add(cursors[i]).write(byte);
                }
                cursors[i] += 1;
            }
            if brs[i].bits_remaining() != -max_bits {
                return Err(DecompressLiteralsError::BitstreamReadMismatch {
                    read_til: brs[i].bits_remaining(),
                    expected: -max_bits,
                });
            }
        }
        let decoded: usize = cursors.iter().zip(starts.iter()).map(|(c, s)| c - s).sum();
        if decoded != regen {
            return Err(DecompressLiteralsError::DecodedLiteralCountMismatch {
                decoded,
                expected: regen,
            });
        }
        // SAFETY: each cursor is bounded by its segment end, and the segments tile
        // `[base, base + regen)`. `decoded == regen` therefore means every cursor reached
        // its end and every byte in that range was written.
        unsafe {
            target.set_len(base + regen);
        }
        bytes_read += source.len() as u32;
    } else {
        assert!(
            num_streams == 1,
            "literals sections carry either 1 or 4 Huffman streams"
        );
        let mut decoder = HuffmanDecoder::new(&scratch.table);
        let mut br = BitReaderReversed::<K>::new(source);
        skip_stream_padding(&mut br)?;
        decoder.init_state(&mut br);
        let expected = -(scratch.table.max_num_bits as isize);
        // Cap the output at `regen` literals so a malformed stream cannot grow `target`
        // far beyond the declared size; a well-formed stream ends exactly at the cap.
        let end = base + regen;
        while br.bits_remaining() > expected && target.len() < end {
            target.push(decoder.decode_symbol_and_advance(&mut br));
        }
        if br.bits_remaining() != expected {
            target.truncate(base);
            return Err(DecompressLiteralsError::BitstreamReadMismatch {
                read_til: br.bits_remaining(),
                expected,
            });
        }
        bytes_read += source.len() as u32;
    }
    if target.len() != base + regen {
        let decoded = target.len() - base;
        target.truncate(base);
        return Err(DecompressLiteralsError::DecodedLiteralCountMismatch {
            decoded,
            expected: regen,
        });
    }
    Ok(bytes_read)
}

/// Reads single bits from the start of a reversed Huffman stream until a `1` bit is
/// consumed, as the stream's padding terminator.
///
/// # Errors
///
/// Returns `ExtraPadding` if no `1` bit appears within the first 8 bits.
#[inline]
fn skip_stream_padding<K: CpuKernel>(
    br: &mut BitReaderReversed<'_, K>,
) -> Result<(), DecompressLiteralsError> {
    let mut skipped_bits = 0;
    loop {
        let val = br.get_bits(1);
        skipped_bits += 1;
        if val == 1 || skipped_bits > 8 {
            break;
        }
    }
    if skipped_bits > 8 {
        return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
    }
    Ok(())
}
/// Parameters for `run_4stream_burst_loop`, computed once per 4-stream literals section.
#[derive(Copy, Clone)]
struct LoopBounds {
    /// Symbols decoded from each stream per burst: `(63 - 8) / max_num_bits`.
    symbols_per_burst: usize,
    /// Upper bound on bits one stream consumes per burst:
    /// `symbols_per_burst * max_num_bits`.
    burst_bits: u8,
    /// `64 - max_num_bits`; shifting a bit window right by this yields its table index.
    table_shift: u32,
    /// Largest stream-0 cursor value at which another full burst still fits inside the
    /// shortest output segment.
    cursor_burst_ceil: usize,
    /// Whether bursts may run at all: at least one symbol per burst, and every segment
    /// holds at least one full burst.
    burst_eligible: bool,
    /// One past the last output byte (`base + regen`); only consulted by a
    /// `debug_assert`.
    alloc_upper_bound: usize,
}
/// Decodes `SPB` symbols from each of the four streams, interleaved stream by stream
/// within each step (presumably so the four independent lookup chains overlap).
///
/// Each `bits[s]` is stream `s`'s left-aligned bit window. The top
/// `64 - table_shift` bits index `packed`. Each entry stores the decoded byte in its low
/// 8 bits and, in the next 8 bits, the number of bits to shift out of the window. Every
/// decoded byte is written at `cursors[s]`, which is then incremented. No refill happens
/// here.
///
/// # Safety
///
/// * `target_ptr.add(cursors[s] + k)` must be valid for writes for every stream `s`
///   and every `k < SPB`.
/// * Every index `bits[s] >> table_shift` reached during the burst must be in bounds
///   for `packed` (reads are unchecked). NOTE(review): presumably guaranteed because
///   `packed` has `1 << (64 - table_shift)` entries; confirm in `huff0`.
/// * Each window must hold enough valid bits for `SPB` symbols, since this function
///   never refills.
#[inline(always)]
unsafe fn burst_decode_symbols<const SPB: usize>(
    bits: &mut [u64; 4],
    cursors: &mut [usize; 4],
    target_ptr: *mut u8,
    packed: &[u16],
    table_shift: u32,
) {
    for _ in 0..SPB {
        // Stream 0: look up, emit the byte, shift out the code length.
        let idx0 = (bits[0] >> table_shift) as usize;
        let entry0 = unsafe { *packed.get_unchecked(idx0) };
        unsafe { target_ptr.add(cursors[0]).write((entry0 & 0xFF) as u8) };
        cursors[0] += 1;
        bits[0] <<= (entry0 >> 8) & 0xFF;
        // Stream 1.
        let idx1 = (bits[1] >> table_shift) as usize;
        let entry1 = unsafe { *packed.get_unchecked(idx1) };
        unsafe { target_ptr.add(cursors[1]).write((entry1 & 0xFF) as u8) };
        cursors[1] += 1;
        bits[1] <<= (entry1 >> 8) & 0xFF;
        // Stream 2.
        let idx2 = (bits[2] >> table_shift) as usize;
        let entry2 = unsafe { *packed.get_unchecked(idx2) };
        unsafe { target_ptr.add(cursors[2]).write((entry2 & 0xFF) as u8) };
        cursors[2] += 1;
        bits[2] <<= (entry2 >> 8) & 0xFF;
        // Stream 3.
        let idx3 = (bits[3] >> table_shift) as usize;
        let entry3 = unsafe { *packed.get_unchecked(idx3) };
        unsafe { target_ptr.add(cursors[3]).write((entry3 & 0xFF) as u8) };
        cursors[3] += 1;
        bits[3] <<= (entry3 >> 8) & 0xFF;
    }
}
/// Lockstep burst decoder for the 4-stream Huffman path.
///
/// Each iteration first checks that stream 0's cursor is at most `cursor_burst_ceil`
/// and that every stream's read index is at least `bytes_per_iter_upper`. It then
/// decodes `symbols_per_burst` symbols from each stream and refills each 64-bit window
/// from `brs[s].source`.
///
/// If at least one burst ran, each reader's `index`, `bit_container` and
/// `bits_consumed`, and each decoder's `state`, are written back. The caller's scalar
/// tail loop then resumes where the bursts stopped. The function returns without
/// touching any state when `burst_eligible` is false or no burst ran.
///
/// Window layout: after a refill a stream's window is `(word | 1) << nb_bits`, where
/// `nb_bits < 8` is how many of the word's top bits were already consumed. The low `1`
/// is a sentinel. The window's `trailing_zeros` therefore equals `nb_bits` plus every
/// bit shifted out since the refill, which gives the number of whole bytes to step the
/// read index back.
///
/// # Safety
///
/// * `target_ptr` must be valid for writes at every position the cursors reach. All
///   four cursors advance together by `symbols_per_burst` while
///   `cursors[0] <= cursor_burst_ceil`. The caller must therefore choose
///   `cursor_burst_ceil` so that no stream's cursor leaves its own output segment during
///   a full burst. The `debug_assert` only checks the coarse bound
///   `cursor_burst_ceil + symbols_per_burst <= alloc_upper_bound`.
/// * Reads of `brs[s].source[ip..ip + 8]` are unchecked. The initial `brs[s].index` must
///   leave 8 readable bytes; the index only decreases afterwards.
///   NOTE(review): this relies on `BitReaderReversed` invariants not visible here; confirm.
/// * Every `max_num_bits`-bit value must be a valid index into `packed`.
///   NOTE(review): presumably `packed` has `1 << max_num_bits` entries; confirm in `huff0`.
/// * `brs[s].bits_consumed >= max_num_bits` on entry, because the subtractions below are
///   on `u8`.
#[inline(always)]
unsafe fn run_4stream_burst_loop<K: CpuKernel>(
    decoders: &mut [HuffmanDecoder<'_>; 4],
    brs: &mut [BitReaderReversed<'_, K>; 4],
    target_ptr: *mut u8,
    packed: &[u16],
    cursors: &mut [usize; 4],
    bounds: &LoopBounds,
) {
    let LoopBounds {
        symbols_per_burst,
        burst_bits,
        table_shift,
        cursor_burst_ceil,
        burst_eligible,
        alloc_upper_bound,
    } = *bounds;
    // Recover the table width from the index shift.
    let max_num_bits = (64 - table_shift) as u8;
    if !burst_eligible {
        return;
    }
    debug_assert!(
        cursor_burst_ceil + symbols_per_burst <= alloc_upper_bound,
        "caller must size the target allocation so the lockstep-advanced \
         cursors stay within bounds across a full burst",
    );
    // Rebuild each window in burst layout: sentinel at bit 0, then shift out all
    // consumed bits except the last `max_num_bits`, which stay on top as the next table
    // index. This mirrors the write-back below (`state = bits >> table_shift`,
    // `bits_consumed = nb_bits + max_num_bits`).
    // NOTE(review): assumes those top bits equal the decoder's current `state`.
    let mut bits: [u64; 4] = [
        (brs[0].bit_container | 1) << (brs[0].bits_consumed - max_num_bits),
        (brs[1].bit_container | 1) << (brs[1].bits_consumed - max_num_bits),
        (brs[2].bit_container | 1) << (brs[2].bits_consumed - max_num_bits),
        (brs[3].bit_container | 1) << (brs[3].bits_consumed - max_num_bits),
    ];
    let mut ip: [usize; 4] = [brs[0].index, brs[1].index, brs[2].index, brs[3].index];
    let mut nb_bits_last: [u8; 4] = [
        brs[0].bits_consumed - max_num_bits,
        brs[1].bits_consumed - max_num_bits,
        brs[2].bits_consumed - max_num_bits,
        brs[3].bits_consumed - max_num_bits,
    ];
    // A burst shifts out at most 7 (refill offset) + `burst_bits` bits, so `ip` steps
    // back at most this many bytes. Stopping while any `ip` is below it keeps
    // `ip - nb_bytes` non-negative.
    let bytes_per_iter_upper = (8 + burst_bits as usize) / 8;
    // Tracks whether the readers/decoders need their state written back.
    let mut any_iter = false;
    while cursors[0] <= cursor_burst_ceil {
        let min_ip = ip[0].min(ip[1]).min(ip[2]).min(ip[3]);
        if min_ip < bytes_per_iter_upper {
            break;
        }
        any_iter = true;
        // Fully unrolled bursts for 5..=7 symbols (what `max_num_bits` 7 through 11
        // produce); any other count uses the runtime loop with the same body.
        // SAFETY: forwarded from this function's contract.
        match symbols_per_burst {
            5 => unsafe {
                burst_decode_symbols::<5>(
                    &mut bits,
                    &mut *cursors,
                    target_ptr,
                    packed,
                    table_shift,
                );
            },
            6 => unsafe {
                burst_decode_symbols::<6>(
                    &mut bits,
                    &mut *cursors,
                    target_ptr,
                    packed,
                    table_shift,
                );
            },
            7 => unsafe {
                burst_decode_symbols::<7>(
                    &mut bits,
                    &mut *cursors,
                    target_ptr,
                    packed,
                    table_shift,
                );
            },
            _ => {
                for _ in 0..symbols_per_burst {
                    let idx0 = (bits[0] >> table_shift) as usize;
                    let entry0 = unsafe { *packed.get_unchecked(idx0) };
                    unsafe { target_ptr.add(cursors[0]).write((entry0 & 0xFF) as u8) };
                    cursors[0] += 1;
                    bits[0] <<= (entry0 >> 8) & 0xFF;
                    let idx1 = (bits[1] >> table_shift) as usize;
                    let entry1 = unsafe { *packed.get_unchecked(idx1) };
                    unsafe { target_ptr.add(cursors[1]).write((entry1 & 0xFF) as u8) };
                    cursors[1] += 1;
                    bits[1] <<= (entry1 >> 8) & 0xFF;
                    let idx2 = (bits[2] >> table_shift) as usize;
                    let entry2 = unsafe { *packed.get_unchecked(idx2) };
                    unsafe { target_ptr.add(cursors[2]).write((entry2 & 0xFF) as u8) };
                    cursors[2] += 1;
                    bits[2] <<= (entry2 >> 8) & 0xFF;
                    let idx3 = (bits[3] >> table_shift) as usize;
                    let entry3 = unsafe { *packed.get_unchecked(idx3) };
                    unsafe { target_ptr.add(cursors[3]).write((entry3 & 0xFF) as u8) };
                    cursors[3] += 1;
                    bits[3] <<= (entry3 >> 8) & 0xFF;
                }
            }
        }
        // Refill: the sentinel's position counts all bits consumed since the last
        // refill. Step back over whole bytes, reload 8 bytes, and re-apply the leftover
        // bit offset.
        for s in 0..4 {
            let ctz = bits[s].trailing_zeros();
            let nb_bytes = (ctz >> 3) as usize;
            let nb_bits = (ctz & 7) as u8;
            ip[s] -= nb_bytes;
            // SAFETY: `ip[s] >= 0` by the `min_ip` check above, and `ip[s] + 8` stays
            // within the stream because `ip` only decreased from a valid start.
            let new_window = u64::from_le_bytes(unsafe {
                brs[s]
                    .source
                    .get_unchecked(ip[s]..ip[s] + 8)
                    .try_into()
                    .unwrap_unchecked()
            });
            bits[s] = (new_window | 1) << nb_bits;
            nb_bits_last[s] = nb_bits;
        }
    }
    if !any_iter {
        return;
    }
    // Hand the streams back to the scalar readers: reload the container at the final
    // `ip`, count the refill offset plus the `max_num_bits` state bits as consumed, and
    // take the decoder state from the window's top bits.
    for s in 0..4 {
        brs[s].index = ip[s];
        brs[s].bit_container = u64::from_le_bytes(unsafe {
            brs[s]
                .source
                .get_unchecked(ip[s]..ip[s] + 8)
                .try_into()
                .unwrap_unchecked()
        });
        brs[s].bits_consumed = nb_bits_last[s] + max_num_bits;
        decoders[s].state = bits[s] >> table_shift;
    }
}
#[cfg(test)]
mod zerocopy_robustness_tests {
    //! Malformed-input and view-boundary checks for `decode_literals_zerocopy`.
    use super::{LiteralsView, decode_literals_zerocopy};
    use crate::blocks::literals_section::{LiteralsSection, LiteralsSectionType};
    use crate::decoding::scratch::HuffmanScratch;
    use crate::huff0::HuffmanTable;
    use alloc::vec::Vec;

    /// Raw literals section header regenerating `regen` bytes.
    fn raw_section(regen: u32) -> LiteralsSection {
        LiteralsSection {
            ls_type: LiteralsSectionType::Raw,
            regenerated_size: regen,
            compressed_size: None,
            num_streams: None,
        }
    }

    /// RLE literals section header regenerating `regen` bytes.
    fn rle_section(regen: u32) -> LiteralsSection {
        LiteralsSection {
            ls_type: LiteralsSectionType::RLE,
            regenerated_size: regen,
            compressed_size: None,
            num_streams: None,
        }
    }

    /// Scratch space holding an empty Huffman table.
    fn fresh_scratch() -> HuffmanScratch {
        HuffmanScratch {
            table: HuffmanTable::new(),
        }
    }

    #[test]
    fn raw_truncated_source_returns_error_no_panic() {
        let section = raw_section(10);
        let source: [u8; 3] = [1, 2, 3];
        let mut target: Vec<u8> = Vec::new();
        let mut scratch = fresh_scratch();
        let result = decode_literals_zerocopy(&section, &mut scratch, &source, &mut target);
        assert!(
            result.is_err(),
            "truncated raw source must error, not panic; got {:?}",
            result.map(|_| ())
        );
        assert!(target.is_empty(), "a failed raw decode must not touch target");
    }

    #[test]
    fn raw_view_borrows_source_and_leaves_target_untouched() {
        let section = raw_section(3);
        let source: [u8; 5] = [1, 2, 3, 4, 5];
        let mut target: Vec<u8> = Vec::from([0xEE]);
        let mut scratch = fresh_scratch();
        let view = decode_literals_zerocopy(&section, &mut scratch, &source, &mut target)
            .expect("raw section with enough source bytes must succeed");
        assert_eq!(view.data, &source[..3]);
        assert_eq!(view.bytes_used, 3);
        assert_eq!(target, [0xEE], "raw literals must not be copied into target");
    }

    #[test]
    fn rle_empty_source_returns_error_no_panic() {
        let section = rle_section(10);
        let source: [u8; 0] = [];
        let mut target: Vec<u8> = Vec::new();
        let mut scratch = fresh_scratch();
        let result = decode_literals_zerocopy(&section, &mut scratch, &source, &mut target);
        assert!(
            result.is_err(),
            "empty RLE source must error, not panic; got {:?}",
            result.map(|_| ())
        );
        assert!(target.is_empty(), "a failed RLE decode must not touch target");
    }

    #[test]
    fn rle_view_excludes_pre_existing_target_bytes() {
        let mut target: Vec<u8> = Vec::from([0xAA, 0xBB, 0xCC]);
        let section = rle_section(4);
        let source: [u8; 1] = [0x42];
        let mut scratch = fresh_scratch();
        let view = decode_literals_zerocopy(&section, &mut scratch, &source, &mut target)
            .expect("RLE with valid source must succeed");
        assert_eq!(view.data.len(), 4, "view length must match regen_size");
        assert!(
            view.data.iter().all(|&b| b == 0x42),
            "view must contain only the newly-RLE-expanded bytes, got {:?}",
            view.data
        );
        assert_eq!(view.bytes_used, 1, "RLE consumes exactly one source byte");
        // Construct a view by hand so the public fields stay exercised.
        let _ = LiteralsView {
            data: view.data,
            bytes_used: view.bytes_used,
        };
        assert_eq!(target, [0xAA, 0xBB, 0xCC, 0x42, 0x42, 0x42, 0x42]);
    }
}
#[cfg(test)]
mod burst_gate_tests {
    //! Round-trip and malformed-header tests for the 4-stream Huffman literals path.
    //! The fixtures are sized to land on both sides of the burst-loop eligibility gate.
    use super::*;
    use crate::bit_io::BitWriter;
    use crate::blocks::literals_section::{LiteralsSection, LiteralsSectionType};
    use crate::decoding::scratch::HuffmanScratch;
    use crate::huff0::huff0_encoder::{HuffmanEncoder, HuffmanTable as EncTable};
    use alloc::vec::Vec;

    /// Huffman-encodes `data` as four streams and returns a matching `Compressed`
    /// literals section header together with the encoded payload.
    fn build_huf4x_block(data: &[u8]) -> (LiteralsSection, Vec<u8>) {
        assert!(data.len() >= 4, "encode4x requires at least 4 bytes");
        let table = EncTable::build_from_data(data);
        let mut source: Vec<u8> = Vec::new();
        {
            let mut writer = BitWriter::from(&mut source);
            let mut encoder = HuffmanEncoder::new(&table, &mut writer);
            encoder.encode4x(data, true);
            writer.flush();
        }
        let section = LiteralsSection {
            ls_type: LiteralsSectionType::Compressed,
            regenerated_size: data.len() as u32,
            compressed_size: Some(source.len() as u32),
            num_streams: Some(4),
        };
        (section, source)
    }

    /// Round-trips `data` through the 4-stream encoder and `decode_literals`. Asserts
    /// that the whole payload is consumed and the output equals `data`.
    ///
    /// Returns the decoded table's `max_num_bits` so callers can check which regime ran.
    fn roundtrip_assert(data: &[u8]) -> u8 {
        let (section, source) = build_huf4x_block(data);
        let mut scratch = HuffmanScratch::new();
        let mut target = Vec::new();
        let bytes_read = decode_literals(&section, &mut scratch, &source, &mut target)
            .expect("decode_literals must succeed on a well-formed roundtrip");
        assert_eq!(
            bytes_read as usize,
            source.len(),
            "decoder must consume every byte of the literals block"
        );
        assert_eq!(
            target, data,
            "decoded literals must match the encoder input"
        );
        scratch.table.max_num_bits
    }

    /// Like `roundtrip_assert`, and also requires `max_num_bits` to fall in `expected`.
    fn roundtrip_with_max_bits_range(data: &[u8], expected: core::ops::RangeInclusive<u8>) {
        let m = roundtrip_assert(data);
        assert!(
            expected.contains(&m),
            "max_num_bits {} outside expected range {:?} for this fixture — \
             test no longer exercises the intended gate regime",
            m,
            expected
        );
    }

    #[test]
    fn burst_gate_lower_boundary_short_skewed_alphabet() {
        let data: Vec<u8> = Vec::from([
            0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
            14, 15, 16, 17, 18, 19, 20, 21, 22,
        ]);
        roundtrip_with_max_bits_range(&data, 5..=11);
    }

    #[test]
    fn burst_gate_upper_boundary_long_mid_alphabet() {
        let data: Vec<u8> = (0..4096u32)
            .map(|i| (i.wrapping_mul(0x9E37_79B1) % 97) as u8)
            .collect();
        roundtrip_with_max_bits_range(&data, 6..=11);
    }

    #[test]
    fn burst_simd_fallback_refill_reentry_long_streams() {
        let data: Vec<u8> = (0..16 * 1024u32).map(|i| (i % 67) as u8).collect();
        roundtrip_with_max_bits_range(&data, 5..=8);
    }

    #[test]
    fn burst_gate_sweep_sizes_and_alphabets() {
        let sizes = [
            16usize, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129, 255, 256, 257, 511, 512, 513, 1023,
            1024, 1025, 4096,
        ];
        for &n in &sizes {
            let bin: Vec<u8> = (0..n).map(|i| (i & 1) as u8).collect();
            roundtrip_assert(&bin);
            let sm: Vec<u8> = (0..n).map(|i| (i % 16) as u8).collect();
            roundtrip_assert(&sm);
            if n >= 128 {
                let wide: Vec<u8> = (0..n)
                    .map(|i| (i.wrapping_mul(2_654_435_761) % 97) as u8)
                    .collect();
                roundtrip_assert(&wide);
            }
        }
    }

    #[test]
    fn burst_gate_malformed_small_regen_returns_error() {
        let data: Vec<u8> = (0..256u32).map(|i| (i % 67) as u8).collect();
        let (mut section, source) = build_huf4x_block(&data);
        section.regenerated_size = 1;
        let mut scratch = HuffmanScratch::new();
        let mut target = Vec::new();
        let result = decode_literals(&section, &mut scratch, &source, &mut target);
        assert!(
            result.is_err(),
            "decoder must reject the malformed header instead of panicking; \
             got Ok({})",
            result.unwrap_or(0)
        );
        assert!(
            target.is_empty(),
            "a failed decode must leave target at its original length"
        );
    }
}