#![allow(dead_code)]
use std::io;
/// Number of low bit-buffer bits used to index the multi-symbol table.
const MULTI_SYM_BITS: usize = 12;
/// Number of entries in the multi-symbol table (`2^MULTI_SYM_BITS`).
const MULTI_SYM_SIZE: usize = 1 << MULTI_SYM_BITS;
/// Mask selecting the low `MULTI_SYM_BITS` bits of a 64-bit bit buffer.
const MULTI_SYM_MASK: u64 = (1u64 << MULTI_SYM_BITS) - 1;
/// One 8-byte slot of the multi-symbol lookup table.
///
/// An entry either describes a run of 1 to 4 literal bytes (`sym_count` in
/// `1..=4`, bytes stored in `sym1..sym4`) or a single non-literal symbol
/// (`sym_count == 0`, 16-bit symbol id stored little-endian in `sym1`/`sym2`).
/// The all-zero default value (`total_bits == 0`) is what unfilled slots hold.
#[derive(Clone, Copy, Default)]
#[repr(C, align(8))]
pub struct MultiSymEntry {
    /// Total number of input bits covered by every symbol in this entry.
    pub total_bits: u8,
    /// Number of literal bytes in the entry; 0 marks a non-literal symbol.
    pub sym_count: u8,
    /// First literal, or the low byte of a non-literal symbol id.
    pub sym1: u8,
    /// Second literal, or the high byte of a non-literal symbol id.
    pub sym2: u8,
    /// Third literal (0 when unused).
    pub sym3: u8,
    /// Fourth literal (0 when unused).
    pub sym4: u8,
    // Explicit padding so the six payload bytes fill out an 8-byte entry.
    _pad: [u8; 2],
}

impl MultiSymEntry {
    /// Entry holding exactly one literal `byte` encoded in `bits` bits.
    #[inline(always)]
    pub fn single_literal(bits: u8, byte: u8) -> Self {
        Self {
            total_bits: bits,
            sym_count: 1,
            sym1: byte,
            ..Self::default()
        }
    }

    /// Entry holding two literals; `total_bits` becomes `bits1 + bits2`.
    #[inline(always)]
    pub fn two_literals(bits1: u8, byte1: u8, bits2: u8, byte2: u8) -> Self {
        Self {
            total_bits: bits1 + bits2,
            sym_count: 2,
            sym1: byte1,
            sym2: byte2,
            ..Self::default()
        }
    }

    /// Entry holding three literals that together occupy `total_bits` bits.
    #[inline(always)]
    pub fn three_literals(total_bits: u8, b1: u8, b2: u8, b3: u8) -> Self {
        Self {
            total_bits,
            sym_count: 3,
            sym1: b1,
            sym2: b2,
            sym3: b3,
            ..Self::default()
        }
    }

    /// Entry holding four literals that together occupy `total_bits` bits.
    #[inline(always)]
    pub fn four_literals(total_bits: u8, b1: u8, b2: u8, b3: u8, b4: u8) -> Self {
        Self {
            total_bits,
            sym_count: 4,
            sym1: b1,
            sym2: b2,
            sym3: b3,
            sym4: b4,
            ..Self::default()
        }
    }

    /// Entry for a non-literal `symbol` encoded in `bits` bits.
    ///
    /// The symbol id is split little-endian across `sym1` (low byte) and
    /// `sym2` (high byte); `sym_count` stays 0 to mark the entry as non-literal.
    #[inline(always)]
    pub fn non_literal(bits: u8, symbol: u16) -> Self {
        let [low, high] = symbol.to_le_bytes();
        Self {
            total_bits: bits,
            sym1: low,
            sym2: high,
            ..Self::default()
        }
    }

    /// Returns `true` when the entry carries at least one literal byte.
    #[inline(always)]
    pub fn is_literal_run(&self) -> bool {
        self.sym_count != 0
    }

    /// Reassembles the 16-bit value stored in `sym1` (low) and `sym2` (high).
    #[inline(always)]
    pub fn symbol(&self) -> u16 {
        u16::from_le_bytes([self.sym1, self.sym2])
    }
}
/// Lookup table that resolves up to four literal bytes from the low
/// `MULTI_SYM_BITS` bits of the bit buffer in a single probe.
pub struct MultiSymTable {
    /// Exactly `MULTI_SYM_SIZE` entries, indexed by the low bits of the bit buffer.
    table: Vec<MultiSymEntry>,
    /// Longest code length in `1..=15` found in the lengths passed to `build`.
    max_code_len: u32,
}

impl MultiSymTable {
    /// Builds the table from per-symbol code lengths (`lens[symbol]`).
    ///
    /// Codes are assigned canonically (shorter lengths first, then by symbol
    /// index) and stored bit-reversed so that the table can be indexed with the
    /// low bits of the bit buffer. Lengths of 0 or above 15 are ignored
    /// entirely; lengths of 13..=15 are counted but get no table slot, so the
    /// slots they would occupy keep the default entry (`total_bits == 0`).
    ///
    /// Symbols below 256 are treated as literal bytes and packed greedily, up
    /// to four per entry, as long as each following code fits completely in
    /// the remaining index bits. Symbols 256 and above become non-literal
    /// entries.
    ///
    /// The current implementation never returns `Err`.
    pub fn build(lens: &[u8]) -> io::Result<Self> {
        // Histogram of code lengths, plus the longest length seen.
        let mut len_histogram = [0u32; 16];
        let mut longest = 0u32;
        for len in lens.iter().copied().filter(|len| (1..=15).contains(len)) {
            len_histogram[usize::from(len)] += 1;
            longest = longest.max(u32::from(len));
        }

        // First canonical code value for every code length.
        let mut next_code = [0u32; 16];
        let mut running = 0u32;
        for bits in 1..next_code.len() {
            running = (running + len_histogram[bits - 1]) << 1;
            next_code[bits] = running;
        }

        // For every table index: the symbol whose code starts there and its
        // length. A length of 0 means "no code short enough matches".
        let mut symbol_info: Vec<(u16, u8)> = vec![(0xFFFF, 0); MULTI_SYM_SIZE];
        for (symbol, len) in lens.iter().copied().enumerate() {
            if len == 0 || usize::from(len) > MULTI_SYM_BITS {
                continue;
            }
            let slot = &mut next_code[usize::from(len)];
            let code = *slot;
            *slot += 1;

            // The reversed code occupies the low `len` bits; every combination
            // of the remaining high bits maps to the same symbol.
            let first = reverse_bits(code, u32::from(len)) as usize;
            for idx in (first..MULTI_SYM_SIZE).step_by(1usize << len) {
                symbol_info[idx] = (symbol as u16, len);
            }
        }

        let mut table = vec![MultiSymEntry::default(); MULTI_SYM_SIZE];
        for (idx, slot) in table.iter_mut().enumerate() {
            let (lead_sym, lead_len) = symbol_info[idx];
            if lead_len == 0 {
                continue;
            }
            if lead_sym >= 256 {
                *slot = MultiSymEntry::non_literal(lead_len, lead_sym);
                continue;
            }

            // Greedily append further literals while index bits remain.
            let mut run = [lead_sym as u8, 0, 0, 0];
            let mut count = 1usize;
            let mut used = u32::from(lead_len);
            while count < run.len() && used < MULTI_SYM_BITS as u32 {
                let room = MULTI_SYM_BITS as u32 - used;
                let peek = (idx >> used) & ((1usize << room) - 1);
                let (sym, len) = symbol_info[peek];
                // Stop at an unknown code, a code that does not fit in the
                // remaining bits, or a non-literal symbol.
                if len == 0 || u32::from(len) > room || sym >= 256 {
                    break;
                }
                run[count] = sym as u8;
                count += 1;
                used += u32::from(len);
            }

            *slot = match count {
                1 => MultiSymEntry::single_literal(lead_len, run[0]),
                2 => MultiSymEntry::two_literals(lead_len, run[0], used as u8 - lead_len, run[1]),
                3 => MultiSymEntry::three_literals(used as u8, run[0], run[1], run[2]),
                _ => MultiSymEntry::four_literals(used as u8, run[0], run[1], run[2], run[3]),
            };
        }

        Ok(Self {
            table,
            max_code_len: longest,
        })
    }

    /// Returns the entry selected by the low `MULTI_SYM_BITS` bits of `bits`.
    #[inline(always)]
    pub fn lookup(&self, bits: u64) -> &MultiSymEntry {
        let slot = (bits & MULTI_SYM_MASK) as usize;
        // SAFETY: `slot` is masked to `0..MULTI_SYM_SIZE`, and `build` (the only
        // constructor in this block) allocates exactly `MULTI_SYM_SIZE` entries.
        unsafe { self.table.get_unchecked(slot) }
    }
}
/// Returns the low `len` bits of `code` in reversed order.
///
/// Bit 0 of `code` becomes bit `len - 1` of the result, bit 1 becomes bit
/// `len - 2`, and so on; bits of `code` at or above position `len` are
/// ignored. A `len` of 0 yields 0.
#[inline(always)]
fn reverse_bits(code: u32, len: u32) -> u32 {
    // Peel bits off the bottom of `rest` and push them onto the bottom of
    // `acc`, one per step, so the first bit taken ends up highest.
    let (reversed, _) = (0..len).fold((0u32, code), |(acc, rest), _| {
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}
/// Decodes literal runs through `table` until a non-literal or empty entry is hit.
///
/// Each iteration calls `bits.ensure(32)`, looks up the entry selected by
/// `bits.buffer()`, copies its 1 to 4 literal bytes to `output` starting at
/// `out_pos`, and then consumes `total_bits` bits. Despite the name, this body
/// uses no SIMD instructions.
///
/// The loop stops, without consuming the bits of the entry it stopped on, when
/// the entry is non-literal (`sym_count == 0`), empty (`total_bits == 0`), or
/// reports more literals than an entry can hold.
///
/// # Returns
///
/// The output position just past the last literal written.
///
/// # Errors
///
/// Returns `io::ErrorKind::WriteZero` when `output` cannot hold every literal
/// of the next entry. The capacity check runs before the bits are consumed, so
/// on error the bit reader still points at that entry and nothing from it has
/// been written.
#[cfg(target_arch = "x86_64")]
pub fn decode_simd_multi_sym(
    table: &MultiSymTable,
    bits: &mut crate::decompress::two_level_table::FastBits,
    output: &mut [u8],
    mut out_pos: usize,
) -> io::Result<usize> {
    loop {
        bits.ensure(32);
        let entry = table.lookup(bits.buffer());
        let literals = [entry.sym1, entry.sym2, entry.sym3, entry.sym4];
        let count = usize::from(entry.sym_count);
        if entry.total_bits == 0 || count == 0 || count > literals.len() {
            break;
        }

        // Reserve the destination first: only commit to the entry (consume its
        // bits) once we know all of its literals fit.
        let dst = match output
            .get_mut(out_pos..)
            .and_then(|tail| tail.get_mut(..count))
        {
            Some(dst) => dst,
            None => {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "Output buffer full",
                ));
            }
        };
        dst.copy_from_slice(&literals[..count]);

        bits.consume(u32::from(entry.total_bits));
        out_pos += count;
    }
    Ok(out_pos)
}
/// Decodes literal runs through `table` until a non-literal or empty entry is hit.
///
/// Each iteration calls `bits.ensure(32)`, looks up the entry selected by
/// `bits.buffer()`, copies its 1 to 4 literal bytes to `output` starting at
/// `out_pos`, and then consumes `total_bits` bits. This is the variant
/// compiled for targets other than x86_64.
///
/// The loop stops, without consuming the bits of the entry it stopped on, when
/// the entry is non-literal (`sym_count == 0`), empty (`total_bits == 0`), or
/// reports more literals than an entry can hold.
///
/// # Returns
///
/// The output position just past the last literal written.
///
/// # Errors
///
/// Returns `io::ErrorKind::WriteZero` when `output` cannot hold every literal
/// of the next entry. The capacity check runs before the bits are consumed, so
/// on error the bit reader still points at that entry and nothing from it has
/// been written.
#[cfg(not(target_arch = "x86_64"))]
pub fn decode_simd_multi_sym(
    table: &MultiSymTable,
    bits: &mut crate::decompress::two_level_table::FastBits,
    output: &mut [u8],
    mut out_pos: usize,
) -> io::Result<usize> {
    loop {
        bits.ensure(32);
        let entry = table.lookup(bits.buffer());
        let literals = [entry.sym1, entry.sym2, entry.sym3, entry.sym4];
        let count = usize::from(entry.sym_count);
        if entry.total_bits == 0 || count == 0 || count > literals.len() {
            break;
        }

        // Reserve the destination first: only commit to the entry (consume its
        // bits) once we know all of its literals fit.
        let dst = match output
            .get_mut(out_pos..)
            .and_then(|tail| tail.get_mut(..count))
        {
            Some(dst) => dst,
            None => {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "Output buffer full",
                ));
            }
        };
        dst.copy_from_slice(&literals[..count]);

        bits.consume(u32::from(entry.total_bits));
        out_pos += count;
    }
    Ok(out_pos)
}
/// A gather-table entry packed into one `u32`.
///
/// Layout: bits 0..=7 hold the code length, bits 8..=15 hold the symbol byte
/// (a literal value or a length-code index), and bit 31 flags a literal.
#[derive(Clone, Copy, Default)]
#[repr(transparent)]
pub struct GatherEntry(pub u32);

impl GatherEntry {
    /// Set when the entry describes a literal byte.
    pub const LITERAL_FLAG: u32 = 1 << 31;
    /// Mask selecting the code length stored in the low byte.
    pub const LENGTH_MASK: u32 = 0xFF;
    /// Bit offset of the symbol byte.
    pub const SYMBOL_SHIFT: u32 = 8;

    /// Packs a literal `byte` whose code is `bits` bits long.
    #[inline(always)]
    pub fn literal(bits: u8, byte: u8) -> Self {
        let symbol_field = u32::from(byte) << Self::SYMBOL_SHIFT;
        GatherEntry(Self::LITERAL_FLAG | symbol_field | u32::from(bits))
    }

    /// Packs a length-code index `len_idx` whose code is `bits` bits long.
    #[inline(always)]
    pub fn length_code(bits: u8, len_idx: u8) -> Self {
        let symbol_field = u32::from(len_idx) << Self::SYMBOL_SHIFT;
        GatherEntry(symbol_field | u32::from(bits))
    }

    /// Returns `true` when the literal flag (bit 31) is set.
    #[inline(always)]
    pub fn is_literal(self) -> bool {
        (self.0 & Self::LITERAL_FLAG) == Self::LITERAL_FLAG
    }

    /// Returns the symbol byte: bits 8..=15 of the packed value.
    #[inline(always)]
    pub fn symbol(self) -> u8 {
        let [_, symbol, _, _] = self.0.to_le_bytes();
        symbol
    }

    /// Returns the code length: the low byte of the packed value.
    #[inline(always)]
    pub fn code_length(self) -> u8 {
        let [length, _, _, _] = self.0.to_le_bytes();
        length
    }
}
/// 1024-entry (10-bit) direct lookup table of packed [`GatherEntry`] values.
pub struct GatherTable {
    /// Entries indexed by the low 10 bits of the (LSB-first) bit stream.
    pub entries: Box<[GatherEntry; 1024]>,
}

impl GatherTable {
    /// Builds the table from literal/length code lengths (`lit_len_lens[symbol]`).
    ///
    /// Only the first 286 lengths are considered. Codes are assigned
    /// canonically and stored bit-reversed so the table can be indexed with the
    /// low 10 bits of the bit buffer. Symbols below 256 become literal
    /// entries; symbols 256 and above become length-code entries holding
    /// `symbol - 256`. Codes longer than 10 bits get no slot, so the slots they
    /// would cover keep the default (all-zero) entry.
    ///
    /// # Returns
    ///
    /// `Some(table)` on success, or `None` when any considered code length
    /// exceeds 15. Such a length used to index past the 16-entry length
    /// histogram and panic.
    pub fn build(lit_len_lens: &[u8]) -> Option<Self> {
        const TABLE_BITS: u32 = 10;
        const MAX_CODE_LEN: u8 = 15;
        const MAX_SYMBOLS: usize = 286;

        let lens = &lit_len_lens[..lit_len_lens.len().min(MAX_SYMBOLS)];
        if lens.iter().any(|&len| len > MAX_CODE_LEN) {
            return None;
        }

        // Histogram of the non-zero code lengths.
        let mut bl_count = [0u32; 16];
        for &len in lens.iter().filter(|&&len| len > 0) {
            bl_count[usize::from(len)] += 1;
        }

        // First canonical code value for every code length.
        let mut next_code = [0u32; 16];
        let mut code = 0u32;
        for bits in 1..16 {
            code = (code + bl_count[bits - 1]) << 1;
            next_code[bits] = code;
        }

        let mut entries = Box::new([GatherEntry::default(); 1024]);
        for (sym, &len) in lens.iter().enumerate() {
            if len == 0 || u32::from(len) > TABLE_BITS {
                continue;
            }
            let code = next_code[usize::from(len)];
            next_code[usize::from(len)] += 1;

            // Reverse the low `len` bits of the code (`len` is 1..=10 here, so
            // the shift amount is 22..=31).
            let rev = (code.reverse_bits() >> (32 - u32::from(len))) as usize;
            let entry = if sym < 256 {
                GatherEntry::literal(len, sym as u8)
            } else {
                GatherEntry::length_code(len, (sym - 256) as u8)
            };
            // The reversed code sits in the low `len` bits; replicate the
            // entry across every combination of the remaining high bits.
            for idx in (rev..entries.len()).step_by(1usize << len) {
                entries[idx] = entry;
            }
        }
        Some(GatherTable { entries })
    }
}
/// Looks up eight table entries with a single AVX2 gather.
///
/// Lane `i` uses the 10 bits of `bitbuf` starting at bit `bit_positions[i]`
/// as its table index. Returns `(symbols, code_lengths)`: for each lane, the
/// entry's symbol byte and its code length. The literal flag of the entries is
/// not inspected here.
///
/// # Safety
///
/// This variant is only compiled when AVX2 is statically enabled for the
/// target, and every index is masked to 10 bits, so the gather stays inside
/// the 1024-entry table.
/// NOTE(review): no further caller-side precondition is visible in this
/// function; confirm why it is declared `unsafe`.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[allow(dead_code)]
pub unsafe fn decode_8_gather(
    table: &GatherTable,
    bit_positions: &[u64; 8],
    bitbuf: u64,
) -> ([u8; 8], [u8; 8]) {
    use std::arch::x86_64::*;

    // One 10-bit table index per lane.
    let mut lane_idx = [0i32; 8];
    for (slot, &pos) in lane_idx.iter_mut().zip(bit_positions.iter()) {
        *slot = ((bitbuf >> pos) & 0x3FF) as i32;
    }

    // Gather eight 32-bit entries; the scale of 4 is the size of one entry.
    let offsets = _mm256_loadu_si256(lane_idx.as_ptr() as *const __m256i);
    let base = table.entries.as_ptr() as *const i32;
    let gathered = _mm256_i32gather_epi32(base, offsets, 4);

    let mut raw = [0u32; 8];
    _mm256_storeu_si256(raw.as_mut_ptr() as *mut __m256i, gathered);

    // Unpack the symbol byte and the code length of every gathered entry.
    let mut literals = [0u8; 8];
    let mut bits_consumed = [0u8; 8];
    for (lane, &word) in raw.iter().enumerate() {
        let entry = GatherEntry(word);
        literals[lane] = entry.symbol();
        bits_consumed[lane] = entry.code_length();
    }
    (literals, bits_consumed)
}
/// Scalar fallback for targets without statically enabled AVX2.
///
/// Performs the same eight lookups as the AVX2 variant, one lane at a time:
/// lane `i` uses the 10 bits of `bitbuf` starting at bit `bit_positions[i]`
/// as its table index. Returns `(symbols, code_lengths)`: for each lane, the
/// entry's symbol byte and its code length. This replaces a stub that returned
/// all zeros regardless of input.
///
/// # Safety
///
/// This implementation performs only bounds-checked indexing with a 10-bit
/// masked index and has no safety preconditions. It stays `unsafe` so the
/// signature matches the AVX2 variant.
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
#[allow(dead_code)]
pub unsafe fn decode_8_gather(
    table: &GatherTable,
    bit_positions: &[u64; 8],
    bitbuf: u64,
) -> ([u8; 8], [u8; 8]) {
    let mut literals = [0u8; 8];
    let mut bits_consumed = [0u8; 8];
    for (lane, &pos) in bit_positions.iter().enumerate() {
        // Same index computation as the AVX2 variant.
        let idx = ((bitbuf >> pos) & 0x3FF) as usize;
        let entry = table.entries[idx];
        literals[lane] = entry.symbol();
        bits_consumed[lane] = entry.code_length();
    }
    (literals, bits_consumed)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Code lengths of the fixed literal/length Huffman code used by these
    /// tests: 8 bits for symbols 0..=143, 9 for 144..=255, 7 for 256..=279
    /// and 8 for 280..=287.
    fn fixed_huffman_lens() -> [u8; 288] {
        let mut lens = [0u8; 288];
        for (symbol, len) in lens.iter_mut().enumerate() {
            *len = match symbol {
                0..=143 => 8,
                144..=255 => 9,
                256..=279 => 7,
                _ => 8,
            };
        }
        lens
    }

    /// Entries must stay exactly 8 bytes so a lookup is a single aligned load.
    #[test]
    fn test_multi_sym_entry_size() {
        assert_eq!(std::mem::size_of::<MultiSymEntry>(), 8);
    }

    /// Sixteen 4-bit literal codes: every 12-bit index packs several literals.
    #[test]
    fn test_multi_sym_table_build() {
        let lens = [4u8; 16];
        let table = MultiSymTable::build(&lens).unwrap();
        let mut single_count = 0;
        let mut multi_count = 0;
        for entry in &table.table {
            if entry.sym_count == 1 {
                single_count += 1;
            } else if entry.sym_count >= 2 {
                multi_count += 1;
            }
        }
        println!(
            "Single: {}, Multi: {}/{}",
            single_count, multi_count, MULTI_SYM_SIZE
        );
        assert!(multi_count > 0, "Should have some multi-symbol entries");
        // The code is complete and literal-only, so no slot may stay empty.
        assert_eq!(single_count + multi_count, MULTI_SYM_SIZE);
    }

    /// The fixed Huffman code is complete, so every slot must be classified.
    #[test]
    fn test_fixed_huffman_table() {
        let lens = fixed_huffman_lens();
        let table = MultiSymTable::build(&lens).unwrap();
        let mut single_count = 0;
        let mut multi_count = 0;
        let mut non_literal_count = 0;
        for entry in &table.table {
            if entry.sym_count == 0 && entry.total_bits > 0 {
                non_literal_count += 1;
            } else if entry.sym_count == 1 {
                single_count += 1;
            } else if entry.sym_count >= 2 {
                multi_count += 1;
            }
        }
        println!(
            "Fixed Huffman: single={}, multi={}, non_literal={}",
            single_count, multi_count, non_literal_count
        );
        assert_eq!(
            single_count + multi_count + non_literal_count,
            MULTI_SYM_SIZE
        );
        // 24 seven-bit codes cover 32 slots each, 8 eight-bit codes cover 16 each.
        assert_eq!(non_literal_count, 24 * 32 + 8 * 16);
        // Literal codes are at least 8 bits long, so two never fit in 12 bits.
        assert_eq!(multi_count, 0);
    }

    /// Smoke test: only checks that the reference encoder produces output.
    /// NOTE(review): despite its name this test does not run the decoder yet.
    #[test]
    fn test_decode_multi_sym() {
        let original = b"AAAAAABBBBBB";
        use flate2::write::DeflateEncoder;
        use flate2::Compression;
        use std::io::Write;
        let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(original).unwrap();
        let compressed = encoder.finish().unwrap();
        println!(
            "Compressed {} bytes to {} bytes",
            original.len(),
            compressed.len()
        );
        assert!(!compressed.is_empty());
    }

    /// Prints how many table entries pack 1, 2, 3 or 4 literals for the fixed
    /// Huffman code. Skipped when the benchmark file is missing.
    #[test]
    fn benchmark_multi_sym_vs_single() {
        let data = match std::fs::read("benchmark_data/silesia-gzip.tar.gz") {
            Ok(d) => d,
            Err(_) => {
                eprintln!("Skipping - no benchmark file");
                return;
            }
        };
        let lens = fixed_huffman_lens();
        let table = MultiSymTable::build(&lens).unwrap();
        // Index `n - 1` counts the entries that pack `n` literals.
        let mut counts = [0usize; 4];
        for entry in &table.table {
            if let 1..=4 = entry.sym_count {
                counts[usize::from(entry.sym_count) - 1] += 1;
            }
        }
        let [single_count, double_count, triple_count, quad_count] = counts;
        println!("\n=== Multi-Symbol Table Analysis ===");
        println!("1-symbol entries: {}", single_count);
        println!("2-symbol entries: {}", double_count);
        println!("3-symbol entries: {}", triple_count);
        println!("4-symbol entries: {}", quad_count);
        println!(
            "Multi-symbol ratio: {:.1}%",
            (double_count + triple_count + quad_count) as f64 / MULTI_SYM_SIZE as f64 * 100.0
        );
        println!("Benchmark file size: {} bytes", data.len());
    }
}